Last active
September 21, 2023 05:57
-
-
Save AnyISalIn/67e97d07e97ff4f07a2298b6c8d3f146 to your computer and use it in GitHub Desktop.
Fuse/unfuse multiple LoRAs
This file has been truncated, but you can view the full file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "ce75a926-c68b-4ea7-947f-1d20f7ced1c6", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# NOTE(review): this cell only runs an empty shell command (`!`), which does nothing; it looks like a leftover placeholder (perhaps for an install step) - TODO fill in or delete.\n", | |
| "!" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "a115ed76-78c7-4f8e-a171-cf00f0ea9847", | |
| "metadata": { | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "/home/anyisalin/develop/diffusiongrid/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", | |
| " from .autonotebook import tqdm as notebook_tqdm\n", | |
| "Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00, 4.09it/s]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "import torch\n", | |
| "from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline\n", | |
| "\n", | |
| "# Load the gsdf/Counterfeit-V2.5 checkpoint in fp16 from safetensors, without a safety checker, and move it to the GPU.\n", | |
| "pipeline = StableDiffusionPipeline.from_pretrained(\n", | |
| "    \"gsdf/Counterfeit-V2.5\",\n", | |
| "    torch_dtype=torch.float16,\n", | |
| "    safety_checker=None,\n", | |
| "    use_safetensors=True,\n", | |
| ").to(\"cuda\")\n", | |
| "\n", | |
| "# Swap in the DPM-Solver multistep scheduler with Karras sigmas, built from the current scheduler's config.\n", | |
| "pipeline.scheduler = DPMSolverMultistepScheduler.from_config(\n", | |
| "    pipeline.scheduler.config,\n", | |
| "    use_karras_sigmas=True,\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "fa7e5b31-b471-40b2-a32e-4350bde925f5", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Every generate_image(...) call below appends its images here, in call order.\n", | |
| "images_list = list()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "id": "12f4a7a7-e112-4697-b3b1-a9ba3a1c614b", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "prompt = \"masterpiece, best quality, 1girl, at dusk\"\n", | |
| "negative_prompt = (\"(low quality, worst quality:1.4), (bad anatomy), (inaccurate limb:1.2), \"\n", | |
| "                   \"bad composition, inaccurate eyes, extra digit, fewer digits, (extra arms:1.2), large breasts\")\n", | |
| "\n", | |
| "\n", | |
| "def generate_image(pipeline, images_list, seed=0, width=512, height=768, num_inference_steps=15):\n", | |
| "    \"\"\"Sample one image from `pipeline` with the prompts above and append it to `images_list`.\n", | |
| "\n", | |
| "    Each call builds a fresh generator seeded with `seed`, so repeated calls with the same\n", | |
| "    arguments start from the same noise, which keeps the images from the cells below\n", | |
| "    comparable while the pipeline's LoRA state changes between them.\n", | |
| "\n", | |
| "    Args:\n", | |
| "        pipeline: Diffusers pipeline to call (e.g. the StableDiffusionPipeline built above).\n", | |
| "        images_list: List the generated images are appended to (mutated in place).\n", | |
| "        seed: Seed for the dedicated torch generator (default 0).\n", | |
| "        width, height: Output size in pixels (default 512 x 768).\n", | |
| "        num_inference_steps: Number of denoising steps (default 15).\n", | |
| "\n", | |
| "    Returns:\n", | |
| "        None, so a call on the last line of a cell does not echo the image list.\n", | |
| "    \"\"\"\n", | |
| "    # A private generator yields the same noise as torch.manual_seed(seed) did, without\n", | |
| "    # resetting torch's global RNG as a side effect.\n", | |
| "    generator = torch.Generator().manual_seed(seed)\n", | |
| "    output = pipeline(\n", | |
| "        prompt=prompt,\n", | |
| "        negative_prompt=negative_prompt,\n", | |
| "        width=width,\n", | |
| "        height=height,\n", | |
| "        num_inference_steps=num_inference_steps,\n", | |
| "        num_images_per_prompt=1,\n", | |
| "        generator=generator,\n", | |
| "    )\n", | |
| "    images_list.extend(output.images)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "1e2262ed-37ea-475a-ae16-a36a090d088c", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:03<00:00, 4.99it/s]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Baseline image: no LoRA has been loaded or fused yet.\n", | |
| "generate_image(pipeline=pipeline, images_list=images_list)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "bc023a47-1bdb-4a6e-9780-3fa9f0b5a79b", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "The current API is supported for operating with a single LoRA file. You are trying to load and fuse more than one LoRA which is not well-supported.\n", | |
| "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:02<00:00, 5.42it/s]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Load and fuse two LoRAs at half strength each, naming each one so it can be unfused on its own later.\n", | |
| "for lora_name, lora_file in [(\"light_and_shadow\", \"light_and_shadow.safetensors\"),\n", | |
| "                             (\"more_details\", \"more_details.safetensors\")]:\n", | |
| "    pipeline.load_lora_weights(\".\", weight_name=lora_file)\n", | |
| "    pipeline.fuse_lora(lora_scale=0.5, lora_name=lora_name)\n", | |
| "\n", | |
| "generate_image(pipeline, images_list)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "id": "b3d47636-6ad2-4ece-abde-72f5a8c684e7", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:02<00:00, 5.40it/s]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Unfuse only light_and_shadow, presumably leaving more_details fused for this image.\n", | |
| "pipeline.unfuse_lora(\n", | |
| "    lora_name=\"light_and_shadow\",\n", | |
| ")\n", | |
| "generate_image(pipeline, images_list)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "27a9d4b6-356f-4633-923c-9c59df5584cd", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:02<00:00, 5.40it/s]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Unfuse more_details by name. Pass lora_name as a keyword, as the previous cell does: in released\n", | |
| "# diffusers the first positional parameter of unfuse_lora is unfuse_unet, not lora_name.\n", | |
| "# NOTE(review): confirm against the diffusers branch this notebook targets.\n", | |
| "pipeline.unfuse_lora(lora_name=\"more_details\")\n", | |
| "\n", | |
| "generate_image(pipeline, images_list)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "9d758099-2509-4964-a0ab-cfbf3ac20fc3", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from IPython.display import display, HTML\n", | |
| "from PIL import Image, ImageDraw, ImageFont\n", | |
| "import io\n", | |
| "import base64\n", | |
| "import html as html_lib\n", | |
| "\n", | |
| "\n", | |
| "def _image_to_data_uri(image, image_size):\n", | |
| "    \"\"\"Return a base64 PNG data URI for a copy of `image` thumbnailed to fit `image_size`.\"\"\"\n", | |
| "    # Work on a copy: PIL's thumbnail() resizes in place and the caller's image must stay untouched.\n", | |
| "    thumb = image.copy()\n", | |
| "    thumb.thumbnail(image_size, Image.LANCZOS)\n", | |
| "    buffer = io.BytesIO()\n", | |
| "    thumb.save(buffer, format=\"PNG\")\n", | |
| "    encoded = base64.b64encode(buffer.getvalue()).decode(\"utf-8\")\n", | |
| "    return f\"data:image/png;base64,{encoded}\"\n", | |
| "\n", | |
| "\n", | |
| "def display_image_grid_with_labels(images, labels, grid_size=(3, 3), image_size=(768, 1024)):\n", | |
| "    \"\"\"\n", | |
| "    Display a grid of images with labels in a Jupyter Notebook.\n", | |
| "\n", | |
| "    Images are laid out row by row in an HTML table, `cols` per row, each with its\n", | |
| "    HTML-escaped label underneath. All images are shown; `rows` only drives the\n", | |
| "    warning printed when there are too few images to fill the grid.\n", | |
| "\n", | |
| "    Args:\n", | |
| "        images (list of PIL.Image.Image): List of PIL Image objects.\n", | |
| "        labels (list of str): Captions matched to `images` by position; images\n", | |
| "            without a label get an empty caption.\n", | |
| "        grid_size (tuple): (rows, cols) of the grid (default is 3x3).\n", | |
| "        image_size (tuple): Maximum (width, height) each image is thumbnailed to,\n", | |
| "            keeping its aspect ratio (default is 768x1024).\n", | |
| "    \"\"\"\n", | |
| "    num_images = len(images)\n", | |
| "    rows, cols = grid_size\n", | |
| "    if num_images < rows * cols:\n", | |
| "        print(\"Warning: Not enough images to fill the grid.\")\n", | |
| "\n", | |
| "    parts = [\"<table style='width:100%'>\"]\n", | |
| "    for row_start in range(0, num_images, cols):\n", | |
| "        parts.append(\"<tr>\")\n", | |
| "        for index in range(row_start, min(row_start + cols, num_images)):\n", | |
| "            # Escape the label so text such as '<' or '&' renders literally instead of as markup.\n", | |
| "            label = html_lib.escape(str(labels[index])) if index < len(labels) else \"\"\n", | |
| "            parts.append(\n", | |
| "                \"<td style='text-align:center; vertical-align:top; padding:10px;'>\"\n", | |
| "                f\"<img src='{_image_to_data_uri(images[index], image_size)}' /><br/>{label}</td>\"\n", | |
| "            )\n", | |
| "        parts.append(\"</tr>\")\n", | |
| "    parts.append(\"</table>\")\n", | |
| "\n", | |
| "    display(HTML(\"\".join(parts)))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "id": "956e1760-ad58-4e56-91f0-40aacd6fce02", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment