stable_diffusion_1_dreambooth_Kohya_S.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/thx-pw/7e8a8ecda5f6649d81bd5202ce8e6a21/dreambooth_stable_diffusion_fixed.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"This notebook has moved to:\n",
"\n",
"https://github.com/thx-pw/stable-diffusion-2.0-dreambooth"
],
"metadata": {
"id": "51U6Ge8sCn1z"
}
},
{
"cell_type": "markdown",
"source": [
"This Colab is licensed under the Apache License 2.0.\n",
"\n",
"Sources:\n",
"\n",
"https://github.com/ShivamShrirao/diffusers/tree/main/examples/dreambooth\n",
"\n",
"https://note.com/kohya_ss/n/nee3ed1649fb6\n",
"\n",
"https://note.com/kohya_ss/n/nad3bce9a3622"
],
"metadata": {
"id": "8DK981h_7p0X"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "XU7NuMAA2drw"
},
"outputs": [],
"source": [
"#@title GPU check\n",
"!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BzM7j0ZSc_9c"
},
"source": [
"https://github.com/ShivamShrirao/diffusers/tree/main/examples/dreambooth"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aLWXPZqjsZVV",
"cellView": "form"
},
"outputs": [],
"source": [
"#@title Install required packages\n",
"!pip install -q diffusers[torch]==0.9.0 accelerate transformers==4.21.3 ftfy albumentations opencv-python einops bitsandbytes fairscale==0.4.6 numpy==1.21.6\n",
"!pip install -q https://github.com/camenduru/stable-diffusion-webui-colab/releases/download/0.0.14/xformers-0.0.14.dev0-cp37-cp37m-linux_x86_64.whl\n",
"\n",
"!pip install -q pytorch_lightning\n",
"\n",
"!git clone https://github.com/salesforce/BLIP --quiet\n",
"\n",
"!wget -q https://github.com/thx-pw/stable-diffusion-2.0-dreambooth/raw/main/gen_img_diffusers.py\n",
"!wget -q https://github.com/thx-pw/stable-diffusion-2.0-dreambooth/raw/main/train_db_fixed.py\n"
]
},
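{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#@title (Optional) Verify the pinned packages imported\n",
"#@markdown Not part of the original notebook: a minimal sanity check, assuming the pip\n",
"#@markdown installs above finished without errors, that the pinned libraries import and\n",
"#@markdown that a CUDA device is visible before moving on.\n",
"import torch\n",
"import diffusers\n",
"import transformers\n",
"\n",
"print(\"diffusers\", diffusers.__version__)\n",
"print(\"transformers\", transformers.__version__)\n",
"print(\"torch\", torch.__version__, \"| CUDA available:\", torch.cuda.is_available())"
]
},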
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "Rxg0y5MBudmd"
},
"outputs": [],
"source": [
"#@title Select a model\n",
"\n",
"#@markdown https://huggingface.co/settings/tokens\n",
"HUGGINGFACE_TOKEN = \"\" #@param {type:\"string\"}\n",
"\n",
"MODEL_NAME = \"stable-v14\" #@param ['trinart-characters-19m', 'waifu-v13-float32', 'waifu-v13-float16', 'stable-v14', 'pokemon', 'robo-v1']\n",
"\n",
"# Directory for the training (instance) images\n",
"TRAIN_DIR = \"/content/input/train\"\n",
"!mkdir -p $TRAIN_DIR\n",
"\n",
"# Directory for the regularization (class) images\n",
"REG_DIR = \"/content/input/reg\"\n",
"!mkdir -p $REG_DIR\n",
"\n",
"# Directory for the trained model checkpoints\n",
"OUTPUT_DIR = \"/content/output\"\n",
"!mkdir -p $OUTPUT_DIR\n"
]
},
{
"cell_type": "code",
"source": [
"#@title Download the model\n",
"\n",
"models_dict = {\n",
"    \"trinart-characters-19m\" : \"https://huggingface.co/naclbit/trinart_characters_19.2m_stable_diffusion_v1/blob/main/trinart_characters_it4_v1.ckpt\",\n",
"    \"waifu-v13-float32\" : \"https://huggingface.co/hakurei/waifu-diffusion-v1-3/blob/main/wd-v1-3-float32.ckpt\",\n",
"    \"waifu-v13-float16\" : \"https://huggingface.co/hakurei/waifu-diffusion-v1-3/blob/main/wd-v1-3-float16.ckpt\",\n",
"    \"stable-v14\" : \"https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/blob/main/sd-v1-4.ckpt\",\n",
"    \"pokemon\" : \"https://huggingface.co/justinpinkney/pokemon-stable-diffusion/blob/main/ema-only-epoch%3D000142.ckpt\",\n",
"    \"robo-v1\" : \"https://huggingface.co/nousr/robo-diffusion/blob/main/models/robo-diffusion-v1.ckpt\",\n",
"}\n",
"\n",
"model_url = models_dict[MODEL_NAME].replace(\"/blob/\", \"/resolve/\")\n",
"user_header = f\"\\\"Authorization: Bearer {HUGGINGFACE_TOKEN}\\\"\"\n",
"!wget --header={user_header} {model_url} -O /content/{MODEL_NAME}.ckpt\n"
],
"metadata": {
"cellView": "form",
"id": "ww7DMIzFxm_h"
},
"execution_count": null,
"outputs": []
},
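{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#@title (Optional) Sanity-check the downloaded checkpoint\n",
"#@markdown Not part of the original notebook: a minimal sketch, assuming the wget above\n",
"#@markdown succeeded, that confirms the .ckpt deserializes and contains a state_dict\n",
"#@markdown before spending time on training.\n",
"import os\n",
"import torch\n",
"\n",
"ckpt_path = f\"/content/{MODEL_NAME}.ckpt\"\n",
"print(f\"{ckpt_path}: {os.path.getsize(ckpt_path) / 1e9:.2f} GB\")\n",
"\n",
"checkpoint = torch.load(ckpt_path, map_location=\"cpu\")\n",
"state_dict = checkpoint.get(\"state_dict\", checkpoint)\n",
"print(f\"{len(state_dict)} tensors; sample keys:\")\n",
"for key in list(state_dict)[:5]:\n",
"    print(\" \", key)\n",
"del checkpoint, state_dict"
]
},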
{
"cell_type": "code",
"source": [
"#@title BLIP helpers\n",
"import torch\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"def load_blip():\n",
"    import sys\n",
"    sys.path.append('BLIP')\n",
"\n",
"    from models.blip import blip_decoder\n",
"\n",
"    %cd /content/BLIP\n",
"\n",
"    blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'\n",
"\n",
"    blip_model = blip_decoder(pretrained=blip_model_url, image_size=384, vit='base')\n",
"    blip_model.eval()\n",
"    blip_model = blip_model.to(device)\n",
"\n",
"    %cd /content\n",
"\n",
"    return blip_model"
],
"metadata": {
"cellView": "form",
"id": "sL8UGjj8tmUH"
},
"execution_count": null,
"outputs": []
},
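{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#@title (Optional) BLIP captioning example\n",
"#@markdown Not part of the original notebook: a minimal usage sketch for the load_blip()\n",
"#@markdown helper above. The image path below is a placeholder; point it at any image\n",
"#@markdown you have uploaded to /content before running.\n",
"import torch\n",
"from PIL import Image\n",
"from torchvision import transforms\n",
"from torchvision.transforms.functional import InterpolationMode\n",
"\n",
"example_image_path = \"/content/example.png\"  # placeholder path, change to a real file\n",
"\n",
"blip_model = load_blip()\n",
"transform = transforms.Compose([\n",
"    transforms.Resize((384, 384), interpolation=InterpolationMode.BICUBIC),\n",
"    transforms.ToTensor(),\n",
"    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),\n",
"])\n",
"image = transform(Image.open(example_image_path).convert(\"RGB\")).unsqueeze(0).to(device)\n",
"with torch.no_grad():\n",
"    print(blip_model.generate(image, sample=False, num_beams=3, max_length=75)[0])"
]
},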
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fe-GgtnUVO_e",
"cellView": "form"
},
"outputs": [],
"source": [
"#@title Upload training images and auto-generate regularization (class) images\n",
"#@markdown Change SKS and CLASS in this cell and run it multiple times to train multiple characters\n",
"\n",
"#@markdown Uploads to Colab are slow, so resize images first with this tool: https://www.birme.net/?target_width=1024&target_height=1024\n",
"SKS = \"zundamon\" #@param {type:\"string\"}\n",
"CLASS = \"boy\" #@param {type:\"string\"}\n",
"TRAIN_N_REPEATS = 20\n",
"REG_N_REPEATS = 1\n",
"#@markdown (Experimental) Use BLIP to infer a prompt from each training image, then auto-generate regularization images from \"CLASS, prompt\"\n",
"\n",
"#@markdown If BLIP is not used, the prompt is just CLASS\n",
"use_blip = False #@param {type:\"boolean\"}\n",
"NEGATIVE_PROMPT = \"\" #@param {type:\"string\"}\n",
"\n",
"PROMPTS_PATH = \"/content/prompts.txt\"\n",
"\n",
"import os\n",
"from google.colab import files\n",
"import shutil\n",
"import glob\n",
"from PIL import Image\n",
"\n",
"train_path = os.path.join(TRAIN_DIR, f\"{TRAIN_N_REPEATS}_{SKS} {CLASS}\")\n",
"os.makedirs(train_path, exist_ok=True)\n",
"reg_path = os.path.join(REG_DIR, f\"{REG_N_REPEATS}_{CLASS}\")\n",
"os.makedirs(reg_path, exist_ok=True)\n",
"\n",
"uploaded = files.upload()\n",
"for filename in uploaded.keys():\n",
"    dst_path = os.path.join(train_path, filename)\n",
"    shutil.move(filename, dst_path)\n",
"\n",
"def get_prompt(blip_model, image):\n",
"    from torchvision import transforms\n",
"    from torchvision.transforms.functional import InterpolationMode\n",
"    image_size = 384\n",
"    transform = transforms.Compose([\n",
"        transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),\n",
"        transforms.ToTensor(),\n",
"        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n",
"    ])\n",
"\n",
"    image = transform(image).unsqueeze(0).to(device)\n",
"    with torch.no_grad():\n",
"        prompt = blip_model.generate(image, sample=False, num_beams=3, max_length=200, min_length=30)\n",
"    return prompt[0]\n",
"\n",
"def get_prompts(blip_model):\n",
"    prompts = []\n",
"    for image_path in glob.glob(f'{train_path}/*.*'):\n",
"        image = Image.open(image_path).convert('RGB')\n",
"        prompt = get_prompt(blip_model, image)\n",
"        prompts.append(prompt)\n",
"    return prompts\n",
"\n",
"def generate_prompts():\n",
"    if use_blip:\n",
"        blip_model = load_blip()\n",
"        prompts = [f\"{CLASS}, {x} --n {NEGATIVE_PROMPT}\" for x in get_prompts(blip_model)]\n",
"\n",
"        del blip_model\n",
"        if torch.cuda.is_available():\n",
"            torch.cuda.empty_cache()\n",
"    else:\n",
"        prompts = [f\"{CLASS}, --n {NEGATIVE_PROMPT}\"]\n",
"\n",
"    with open(PROMPTS_PATH, \"w\") as f:\n",
"        f.write('\\n'.join(prompts))\n",
"\n",
"def generate_reg_images():\n",
"    reg_num_images = sum(os.path.isfile(os.path.join(reg_path, name)) for name in os.listdir(reg_path))\n",
"    reg_num_images = (TRAIN_N_REPEATS * train_num_images) // REG_N_REPEATS - reg_num_images\n",
"\n",
"    !python gen_img_diffusers.py \\\n",
"        --ckpt {MODEL_NAME}.ckpt \\\n",
"        --outdir {reg_path} \\\n",
"        --xformers \\\n",
"        --fp16 \\\n",
"        --W 512 \\\n",
"        --H 512 \\\n",
"        --scale 12.5 \\\n",
"        --sampler ddim \\\n",
"        --steps 20 \\\n",
"        --batch_size 4 \\\n",
"        --images_per_prompt {reg_num_images} \\\n",
"        --from_file {PROMPTS_PATH}\n",
"\n",
"train_num_images = sum(os.path.isfile(os.path.join(train_path, name)) for name in os.listdir(train_path))\n",
"if train_num_images > 0:\n",
"    generate_prompts()\n",
"    generate_reg_images()\n",
"else:\n",
"    print(\"No training images were uploaded; skipping prompt and regularization image generation.\")\n",
"\n"
]
},
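{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#@title (Optional) Estimate the training workload\n",
"#@markdown Not part of the original notebook: a small worked example of how the\n",
"#@markdown kohya-style \"repeats_identifier class\" folder name set above translates into\n",
"#@markdown examples per epoch. This is only a rough estimate (regularization images add\n",
"#@markdown further steps) and assumes the batch size of 4 used in the training cell below.\n",
"ASSUMED_BATCH_SIZE = 4\n",
"repeated_examples = TRAIN_N_REPEATS * train_num_images\n",
"approx_steps_per_epoch = repeated_examples // ASSUMED_BATCH_SIZE\n",
"print(f\"{train_num_images} training images x {TRAIN_N_REPEATS} repeats = {repeated_examples} examples per epoch\")\n",
"print(f\"~{approx_steps_per_epoch} optimizer steps per epoch at batch size {ASSUMED_BATCH_SIZE} (before regularization images)\")"
]
},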
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jjcSXTp-u-Eg"
},
"outputs": [],
"source": [
"#@title Run DreamBooth training\n",
"!accelerate launch --num_cpu_threads_per_process 2 train_db_fixed.py \\\n",
"  --pretrained_model_name_or_path={MODEL_NAME}.ckpt \\\n",
"  --train_data_dir=$TRAIN_DIR \\\n",
"  --reg_data_dir=$REG_DIR \\\n",
"  --output_dir=$OUTPUT_DIR \\\n",
"  --prior_loss_weight=1.0 \\\n",
"  --resolution=512 \\\n",
"  --train_batch_size=4 \\\n",
"  --learning_rate=2e-6 \\\n",
"  --max_train_steps=400 \\\n",
"  --use_8bit_adam \\\n",
"  --mixed_precision='fp16' \\\n",
"  --xformers \\\n",
"  --cache_latents \\\n",
"  --gradient_checkpointing \\\n",
"  --save_precision='fp16' \\\n",
"  --save_every_n_epochs 2 \\\n",
"  --logging_dir=logs"
]
},
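{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#@title (Optional) List the saved checkpoints\n",
"#@markdown Not part of the original notebook: a minimal sketch that lists the .ckpt files\n",
"#@markdown written into OUTPUT_DIR by --save_every_n_epochs, so you know which name\n",
"#@markdown (e.g. \"epoch-000010\") to enter as ckpt_name in the cells below.\n",
"import glob\n",
"import os\n",
"\n",
"for ckpt_path in sorted(glob.glob(os.path.join(OUTPUT_DIR, \"*.ckpt\"))):\n",
"    size_gb = os.path.getsize(ckpt_path) / 1e9\n",
"    print(f\"{os.path.basename(ckpt_path)}  {size_gb:.2f} GB\")"
]
},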
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "89Az5NUxOWdy"
},
"outputs": [],
"source": [
"#@title View the training logs\n",
"%load_ext tensorboard\n",
"%tensorboard --logdir=logs"
]
},
{
"cell_type": "code",
"source": [
"#@title Save the ckpt to Google Drive\n",
"ckpt_name = \"epoch-000010\" #@param {type:\"string\"}\n",
"from google.colab import drive\n",
"drive.mount('/content/drive')\n",
"\n",
"import os\n",
"model_checkpoints = \"/content/drive/MyDrive/sd/stable-diffusion-webui/models/Stable-diffusion\"\n",
"os.makedirs(model_checkpoints, exist_ok=True)\n",
"!cp \"{OUTPUT_DIR}/{ckpt_name}.ckpt\" {model_checkpoints}\n",
"\n",
"print(f\"Saved to {model_checkpoints}\")"
],
"metadata": {
"cellView": "form",
"id": "MYDjfXf8MB2R"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title Generate images with the trained model\n",
"ckpt_name = \"epoch-000010\" #@param {type:\"string\"}\n",
"!python gen_img_diffusers.py \\\n",
"  --ckpt \"{OUTPUT_DIR}/{ckpt_name}.ckpt\" \\\n",
"  --outdir 'tmp' \\\n",
"  --xformers \\\n",
"  --fp16 \\\n",
"  --W 768 \\\n",
"  --H 768 \\\n",
"  --scale 12.5 \\\n",
"  --sampler ddim \\\n",
"  --steps 20 \\\n",
"  --batch_size 4 \\\n",
"  --images_per_prompt 4 \\\n",
"  --prompt \"{SKS} {CLASS} eating lunch at McDonald's -n\"\n",
"\n",
"print(\"Images were written to /content/tmp\")"
],
"metadata": {
"cellView": "form",
"id": "OuQGBG737QJc"
},
"execution_count": null,
"outputs": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"provenance": [],
"private_outputs": true,
"name": "stable_diffusion_1_dreambooth_Kohya_S.ipynb",
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3.8.12 ('pytorch')",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.8.12"
},
"vscode": {
"interpreter": {
"hash": "2d58e898dde0263bc564c6968b04150abacfd33eed9b19aaa8e45c040360e146"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}