@arenasys
Last active January 3, 2023 14:01
prune.py
#!/usr/bin/python
# prune.py - prune Stable Diffusion checkpoints down to the weights needed
# for inference, optionally swapping in an external VAE and converting to fp16.
import torch
import safetensors
import safetensors.torch
import os
# Key prefixes that make up a usable inference model (UNet, VAE, CLIP and the
# sampler schedule tensors); everything else gets pruned.
MODEL_KEYS = [
    "model.",
    "first_stage_model",
    "cond_stage_model",
    "alphas_cumprod",
    "alphas_cumprod_prev",
    "betas",
    "model_ema.decay",
    "model_ema.num_updates",
    "posterior_log_variance_clipped",
    "posterior_mean_coef1",
    "posterior_mean_coef2",
    "posterior_variance",
    "sqrt_alphas_cumprod",
    "sqrt_one_minus_alphas_cumprod",
    "sqrt_recip_alphas_cumprod",
    "sqrt_recipm1_alphas_cumprod"
]
# Top-level key prefixes belonging to the VAE, both inside a full checkpoint
# (under "first_stage_model.") and in a standalone VAE file.
VAE_KEYS = [
    "encoder",
    "decoder",
    "quant_conv",
    "post_quant_conv"
]

# Old-style CLIP key names (as found in NAI checkpoints) mapped to the
# standard SD-v1 names with the "text_model." infix.
RENAME_KEYS = {
    'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
    'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
    'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.'
}
METADATA = {'epoch': 0, 'global_step': 0, 'pytorch-lightning_version': '1.6.0'}

# Known fingerprint distances and telltale keys for identifying model lineage
# (kept for reference; not consumed by the pruning code below).
IDENTIFICATION = {
    "unet": {
        0: "SD-v1",
        1897: "NAI"
    },
    "vae": {
        0: "SD-v1",
        2982: "NAI"
    },
    "clip": {
        0: "SD-v1"
    },
    "keys": {
        "cond_stage_model.transformer.embeddings.position_ids": "NAI",
        "cond_stage_model.transformer.text_model.embeddings.position_ids": "SD-v1",
        "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_bias": "SD-v2"
    }
}
def metric(model):
    # Fingerprint the UNet/VAE/CLIP weights as three 4-digit distances from
    # the SD-v1 baselines ("0000" means identical, "----" means absent).
    def tensor_metric(t):
        # Round-trip through fp16 so fp32 and fp16 copies of the same
        # weights produce the same fingerprint.
        t = t.to(torch.float16).to(torch.float32)
        return torch.sum(torch.sigmoid(t) - 0.5)
    unet, vae, clip = 0, 0, 0
    for k in model:
        if k.startswith("model."):
            unet += tensor_metric(model[k])
        kk = k.replace("first_stage_model.", "")
        if kk.startswith("encode") or kk.startswith("decode"):
            vae += tensor_metric(model[k])
        if k.startswith("cond_stage_model."):
            clip += tensor_metric(model[k])
    b_unet, b_vae, b_clip = -6131.5400, 17870.7051, 4509.0234
    s_unet, s_vae, s_clip = 10000, 10000, 10000
    n_unet = int(abs(unet / b_unet - 1) * s_unet)
    n_vae = int(abs(vae / b_vae - 1) * s_vae)
    n_clip = int(abs(clip / b_clip - 1) * s_clip)
    unet = f"{n_unet:04}" if unet != 0 else "----"
    vae = f"{n_vae:04}" if vae != 0 else "----"
    clip = f"{n_clip:04}" if clip != 0 else "----"
    return unet + "/" + vae + "/" + clip, (n_unet, n_vae, n_clip)
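
# A minimal sketch of using metric() on its own (filenames are hypothetical):
#   model, _ = load("some-model.ckpt")
#   fingerprint, distances = metric(model)
#   print(fingerprint)  # e.g. "0000/0000/0000" for stock SD-v1 weights;
#   # the numeric distances can be compared against IDENTIFICATION above.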
def prune(model, verbose=False):
    # Keep only the keys needed for inference (see MODEL_KEYS); drops the
    # per-weight EMA copies, optimizer state and other training leftovers.
    in_state = model
    out_state = {}
    for k in in_state.keys():
        saved = False
        for a in MODEL_KEYS:
            if k.startswith(a):
                out_state[k] = in_state[k]
                saved = True
                break
        if not saved and verbose:
            print("DELETING", k)
    if len(out_state) == 0:
        print("NO KEYS FOUND")
    return out_state
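
# Sketch: pruning a checkpoint in memory, without going through the CLI
# (filenames here are hypothetical; load() and save() are defined below):
#   model, meta = load("full-checkpoint.ckpt")
#   model = prune(model, verbose=True)
#   save(model, meta, "pruned-checkpoint.ckpt")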
def extract_vae(model):
    # Pull the VAE weights out of a model, stripping the
    # "first_stage_model." prefix used inside full checkpoints.
    out = {}
    for k in model:
        kk = k
        if k.startswith('first_stage_model'):
            kk = k.replace('first_stage_model.', '')
        if kk.split('.')[0] in VAE_KEYS:
            out[kk] = model[k]
    return out
def extract_metadata(model):
    # Copy the PyTorch Lightning bookkeeping fields, if present.
    metadata = {}
    for k in METADATA.keys():
        if k in model:
            metadata[k] = model[k]
    return metadata
def replace_vae(model, vae, verbose=False):
    # Overwrite the checkpoint's baked-in VAE with an external one.
    for v in vae:
        k = "first_stage_model." + v
        if k in model:
            if verbose:
                print("REPLACING", k)
            model[k] = vae[v]
    return model
def rename(model, verbose=False):
    # Migrate old-style CLIP key names to the standard SD-v1 layout.
    for k in list(model.keys()):
        for r in RENAME_KEYS.keys():
            if r in k and RENAME_KEYS[r] not in k:
                kk = k.replace(r, RENAME_KEYS[r])
                if verbose:
                    print("RENAMING", k)
                model[kk] = model[k]
                del model[k]
    return model
def load(file):
    # Load either a .safetensors file or a pickled .ckpt/.pt checkpoint.
    model = {}
    metadata = {}
    if file.endswith(".safetensors"):
        with safetensors.safe_open(file, framework="pt", device="cpu") as f:
            for key in f.keys():
                model[key] = f.get_tensor(key)
    else:
        raw = torch.load(file, map_location="cpu")
        if 'state_dict' in raw:
            metadata = extract_metadata(raw)
            model = raw["state_dict"]
        else:
            model = raw
    return model, metadata
def save(model, metadata, file):
    # Safetensors files hold only tensors, so the Lightning metadata is
    # dropped when saving in that format.
    if file.endswith(".safetensors"):
        safetensors.torch.save_file(model, file)
        return
    if not metadata and ".vae." not in file:
        metadata = dict(METADATA)
    out = metadata
    out['state_dict'] = model
    torch.save(out, file)
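
# Sketch: load() and save() together also work as a format converter
# (hypothetical filenames):
#   model, meta = load("model.ckpt")
#   save(model, meta, "model.safetensors")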
def do_half(model, half_unet, half_vae, half_clip):
    # Convert the selected components to fp16. The parentheses matter:
    # without them, "a or b or c and ..." binds the "and" to "c" alone.
    for k in model:
        a = k.startswith("model.") and half_unet
        b = k.startswith("first_stage_model.") and half_vae
        c = k.startswith("cond_stage_model.") and half_clip
        if (a or b or c) and type(model[k]) == torch.Tensor:
            model[k] = model[k].half()
    return model
def verbose_print(verbose, *argv):
    if verbose:
        print(*argv)
def do_prune(model_file, vae_file, out_file, half_unet, half_vae, half_clip, verbose=False):
    # Full pipeline: load -> rename -> prune -> (swap VAE) -> (fp16) -> save.
    print("LOADING...")
    model, metadata = load(model_file)
    model = rename(model, verbose)
    input_metric, _ = metric(model)
    in_size = os.path.getsize(model_file)
    print(f"INPUT {os.path.basename(model_file)} ({input_metric}) ({in_size*1e-9:.2f} GB)")
    print("PRUNING...")
    model = prune(model, verbose)
    if vae_file:
        vae, _ = load(vae_file)
        vae_size = os.path.getsize(vae_file)
        vae_metric, _ = metric(vae)
        print(f"ADDING {os.path.basename(vae_file)} ({vae_metric}) ({vae_size*1e-9:.2f} GB)")
        vae = extract_vae(vae)
        model = replace_vae(model, vae, verbose)
    if half_unet or half_vae or half_clip:
        print("HALVING...")
        model = do_half(model, half_unet, half_vae, half_clip)
    print("SAVING...")
    save(model, metadata, out_file)
    output_metric, _ = metric(model)
    out_size = os.path.getsize(out_file)
    print(f"OUTPUT {os.path.basename(out_file)} ({output_metric}) ({out_size*1e-9:.2f} GB)")
def prune_cli():
    import argparse
    parser = argparse.ArgumentParser(description='Pruning')
    parser.add_argument('--ckpt', type=str, default=None, required=True, help='path to model')
    parser.add_argument('--vae', type=str, default=None, help='path to vae')
    parser.add_argument('--out', type=str, default=None, required=True, help='path to save the pruned model')
    parser.add_argument('-full', action='store_true', help='full precision (fp32)')
    parser.add_argument('-v', action='store_true', help='verbose logging')
    args = parser.parse_args()
    model_file = args.ckpt
    vae_file = args.vae
    out_file = args.out
    half = not args.full
    verbose = args.v
    if not os.path.exists(model_file):
        print("model not found")
        return
    if vae_file and not os.path.exists(vae_file):
        print("VAE not found")
        return
    do_prune(model_file, vae_file, out_file, half, half, half, verbose)
if __name__ == "__main__":
    prune_cli()
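
# Example invocations (a sketch; filenames are hypothetical):
#   python prune.py --ckpt model.ckpt --out model-pruned.safetensors
#   python prune.py --ckpt model.ckpt --vae better.vae.pt --out model-pruned.ckpt -full -v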