# Gist by @laksjdjf: 4b687345bbe8f750f603d07845650c45 (created November 22, 2025)
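"""Custom ComfyUI nodes for Qwen Image Edit (summary based on the code below).

- TextEncodeQwenImageEditPlusFixPixelShift: a variant of the stock
  TextEncodeQwenImageEditPlus encoder. With ``unscale_latent`` enabled it
  VAE-encodes the reference images at their original resolution instead of
  rescaling them to roughly one megapixel.
- QwenImageTextBottomRight: clones the model and registers a diffusion-model
  wrapper that places the text prompt's positional ids at the bottom right
  of the image grid.
"""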
import math

import torch
from typing_extensions import override

import comfy.utils
import node_helpers
from comfy.patcher_extension import WrappersMP
from comfy_api.latest import ComfyExtension, io


class TextEncodeQwenImageEditPlusFixPixelShift(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="TextEncodeQwenImageEditPlusFixPixelShift",
            category="advanced/conditioning",
            inputs=[
                io.Clip.Input("clip"),
                io.String.Input("prompt", multiline=True, dynamic_prompts=True),
                io.Vae.Input("vae", optional=True),
                io.Image.Input("image1", optional=True),
                io.Image.Input("image2", optional=True),
                io.Image.Input("image3", optional=True),
                io.Boolean.Input("unscale_latent", default=True),
                io.Boolean.Input("input_image_to_clip", default=True),
            ],
            outputs=[
                io.Conditioning.Output(),
            ],
        )
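
    # execute() largely mirrors the stock TextEncodeQwenImageEditPlus flow: it builds the
    # "Picture N: <|vision_start|>..." prompt from up to three images, optionally
    # VAE-encodes them as reference latents, and returns the conditioning.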
    @classmethod
    def execute(cls, clip, prompt, vae=None, image1=None, image2=None, image3=None, unscale_latent=True, input_image_to_clip=True) -> io.NodeOutput:
        ref_latents = []
        images = [image1, image2, image3]
        images_vl = []
        llama_template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
        image_prompt = ""

        for i, image in enumerate(images):
            if image is not None:
                samples = image.movedim(-1, 1)
                if input_image_to_clip:
                    # Downscale to roughly 384x384 pixels for the vision-language branch of the text encoder.
                    total = int(384 * 384)
                    scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
                    width = round(samples.shape[3] * scale_by)
                    height = round(samples.shape[2] * scale_by)
                    s = comfy.utils.common_upscale(samples, width, height, "area", "disabled")
                    images_vl.append(s.movedim(1, -1))
                if vae is not None:
                    if unscale_latent:
                        # Encode the reference image at its original resolution (no rescale).
                        ref_latents.append(vae.encode(samples.movedim(1, -1)))
                    else:
                        # Rescale to roughly 1024x1024 pixels, snapped to multiples of 8, before VAE encoding.
                        total = int(1024 * 1024)
                        scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
                        width = round(samples.shape[3] * scale_by / 8.0) * 8
                        height = round(samples.shape[2] * scale_by / 8.0) * 8
                        s = comfy.utils.common_upscale(samples, width, height, "area", "disabled")
                        ref_latents.append(vae.encode(s.movedim(1, -1)[:, :, :, :3]))
                image_prompt += "Picture {}: <|vision_start|><|image_pad|><|vision_end|>".format(i + 1)

        tokens = clip.tokenize(image_prompt + prompt, images=images_vl, llama_template=llama_template)
        conditioning = clip.encode_from_tokens_scheduled(tokens)
        if len(ref_latents) > 0:
            conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True)
        return io.NodeOutput(conditioning)
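

# wrapper() produces a WrappersMP.DIFFUSION_MODEL wrapper around the Qwen diffusion
# model's forward call. It rebuilds the image and text positional ids (including
# those of any reference latents), offsets the text ids so the prompt sits toward
# the bottom right of the image grid, and patches the first double block so the
# recomputed rotary embedding overwrites the original in place before it is used.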
def wrapper(self):
    # ``self`` is the diffusion model being wrapped (see QwenImageTextBottomRight.execute below).
    def func(executor, x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs):
        hidden_states, img_ids, orig_shape = self.process_img(x)
        if ref_latents is not None:
            h = 0
            w = 0
            index = 0
            index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
            for ref in ref_latents:
                if index_ref_method:
                    index += 1
                    h_offset = 0
                    w_offset = 0
                else:
                    index = 1
                    h_offset = 0
                    w_offset = 0
                    if ref.shape[-2] + h > ref.shape[-1] + w:
                        w_offset = w
                    else:
                        h_offset = h
                    h = max(h, ref.shape[-2] + h_offset)
                    w = max(w, ref.shape[-1] + w_offset)
                kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
                hidden_states = torch.cat([hidden_states, kontext], dim=1)
                img_ids = torch.cat([img_ids, kontext_ids], dim=1)

        # Prepare text positional embeddings for bottom-right placement.
        txt_start_h = round(((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) / 2)
        txt_start_w = round(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) / 2)
        txt_start_id = max(txt_start_h, txt_start_w)
        txt_start_list = [txt_start_id, txt_start_h, txt_start_w]
        txt_ids = torch.cat(
            [
                torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 1)
                for txt_start in txt_start_list
            ],
            dim=-1,
        )
        ids = torch.cat((txt_ids, img_ids), dim=1)
        image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
        del ids, txt_ids, img_ids

        def patch(args, original_func):
            # Overwrite the rotary embedding in place before the first double block runs,
            # then execute the original block unchanged.
            args["pe"].copy_(image_rotary_emb)
            return original_func["original_block"](args)

        # Inject the patch through a copy of transformer_options so other calls are unaffected.
        transformer_options_copy = transformer_options.copy()
        transformer_options_copy["patches_replace"] = transformer_options_copy.get("patches_replace", {})
        transformer_options_copy["patches_replace"]["dit"] = transformer_options_copy["patches_replace"].get("dit", {})
        transformer_options_copy["patches_replace"]["dit"][("double_block", 0)] = patch

        model_output = executor(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options_copy, **kwargs)
        return model_output

    return func


class QwenImageTextBottomRight(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="QwenImageTextBottomRight",
            category="_for_testing",
            inputs=[
                io.Model.Input("model"),
            ],
            outputs=[
                io.Model.Output(),
            ],
            description="A node that adds a wrapper to position text prompts at the bottom right in Qwen image generation.",
        )

    @classmethod
    def execute(cls, model) -> io.NodeOutput:
        new_model = model.clone()
        # Register the forward wrapper on the cloned model so the original model stays untouched.
        w = new_model.wrappers.setdefault(WrappersMP.DIFFUSION_MODEL, {}).setdefault(None, [])
        w.append(wrapper(model.model.diffusion_model))
        return io.NodeOutput(new_model)


class QwenExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            TextEncodeQwenImageEditPlusFixPixelShift,
            QwenImageTextBottomRight,
        ]


async def comfy_entrypoint() -> QwenExtension:
    return QwenExtension()
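
# Typical wiring (an assumption based on the node schemas above):
#   CLIP / VAE / image inputs -> TextEncodeQwenImageEditPlusFixPixelShift -> CONDITIONING -> sampler
#   MODEL -> QwenImageTextBottomRight -> patched MODEL -> sampler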