Last active
January 3, 2023 14:01
-
-
Save arenasys/540869f9260506770f9985d19f79f479 to your computer and use it in GitHub Desktop.
prune.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python | |
import torch | |
import safetensors | |
import safetensors.torch | |
import os | |
# Key prefixes that make up a usable inference checkpoint. prune() keeps only
# keys starting with one of these; everything else (optimizer state, most EMA
# entries, training leftovers) is dropped.
MODEL_KEYS = [
    "model.",
    "first_stage_model",
    "cond_stage_model",
    "alphas_cumprod",
    "alphas_cumprod_prev",
    "betas",
    "model_ema.decay",
    "model_ema.num_updates",
    "posterior_log_variance_clipped",
    "posterior_mean_coef1",
    "posterior_mean_coef2",
    "posterior_variance",
    "sqrt_alphas_cumprod",
    "sqrt_one_minus_alphas_cumprod",
    "sqrt_recip_alphas_cumprod",
    "sqrt_recipm1_alphas_cumprod"
]
# Top-level submodule names that belong to the VAE (the first_stage_model),
# used by extract_vae() to pull VAE weights out of a full checkpoint.
VAE_KEYS = [
    "encoder",
    "decoder",
    "quant_conv",
    "post_quant_conv"
]
# Legacy CLIP key prefixes -> the standard `text_model.` layout; applied by
# rename() so downstream loaders see consistent key names.
RENAME_KEYS = {
    'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
    'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
    'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.'
}
# Minimal PyTorch-Lightning metadata stub written into pruned .ckpt files when
# the input checkpoint carried none (see save()).
METADATA = {'epoch': 0, 'global_step': 0, 'pytorch-lightning_version': '1.6.0'}
# Lookup tables mapping metric() fingerprint values and tell-tale keys to known
# base models. NOTE(review): not referenced by the code below — presumably used
# by other tooling or an earlier revision; verify before removing.
IDENTIFICATION = {
    "unet": {
        0: "SD-v1",
        1897: "NAI"
    },
    "vae": {
        0: "SD-v1",
        2982: "NAI"
    },
    "clip": {
        0: "SD-v1"
    },
    "keys": {
        "cond_stage_model.transformer.embeddings.position_ids": "NAI",
        "cond_stage_model.transformer.text_model.embeddings.position_ids": "SD-v1",
        "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_bias": "SD-v2"
    }
}
def metric(model):
    """Fingerprint a state dict as "unet/vae/clip" plus the raw numbers.

    Each component's tensors are folded into a scalar (sum of sigmoid(t)-0.5),
    compared against hard-coded SD-v1 baselines, and reported as a 4-digit
    deviation score; a component with no weights shows as "----".
    Returns (fingerprint_string, (n_unet, n_vae, n_clip)).
    """
    def fold(tensor):
        # Round-trip through fp16 so fp16 and fp32 checkpoints fingerprint alike.
        tensor = tensor.to(torch.float16).to(torch.float32)
        return torch.sum(torch.sigmoid(tensor) - 0.5)

    totals = {"unet": 0, "vae": 0, "clip": 0}
    for key in model:
        if key.startswith("model."):
            totals["unet"] += fold(model[key])
        # Standalone VAE files have bare encoder/decoder keys; full checkpoints
        # carry them under first_stage_model. — handle both.
        stripped = key.replace("first_stage_model.", "")
        if stripped.startswith(("encode", "decode")):
            totals["vae"] += fold(model[key])
        if key.startswith("cond_stage_model."):
            totals["clip"] += fold(model[key])

    # Reference sums for SD-v1 and a common scale for the deviation score.
    baselines = {"unet": -6131.5400, "vae": 17870.7051, "clip": 4509.0234}
    scale = 10000
    scores = {}
    parts = []
    for name in ("unet", "vae", "clip"):
        scores[name] = int(abs(totals[name] / baselines[name] - 1) * scale)
        parts.append(f"{scores[name]:04}" if totals[name] != 0 else "----")
    return "/".join(parts), (scores["unet"], scores["vae"], scores["clip"])
def prune(model, verbose=False):
    """Return a new state dict keeping only keys prefixed by MODEL_KEYS.

    Deleted keys are logged when verbose; warns if nothing survives.
    """
    kept = {}
    for key in model.keys():
        if any(key.startswith(prefix) for prefix in MODEL_KEYS):
            kept[key] = model[key]
        elif verbose:
            print("DELETING", key)
    if not kept:
        print("NO KEYS FOUND")
    return kept
def extract_vae(model):
    """Extract VAE weights from a checkpoint, stripping any first_stage_model. prefix.

    Accepts both full checkpoints (prefixed keys) and standalone VAE files
    (bare encoder/decoder/... keys). Returns a new dict with bare keys.
    """
    vae = {}
    for key in model:
        name = key.replace('first_stage_model.', '') if key.startswith('first_stage_model') else key
        # Keep only entries whose top-level module is a known VAE component.
        if name.split('.', 1)[0] in VAE_KEYS:
            vae[name] = model[key]
    return vae
def extract_metadata(model):
    """Copy the known training-metadata entries (keys of METADATA) out of a raw checkpoint dict."""
    return {key: model[key] for key in METADATA if key in model}
def replace_vae(model, vae, verbose=False):
    """Overwrite the checkpoint's first_stage_model.* weights with a standalone VAE.

    Only keys already present in the model are replaced; mutates and returns model.
    """
    for name, value in vae.items():
        target = "first_stage_model." + name
        if target not in model:
            continue
        if verbose:
            print("REPLACING", target)
        model[target] = value
    return model
def rename(model, verbose=False):
    """Rename legacy CLIP key prefixes to the text_model.* layout (RENAME_KEYS).

    Keys already in the new layout are left alone. Mutates and returns model.
    """
    for key in list(model.keys()):
        for old, new in RENAME_KEYS.items():
            if old in key and new not in key:
                renamed = key.replace(old, new)
                if verbose:
                    print("RENAMING", key)
                model[renamed] = model[key]
                del model[key]
                # BUG FIX: the original kept scanning the remaining patterns
                # against the now-deleted key, which would raise KeyError if a
                # second pattern also matched. A key is renamed at most once.
                break
    return model
def load(file):
    """Load a checkpoint as (state_dict, metadata).

    .safetensors files are read tensor-by-tensor (no metadata); anything else
    is assumed to be a torch pickle, optionally wrapped in a Lightning-style
    {'state_dict': ...} dict whose metadata is extracted.
    """
    if file.endswith(".safetensors"):
        tensors = {}
        with safetensors.safe_open(file, framework="pt", device="cpu") as f:
            for key in f.keys():
                tensors[key] = f.get_tensor(key)
        return tensors, {}
    # NOTE(review): torch.load unpickles arbitrary code — only open trusted files.
    raw = torch.load(file, map_location="cpu")
    if 'state_dict' in raw:
        return raw["state_dict"], extract_metadata(raw)
    return raw, {}
def save(model, metadata, file):
    """Write the state dict to disk.

    .safetensors output discards metadata (the format has no place for it
    here); other paths get a Lightning-style {'state_dict': ..., **metadata}
    pickle. Non-VAE checkpoints with no metadata get the METADATA stub so
    loaders expecting Lightning fields still work.
    """
    if file.endswith(".safetensors"):
        safetensors.torch.save_file(model, file)
        return
    if not metadata and ".vae." not in file:
        metadata = dict(METADATA)
    # BUG FIX: copy before inserting 'state_dict' so the caller's metadata
    # dict is not mutated as a side effect of saving.
    out = dict(metadata)
    out['state_dict'] = model
    torch.save(out, file)
def do_half(model, half_unet, half_vae, half_clip):
    """Convert the selected components' tensors to fp16 in place; returns model.

    half_unet/half_vae/half_clip select the model./first_stage_model./
    cond_stage_model. prefixes respectively. Non-tensor entries are skipped.
    """
    for k in model:
        unet = half_unet and k.startswith("model.")
        vae = half_vae and k.startswith("first_stage_model.")
        clip = half_clip and k.startswith("cond_stage_model.")
        # BUG FIX: the original wrote `a or b or c and type(...) == Tensor`.
        # `and` binds tighter than `or`, so the tensor check only guarded the
        # clip branch and non-tensor unet/vae entries crashed on .half().
        if (unet or vae or clip) and isinstance(model[k], torch.Tensor):
            model[k] = model[k].half()
    return model
def verbose_print(verbose, *argv):
    """Print *argv* exactly like print() would, but only when verbose is truthy."""
    if verbose:
        # BUG FIX: the original did print(argv), emitting the tuple repr
        # (e.g. "('REPLACING', 'key')"); unpack to match normal print output.
        print(*argv)
def do_prune(model_file, vae_file, out_file, half_unet, half_vae, half_clip, verbose=False):
    """Run the full pipeline: load, rename legacy keys, prune, optionally swap
    in an external VAE, optionally convert to fp16, and save.

    Prints progress plus a "name (fingerprint) (size GB)" summary for each
    file involved.
    """
    def describe(path, fingerprint):
        size_gb = os.path.getsize(path) * 1e-9
        return f"{os.path.basename(path)} ({fingerprint}) ({size_gb:.2f} GB)"

    print("LOADING...")
    model, metadata = load(model_file)
    model = rename(model, verbose)
    fingerprint, _ = metric(model)
    print("INPUT " + describe(model_file, fingerprint))

    print("PRUNING...")
    model = prune(model, verbose)

    if vae_file:
        vae, _ = load(vae_file)
        vae_fingerprint, _ = metric(vae)
        print("ADDING " + describe(vae_file, vae_fingerprint))
        model = replace_vae(model, extract_vae(vae), verbose)

    if half_unet or half_vae or half_clip:
        print("HALVING...")
        model = do_half(model, half_unet, half_vae, half_clip)

    print("SAVING...")
    save(model, metadata, out_file)
    fingerprint, _ = metric(model)
    print("OUTPUT " + describe(out_file, fingerprint))
def prune_cli():
    """Command-line entry point: parse arguments, validate paths, run do_prune.

    Returns None; prints an error and bails out early if an input path is missing.
    """
    import argparse
    parser = argparse.ArgumentParser(description='Pruning')
    parser.add_argument('--ckpt', type=str, default=None, required=True, help='path to model')
    parser.add_argument('--vae', type=str, default=None, help='path to vae')
    parser.add_argument('--out', type=str, default=None, required=True, help='path to save the pruned model')
    parser.add_argument('-full', action='store_true', help='full precision (fp32)')
    parser.add_argument('-v', action='store_true', help='verbose logging')
    args = parser.parse_args()
    # FIX: dropped the redundant local `import os` — os is already imported at
    # module level.
    if not os.path.exists(args.ckpt):
        print("model not found")
        return
    if args.vae and not os.path.exists(args.vae):
        print("VAE not found")
        return
    # -full means keep fp32; otherwise halve unet, vae and clip alike.
    half = not args.full
    do_prune(args.ckpt, args.vae, args.out, half, half, half, args.v)
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    prune_cli()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment