Skip to content

Instantly share code, notes, and snippets.

@blepping
Last active December 5, 2024 05:01
Show Gist options
  • Save blepping/df3d7038089899133900514c59a6fe95 to your computer and use it in GitHub Desktop.
Save blepping/df3d7038089899133900514c59a6fe95 to your computer and use it in GitHub Desktop.
ComfyUI sampler wrapper that saves all model predictions
# ComfyUI sampler wrapper that saves the state of denoised for every model call.
# By default this ignores whatever the sampler actually returns, so you will get
# a result of batch_size * times_model_was_called latents. Enabling
# append_sampler_result appends the sampler's own result as extra latents.
from __future__ import annotations
import torch
from comfy.samplers import KSAMPLER
from comfy.model_management import device_supports_non_blocking
def denoised_history_sampler(
    model: object,
    x: torch.Tensor,
    sigmas: torch.Tensor,
    *,
    dhs_options: tuple,
    **kwargs: dict,
) -> torch.Tensor:
    """Run the wrapped sampler while recording every model prediction.

    Args:
        model: Callable invoked as ``model(x, sigma, **extra_args)``. Its
            ``inner_model`` and ``sigmas`` attributes, when present, are
            mirrored onto the recording wrapper.
        x: Initial latent passed through to the wrapped sampler.
        sigmas: Sigma schedule passed through to the wrapped sampler.
        dhs_options: ``(wrapped_sampler, append_result)`` pair.
            ``wrapped_sampler`` must expose ``sampler_function`` and
            ``extra_options``; ``append_result`` selects whether the wrapped
            sampler's own return value is appended after the history.
        **kwargs: Forwarded unchanged to the wrapped sampler function.

    Returns:
        All recorded predictions (plus the sampler result when
        ``append_result`` is true) moved to the CPU and concatenated along
        dim 0. If nothing was recorded at all, the wrapped sampler's result
        is returned as-is instead of failing on an empty concatenation.
    """
    wrapped_sampler, append_result = dhs_options
    history = []
    # Asynchronous device-to-host copies are only safe when we can wait for
    # them afterwards, which this function does via torch.cuda.synchronize.
    # Restrict them to CUDA latents so CPU/MPS/other devices never reach the
    # CUDA-only synchronize call (it raises for non-CUDA devices).
    on_cuda = x.device.type == "cuda"
    non_blocking = on_cuda and device_supports_non_blocking(x.device)

    def model_wrapper(x: torch.Tensor, sigma: torch.Tensor, **extra_args: dict):
        denoised = model(x, sigma, **extra_args)
        # clone() makes the saved copy independent of the tensor handed back
        # to the sampler.
        saved = denoised.detach().clone()
        history.append(saved.to("cpu", non_blocking=non_blocking))
        return denoised

    # Mirror the attributes of the real model onto the wrapper function so
    # code that reads them from the model object keeps working.
    for attr_name in ("inner_model", "sigmas"):
        if hasattr(model, attr_name):
            setattr(model_wrapper, attr_name, getattr(model, attr_name))

    result = wrapped_sampler.sampler_function(
        model_wrapper,
        x,
        sigmas,
        **kwargs,
        **wrapped_sampler.extra_options,
    )

    if non_blocking:
        # Make sure every pending asynchronous copy has landed in host memory
        # before the history tensors are read by torch.cat.
        torch.cuda.synchronize(x.device)

    if append_result:
        history.append(result.cpu())

    if not history:
        # The model was never called (e.g. a schedule with no steps) and the
        # sampler result was not requested: torch.cat would raise on an empty
        # list, so fall back to what the wrapped sampler produced.
        return result

    return torch.cat(history, dim=0)
class DenoisedHistorySamplerNode:
    """Node that wraps a SAMPLER input in the denoised-history sampler."""

    DESCRIPTION = (
        "This sampler wrapper saves the state of denoised each time the model is called. "
        "The output from the sampler will be batch_size * model calls, with the last batch_size items being the model result from the last step. "
        "If you enable append_sampler_result then you will get an additional batch_size latents."
    )
    CATEGORY = "sampling/custom_sampling/samplers"
    RETURN_TYPES = ("SAMPLER",)
    FUNCTION = "go"

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        """Describe the node inputs: a required sampler and an optional flag."""
        required_inputs = {
            "sampler": ("SAMPLER",),
        }
        optional_inputs = {
            "append_sampler_result": ("BOOLEAN", {"default": False}),
        }
        return {"required": required_inputs, "optional": optional_inputs}

    @classmethod
    def go(
        cls,
        *,
        sampler: object,
        append_sampler_result: bool = False,
    ) -> tuple:
        """Build the wrapping sampler and return it as a one-element tuple.

        The wrapped sampler and the append flag travel to
        ``denoised_history_sampler`` through the ``dhs_options`` extra option.
        """
        options = {"dhs_options": (sampler, append_sampler_result)}
        history_sampler = KSAMPLER(denoised_history_sampler, extra_options=options)
        return (history_sampler,)
# Maps the node identifier to the class implementing it (presumably read by
# ComfyUI's custom-node loader -- confirm against the loader).
NODE_CLASS_MAPPINGS = dict(DenoisedHistorySampler=DenoisedHistorySamplerNode)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment