'''
https://arxiv.org/abs/2312.00858

1. Put this file in ComfyUI/custom_nodes.
2. Load the node from <loaders>.

start_step, end_step: apply this method when the timestep is between start_step and end_step
cache_interval: interval of caching (1 means no caching)
cache_depth: depth of caching
'''
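# Illustrative behavior (hypothetical numbers): with cache_interval=5, cache_depth=3,
# start_step=0, and end_step=1000, the full UNet runs on every 5th step in the window
# and the cached deep features are reused on the steps in between.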
import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import forward_timestep_embed, timestep_embedding, th, apply_control
class DeepCache:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL", ),
                "cache_interval": ("INT", {
                    "default": 5,
                    "min": 1,
                    "max": 1000,
                    "step": 1,
                    "display": "number"
                }),
                "cache_depth": ("INT", {
                    "default": 3,
                    "min": 0,
                    "max": 12,
                    "step": 1,
                    "display": "number"
                }),
                "start_step": ("INT", {
                    "default": 0,
                    "min": 0,
                    "max": 1000,
                    "step": 1,
                    "display": "number"
                }),
                "end_step": ("INT", {
                    "default": 1000,
                    "min": 0,
                    "max": 1000,
                    "step": 1,
                    "display": "number"
                }),
            },
        }

    RETURN_TYPES = ("MODEL", )
    FUNCTION = "apply"
    CATEGORY = "loaders"
    def apply(self, model, cache_interval, cache_depth, start_step, end_step):
        new_model = model.clone()
        current_t = -1
        current_step = -1
        cache_h = None

        def apply_model(model_function, kwargs):
            nonlocal current_t, current_step, cache_h

            xa = kwargs["input"]
            t = kwargs["timestep"]
            c_concat = kwargs["c"].get("c_concat", None)
            c_crossattn = kwargs["c"].get("c_crossattn", None)
            y = kwargs["c"].get("y", None)
            control = kwargs["c"].get("control", None)
            transformer_options = kwargs["c"].get("transformer_options", None)

            # https://github.com/comfyanonymous/ComfyUI/blob/629e4c552cc30a75d2756cbff8095640af3af163/comfy/model_base.py#L51-L69
            sigma = t
            xc = new_model.model.model_sampling.calculate_input(sigma, xa)
            if c_concat is not None:
                xc = torch.cat([xc] + [c_concat], dim=1)

            context = c_crossattn
            dtype = new_model.model.get_dtype()
            xc = xc.to(dtype)
            t = new_model.model.model_sampling.timestep(t).float()
            context = context.to(dtype)
            extra_conds = {}
            for o in kwargs:
                extra = kwargs[o]
                if hasattr(extra, "to"):
                    extra = extra.to(dtype)
                extra_conds[o] = extra

            x = xc
            timesteps = t
            y = None if y is None else y.to(dtype)
            transformer_options["original_shape"] = list(x.shape)
            transformer_options["current_index"] = 0
            transformer_patches = transformer_options.get("patches", {})
""" | |
Apply the model to an input batch. | |
:param x: an [N x C x ...] Tensor of inputs. | |
:param timesteps: a 1-D batch of timesteps. | |
:param context: conditioning plugged in via crossattn | |
:param y: an [N] Tensor of labels, if class-conditional. | |
:return: an [N x C x ...] Tensor of outputs. | |
""" | |
            unet = new_model.model.diffusion_model

            # Assume a new sampling run has started when the timestep goes back up
            # (within a run, t only decreases); this may misbehave with a Refiner, etc.
            if t[0].item() > current_t:
                current_step = -1

            current_t = t[0].item()
            apply = 1000 - end_step <= current_t <= 1000 - start_step  # t runs from 999 down to 0

            if apply:
                current_step += 1
            else:
                current_step = -1
                current_t = t[0].item()
            # https://github.com/comfyanonymous/ComfyUI/blob/629e4c552cc30a75d2756cbff8095640af3af163/comfy/ldm/modules/diffusionmodules/openaimodel.py#L598-L659
            assert (y is not None) == (
                unet.num_classes is not None
            ), "must specify y if and only if the model is class-conditional"
            hs = []
            t_emb = timestep_embedding(timesteps, unet.model_channels, repeat_only=False).to(unet.dtype)
            emb = unet.time_embed(t_emb)

            if unet.num_classes is not None:
                assert y.shape[0] == x.shape[0]
                emb = emb + unet.label_emb(y)

            h = x.type(unet.dtype)
            for id, module in enumerate(unet.input_blocks):
                transformer_options["block"] = ("input", id)
                h = forward_timestep_embed(module, h, emb, context, transformer_options)
                h = apply_control(h, control, 'input')
                if "input_block_patch" in transformer_patches:
                    patch = transformer_patches["input_block_patch"]
                    for p in patch:
                        h = p(h, transformer_options)

                hs.append(h)
                if "input_block_patch_after_skip" in transformer_patches:
                    patch = transformer_patches["input_block_patch_after_skip"]
                    for p in patch:
                        h = p(h, transformer_options)
                if id == cache_depth and apply:
                    if not current_step % cache_interval == 0:
                        break  # on cache-reuse steps, skip all blocks deeper than the cache position
            if current_step % cache_interval == 0 or not apply:
                transformer_options["block"] = ("middle", 0)
                h = forward_timestep_embed(unet.middle_block, h, emb, context, transformer_options)
                h = apply_control(h, control, 'middle')
            for id, module in enumerate(unet.output_blocks):
                if id < len(unet.output_blocks) - cache_depth - 1 and apply:
                    if not current_step % cache_interval == 0:
                        continue  # on cache-reuse steps, skip the output blocks before the cache position

                if id == len(unet.output_blocks) - cache_depth - 1 and apply:
                    if current_step % cache_interval == 0:
                        cache_h = h  # store the cache
                    else:
                        h = cache_h  # load the cache
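                # On cache-reuse steps only input blocks 0..cache_depth ran, so hs holds just
                # the shallow skip features; the middle and deeper features come from cache_h.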
transformer_options["block"] = ("output", id) | |
hsp = hs.pop() | |
hsp = apply_control(hsp, control, 'output') | |
if "output_block_patch" in transformer_patches: | |
patch = transformer_patches["output_block_patch"] | |
for p in patch: | |
h, hsp = p(h, hsp, transformer_options) | |
h = th.cat([h, hsp], dim=1) | |
del hsp | |
if len(hs) > 0: | |
output_shape = hs[-1].shape | |
else: | |
output_shape = None | |
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape) | |
h = h.type(x.dtype) | |
if unet.predict_codebook_ids: | |
model_output = unet.id_predictor(h) | |
else: | |
model_output = unet.out(h) | |
return new_model.model.model_sampling.calculate_denoised(sigma, model_output, xa) | |
new_model.set_model_unet_function_wrapper(apply_model) | |
return (new_model, ) | |
NODE_CLASS_MAPPINGS = {
    "DeepCache": DeepCache,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "DeepCache": "Deep Cache",
}

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
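For readers who just want the gist of when the cache is used, the per-step decision made by the wrapper above can be summarized in isolation. This is a minimal sketch; the function name and string return values are made up for illustration and are not part of the node:

def deepcache_mode(t, current_step, cache_interval, start_step, end_step):
    # t is the discretized timestep, which runs from 999 down to 0 over a sampling run.
    in_window = 1000 - end_step <= t <= 1000 - start_step
    if not in_window:
        return "full"            # outside the window: run the whole UNet as usual
    if current_step % cache_interval == 0:
        return "full_and_cache"  # run the whole UNet and refresh cache_h
    return "reuse_cache"         # run only input blocks 0..cache_depth and splice in cache_h

With cache_interval=5 and the default window, this returns "full_and_cache" on steps 0, 5, 10, ... and "reuse_cache" in between, which is where the speedup comes from.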
I love this script, and I'd like to know if it's possible to integrate the multiprocessing approach below. I'm a rookie, so I'm not able to do it myself.
import multiprocessing

# Define a function to perform the computation for a single iteration
def process_iteration(args):
    # Perform the computation based on the given arguments
    # ...
    pass

# Modify the relevant part of your script where the loop or computation occurs
def parallel_computation():
    # Determine the number of iterations or tasks
    num_iterations = 100  # Replace with the actual number

    # Define the number of processes to use
    num_processes = multiprocessing.cpu_count()  # Use available CPU cores or specify a number

    # Create a multiprocessing Pool
    with multiprocessing.Pool(processes=num_processes) as pool:
        # Generate arguments for each iteration/task (modify based on your code structure)
        iterations_args = [(i,) for i in range(num_iterations)]
        # ^ Replace (i,) with the actual arguments needed for each iteration

        # Map the function to the arguments for parallel execution
        pool.map(process_iteration, iterations_args)

# Call the function to execute the parallel computation
if __name__ == "__main__":
    parallel_computation()
Thanks for this, man. Any chance the XL or video versions could be implemented? Appreciate the share!
It should be compatible with SDXL.
Is this theoretically compatible with Perturbed Attention Guidance? The current implementations in ComfyUI do not work together, but I wonder if it could be made to work.
Hi, many thanks for this. Is there any possibility of also having implementations for SVD and SDXL? Thanks again, and happy holidays.