# Reference: https://github.com/arcee-ai/mergekit/blob/488957e8e67c82861ecf63ef761f6bc59122dc74/mergekit/scripts/extract_lora.py
import argparse

import torch
from safetensors.torch import load_file, save_file
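
# The flags below opt into TF32 matmuls and the cuSOLVER linear-algebra backend, trading a little
# numerical precision for faster SVD/QR factorizations on recent NVIDIA GPUs.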
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.preferred_linalg_library("cusolver")


def get_low_rank_weight(x: torch.Tensor, rank: int, distribute_scale: bool, method: str, dtype: torch.dtype = torch.float32):
    # Factor a 2D weight delta ``x`` (out_features x in_features) into ``lora_B @ lora_A ~= x``,
    # with ``lora_A`` of shape (rank, in_features) and ``lora_B`` of shape (out_features, rank).
    def get_scale(s):
        # Either spread sqrt(S) across both factors, or fold all of the singular values into lora_B.
        if distribute_scale:
            sqrt_s = torch.sqrt(s)
            scale_a = torch.diag(sqrt_s)
            scale_b = torch.diag(sqrt_s)
        else:
            scale_a = torch.diag(s)
            scale_b = torch.eye(rank, dtype=torch.float32, device=s.device)
        return scale_a, scale_b
    if method == "svd":
        U, S, Vh = torch.linalg.svd(x, full_matrices=False)
        rank = min(rank, S.shape[0])
        U = U[:, :rank]
        S = S[:rank]
        Vh = Vh[:rank, :]
        scale_a, scale_b = get_scale(S)
        lora_A = scale_b @ Vh
        lora_B = U @ scale_a
    elif method == "qr_svd":
        # QR first, then an SVD of the smaller R factor; Q folds back into the left singular vectors.
        Q, R = torch.linalg.qr(x)
        U, S, Vh = torch.linalg.svd(R, full_matrices=False)
        rank = min(rank, S.shape[0])
        U = Q @ U[:, :rank]
        S = S[:rank]
        Vh = Vh[:rank, :]
        scale_a, scale_b = get_scale(S)
        lora_A = scale_b @ Vh
        lora_B = U @ scale_a
elif method == "randomized_svd": | |
rand_matrix = torch.randn(x.shape[1], rank, device=x.device, dtype=x.dtype) | |
Y = x @ rand_matrix | |
Q, _ = torch.linalg.qr(Y) | |
B = Q.T @ x | |
U_tilde, S, Vh = torch.linalg.svd(B, full_matrices=False) | |
rank = min(rank, S.shape[0]) | |
U = Q @ U_tilde | |
U = U[:, :rank] | |
S = S[:rank] | |
Vh = Vh[:rank, :] | |
lora_A = scale_b @ Vh | |
lora_B = U @ scale_a | |
elif method == "cur": | |
col_norms = torch.norm(x, dim=0) | |
row_norms = torch.norm(x, dim=1) | |
col_probs = col_norms / col_norms.sum() | |
row_probs = row_norms / row_norms.sum() | |
col_indices = torch.multinomial(col_probs, rank, replacement=False) | |
row_indices = torch.multinomial(row_probs, rank, replacement=False) | |
C = x[:, col_indices] | |
R = x[row_indices, :] | |
U = torch.linalg.pinv(C[row_indices, :]) @ x[row_indices, col_indices] @ torch.linalg.pinv(R[:, col_indices]) | |
lora_A = C @ U | |
lora_B = R | |
return ( | |
lora_A.contiguous().to(dtype=dtype), | |
lora_B.contiguous().to(dtype=dtype), | |
) | |
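

# Optional sanity check (illustrative helper, not used by the script): whichever method is chosen,
# ``lora_B @ lora_A`` should be a reasonable rank-limited approximation of the input delta.
def _relative_reconstruction_error(x: torch.Tensor, lora_A: torch.Tensor, lora_B: torch.Tensor) -> float:
    # Relative Frobenius-norm error of the low-rank approximation.
    approx = lora_B.float() @ lora_A.float()
    return (torch.linalg.matrix_norm(x.float() - approx) / torch.linalg.matrix_norm(x.float())).item()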


def main(args):
    low_rank_fn = torch.compile(get_low_rank_weight, mode="max-autotune-no-cudagraphs", dynamic=True)

    model1_state_dict = load_file(args.model1_path)
    model1_state_dict = convert_flux_transformer_checkpoint_to_diffusers(model1_state_dict, 19, 38, 3072, 4)
    model2_state_dict = load_file(args.model2_path)
    model2_state_dict = convert_flux_transformer_checkpoint_to_diffusers(model2_state_dict, 19, 38, 3072, 4)

    lora_state_dict = {}
    for key in model2_state_dict:
        # Skip biases, norm scales, and keys missing from model1; only projection weights shared
        # by both checkpoints are factorized.
        if (
            (not key.endswith(".weight")) or
            ("norm" in key) or
            (key not in model1_state_dict)
        ):
            continue
        diff = model2_state_dict[key].float() - model1_state_dict[key].float()
        print(f"Processing key: {key} {diff.shape}")
        diff = diff.cuda()
        lora_A, lora_B = low_rank_fn(diff, args.rank, args.distribute_scale, args.method, dtype=model2_state_dict[key].dtype)
        A_key = "transformer." + key.removesuffix(".weight") + ".lora_A.weight"
        B_key = "transformer." + key.removesuffix(".weight") + ".lora_B.weight"
        lora_state_dict[A_key] = lora_A.cpu()
        lora_state_dict[B_key] = lora_B.cpu()

    save_file(lora_state_dict, args.output_path)


def get_args():
    parser = argparse.ArgumentParser("Extract LoRA from difference between model2 and model1 weights")
    parser.add_argument("--model1_path", type=str, required=True)
    parser.add_argument("--model2_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--rank", type=int, default=256)
    parser.add_argument("--distribute_scale", action="store_true")
    parser.add_argument(
        "--method",
        type=str,
        choices=["svd", "qr_svd", "randomized_svd", "cur"],
        default="svd",
    )
    return parser.parse_args()


# In the original SD3 implementation of AdaLayerNormContinuous, the linear projection output is
# split into (shift, scale), while in diffusers it is split into (scale, shift). Swap the two halves
# of the projection weights so the diffusers implementation can be used directly.
def swap_scale_shift(weight):
    shift, scale = weight.chunk(2, dim=0)
    new_weight = torch.cat([scale, shift], dim=0)
    return new_weight
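

# The function below remaps original Flux checkpoint keys ("time_in.*", "vector_in.*", "txt_in.*",
# "img_in.*", "double_blocks.*", "single_blocks.*", "final_layer.*") onto diffusers'
# FluxTransformer2DModel module names, so that both checkpoints share one namespace before their
# weights are diffed in main().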
def convert_flux_transformer_checkpoint_to_diffusers(
    original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0
):
    converted_state_dict = {}

    ## time_text_embed.timestep_embedder <- time_in
    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
        "time_in.in_layer.weight"
    )
    converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
        "time_in.in_layer.bias"
    )
    converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
        "time_in.out_layer.weight"
    )
    converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
        "time_in.out_layer.bias"
    )

    ## time_text_embed.text_embedder <- vector_in
    converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop(
        "vector_in.in_layer.weight"
    )
    converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop(
        "vector_in.in_layer.bias"
    )
    converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop(
        "vector_in.out_layer.weight"
    )
    converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop(
        "vector_in.out_layer.bias"
    )

    # guidance
    has_guidance = any("guidance" in k for k in original_state_dict)
    if has_guidance:
        converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = original_state_dict.pop(
            "guidance_in.in_layer.weight"
        )
        converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = original_state_dict.pop(
            "guidance_in.in_layer.bias"
        )
        converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = original_state_dict.pop(
            "guidance_in.out_layer.weight"
        )
        converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = original_state_dict.pop(
            "guidance_in.out_layer.bias"
        )

    # context_embedder
    converted_state_dict["context_embedder.weight"] = original_state_dict.pop("txt_in.weight")
    converted_state_dict["context_embedder.bias"] = original_state_dict.pop("txt_in.bias")

    # x_embedder
    converted_state_dict["x_embedder.weight"] = original_state_dict.pop("img_in.weight")
    converted_state_dict["x_embedder.bias"] = original_state_dict.pop("img_in.bias")

    # double transformer blocks
    for i in range(num_layers):
        block_prefix = f"transformer_blocks.{i}."
        # norms.
        ## norm1
        converted_state_dict[f"{block_prefix}norm1.linear.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mod.lin.weight"
        )
        converted_state_dict[f"{block_prefix}norm1.linear.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mod.lin.bias"
        )
        ## norm1_context
        converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mod.lin.weight"
        )
        converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mod.lin.bias"
        )
        # Q, K, V
        sample_q, sample_k, sample_v = torch.chunk(
            original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0
        )
        context_q, context_k, context_v = torch.chunk(
            original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
        )
        sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
            original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
        )
        context_q_bias, context_k_bias, context_v_bias = torch.chunk(
            original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
        )
        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
        converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
        converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
        converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
        converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
        converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
        converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
        # qk_norm
        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_attn.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_attn.norm.key_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
        )
        # ff img_mlp
        converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mlp.0.weight"
        )
        converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mlp.0.bias"
        )
        converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mlp.2.weight"
        )
        converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mlp.2.bias"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mlp.0.weight"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mlp.0.bias"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mlp.2.weight"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mlp.2.bias"
        )
        # output projections.
        converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_attn.proj.weight"
        )
        converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.img_attn.proj.bias"
        )
        converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_attn.proj.weight"
        )
        converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_attn.proj.bias"
        )

    # single transformer blocks
    for i in range(num_single_layers):
        block_prefix = f"single_transformer_blocks.{i}."
        # norm.linear <- single_blocks.0.modulation.lin
        converted_state_dict[f"{block_prefix}norm.linear.weight"] = original_state_dict.pop(
            f"single_blocks.{i}.modulation.lin.weight"
        )
        converted_state_dict[f"{block_prefix}norm.linear.bias"] = original_state_dict.pop(
            f"single_blocks.{i}.modulation.lin.bias"
        )
        # Q, K, V, mlp
        mlp_hidden_dim = int(inner_dim * mlp_ratio)
        split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
        q, k, v, mlp = torch.split(original_state_dict.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
        q_bias, k_bias, v_bias, mlp_bias = torch.split(
            original_state_dict.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
        )
        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
        converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
        converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
        # qk norm
        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
            f"single_blocks.{i}.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
            f"single_blocks.{i}.norm.key_norm.scale"
        )
        # output projections.
        converted_state_dict[f"{block_prefix}proj_out.weight"] = original_state_dict.pop(
            f"single_blocks.{i}.linear2.weight"
        )
        converted_state_dict[f"{block_prefix}proj_out.bias"] = original_state_dict.pop(
            f"single_blocks.{i}.linear2.bias"
        )

    converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
    converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
        original_state_dict.pop("final_layer.adaLN_modulation.1.weight")
    )
    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
        original_state_dict.pop("final_layer.adaLN_modulation.1.bias")
    )

    return converted_state_dict


if __name__ == "__main__":
    args = get_args()
    main(args)
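

# Example invocation (file names are illustrative):
#   python extract_lora.py \
#       --model1_path flux-base.safetensors \
#       --model2_path flux-finetuned.safetensors \
#       --output_path extracted_lora.safetensors \
#       --rank 256 --method svd
#
# The saved file uses "transformer.<module>.lora_A.weight" / "transformer.<module>.lora_B.weight"
# keys, so it should be loadable as a regular Flux LoRA in diffusers, e.g. via
# FluxPipeline.load_lora_weights("extracted_lora.safetensors").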