ComfyUI sampler wrapper that saves all model predictions
# ComfyUI sampler wrapper that saves the state of denoised for every model call.
# This ignores whatever the sampler actually returns. You will get a result of
# batch_size * times_model_was_called latents.
from __future__ import annotations

import torch

from comfy.samplers import KSAMPLER
from comfy.model_management import device_supports_non_blocking
def denoised_history_sampler(
    model: object,
    x: torch.Tensor,
    sigmas: torch.Tensor,
    *,
    dhs_options: tuple,
    **kwargs: dict,
) -> torch.Tensor:
    wrapped_sampler, append_result = dhs_options
    dn_list = []
    cuda = getattr(torch, "cuda", None)
    # Only use non-blocking CPU copies when the device actually supports them.
    non_blocking = cuda is not None and device_supports_non_blocking(x.device)

    def model_wrapper(x: torch.Tensor, sigma: torch.Tensor, **extra_args: dict):
        # Record a CPU copy of the model's denoised prediction, then pass the
        # original through unchanged so sampling itself is unaffected.
        denoised = model(x, sigma, **extra_args)
        dn_list.append(denoised.detach().clone().to("cpu", non_blocking=non_blocking))
        return denoised

    # Some samplers expect these attributes on the model object, so mirror
    # them onto the wrapper when present.
    for k in ("inner_model", "sigmas"):
        if hasattr(model, k):
            setattr(model_wrapper, k, getattr(model, k))

    result = wrapped_sampler.sampler_function(
        model_wrapper,
        x,
        sigmas,
        **kwargs,
        **wrapped_sampler.extra_options,
    )
    # Ensure any pending non-blocking copies have completed before we
    # concatenate. Only synchronize when sampling actually ran on CUDA.
    if cuda is not None and x.device.type == "cuda":
        cuda.synchronize(x.device)
    if append_result:
        dn_list.append(result.cpu())
    return torch.cat(dn_list, dim=0)
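
# --- Illustration only (not used by the node) ---------------------------------
# A minimal, self-contained sketch of the capture pattern used above, with a
# toy stand-in for the real denoiser. Every name below (_demo_history, the toy
# model, the sigma values) is invented for this demo and is not ComfyUI API.
def _demo_history() -> torch.Tensor:
    history: list[torch.Tensor] = []

    def toy_model(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        return x / (1.0 + sigma)  # pretend "denoised" prediction

    def wrapper(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        denoised = toy_model(x, sigma)
        history.append(denoised.detach().clone().cpu())
        return denoised

    x = torch.randn(2, 4, 8, 8)  # batch_size = 2
    for sigma in (torch.tensor(1.0), torch.tensor(0.5)):  # two "model calls"
        x = wrapper(x, sigma)
    # batch_size * times_model_was_called = 2 * 2 = 4 latents
    return torch.cat(history, dim=0)  # shape: (4, 4, 8, 8)
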
class DenoisedHistorySamplerNode:
    DESCRIPTION = "Sampler wrapper that saves the state of denoised each time the model is called. The sampler's output will contain batch_size * number_of_model_calls latents, with the last batch_size items being the model prediction from the final step. Enabling append_sampler_result appends an additional batch_size latents holding the sampler's actual result."
    CATEGORY = "sampling/custom_sampling/samplers"
    RETURN_TYPES = ("SAMPLER",)
    FUNCTION = "go"

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        return {
            "required": {
                "sampler": ("SAMPLER",),
            },
            "optional": {
                "append_sampler_result": ("BOOLEAN", {"default": False}),
            },
        }
    @classmethod
    def go(
        cls,
        *,
        sampler: object,
        append_sampler_result: bool = False,
    ) -> tuple:
        return (
            KSAMPLER(
                denoised_history_sampler,
                extra_options={"dhs_options": (sampler, append_sampler_result)},
            ),
        )
NODE_CLASS_MAPPINGS = {
    "DenoisedHistorySampler": DenoisedHistorySamplerNode,
}
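
# --- Usage sketch (assumes a working ComfyUI environment) ---------------------
# In normal use you connect this node's SAMPLER output to a SamplerCustom node
# in the graph. The snippet below only shows constructing the wrapped sampler
# programmatically with a stock sampler via comfy.samplers.sampler_object();
# it does not run sampling.
if __name__ == "__main__":
    from comfy.samplers import sampler_object

    base_sampler = sampler_object("euler")  # any stock sampler works
    (history_sampler,) = DenoisedHistorySamplerNode.go(
        sampler=base_sampler,
        append_sampler_result=True,
    )
    # history_sampler is a KSAMPLER; SamplerCustom can call it like any other.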