APG implementation for ComfyUI
# By https://github.com/blepping
# License: Apache2
# Initial APG implementation referenced from https://arxiv.org/pdf/2410.02416 and https://github.com/ace-step/ACE-Step/blob/e5610345db9f450a855994169f4ca7a7b5fb4f1d/acestep/apg_guidance.py
#
# Changes:
#   250616: Removed alt2 mode, it was the same as positive momentum. Derp.
#   250616: New alt2 mode that blends history with the current diff. Added advanced YAML parameters input.
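#
# Rough sketch of the approach (per the APG paper linked above, as implemented
# in this file): instead of applying the raw CFG difference (cond - uncond)
# directly, the difference is momentum-smoothed, optionally clamped to a
# maximum norm, then split into components parallel and orthogonal to the
# conditional prediction, with the parallel part (the main source of
# oversaturation) down-weighted by eta:
#
#   diff   = momentum_update(cond - uncond)
#   diff   = diff * min(1, norm_threshold / ||diff||)
#   update = orthogonal(diff, cond) + eta * parallel(diff, cond)
#   result = cond + (apg_scale - 1) * update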

import math
from enum import Enum, auto
from typing import NamedTuple

import torch
import torch.nn.functional as F
import yaml
from tqdm import tqdm

import nodes
from comfy.samplers import CFGGuider

BLEND_MODES = None


def _ensure_blend_modes():
    global BLEND_MODES
    if BLEND_MODES is None:
        bleh = getattr(nodes, "_blepping_integrations", {}).get("bleh")
        if bleh is not None:
            BLEND_MODES = bleh.py.latent_utils.BLENDING_MODES
        else:
            BLEND_MODES = {"lerp": torch.lerp, "a_only": lambda a, _b, _t: a, "b_only": lambda _a, b, _t: b}


class UpdateMode(Enum):
    DEFAULT = auto()
    ALT1 = auto()
    ALT2 = auto()


class APGConfig(NamedTuple):
    start_sigma: float = math.inf
    momentum: float = -0.5
    eta: float = 0.0
    apg_scale: float = 4.0
    norm_threshold: float = 2.5
    dims: tuple = (-2, -1)
    update_mode: UpdateMode = UpdateMode.DEFAULT
    update_blend_mode: str = "lerp"
    cfg: float = 1.0
    apg_blend: float = 1.0
    apg_blend_mode: str = "lerp"
    predict_image: bool = True
    pre_cfg_mode: bool = False

    @staticmethod
    def fixup_param(k, v):
        if k == "dims":
            if not isinstance(v, str):
                return v
            dims = v.strip()
            return tuple(int(d) for d in dims.split(",")) if dims else ()
        if k == "update_mode":
            return getattr(UpdateMode, v.strip().upper())
        if k == "start_sigma":
            return math.inf if v < 0 else float(v)
        return v

    @classmethod
    def build(cls, *, mode: str = "pure_apg", **params: dict):
        pre_mode, update_mode = mode.split("_", 1)
        params["pre_cfg_mode"] = pre_mode == "pre"
        params["update_mode"] = "default" if update_mode in {"apg", "cfg"} else update_mode
        fields = frozenset(cls._fields)
        params = {k: cls.fixup_param(k, v) for k, v in params.items() if k in fields}
        defaults = cls()
        params |= {k: getattr(defaults, k) for k in fields if k not in params}
        return cls(**params)

    def __str__(self):
        if self.apg_blend == 0 or self.apg_scale == 0:
            fields = ("start_sigma", "cfg")
        else:
            fields = self._fields
        pretty_fields = ", ".join(f"{k}={getattr(self, k)}" for k in fields)
        return f"APGConfig({pretty_fields})"


class APG:
    def __init__(self, config: APGConfig):
        self.config = config
        self.running_average = 0.0

    def __getattr__(self, k):
        return getattr(self.config, k)

    def update(self, val: torch.Tensor) -> torch.Tensor:
        if self.momentum == 0:
            return val
        avg = self.running_average
        # Restart the running average when it doesn't match the input tensor
        # (first call, or the model output changed dtype/device/shape).
        if isinstance(avg, float) or (isinstance(avg, torch.Tensor) and (avg.dtype != val.dtype or avg.device != val.device or avg.shape != val.shape)):
            self.running_average = val.clone()
            return self.running_average
        result = val + self.momentum * avg
        if self.update_mode == UpdateMode.ALT1:
            self.running_average = val + abs(self.momentum) * avg
        elif self.update_mode == UpdateMode.ALT2:
            blend = BLEND_MODES.get(self.update_blend_mode)
            if blend is None:
                raise ValueError("Unknown blend mode")
            result = blend(val, avg.neg() if self.momentum < 0 else avg, abs(self.momentum))
            self.running_average = blend(val, avg, abs(self.momentum))
        else:
            self.running_average = result
        return result

    def reset(self):
        self.running_average = 0.0
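
    # project() splits v0 into components parallel and orthogonal to v1
    # (a per-sample vector projection over self.dims). The math runs in double
    # precision for stability; MPS devices lack float64 support, hence the CPU
    # round-trip.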
    def project(self, v0_orig: torch.Tensor, v1_orig: torch.Tensor) -> tuple:
        if v0_orig.device.type == "mps":
            v0, v1 = v0_orig.cpu().double(), v1_orig.cpu().double()
        else:
            v0, v1 = v0_orig.double(), v1_orig.double()
        v1 = F.normalize(v1, dim=self.dims)
        v0_p = (v0 * v1).sum(dim=self.dims, keepdim=True) * v1
        v0_o = v0 - v0_p
        return (
            v0_p.to(dtype=v0_orig.dtype, device=v0_orig.device),
            v0_o.to(dtype=v0_orig.dtype, device=v0_orig.device),
        )

    def apg(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        pred_diff = self.update(cond - uncond)
        if self.norm_threshold > 0:
            diff_norm = pred_diff.norm(p=2, dim=self.dims, keepdim=True)
            scale_factor = torch.minimum(torch.ones_like(pred_diff), self.norm_threshold / diff_norm)
            pred_diff = pred_diff * scale_factor
        diff_p, diff_o = self.project(pred_diff, cond)
        update = diff_o
        if self.eta != 0:
            update += self.eta * diff_p
        return update

    def cfg_function(self, args: dict) -> torch.Tensor:
        sigma = args["sigma"]
        cond, uncond = (args["cond_denoised"], args["uncond_denoised"]) if self.predict_image else (args["cond"], args["uncond"])
        result = cond + (self.apg_scale - 1.0) * self.apg(cond, uncond)
        return args["input"] - result if self.predict_image else result

    def pre_cfg_function(self, args: dict) -> list:
        conds_out = args["conds_out"]
        if len(conds_out) < 2:
            return conds_out
        cond, uncond = conds_out[:2]
        update = self.apg(cond, uncond)
        cond_apg = uncond + update + (cond - uncond) / self.apg_scale
        return [cond_apg, *conds_out[1:]]
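
# Note on the two modes, derived from the functions above (assuming downstream
# CFG computes uncond + scale * (cond - uncond)): pure mode produces
# cond + (apg_scale - 1) * update directly, while pre-CFG mode replaces cond
# with uncond + update + (cond - uncond) / apg_scale so that running CFG at
# apg_scale works out to cond + apg_scale * update, which is (mostly) how the
# built-in ComfyUI APG node behaves.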


class APGGuider(CFGGuider):
    def __init__(self, model, *, positive, negative, rules: tuple, params: dict) -> None:
        super().__init__(model)
        self.set_conds(positive, negative)
        self.set_cfg(1.0)
        self.apg_rules = tuple(APG(rule_config) for rule_config in rules)
        self.apg_params = params
        self.apg_verbose = params.get("verbose", False) is True
        if self.apg_verbose:
            tqdm.write(f"* APG rules: {rules}")

    def apg_reset(self, *, exclude=None):
        for apg_rule in self.apg_rules:
            if apg_rule is not exclude:
                apg_rule.reset()

    def apg_get_match(self, sigma: float) -> APG:
        # Rules are sorted by ascending start_sigma, so this returns the rule
        # with the smallest start_sigma that is still >= the current sigma.
        for rule in self.apg_rules:
            if sigma <= rule.start_sigma:
                return rule
        raise RuntimeError("Could not get APG rule")

    def outer_sample(self, *args: list, **kwargs: dict) -> torch.Tensor:
        self.apg_reset()
        result = super().outer_sample(*args, **kwargs)
        self.apg_reset()
        return result

    def predict_noise(self, x: torch.Tensor, timestep, model_options=None, seed=None, **kwargs: dict) -> torch.Tensor:
        if model_options is None:
            model_options = {}
        sigma = timestep.max().detach().cpu().item() if isinstance(timestep, torch.Tensor) else timestep
        rule = self.apg_get_match(sigma)
        self.apg_reset(exclude=rule)
        matched = rule.apg_blend != 0 and rule.apg_scale != 0
        if self.apg_verbose:
            tqdm.write(f"* APG rule matched: sigma={sigma}, rule={rule.config}")
        if matched:
            model_options = model_options | {"disable_cfg1_optimization": True}
            if rule.pre_cfg_mode:
                pre_cfg_handlers = model_options.get("sampler_pre_cfg_function", []).copy()
                pre_cfg_handlers.append(rule.pre_cfg_function)
                model_options["sampler_pre_cfg_function"] = pre_cfg_handlers
                cfg = rule.apg_scale
            else:
                model_options["sampler_cfg_function"] = rule.cfg_function
                cfg = rule.cfg
        else:
            cfg = rule.cfg
        orig_cfg = self.cfg
        try:
            self.cfg = cfg
            result = super().predict_noise(x, timestep, model_options=model_options, seed=seed, **kwargs)
        finally:
            self.cfg = orig_cfg
        return result


class APGGuiderNode:
    CATEGORY = "sampling/custom_sampling/guiders"
    FUNCTION = "go"
    RETURN_TYPES = ("GUIDER",)
    DESCRIPTION = "Simple APG guider"

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        _ensure_blend_modes()
        return {
            "required": {
                "model": ("MODEL",),
                "positive": ("CONDITIONING",),
                "negative": ("CONDITIONING",),
                "apg_scale": ("FLOAT", {"default": 4.0, "min": -1000.0, "max": 1000.0, "tooltip": "If apg_scale is exactly 0, you will get cfg_before up to and including start_sigma, then cfg_after."}),
                "cfg_before": ("FLOAT", {"default": 1.0, "min": 1.0, "max": 1000.0, "tooltip": "CFG value to use before APG."}),
                "cfg_after": ("FLOAT", {"default": 1.0, "min": 1.0, "max": 1000.0, "tooltip": "CFG value to use after APG."}),
                "eta": ("FLOAT", {"default": 0.0, "min": -1000.0, "max": 1000.0}),
                "norm_threshold": ("FLOAT", {"default": 2.5, "min": -1000.0, "max": 1000.0}),
                "momentum": ("FLOAT", {"default": -0.75, "min": -1000.0, "max": 1000.0}),
                "start_sigma": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 10000.0, "tooltip": "Any negative value means no restriction."}),
                "end_sigma": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 10000.0, "tooltip": "Any negative value means no restriction."}),
                "dims": ("STRING", {"default": "-1, -2", "tooltip": "Comma-separated list of dimensions to normalize for guidance. No error checking."}),
                "predict_image": ("BOOLEAN", {"default": True, "tooltip": "Determines whether APG guides a prediction of the image or a prediction of the noise. Using the image prediction may work better. Only has an effect in pure_apg mode."}),
                "mode": (("pure_apg", "pre_cfg", "pure_alt1", "pre_alt1", "pure_alt2", "pre_alt2"), {"default": "pure_apg", "tooltip": "pure_apg uses exactly the APG calculation when APG mode is active. pre_cfg mode works like the built-in ComfyUI node and will replace cond with APG and then run CFG with apg_scale as the CFG scale. Other modes are experimental and may or may not work."}),
            },
            "optional": {
                "yaml_parameters_opt": (
                    "STRING",
                    {
                        "tooltip": "Allows specifying custom parameters via YAML. Note: When specifying parameters this way, there is no error checking.",
                        "forceInput": True,
                    },
                ),
            },
        }

    @classmethod
    def go(
        cls, *,
        model, positive, negative, apg_scale, cfg_before, cfg_after, eta,
        norm_threshold, momentum, start_sigma, end_sigma, dims, predict_image,
        mode, yaml_parameters_opt=None,
    ) -> tuple:
        if yaml_parameters_opt is None:
            yaml_parameters_opt = ""
        yaml_parameters_opt = yaml_parameters_opt.strip()
        if yaml_parameters_opt:
            params = yaml.safe_load(yaml_parameters_opt)
            if not params:
                params = {}
            elif isinstance(params, (tuple, list)):
                params = {"rules": tuple(params)}
            elif not isinstance(params, dict):
                raise TypeError("Bad format for YAML options")
        else:
            params = {}
        rules = tuple(params.pop("rules", ()))
        if not rules:
            # No YAML rules: build one APG rule from the node inputs, plus an
            # optional plain-CFG rule that takes over below end_sigma.
            rules = []
            rules.append(APGConfig.build(
                cfg=1.0,
                apg_blend=1.0,
                start_sigma=start_sigma if start_sigma >= 0 else math.inf,
                momentum=momentum,
                eta=eta,
                apg_scale=apg_scale,
                norm_threshold=norm_threshold,
                dims=dims,
                predict_image=predict_image,
                mode=mode,
            ))
            if end_sigma > 0:
                # Nudge end_sigma down by one ULP so the APG rule still matches
                # at exactly end_sigma.
                end_sigma = math.nextafter(end_sigma, -math.inf)
                rules.append(APGConfig.build(
                    cfg=cfg_after,
                    start_sigma=end_sigma,
                    apg_blend=0.0,
                ))
        else:
            rules = tuple(cfg for cfg in (APGConfig.build(**rule) for rule in rules) if cfg.start_sigma > 0)
        rules = sorted(rules, key=lambda rule: rule.start_sigma)
        if not rules or rules[-1].start_sigma < math.inf:
            # Fall back to plain CFG with cfg_before for sigmas above the last rule.
            rules = (*rules, APGConfig.build(cfg=cfg_before, start_sigma=math.inf, apg_blend=0.0))
        guider = APGGuider(
            model,
            positive=positive,
            negative=negative,
            rules=rules,
            params=params,
        )
        return (guider,)


NODE_CLASS_MAPPINGS = {}
if "BlehAPGGuider" not in nodes.NODE_CLASS_MAPPINGS:
    NODE_CLASS_MAPPINGS["BlehAPGGuider"] = APGGuiderNode

@victornpb Thanks for posting an example! Unfortunately I probably broke it with my recent changes that added a few new parameters. Setting `predict_image` to false should make it work like it did before (but I think that's worse). Setting the mode to `pre_cfg` should make it (mostly) work like the built-in APG node that ComfyUI just added.
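
In case it helps anyone: here's an untested sketch of what the advanced YAML parameters input can look like. The field names come from `APGConfig` in the gist; the values are just illustrative:

```yaml
# Top-level keys other than "rules" are passed through as guider params.
verbose: true
rules:
  # APG with momentum from the start of sampling down to sigma 0.7...
  - start_sigma: -1        # negative = no restriction (treated as infinity)
    apg_scale: 4.0
    momentum: -0.75
    norm_threshold: 2.5
    mode: pure_apg
  # ...then plain CFG 1.0 for the rest of sampling.
  - start_sigma: 0.7
    apg_blend: 0.0
    cfg: 1.0
```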