@a-r-r-o-w
Created May 28, 2025 12:01
# Reference: https://github.com/arcee-ai/mergekit/blob/488957e8e67c82861ecf63ef761f6bc59122dc74/mergekit/scripts/extract_lora.py
import argparse
import torch
from safetensors.torch import load_file, save_file
# Allow TF32 matmuls and prefer the cuSOLVER backend for the linear-algebra
# routines (SVD/QR/pinv) used below; this trades a little precision for speed.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.preferred_linalg_library("cusolver")

def get_low_rank_weight(x: torch.Tensor, rank: int, distribute_scale: bool, method: str, dtype: torch.dtype = torch.float32):
    def get_scale(s):
        if distribute_scale:
            # Split the singular values evenly between both factors.
            sqrt_s = torch.sqrt(s)
            scale_a = torch.diag(sqrt_s)
            scale_b = torch.diag(sqrt_s)
        else:
            # Fold the full scale into lora_B; lora_A only gets the identity.
            scale_a = torch.diag(s)
            scale_b = torch.eye(rank, dtype=torch.float32, device=s.device)
        return scale_a, scale_b

    if method == "svd":
        U, S, Vh = torch.linalg.svd(x, full_matrices=False)
        rank = min(rank, S.shape[0])
        U = U[:, :rank]
        S = S[:rank]
        Vh = Vh[:rank, :]
        scale_a, scale_b = get_scale(S)
        lora_A = scale_b @ Vh
        lora_B = U @ scale_a
    elif method == "qr_svd":
        # QR first, then SVD of the (smaller) R factor.
        Q, R = torch.linalg.qr(x)
        U, S, Vh = torch.linalg.svd(R, full_matrices=False)
        rank = min(rank, S.shape[0])
        U = Q @ U[:, :rank]
        S = S[:rank]
        Vh = Vh[:rank, :]
        scale_a, scale_b = get_scale(S)
        lora_A = scale_b @ Vh
        lora_B = U @ scale_a
    elif method == "randomized_svd":
        # Project onto a random subspace, then run SVD on the much smaller sketch.
        rand_matrix = torch.randn(x.shape[1], rank, device=x.device, dtype=x.dtype)
        Y = x @ rand_matrix
        Q, _ = torch.linalg.qr(Y)
        B = Q.T @ x
        U_tilde, S, Vh = torch.linalg.svd(B, full_matrices=False)
        rank = min(rank, S.shape[0])
        U = Q @ U_tilde
        U = U[:, :rank]
        S = S[:rank]
        Vh = Vh[:rank, :]
        scale_a, scale_b = get_scale(S)
        lora_A = scale_b @ Vh
        lora_B = U @ scale_a
    elif method == "cur":
        # CUR decomposition: sample columns/rows with probability proportional to their norms.
        col_norms = torch.norm(x, dim=0)
        row_norms = torch.norm(x, dim=1)
        col_probs = col_norms / col_norms.sum()
        row_probs = row_norms / row_norms.sum()
        col_indices = torch.multinomial(col_probs, rank, replacement=False)
        row_indices = torch.multinomial(row_probs, rank, replacement=False)
        C = x[:, col_indices]
        R = x[row_indices, :]
        # Intersection block of the sampled rows and columns (rank x rank); note that
        # `x[row_indices, col_indices]` would only gather the paired entries, not the block.
        W = x[row_indices][:, col_indices]
        U = torch.linalg.pinv(C[row_indices, :]) @ W @ torch.linalg.pinv(R[:, col_indices])
        # Match the convention above: lora_A is (rank, in_features), lora_B is (out_features, rank).
        lora_A = R
        lora_B = C @ U
    else:
        raise ValueError(f"Unknown low-rank method: {method}")
    return (
        lora_A.contiguous().to(dtype=dtype),
        lora_B.contiguous().to(dtype=dtype),
    )
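

# A minimal sanity check (not part of the original gist): the product lora_B @ lora_A should
# approximate the weight delta that was decomposed. The helper name and the example below are
# illustrative assumptions, not a fixed API.
def _reconstruction_error(diff: torch.Tensor, lora_A: torch.Tensor, lora_B: torch.Tensor) -> float:
    """Relative Frobenius-norm error of the low-rank approximation lora_B @ lora_A."""
    approx = lora_B.float() @ lora_A.float()
    return (torch.linalg.norm(diff.float() - approx) / torch.linalg.norm(diff.float())).item()


# Example usage (hypothetical shapes):
# diff = torch.randn(3072, 3072)
# lora_A, lora_B = get_low_rank_weight(diff, rank=64, distribute_scale=False, method="svd")
# print(f"relative error: {_reconstruction_error(diff, lora_A, lora_B):.4f}")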

def main(args):
    low_rank_fn = torch.compile(get_low_rank_weight, mode="max-autotune-no-cudagraphs", dynamic=True)

    model1_state_dict = load_file(args.model1_path)
    model1_state_dict = convert_flux_transformer_checkpoint_to_diffusers(model1_state_dict, 19, 38, 3072, 4)
    model2_state_dict = load_file(args.model2_path)
    model2_state_dict = convert_flux_transformer_checkpoint_to_diffusers(model2_state_dict, 19, 38, 3072, 4)

    lora_state_dict = {}
    for key in model2_state_dict:
        # LoRA is only extracted for weight matrices present in both models; biases and norm scales are skipped.
        if (
            (not key.endswith(".weight")) or
            ("norm" in key) or
            (key not in model1_state_dict)
        ):
            continue
        diff = model2_state_dict[key].float() - model1_state_dict[key].float()
        print(f"Processing key: {key} {diff.shape}")
        diff = diff.cuda()
        lora_A, lora_B = low_rank_fn(diff, args.rank, args.distribute_scale, args.method, dtype=model2_state_dict[key].dtype)
        A_key = "transformer." + key.removesuffix(".weight") + ".lora_A.weight"
        B_key = "transformer." + key.removesuffix(".weight") + ".lora_B.weight"
        lora_state_dict[A_key] = lora_A.cpu()
        lora_state_dict[B_key] = lora_B.cpu()
    save_file(lora_state_dict, args.output_path)

def get_args():
    parser = argparse.ArgumentParser("Extract LoRA from difference between model2 and model1 weights")
    parser.add_argument("--model1_path", type=str, required=True)
    parser.add_argument("--model2_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--rank", type=int, default=256)
    parser.add_argument("--distribute_scale", action="store_true")
    parser.add_argument(
        "--method",
        type=str,
        choices=["svd", "qr_svd", "randomized_svd", "cur"],
        default="svd",
    )
    return parser.parse_args()
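

# Example invocation (script name and paths are placeholders, not from the original gist).
# The two inputs are original-format Flux transformer checkpoints; the output is a
# diffusers/peft-style LoRA safetensors file:
#
#   python extract_lora_flux.py \
#       --model1_path flux-base.safetensors \
#       --model2_path flux-finetuned.safetensors \
#       --output_path extracted_lora.safetensors \
#       --rank 64 --method qr_svd --distribute_scale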

# In the original (SD3-style) implementation of AdaLayerNormContinuous, the linear projection
# output is split into (shift, scale), while diffusers splits it into (scale, shift). Swap the
# two halves of the projection weights/biases so the diffusers implementation can be used.
def swap_scale_shift(weight):
    shift, scale = weight.chunk(2, dim=0)
    new_weight = torch.cat([scale, shift], dim=0)
    return new_weight
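

# Toy illustration (assumed shapes): for a projection whose output is [shift | scale] along dim 0,
# swap_scale_shift reorders the rows so the same parameters produce [scale | shift] instead.
# >>> w = torch.arange(4.0)   # rows 0-1 play the role of "shift", rows 2-3 of "scale"
# >>> swap_scale_shift(w)     # tensor([2., 3., 0., 1.])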

def convert_flux_transformer_checkpoint_to_diffusers(
    original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0
):
    converted_state_dict = {}

    ## time_text_embed.timestep_embedder <- time_in
    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
        "time_in.in_layer.weight"
    )
    converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
        "time_in.in_layer.bias"
    )
    converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
        "time_in.out_layer.weight"
    )
    converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
        "time_in.out_layer.bias"
    )

    ## time_text_embed.text_embedder <- vector_in
    converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop(
        "vector_in.in_layer.weight"
    )
    converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop(
        "vector_in.in_layer.bias"
    )
    converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop(
        "vector_in.out_layer.weight"
    )
    converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop(
        "vector_in.out_layer.bias"
    )

    # guidance
    has_guidance = any("guidance" in k for k in original_state_dict)
    if has_guidance:
        converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = original_state_dict.pop(
            "guidance_in.in_layer.weight"
        )
        converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = original_state_dict.pop(
            "guidance_in.in_layer.bias"
        )
        converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = original_state_dict.pop(
            "guidance_in.out_layer.weight"
        )
        converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = original_state_dict.pop(
            "guidance_in.out_layer.bias"
        )

    # context_embedder
    converted_state_dict["context_embedder.weight"] = original_state_dict.pop("txt_in.weight")
    converted_state_dict["context_embedder.bias"] = original_state_dict.pop("txt_in.bias")

    # x_embedder
    converted_state_dict["x_embedder.weight"] = original_state_dict.pop("img_in.weight")
    converted_state_dict["x_embedder.bias"] = original_state_dict.pop("img_in.bias")
    # double transformer blocks
    for i in range(num_layers):
        block_prefix = f"transformer_blocks.{i}."
        # norms.
        ## norm1
        converted_state_dict[f"{block_prefix}norm1.linear.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mod.lin.weight"
        )
        converted_state_dict[f"{block_prefix}norm1.linear.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mod.lin.bias"
        )
        ## norm1_context
        converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mod.lin.weight"
        )
        converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mod.lin.bias"
        )
        # Q, K, V
        sample_q, sample_k, sample_v = torch.chunk(
            original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0
        )
        context_q, context_k, context_v = torch.chunk(
            original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
        )
        sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
            original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
        )
        context_q_bias, context_k_bias, context_v_bias = torch.chunk(
            original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
        )
        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
        converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
        converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
        converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
        converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
        converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
        converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
        # qk_norm
        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_attn.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_attn.norm.key_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
        )
        # ff img_mlp
        converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mlp.0.weight"
        )
        converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mlp.0.bias"
        )
        converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mlp.2.weight"
        )
        converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.img_mlp.2.bias"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mlp.0.weight"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mlp.0.bias"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mlp.2.weight"
        )
        converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_mlp.2.bias"
        )
        # output projections.
        converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.img_attn.proj.weight"
        )
        converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.img_attn.proj.bias"
        )
        converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_attn.proj.weight"
        )
        converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.txt_attn.proj.bias"
        )
    # single transformer blocks
    for i in range(num_single_layers):
        block_prefix = f"single_transformer_blocks.{i}."
        # norm.linear <- single_blocks.0.modulation.lin
        converted_state_dict[f"{block_prefix}norm.linear.weight"] = original_state_dict.pop(
            f"single_blocks.{i}.modulation.lin.weight"
        )
        converted_state_dict[f"{block_prefix}norm.linear.bias"] = original_state_dict.pop(
            f"single_blocks.{i}.modulation.lin.bias"
        )
        # Q, K, V, mlp
        mlp_hidden_dim = int(inner_dim * mlp_ratio)
        split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
        q, k, v, mlp = torch.split(original_state_dict.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
        q_bias, k_bias, v_bias, mlp_bias = torch.split(
            original_state_dict.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
        )
        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
        converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
        converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
        # qk norm
        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
            f"single_blocks.{i}.norm.query_norm.scale"
        )
        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
            f"single_blocks.{i}.norm.key_norm.scale"
        )
        # output projections.
        converted_state_dict[f"{block_prefix}proj_out.weight"] = original_state_dict.pop(
            f"single_blocks.{i}.linear2.weight"
        )
        converted_state_dict[f"{block_prefix}proj_out.bias"] = original_state_dict.pop(
            f"single_blocks.{i}.linear2.bias"
        )

    converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
    converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
        original_state_dict.pop("final_layer.adaLN_modulation.1.weight")
    )
    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
        original_state_dict.pop("final_layer.adaLN_modulation.1.bias")
    )
    return converted_state_dict

if __name__ == "__main__":
    args = get_args()
    main(args)
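

# Loading the extracted LoRA afterwards (a sketch, assuming a recent diffusers release with Flux
# LoRA support; the model id, file name, and adapter name below are placeholders):
#
#   from diffusers import FluxPipeline
#   pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
#   pipe.load_lora_weights("extracted_lora.safetensors", adapter_name="extracted")
#   pipe.to("cuda")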