ComfyUI Model Merger Custom Nodes
@Clybius · Created February 9, 2024
import torch
import numpy as np

def train_difference(a, b, c, key):
    """Train-difference merge: scale the (a - b) diff by where b sits between a and c, times 1.8."""
    a = a[key][0]
    b = b[key][0]
    c = c[key][0]
    atype = a.dtype
    diff_AB = a.float() - b.float()
    distance_A0 = torch.abs(b.float() - c.float())
    distance_A1 = torch.abs(b.float() - a.float())
    sum_distances = distance_A0 + distance_A1
    # Guard against division by zero where both distances vanish
    scale = torch.where(
        sum_distances != 0, distance_A1 / sum_distances, torch.tensor(0.0).float()
    )
    sign_scale = torch.sign(b.float() - c.float())
    scale = sign_scale * torch.abs(scale)
    new_diff = scale * torch.abs(diff_AB)
    return ((new_diff * 1.8).to(atype),)
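
# Worked example (a sketch, not part of the node code; assumes the
# {key: (tensor,)} key-patch shape used throughout this file):
#   with a = 1, b = 0, c = 2:
#     distance_A0 = |0 - 2| = 2, distance_A1 = |0 - 1| = 1, scale = 1/3
#     sign_scale = sign(0 - 2) = -1, so new_diff = -1/3 * |1 - 0| = -1/3
#     and the returned patch is -1/3 * 1.8 = -0.6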

def weighted_sum(a, b, alpha):
    # Linear interpolation between a and b
    return (1 - alpha) * a + alpha * b

def multiply_difference(a, b, c, key, alpha, beta):
    """Multiply-difference merge: geometric blend |b-a|^(1-alpha) * |c-a|^alpha, signed by the beta-weighted lerp of b and c minus a."""
    a = a[key][0]
    b = b[key][0]
    c = c[key][0]
    atype = a.dtype
    diff_b = torch.pow(torch.abs(b.float() - a.float()), (1 - alpha))
    diff_c = torch.pow(torch.abs(c.float() - a.float()), alpha)
    difference = torch.copysign(
        diff_b * diff_c, weighted_sum(b.float(), c.float(), beta) - a.float()
    )
    return (difference.to(atype),)

def euclidean_difference(a, b, c, key, alpha):
    """Euclidean add-difference: combine the normalized b-a and c-a diffs in quadrature, then rescale to the norm of the plain lerp diff."""
    a = a[key][0]
    b = b[key][0]
    c = c[key][0]
    atype = a.dtype
    b_diff = b.float() - a.float()
    c_diff = c.float() - a.float()
    b_diff = torch.nan_to_num(b_diff / torch.linalg.norm(b_diff))
    c_diff = torch.nan_to_num(c_diff / torch.linalg.norm(c_diff))
    distance = (1 - alpha) * b_diff**2 + alpha * c_diff**2
    distance = torch.sqrt(distance)
    sum_diff = weighted_sum(b.float(), c.float(), alpha) - a.float()
    distance = torch.copysign(distance, sum_diff)
    target_norm = torch.linalg.norm(sum_diff)
    return ((distance / torch.linalg.norm(distance) * target_norm).to(atype),)
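
# Note: by construction the returned tensor's overall norm matches the norm
# of the plain lerp diff (sum_diff); only the per-element distribution of the
# difference changes, with signs copied from sum_diff.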

def get_cos_similarity(a, b, sum_mode):
    """Collect per-column cosine similarities between a and b, trimming the 1st/99th percentile outliers."""
    sim = torch.nn.CosineSimilarity(dim=0)
    sims = np.array([], dtype=np.float64)
    if sum_mode == 'cos_a':
        theta_A_norm = torch.nn.functional.normalize(a.to(torch.float32), p=2, dim=0)
        theta_B_norm = torch.nn.functional.normalize(b.to(torch.float32), p=2, dim=0)
        simab = sim(theta_A_norm, theta_B_norm)
        sims = np.append(sims, simab.numpy())
    elif sum_mode == 'cos_b':
        simab = sim(a.to(torch.float32), b.to(torch.float32))
        dot_product = torch.dot(
            a.view(-1).to(torch.float32),
            b.view(-1).to(torch.float32),
        )
        # Magnitude similarity: dot(a, b) / (|a| * |b|)
        magnitude_similarity = dot_product / (
            torch.norm(a.to(torch.float32)) * torch.norm(b.to(torch.float32))
        )
        combined_similarity = (simab + magnitude_similarity) / 2.0
        sims = np.append(sims, combined_similarity.numpy())
    # Drop outliers below the 1st and above the 99th percentile
    sims = np.delete(sims, np.where(sims < np.percentile(sims, 1, method="midpoint")))
    sims = np.delete(sims, np.where(sims > np.percentile(sims, 99, method="midpoint")))
    return sim, sims
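
# The percentile trim drops the extreme 1% tails so a handful of outlier
# columns do not stretch the (sims.min(), sims.max()) range used to
# normalize k in the helpers below.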

def cosine_similarity_a(a, b, alpha, sim, sims):
    """Returns the blend value k (cos_a variant, computed on L2-normalized tensors)."""
    a_norm = torch.nn.functional.normalize(a.to(torch.float32), p=2, dim=0)
    b_norm = torch.nn.functional.normalize(b.to(torch.float32), p=2, dim=0)
    simab = sim(a_norm, b_norm)
    dot_product = torch.dot(a_norm.view(-1), b_norm.view(-1))
    magnitude_similarity = dot_product / (torch.norm(a) * torch.norm(b))
    combined_similarity = (simab + magnitude_similarity) / 2.0
    k = (combined_similarity - sims.min()) / (sims.max() - sims.min())
    k = k - alpha
    return 1 - k.clip(min=0.0, max=1.0)

def add_cosineA(a, b, key, alpha, model2_is_diff):
    a = a[key][0]
    b = b[key][0]
    sim, sims = get_cos_similarity(a, (a + b) if model2_is_diff else b, "cos_a")
    k = cosine_similarity_a(a, (a + b) if model2_is_diff else b, alpha, sim, sims)
    return (((b - a) * k) if model2_is_diff else ((a * (1 - k) + b * k) - a),)

def cosine_similarity_b(a, b, alpha, sim, sims):
    """Returns the blend value k (cos_b variant, computed on the raw tensors)."""
    simab = sim(a.to(torch.float32), b.to(torch.float32))
    dot_product = torch.dot(a.view(-1).to(torch.float32), b.view(-1).to(torch.float32))
    magnitude_similarity = dot_product / (
        torch.norm(a.to(torch.float32)) * torch.norm(b.to(torch.float32))
    )
    combined_similarity = (simab + magnitude_similarity) / 2.0
    k = (combined_similarity - sims.min()) / (sims.max() - sims.min())
    k = k - alpha
    return 1 - k.clip(min=0.0, max=1.0)

def add_cosineB(a, b, key, alpha, model2_is_diff):
    a = a[key][0]
    b = b[key][0]
    sim, sims = get_cos_similarity(a, (a + b) if model2_is_diff else b, "cos_b")
    k = cosine_similarity_b(a, (a + b) if model2_is_diff else b, alpha, sim, sims)
    return (((b - a) * k) if model2_is_diff else ((a * (1 - k) + b * k) - a),)
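
# Note on model2_is_diff: when True, `b` is treated as an already-extracted
# difference, so both cosine variants measure similarity between `a` and the
# implied second model `a + b` before scaling the returned patch by k.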

def normalize(v, eps: float):
    norm_v = torch.linalg.norm(v)
    if norm_v > eps:
        v = v / norm_v
    return v

def slerp(a, b, key, alpha, DOT_THRESHOLD: float = 0.9995, eps: float = 1e-8):
    """
    Spherical linear interpolation, returned as a diff against `a`.
    Adapted from: https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
    Args:
        a, b: Key-patch dicts holding the start and end tensors
        key: Key of the tensor pair to interpolate
        alpha (float): Interpolation factor between 0.0 and 1.0
        DOT_THRESHOLD (float): Threshold for considering the two vectors as
            colinear. Not recommended to alter this.
    Returns:
        One-tuple holding the interpolated tensor minus `a`
    """
    a = a[key][0]
    b = b[key][0]
    atype = a.dtype
    aback = a
    a = a.float()
    b = b.float()
    # Copy the vectors to reuse them later
    a_copy = a
    b_copy = b
    # Normalize the vectors to get the directions and angles
    a = normalize(a, eps)
    b = normalize(b, eps)
    # Dot product with the normalized vectors
    dot = torch.sum(a * b)
    # If the absolute dot product is almost 1, the vectors are ~colinear, so use lerp
    if torch.abs(dot) > DOT_THRESHOLD:
        res = weighted_sum(a_copy, b_copy, alpha)
        return (res.to(atype) - aback.to(atype),)
    # Calculate initial angle between a and b
    theta_0 = torch.arccos(dot)
    sin_theta_0 = torch.sin(theta_0)
    # Angle at timestep alpha
    theta_t = theta_0 * alpha
    sin_theta_t = torch.sin(theta_t)
    # Finish the slerp algorithm
    s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
    s1 = sin_theta_t / sin_theta_0
    res = s0 * a_copy + s1 * b_copy
    return (res.to(atype) - aback.to(atype),)
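
# Endpoint check (sketch): because the result is returned as a diff against
# `a`, alpha = 0 yields ~0 and alpha = 1 yields ~(b - a); near-colinear
# tensors (|dot| > DOT_THRESHOLD) silently fall back to plain lerp.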

class ModelMergeScaledDifference:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model1": ("MODEL",),
            "model2": ("MODEL",),
            "model3": ("MODEL",),
            "ratio": ("FLOAT", {"default": 1.0, "min": -5.0, "max": 5.0, "step": 0.01}),
            "return_model": ("BOOLEAN", {"default": True}),
        }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, model1, model2, model3, ratio, return_model):
        m = model1.clone()
        kp_a = model1.get_key_patches("diffusion_model.")
        kp_b = model2.get_key_patches("diffusion_model.")
        kp_c = model3.get_key_patches("diffusion_model.")
        for k in kp_a:
            m.add_patches({k: train_difference(kp_a, kp_b, kp_c, k)}, ratio, 1.0 if return_model else 0.0)
        return (m,)

class ModelMergeMultiplyDifference:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model1": ("MODEL",),
            "model2": ("MODEL",),
            "model3": ("MODEL",),
            "alpha": ("FLOAT", {"default": 0.5, "min": -5.0, "max": 5.0, "step": 0.01}),
            "beta": ("FLOAT", {"default": 0.5, "min": -5.0, "max": 5.0, "step": 0.01}),
            "ratio": ("FLOAT", {"default": 1.0, "min": -5.0, "max": 5.0, "step": 0.01}),
            "return_model": ("BOOLEAN", {"default": True}),
        }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "multiply_merge"
    CATEGORY = "advanced/model_merging"

    def multiply_merge(self, model1, model2, model3, alpha, beta, ratio, return_model):
        m = model1.clone()
        kp_a = model1.get_key_patches("diffusion_model.")
        kp_b = model2.get_key_patches("diffusion_model.")
        kp_c = model3.get_key_patches("diffusion_model.")
        for k in kp_a:
            m.add_patches({k: multiply_difference(kp_a, kp_b, kp_c, k, alpha, beta)}, ratio, 1.0 if return_model else 0.0)
        return (m,)

class ModelMergeEuclideanAddDifference:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model1": ("MODEL",),
            "model2": ("MODEL",),
            "model3": ("MODEL",),
            "alpha": ("FLOAT", {"default": 0.5, "min": -5.0, "max": 5.0, "step": 0.01}),
            "ratio": ("FLOAT", {"default": 1.0, "min": -5.0, "max": 5.0, "step": 0.01}),
            "return_model": ("BOOLEAN", {"default": True}),
        }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "euclidean_merge"
    CATEGORY = "advanced/model_merging"

    def euclidean_merge(self, model1, model2, model3, alpha, ratio, return_model):
        m = model1.clone()
        kp_a = model1.get_key_patches("diffusion_model.")
        kp_b = model2.get_key_patches("diffusion_model.")
        kp_c = model3.get_key_patches("diffusion_model.")
        for k in kp_a:
            m.add_patches({k: euclidean_difference(kp_a, kp_b, kp_c, k, alpha)}, ratio, 1.0 if return_model else 0.0)
        return (m,)

# Diff Merges
class ModelMergeAddCosine:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model1": ("MODEL",),
            "model2": ("MODEL",),
            "cos": (["model1", "model2"],),
            "alpha": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
            "ratio": ("FLOAT", {"default": 1.0, "min": -5.0, "max": 5.0, "step": 0.01}),
            "model2_is_diff": ("BOOLEAN", {"default": False}),
        }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "cos_merge"
    CATEGORY = "advanced/model_merging"

    def cos_merge(self, model1, model2, cos, alpha, ratio, model2_is_diff):
        m = model1.clone()
        kp_a = model1.get_key_patches("diffusion_model.")
        kp_b = model2.get_key_patches("diffusion_model.")
        merge_fn = add_cosineA if cos == "model1" else add_cosineB
        for k in kp_a:
            m.add_patches({k: merge_fn(kp_a, kp_b, k, alpha, model2_is_diff)}, ratio, 1.0)
        return (m,)

class ModelMergeSlerp:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model1": ("MODEL",),
            "model2": ("MODEL",),
            "alpha": ("FLOAT", {"default": 0.5, "min": -5.0, "max": 5.0, "step": 0.01}),
            "ratio": ("FLOAT", {"default": 1.0, "min": -5.0, "max": 5.0, "step": 0.01}),
            "return_model": ("BOOLEAN", {"default": True}),
        }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "slerp_merge"
    CATEGORY = "advanced/model_merging"

    def slerp_merge(self, model1, model2, alpha, ratio, return_model):
        m = model1.clone()
        kp_a = model1.get_key_patches("diffusion_model.")
        kp_b = model2.get_key_patches("diffusion_model.")
        for k in kp_a:
            m.add_patches({k: slerp(kp_a, kp_b, k, alpha)}, ratio, 1.0 if return_model else 0.0)
        return (m,)

NODE_CLASS_MAPPINGS = {
    "ModelMergeScaledDifference": ModelMergeScaledDifference,
    "ModelMergeMultiplyDifference": ModelMergeMultiplyDifference,
    "ModelMergeEuclideanAddDifference": ModelMergeEuclideanAddDifference,
    "ModelMergeAddCosine": ModelMergeAddCosine,
    "ModelMergeSlerp": ModelMergeSlerp,
}
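
# Usage note: ComfyUI loads NODE_CLASS_MAPPINGS from modules placed under
# ComfyUI/custom_nodes/, so saving this file there and restarting should
# register the nodes under "advanced/model_merging".

# Minimal smoke test (a sketch, not part of the nodes): exercises the tensor
# math without ComfyUI by faking the {key: (tensor,)} key-patch structure
# the merge helpers expect.
if __name__ == "__main__":
    def kp(t):
        return {"w": (t,)}

    x = torch.ones(4)
    y = torch.zeros(4)
    z = torch.full((4,), 2.0)
    print(train_difference(kp(x), kp(y), kp(z), "w")[0])  # expect -0.6 everywhere
    print(slerp(kp(x), kp(z), "w", 0.0)[0])               # expect ~0 (diff against a)
    print(slerp(kp(x), kp(z), "w", 1.0)[0])               # expect ~(z - x) = 1.0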