Skip to content

Instantly share code, notes, and snippets.

@laksjdjf
Created April 27, 2025 07:08
Show Gist options
  • Save laksjdjf/5bdb5b8f20bd1ba289704b15f9868680 to your computer and use it in GitHub Desktop.
Save laksjdjf/5bdb5b8f20bd1ba289704b15f9868680 to your computer and use it in GitHub Desktop.
import torch
def new_vec(mode, chunks, x):
    """Build reference key/value inputs from the first sample of each batch chunk.

    ``x`` is split along dim 0 into ``chunks`` pieces with ``torch.chunk`` (the
    last piece may be smaller). Within each piece the first sample is broadcast
    over the whole piece, so every sample of the piece carries that first
    sample's tokens.

    Args:
        mode: ``"concat"`` appends the reference tokens to ``x`` along dim 1;
            any other value returns the reference tensor alone ("replace").
        chunks: Number of pieces dim 0 is split into.
        x: Tensor whose dim 0 is the batch. Presumably (batch, tokens,
            channels) at the call site — TODO confirm; any rank works for
            "replace", and any rank >= 2 works for "concat".

    Returns:
        A newly allocated tensor; ``x`` itself is never modified or aliased.
    """
    # ``expand_as`` is a zero-copy broadcast view of each chunk's first sample.
    # ``torch.cat`` materializes the views into fresh memory, so no explicit
    # ``clone()`` is needed anywhere.
    ref_x = torch.cat(
        [xi[:1].expand_as(xi) for xi in x.chunk(chunks, dim=0)], dim=0
    )
    if mode == "concat":
        # Extend the token axis: originals first, then the reference tokens.
        return torch.cat([x, ref_x], dim=1)
    return ref_x
class ReferenceFirstBatch:
    """Node that patches a model so samples attend to their chunk's first sample.

    Uses the ``INPUT_TYPES`` / ``RETURN_TYPES`` / ``FUNCTION`` node interface.
    ``reference_only`` clones the incoming model and registers a patch with
    ``set_model_attn1_patch`` that rewrites the patch's ``k``/``v`` inputs via
    ``new_vec``. This happens only on selected output blocks, and only while
    the current sigma is inside the configured window.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's required inputs and their widget settings."""
        return {
            "required": {
                "model": ("MODEL",),
                "mode": (["concat", "replace"], {"default": "concat"}),
                "depth": ("INT", {"default": 0, "min": -1, "max": 12}),
                "start_step": ("FLOAT", {"default": 0, "min": 0, "max": 1, "step": 0.01}),
                "end_step": ("FLOAT", {"default": 1, "min": 0, "max": 1, "step": 0.01}),
            }
        }

    RETURN_TYPES = ("MODEL", )
    FUNCTION = "reference_only"
    CATEGORY = "loaders"

    def reference_only(self, model, mode, depth, start_step, end_step):
        """Return a clone of ``model`` with the reference patch installed.

        Args:
            model: Model object exposing ``clone()``; the clone is patched and
                the input is left untouched.
            mode: ``"concat"`` or ``"replace"``, forwarded to ``new_vec``.
            depth: Output block ``block_id`` is patched when
                ``num_blocks - block_id <= depth``; ``-1`` patches nothing.
            start_step: Schedule fraction converted with ``percent_to_sigma``;
                it is the upper sigma bound of the active window.
            end_step: Schedule fraction converted with ``percent_to_sigma``;
                it is the lower sigma bound of the active window.

        Returns:
            A one-element tuple holding the patched model clone.
        """
        model_reference = model.clone()
        model_sampling = model_reference.model.model_sampling
        start_sigma = model_sampling.percent_to_sigma(start_step)
        end_sigma = model_sampling.percent_to_sigma(end_step)

        # SDXL is detected by the presence of ``label_emb`` on the diffusion
        # model. NOTE(review): any other UNet that has ``label_emb`` is also
        # treated as SDXL — confirm this heuristic is intended.
        sdxl = hasattr(model_reference.model.diffusion_model, "label_emb")
        num_blocks = 8 if sdxl else 11

        # Still stored on the instance for backward compatibility. The patch
        # closes over the locals instead, so calling this method again on the
        # same node instance cannot retarget a model returned earlier.
        self.depth = depth
        self.sdxl = sdxl
        self.num_blocks = num_blocks

        def reference_apply(q, k, v, extra_options):
            block_name, block_id = extra_options["block"]
            # Check the cheap block condition first, so that ``.item()`` on
            # the sigmas tensor is only evaluated where the patch can apply.
            active = (
                block_name == "output"
                and num_blocks - block_id <= depth
                and end_sigma <= extra_options["sigmas"][0].item() <= start_sigma
            )
            # Always hand back fresh tensors rather than the caller's objects.
            if not active:
                return q.clone(), k.clone(), v.clone()
            # One chunk per entry of cond_or_uncond along the batch axis.
            chunks = len(extra_options["cond_or_uncond"])
            # new_vec already returns newly allocated tensors, so k/v are not
            # cloned beforehand.
            return q.clone(), new_vec(mode, chunks, k), new_vec(mode, chunks, v)

        model_reference.set_model_attn1_patch(reference_apply)
        return (model_reference, )
# Exported node-name -> class table (presumably read by the host application's
# custom-node loader; the key is the node's registered name).
NODE_CLASS_MAPPINGS = dict(ReferenceFirstBatch=ReferenceFirstBatch)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment