Created November 22, 2025 04:28
Gist: laksjdjf/4b687345bbe8f750f603d07845650c45
import node_helpers
import comfy.utils
import math
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy.patcher_extension import WrappersMP
import torch
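
# This gist defines two ComfyUI nodes for Qwen image editing:
# - TextEncodeQwenImageEditPlusFixPixelShift: a variant of the Qwen image edit text encoder
#   that can VAE-encode reference images at their original resolution instead of rescaling
#   them to roughly one megapixel, which is presumably the "pixel shift" fix in the name.
# - QwenImageTextBottomRight: adds a diffusion model wrapper that moves the text tokens'
#   rotary position ids past the bottom-right corner of the image grid.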

class TextEncodeQwenImageEditPlusFixPixelShift(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="TextEncodeQwenImageEditPlusFixPixelShift",
            category="advanced/conditioning",
            inputs=[
                io.Clip.Input("clip"),
                io.String.Input("prompt", multiline=True, dynamic_prompts=True),
                io.Vae.Input("vae", optional=True),
                io.Image.Input("image1", optional=True),
                io.Image.Input("image2", optional=True),
                io.Image.Input("image3", optional=True),
                io.Boolean.Input("unscale_latent", default=True),
                io.Boolean.Input("input_image_to_clip", default=True),
            ],
            outputs=[
                io.Conditioning.Output(),
            ],
        )

    @classmethod
    def execute(cls, clip, prompt, vae=None, image1=None, image2=None, image3=None, unscale_latent=True, input_image_to_clip=True) -> io.NodeOutput:
        ref_latents = []
        images = [image1, image2, image3]
        images_vl = []
        llama_template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
        image_prompt = ""
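        # For each provided image: downscale to roughly a 384x384 pixel area for the
        # vision-language text encoder, optionally VAE-encode a reference latent, and
        # append a "Picture N" vision placeholder to the prompt.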
        for i, image in enumerate(images):
            if image is not None:
                samples = image.movedim(-1, 1)
                if input_image_to_clip:
                    total = int(384 * 384)
                    scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
                    width = round(samples.shape[3] * scale_by)
                    height = round(samples.shape[2] * scale_by)
                    s = comfy.utils.common_upscale(samples, width, height, "area", "disabled")
                    images_vl.append(s.movedim(1, -1))
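                # With unscale_latent the reference is VAE-encoded at its original size;
                # otherwise it is resampled to roughly a 1024x1024 pixel area rounded to
                # multiples of 8, mirroring the stock Qwen edit encode node.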
                if vae is not None:
                    if unscale_latent:
                        ref_latents.append(vae.encode(samples.movedim(1, -1)))
                    else:
                        total = int(1024 * 1024)
                        scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
                        width = round(samples.shape[3] * scale_by / 8.0) * 8
                        height = round(samples.shape[2] * scale_by / 8.0) * 8
                        s = comfy.utils.common_upscale(samples, width, height, "area", "disabled")
                        ref_latents.append(vae.encode(s.movedim(1, -1)[:, :, :, :3]))
                image_prompt += "Picture {}: <|vision_start|><|image_pad|><|vision_end|>".format(i + 1)
        tokens = clip.tokenize(image_prompt + prompt, images=images_vl, llama_template=llama_template)
        conditioning = clip.encode_from_tokens_scheduled(tokens)
        if len(ref_latents) > 0:
            conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True)
        return io.NodeOutput(conditioning)
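
# The wrapper below runs around the diffusion model's forward call. It rebuilds the rotary
# position ids so the text tokens start past both the bottom and the right edge of the image
# grid (hence "bottom right"), then hooks the first transformer block so the recomputed
# embedding replaces the "pe" tensor. `self` is the diffusion model instance
# (model.model.diffusion_model) captured by the closure.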

def wrapper(self):
    def func(executor, x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs):
        hidden_states, img_ids, orig_shape = self.process_img(x)
        if ref_latents is not None:
            h = 0
            w = 0
            index = 0
            index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
            for ref in ref_latents:
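                # "index" method: every reference latent gets its own index id with no spatial
                # offset; any other ref_latents_method keeps index 1 and tiles the references
                # spatially via h/w offsets instead.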
                if index_ref_method:
                    index += 1
                    h_offset = 0
                    w_offset = 0
                else:
                    index = 1
                    h_offset = 0
                    w_offset = 0
                    if ref.shape[-2] + h > ref.shape[-1] + w:
                        w_offset = w
                    else:
                        h_offset = h
                    h = max(h, ref.shape[-2] + h_offset)
                    w = max(w, ref.shape[-1] + w_offset)
                kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
                hidden_states = torch.cat([hidden_states, kontext], dim=1)
                img_ids = torch.cat([img_ids, kontext_ids], dim=1)

        # Prepare text positional embeddings for bottom-right placement
        txt_start_h = round(((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) / 2)
        txt_start_w = round(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) / 2)
        txt_start_id = max(txt_start_h, txt_start_w)
        txt_start_list = [txt_start_id, txt_start_h, txt_start_w]
        txt_ids = torch.cat([torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 1) for txt_start in txt_start_list], dim=-1)
        ids = torch.cat((txt_ids, img_ids), dim=1)
        image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
        del ids, txt_ids, img_ids
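
        # Block-replace hook for the first transformer block: copy_() overwrites the "pe"
        # tensor in place, so later blocks, which receive the same tensor, also see the
        # recomputed rotary embedding.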
        def patch(args, original_func):
            args["pe"].copy_(image_rotary_emb)
            return original_func["original_block"](args)

        transformer_options_copy = transformer_options.copy()
        transformer_options_copy["patches_replace"] = transformer_options_copy.get("patches_replace", {})
        transformer_options_copy["patches_replace"]["dit"] = transformer_options_copy["patches_replace"].get("dit", {})
        transformer_options_copy["patches_replace"]["dit"][("double_block", 0)] = patch
        model_output = executor(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options_copy, **kwargs)
        return model_output
    return func


class QwenImageTextBottomRight(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="QwenImageTextBottomRight",
            category="_for_testing",
            inputs=[
                io.Model.Input("model"),
            ],
            outputs=[
                io.Model.Output(),
            ],
            description="Adds a diffusion model wrapper that places the text tokens' rotary position ids past the bottom-right corner of the image grid for Qwen image generation.",
        )

    @classmethod
    def execute(cls, model) -> io.NodeOutput:
        new_model = model.clone()
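        # Register the wrapper on a clone so the incoming model is left untouched; ComfyUI
        # runs WrappersMP.DIFFUSION_MODEL wrappers around the diffusion model's forward pass.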
        w = new_model.wrappers.setdefault(WrappersMP.DIFFUSION_MODEL, {}).setdefault(None, [])
        w.append(wrapper(model.model.diffusion_model))
        return io.NodeOutput(new_model)


class QwenExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            TextEncodeQwenImageEditPlusFixPixelShift,
            QwenImageTextBottomRight,
        ]


async def comfy_entrypoint() -> QwenExtension:
    return QwenExtension()
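
A quick way to see why unscale_latent exists: the stock rescale path (the else branch in execute above) resizes the reference to roughly a 1024x1024 pixel area rounded to multiples of 8, which resamples pixels even when the input is already latent-aligned. The snippet below simply replays that arithmetic for an arbitrary 1248x832 example; the size is illustrative and not taken from the gist.

import math

w, h = 1248, 832                       # example reference size, already divisible by 8
scale_by = math.sqrt(1024 * 1024 / (w * h))
new_w = round(w * scale_by / 8.0) * 8  # 1256
new_h = round(h * scale_by / 8.0) * 8  # 840
print(new_w, new_h)                    # stock path resamples 1248x832 to 1256x840

With unscale_latent=True the same image is VAE-encoded at 1248x832 as-is, so the reference latent stays aligned with the original pixels, which is presumably the pixel shift the node name refers to.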