Skip to content

Instantly share code, notes, and snippets.

@AnyISalIn
Last active September 21, 2023 05:57
Show Gist options
  • Select an option

  • Save AnyISalIn/67e97d07e97ff4f07a2298b6c8d3f146 to your computer and use it in GitHub Desktop.

Select an option

Save AnyISalIn/67e97d07e97ff4f07a2298b6c8d3f146 to your computer and use it in GitHub Desktop.
Fuse/unfuse multiple LoRAs
This file has been truncated, but you can view the full file.
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "ce75a926-c68b-4ea7-947f-1d20f7ced1c6",
"metadata": {},
"outputs": [],
"source": [
"!"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "a115ed76-78c7-4f8e-a171-cf00f0ea9847",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/anyisalin/develop/diffusiongrid/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00, 4.09it/s]\n"
]
}
],
"source": [
"import torch\n",
"\n",
"from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler\n",
"\n",
"pipeline = StableDiffusionPipeline.from_pretrained(\n",
" \"gsdf/Counterfeit-V2.5\", torch_dtype=torch.float16, safety_checker=None, use_safetensors=True\n",
").to(\"cuda\")\n",
"pipeline.scheduler = DPMSolverMultistepScheduler.from_config(\n",
" pipeline.scheduler.config, use_karras_sigmas=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "fa7e5b31-b471-40b2-a32e-4350bde925f5",
"metadata": {},
"outputs": [],
"source": [
"images_list = []"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "12f4a7a7-e112-4697-b3b1-a9ba3a1c614b",
"metadata": {},
"outputs": [],
"source": [
"prompt = \"masterpiece, best quality, 1girl, at dusk\"\n",
"negative_prompt = (\"(low quality, worst quality:1.4), (bad anatomy), (inaccurate limb:1.2), \"\n",
" \"bad composition, inaccurate eyes, extra digit, fewer digits, (extra arms:1.2), large breasts\")\n",
"\n",
"def generate_image(pipeline, images_list):\n",
" images_list.extend(pipeline(prompt=prompt, \n",
" negative_prompt=negative_prompt, \n",
" width=512, \n",
" height=768, \n",
" num_inference_steps=15, \n",
" num_images_per_prompt=1,\n",
" generator=torch.manual_seed(0)\n",
").images)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "1e2262ed-37ea-475a-ae16-a36a090d088c",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:03<00:00, 4.99it/s]\n"
]
}
],
"source": [
"generate_image(pipeline, images_list)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "bc023a47-1bdb-4a6e-9780-3fa9f0b5a79b",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The current API is supported for operating with a single LoRA file. You are trying to load and fuse more than one LoRA which is not well-supported.\n",
"100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:02<00:00, 5.42it/s]\n"
]
}
],
"source": [
"pipeline.load_lora_weights(\".\", weight_name=\"light_and_shadow.safetensors\")\n",
"pipeline.fuse_lora(lora_scale=0.5, lora_name=\"light_and_shadow\")\n",
"pipeline.load_lora_weights(\".\", weight_name=\"more_details.safetensors\")\n",
"pipeline.fuse_lora(lora_scale=0.5, lora_name=\"more_details\")\n",
"\n",
"generate_image(pipeline, images_list)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "b3d47636-6ad2-4ece-abde-72f5a8c684e7",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:02<00:00, 5.40it/s]\n"
]
}
],
"source": [
"pipeline.unfuse_lora(lora_name=\"light_and_shadow\")\n",
"\n",
"generate_image(pipeline, images_list)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "27a9d4b6-356f-4633-923c-9c59df5584cd",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:02<00:00, 5.40it/s]\n"
]
}
],
"source": [
"pipeline.unfuse_lora(\"more_details\")\n",
"\n",
"generate_image(pipeline, images_list)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "9d758099-2509-4964-a0ab-cfbf3ac20fc3",
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import display, HTML\n",
"from PIL import Image, ImageDraw, ImageFont\n",
"import io\n",
"import base64 \n",
"\n",
"def display_image_grid_with_labels(images, labels, grid_size=(3, 3), image_size=(768, 1024)):\n",
" \"\"\"\n",
" Display a grid of images with labels in a Jupyter Notebook.\n",
"\n",
" Args:\n",
" images (list of PIL.Image.Image): List of PIL Image objects.\n",
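" labels (list of str): List of labels corresponding to each image; must have at least len(images) entries.\n",
" grid_size (tuple): (rows, cols) of the grid (default is 3x3). Only cols controls the layout; rows is used solely to warn when there are fewer than rows * cols images.\n",
" image_size (tuple): Maximum (width, height) of each displayed thumbnail; aspect ratio is preserved (default is 768x1024).\n",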
" \"\"\"\n",
" num_images = len(images)\n",
" rows, cols = grid_size\n",
" if num_images < rows * cols:\n",
" print(\"Warning: Not enough images to fill the grid.\")\n",
" \n",
" html = \"<table style='width:100%'>\"\n",
" for i in range(0, num_images, cols):\n",
" html += \"<tr>\"\n",
" for j in range(cols):\n",
" index = i + j\n",
" if index < num_images:\n",
" # Resize the image\n",
" img = images[index].copy()\n",
" img.thumbnail(image_size, Image.LANCZOS)\n",
" \n",
" # Add label to the image\n",
" draw = ImageDraw.Draw(img)\n",
" label = labels[index]\n",
" _, _, text_width, text_height = draw.textbbox((0, 0), label, font=ImageFont.load_default())\n",
" # draw.text(((img.width - text_width) / 2, img.height - text_height - 5), label, fill=\"white\")\n",
" \n",
" # Convert the image to base64\n",
" img_byte_array = io.BytesIO()\n",
" img.save(img_byte_array, format=\"PNG\")\n",
" img_data = img_byte_array.getvalue()\n",
" img_base64_ = base64.b64encode(img_data).decode(\"utf-8\")\n",
" img_base64 = f\"data:image/png;base64,{img_base64_}\"\n",
" \n",
" # Add the image and label to the HTML table\n",
" html += f\"<td style='text-align:center; vertical-align:top; padding:10px;'><img src='{img_base64}' /><br/>{label}</td>\"\n",
" html += \"</tr>\"\n",
" html += \"</table>\"\n",
" \n",
" display(HTML(html))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "956e1760-ad58-4e56-91f0-40aacd6fce02",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment