Jupyter notebook to convert .safetensors to diffusers
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ddb4c3fd",
   "metadata": {},
   "source": [
    "## Convert .safetensors to Diffusers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "08597be1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import importlib\n",
    "\n",
    "import torch\n",
    "from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt\n",
    "\n",
    "\n",
    "class Args:\n",
    "    # Stand-in for an argparse.Namespace; attributes are set below.\n",
    "    pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "599eb851",
   "metadata": {},
   "outputs": [],
   "source": [
    "args = Args()\n",
    "\n",
    "args.checkpoint_path = r\"D:\\AI\\automatic\\models\\Stable-diffusion\\realisticVisionV51.safetensors\"\n",
    "args.original_config_file = r\"D:\\AI\\automatic\\configs\\v1-inference.yaml\"\n",
    "args.dump_path = r\"D:\\AI\\invokeai\\models\\realisticVisionv51\"\n",
    "args.image_size = None\n",
    "args.prediction_type = None\n",
    "args.pipeline_type = None\n",
    "args.extract_ema = True\n",
    "args.scheduler_type = \"ddim\"\n",
    "args.num_in_channels = None\n",
    "args.upcast_attention = True\n",
    "args.from_safetensors = True\n",
    "args.to_safetensors = True\n",
    "args.device = \"cuda:0\"\n",
    "args.stable_unclip = None\n",
    "args.stable_unclip_prior = None\n",
    "args.clip_stats_path = None\n",
    "args.controlnet = False\n",
    "args.vae_path = None\n",
    "args.pipeline_class_name = None\n",
    "args.half = True\n",
    "\n",
    "\n",
    "if args.pipeline_class_name:\n",
    "    library = importlib.import_module(\"diffusers\")\n",
    "    # Show all pipelines\n",
    "    for p in dir(library):\n",
    "        if \"Pipeline\" in p:\n",
    "            print(p)\n",
    "    pipeline_class = getattr(library, args.pipeline_class_name)\n",
    "else:\n",
    "    pipeline_class = None\n",
    "\n",
    "# pipeline_class = library.StableDiffusionPipeline\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "087c4b1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "pipe = download_from_original_stable_diffusion_ckpt(\n",
    "    checkpoint_path=args.checkpoint_path,\n",
    "    original_config_file=args.original_config_file,\n",
    "    image_size=args.image_size,\n",
    "    prediction_type=args.prediction_type,\n",
    "    model_type=args.pipeline_type,\n",
    "    extract_ema=args.extract_ema,\n",
    "    scheduler_type=args.scheduler_type,\n",
    "    num_in_channels=args.num_in_channels,\n",
    "    upcast_attention=args.upcast_attention,\n",
    "    from_safetensors=args.from_safetensors,\n",
    "    device=args.device,\n",
    "    stable_unclip=args.stable_unclip,\n",
    "    stable_unclip_prior=args.stable_unclip_prior,\n",
    "    clip_stats_path=args.clip_stats_path,\n",
    "    controlnet=args.controlnet,\n",
    "    vae_path=args.vae_path,\n",
    "    pipeline_class=pipeline_class,\n",
    ")\n",
    "\n",
    "if args.half:\n",
    "    pipe.to(torch_dtype=torch.float16)\n",
    "\n",
    "if args.controlnet:\n",
    "    # Only save the ControlNet model\n",
    "    pipe.controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)\n",
    "else:\n",
    "    pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)\n"
   ]
  },
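  {
   "cell_type": "markdown",
   "id": "a9c3f1d8",
   "metadata": {},
   "source": [
    "Optional smoke test, a minimal sketch assuming the `dump_path` above and an available CUDA device: reload the freshly written Diffusers folder and run one inference pass. The prompt and step count are arbitrary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2b6e8f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffusers import StableDiffusionPipeline\n",
    "\n",
    "# Reload the converted folder and generate one image to confirm it round-trips.\n",
    "test_pipe = StableDiffusionPipeline.from_pretrained(args.dump_path, torch_dtype=torch.float16).to(\"cuda\")\n",
    "image = test_pipe(\"a photo of a red fox in a forest\", num_inference_steps=20).images[0]\n",
    "image\n"
   ]
  },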
  {
   "cell_type": "markdown",
   "id": "7308aa50",
   "metadata": {},
   "source": [
    "### Check available pipelines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc2e94c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show all pipelines\n",
    "library = importlib.import_module(\"diffusers\")\n",
    "\n",
    "for p in dir(library):\n",
    "    if \"Pipeline\" in p:\n",
    "        print(p)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0065561",
   "metadata": {},
   "source": [
    "## Convert VAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d30b82d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "\n",
    "import requests\n",
    "import torch\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "from diffusers import AutoencoderKL\n",
    "from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (\n",
    "    assign_to_checkpoint,\n",
    "    conv_attn_to_linear,\n",
    "    create_vae_diffusers_config,\n",
    "    renew_vae_attention_paths,\n",
    "    renew_vae_resnet_paths,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "27fcefd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function definitions\n",
    "\n",
    "def custom_convert_ldm_vae_checkpoint(checkpoint, config):\n",
    "    vae_state_dict = checkpoint\n",
    "\n",
    "    new_checkpoint = {}\n",
    "\n",
    "    new_checkpoint[\"encoder.conv_in.weight\"] = vae_state_dict[\"encoder.conv_in.weight\"]\n",
    "    new_checkpoint[\"encoder.conv_in.bias\"] = vae_state_dict[\"encoder.conv_in.bias\"]\n",
    "    new_checkpoint[\"encoder.conv_out.weight\"] = vae_state_dict[\"encoder.conv_out.weight\"]\n",
    "    new_checkpoint[\"encoder.conv_out.bias\"] = vae_state_dict[\"encoder.conv_out.bias\"]\n",
    "    new_checkpoint[\"encoder.conv_norm_out.weight\"] = vae_state_dict[\"encoder.norm_out.weight\"]\n",
    "    new_checkpoint[\"encoder.conv_norm_out.bias\"] = vae_state_dict[\"encoder.norm_out.bias\"]\n",
    "\n",
    "    new_checkpoint[\"decoder.conv_in.weight\"] = vae_state_dict[\"decoder.conv_in.weight\"]\n",
    "    new_checkpoint[\"decoder.conv_in.bias\"] = vae_state_dict[\"decoder.conv_in.bias\"]\n",
    "    new_checkpoint[\"decoder.conv_out.weight\"] = vae_state_dict[\"decoder.conv_out.weight\"]\n",
    "    new_checkpoint[\"decoder.conv_out.bias\"] = vae_state_dict[\"decoder.conv_out.bias\"]\n",
    "    new_checkpoint[\"decoder.conv_norm_out.weight\"] = vae_state_dict[\"decoder.norm_out.weight\"]\n",
    "    new_checkpoint[\"decoder.conv_norm_out.bias\"] = vae_state_dict[\"decoder.norm_out.bias\"]\n",
    "\n",
    "    new_checkpoint[\"quant_conv.weight\"] = vae_state_dict[\"quant_conv.weight\"]\n",
    "    new_checkpoint[\"quant_conv.bias\"] = vae_state_dict[\"quant_conv.bias\"]\n",
    "    new_checkpoint[\"post_quant_conv.weight\"] = vae_state_dict[\"post_quant_conv.weight\"]\n",
    "    new_checkpoint[\"post_quant_conv.bias\"] = vae_state_dict[\"post_quant_conv.bias\"]\n",
    "\n",
    "    # Retrieves the keys for the encoder down blocks only\n",
    "    num_down_blocks = len({\".\".join(layer.split(\".\")[:3]) for layer in vae_state_dict if \"encoder.down\" in layer})\n",
    "    down_blocks = {\n",
    "        layer_id: [key for key in vae_state_dict if f\"down.{layer_id}\" in key] for layer_id in range(num_down_blocks)\n",
    "    }\n",
    "\n",
    "    # Retrieves the keys for the decoder up blocks only\n",
    "    num_up_blocks = len({\".\".join(layer.split(\".\")[:3]) for layer in vae_state_dict if \"decoder.up\" in layer})\n",
    "    up_blocks = {\n",
    "        layer_id: [key for key in vae_state_dict if f\"up.{layer_id}\" in key] for layer_id in range(num_up_blocks)\n",
    "    }\n",
    "\n",
    "    for i in range(num_down_blocks):\n",
    "        resnets = [key for key in down_blocks[i] if f\"down.{i}\" in key and f\"down.{i}.downsample\" not in key]\n",
    "\n",
    "        if f\"encoder.down.{i}.downsample.conv.weight\" in vae_state_dict:\n",
    "            new_checkpoint[f\"encoder.down_blocks.{i}.downsamplers.0.conv.weight\"] = vae_state_dict.pop(\n",
    "                f\"encoder.down.{i}.downsample.conv.weight\"\n",
    "            )\n",
    "            new_checkpoint[f\"encoder.down_blocks.{i}.downsamplers.0.conv.bias\"] = vae_state_dict.pop(\n",
    "                f\"encoder.down.{i}.downsample.conv.bias\"\n",
    "            )\n",
    "\n",
    "        paths = renew_vae_resnet_paths(resnets)\n",
    "        meta_path = {\"old\": f\"down.{i}.block\", \"new\": f\"down_blocks.{i}.resnets\"}\n",
    "        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)\n",
    "\n",
    "    mid_resnets = [key for key in vae_state_dict if \"encoder.mid.block\" in key]\n",
    "    num_mid_res_blocks = 2\n",
    "    for i in range(1, num_mid_res_blocks + 1):\n",
    "        resnets = [key for key in mid_resnets if f\"encoder.mid.block_{i}\" in key]\n",
    "\n",
    "        paths = renew_vae_resnet_paths(resnets)\n",
    "        meta_path = {\"old\": f\"mid.block_{i}\", \"new\": f\"mid_block.resnets.{i - 1}\"}\n",
    "        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)\n",
    "\n",
    "    mid_attentions = [key for key in vae_state_dict if \"encoder.mid.attn\" in key]\n",
    "    paths = renew_vae_attention_paths(mid_attentions)\n",
    "    meta_path = {\"old\": \"mid.attn_1\", \"new\": \"mid_block.attentions.0\"}\n",
    "    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)\n",
    "    conv_attn_to_linear(new_checkpoint)\n",
    "\n",
    "    for i in range(num_up_blocks):\n",
    "        block_id = num_up_blocks - 1 - i\n",
    "        resnets = [\n",
    "            key for key in up_blocks[block_id] if f\"up.{block_id}\" in key and f\"up.{block_id}.upsample\" not in key\n",
    "        ]\n",
    "\n",
    "        if f\"decoder.up.{block_id}.upsample.conv.weight\" in vae_state_dict:\n",
    "            new_checkpoint[f\"decoder.up_blocks.{i}.upsamplers.0.conv.weight\"] = vae_state_dict[\n",
    "                f\"decoder.up.{block_id}.upsample.conv.weight\"\n",
    "            ]\n",
    "            new_checkpoint[f\"decoder.up_blocks.{i}.upsamplers.0.conv.bias\"] = vae_state_dict[\n",
    "                f\"decoder.up.{block_id}.upsample.conv.bias\"\n",
    "            ]\n",
    "\n",
    "        paths = renew_vae_resnet_paths(resnets)\n",
    "        meta_path = {\"old\": f\"up.{block_id}.block\", \"new\": f\"up_blocks.{i}.resnets\"}\n",
    "        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)\n",
    "\n",
    "    mid_resnets = [key for key in vae_state_dict if \"decoder.mid.block\" in key]\n",
    "    num_mid_res_blocks = 2\n",
    "    for i in range(1, num_mid_res_blocks + 1):\n",
    "        resnets = [key for key in mid_resnets if f\"decoder.mid.block_{i}\" in key]\n",
    "\n",
    "        paths = renew_vae_resnet_paths(resnets)\n",
    "        meta_path = {\"old\": f\"mid.block_{i}\", \"new\": f\"mid_block.resnets.{i - 1}\"}\n",
    "        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)\n",
    "\n",
    "    mid_attentions = [key for key in vae_state_dict if \"decoder.mid.attn\" in key]\n",
    "    paths = renew_vae_attention_paths(mid_attentions)\n",
    "    meta_path = {\"old\": \"mid.attn_1\", \"new\": \"mid_block.attentions.0\"}\n",
    "    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)\n",
    "    conv_attn_to_linear(new_checkpoint)\n",
    "    return new_checkpoint\n",
    "\n",
    "\n",
    "def vae_pt_to_vae_diffuser(\n",
    "    checkpoint_path: str,\n",
    "    output_path: str,\n",
    "):\n",
    "    # Only supports SD v1 configs\n",
    "    r = requests.get(\n",
    "        \"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml\"\n",
    "    )\n",
    "    io_obj = io.BytesIO(r.content)\n",
    "\n",
    "    original_config = OmegaConf.load(io_obj)\n",
    "    image_size = 512\n",
    "    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "    if checkpoint_path.endswith(\"safetensors\"):\n",
    "        from safetensors import safe_open\n",
    "\n",
    "        checkpoint = {}\n",
    "        with safe_open(checkpoint_path, framework=\"pt\", device=\"cpu\") as f:\n",
    "            for key in f.keys():\n",
    "                checkpoint[key] = f.get_tensor(key)\n",
    "    else:\n",
    "        checkpoint = torch.load(checkpoint_path, map_location=device)[\"state_dict\"]\n",
    "\n",
    "    # Convert the VAE model.\n",
    "    vae_config = create_vae_diffusers_config(original_config, image_size=image_size)\n",
    "    converted_vae_checkpoint = custom_convert_ldm_vae_checkpoint(checkpoint, vae_config)\n",
    "\n",
    "    vae = AutoencoderKL(**vae_config)\n",
    "    vae.load_state_dict(converted_vae_checkpoint)\n",
    "    vae.save_pretrained(output_path)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "bb30dc66",
   "metadata": {},
   "outputs": [],
   "source": [
    "vae_pt_path = r\"D:\\AI\\automatic\\models\\VAE\\vae-ft-ema-560000-ema-pruned.safetensors\"\n",
    "dump_path = r\"D:\\AI\\invokeai\\models\\vae-ft-ema-560000-ema-pruned\"\n",
    "\n",
    "vae_pt_to_vae_diffuser(vae_pt_path, dump_path)"
   ]
  },
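  {
   "cell_type": "markdown",
   "id": "c8e4b2f6",
   "metadata": {},
   "source": [
    "Optional follow-up, a minimal sketch: load the converted VAE from `dump_path` and attach it to a Diffusers pipeline. The pipeline folder below reuses the checkpoint conversion output from earlier in this notebook and is purely illustrative; point it at any diffusers-format model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1a7d3b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffusers import StableDiffusionPipeline\n",
    "\n",
    "# Load the converted VAE and swap it into a pipeline (path is illustrative).\n",
    "vae = AutoencoderKL.from_pretrained(dump_path)\n",
    "pipe = StableDiffusionPipeline.from_pretrained(r\"D:\\AI\\invokeai\\models\\realisticVisionv51\", vae=vae)"
   ]
  }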
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}