Created
November 7, 2025 12:59
-
-
Save laksjdjf/7ed7174eaf125918e7a5848a694500be to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from typing_extensions import override | |
| from comfy_api.latest import ComfyExtension, io | |
| import comfy | |
class Noise_ReferenceAlign:
    """Seeded noise whose per-position channel mean is pulled toward a reference latent.

    Instances are created by ``ReferenceAlignNoise.execute``.

    Args:
        seed: Seed forwarded to ``comfy.sample.prepare_noise``.
        reference_latents: Latent dict whose ``"samples"`` tensor supplies the
            target channel mean (mean over dim 1).
        strength: Linear blend factor. ``0`` leaves the seeded noise unchanged;
            ``1`` makes the noise's channel mean at every position equal to the
            reference's channel mean; other values interpolate/extrapolate.
    """

    def __init__(self, seed, reference_latents, strength):
        self.seed = seed
        self.reference_latents = reference_latents
        self.strength = strength

    def generate_noise(self, input_latent):
        """Return seeded noise shaped like ``input_latent["samples"]``, mean-aligned to the reference.

        The noise is drawn for the *input* latent's shape, so it always matches
        the tensor being sampled. The reference only contributes its channel
        mean (dim 1), which is broadcast onto the noise. Its channel count may
        therefore differ from the input's, and a reference with batch size 1
        is applied to every item of the input batch.

        Args:
            input_latent: Latent dict with a ``"samples"`` tensor and an optional
                ``"batch_index"`` entry that is forwarded to ``prepare_noise``.

        Returns:
            ``noise + (reference_mean - noise_mean) * strength``, where both
            means are taken over dim 1 with ``keepdim=True``.

        Raises:
            ValueError: if the reference's rank or its dimensions after the
                channel axis differ from the input's, or its batch size is
                neither 1 nor the input's batch size.
        """
        samples = input_latent["samples"]
        reference = self.reference_latents["samples"]

        # Validate before drawing noise so a mismatched reference fails with a
        # clear message instead of a broadcasting error (or silently broadcast
        # noise) deeper in the sampler.
        if (
            reference.ndim != samples.ndim
            or reference.shape[2:] != samples.shape[2:]
            or reference.shape[0] not in (1, samples.shape[0])
        ):
            raise ValueError(
                f"reference_latents shape {tuple(reference.shape)} is incompatible "
                f"with input latent shape {tuple(samples.shape)}: dimensions after "
                "the channel axis must match and the batch size must be 1 or equal "
                "to the input's"
            )

        batch_inds = input_latent.get("batch_index", None)
        # Draw noise for the latent actually being sampled.
        noise = comfy.sample.prepare_noise(samples, self.seed, batch_inds)

        # Put the reference on the noise's own device/dtype so the arithmetic
        # below never mixes devices, wherever prepare_noise allocated it.
        reference_mean = reference.to(noise.device, dtype=noise.dtype).mean(dim=1, keepdim=True)
        noise_mean = noise.mean(dim=1, keepdim=True)
        # Adding the same offset to every channel shifts the channel mean by that
        # offset, so strength=1 makes the noise's channel mean equal the reference's.
        return noise + (reference_mean - noise_mean) * self.strength
class ReferenceAlignNoise(io.ComfyNode):
    """Node that outputs a ``Noise_ReferenceAlign`` noise source.

    Inputs: an integer ``seed``, a ``reference_latents`` latent, and a
    ``strength`` float in [-1.0, 1.0] (step 0.01, default 0.0).
    Output: a single noise object.
    """

    @classmethod
    def define_schema(cls):
        """Describe the node's id, category, inputs, output and description."""
        seed_in = io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff)
        reference_in = io.Latent.Input("reference_latents")
        strength_in = io.Float.Input("strength", default=0.0, min=-1.0, max=1.0, step=0.01)
        node_inputs = [seed_in, reference_in, strength_in]
        node_outputs = [io.Noise.Output()]
        schema = io.Schema(
            node_id="ReferenceAlignNoise",
            category="sampling/custom_sampling/noise",
            inputs=node_inputs,
            outputs=node_outputs,
            description="Generates noise aligned to reference latents with adjustable strength.",
        )
        return schema

    @classmethod
    def execute(cls, seed: int, reference_latents, strength: float) -> io.NodeOutput:
        """Bundle the inputs into a ``Noise_ReferenceAlign`` and return it as the node output."""
        noise_source = Noise_ReferenceAlign(seed, reference_latents, strength)
        return io.NodeOutput(noise_source)
class ReferenceAlignNoiseExtension(ComfyExtension):
    """Extension that provides the ``ReferenceAlignNoise`` node."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes this extension provides."""
        provided_nodes = [ReferenceAlignNoise]
        return provided_nodes
async def comfy_entrypoint() -> ReferenceAlignNoiseExtension:
    """Create and return the ``ReferenceAlignNoiseExtension`` exported by this module."""
    extension = ReferenceAlignNoiseExtension()
    return extension
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment