@blepping
Last active May 17, 2025 16:01
APG implementation for ComfyUI
# By https://github.com/blepping
# License: Apache2
# Initial APG implementation referenced from https://arxiv.org/pdf/2410.02416 and https://github.com/ace-step/ACE-Step/blob/e5610345db9f450a855994169f4ca7a7b5fb4f1d/acestep/apg_guidance.py
#
# Changes:
# 250616: Removed alt2 mode, it was the same as positive momentum. Derp.
# 250616: New alt2 mode that blends history with the current diff. Added advanced YAML parameters input.
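#
# Overview (added summary, based on the code below): instead of the plain CFG combine,
# the guider takes the cond - uncond difference (optionally smoothed with a momentum
# running average and clamped to norm_threshold), projects it into components parallel
# and orthogonal to the conditional prediction, and applies the orthogonal part
# (plus eta times the parallel part), scaled according to apg_scale.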
import math
import yaml
from enum import Enum, auto
from typing import NamedTuple
import torch
import torch.nn.functional as F
from tqdm import tqdm
from comfy.samplers import CFGGuider
import nodes
BLEND_MODES = None


def _ensure_blend_modes():
    global BLEND_MODES
    if BLEND_MODES is None:
        bleh = getattr(nodes, "_blepping_integrations", {}).get("bleh")
        if bleh is not None:
            BLEND_MODES = bleh.py.latent_utils.BLENDING_MODES
        else:
            BLEND_MODES = {"lerp": torch.lerp, "a_only": lambda a, _b, _t: a, "b_only": lambda _a, b, _t: b}


class UpdateMode(Enum):
    DEFAULT = auto()
    ALT1 = auto()
    ALT2 = auto()


class APGConfig(NamedTuple):
    start_sigma: float = math.inf
    momentum: float = -0.5
    eta: float = 0.0
    apg_scale: float = 4.0
    norm_threshold: float = 2.5
    dims: tuple = (-2, -1)
    update_mode: UpdateMode = UpdateMode.DEFAULT
    update_blend_mode: str = "lerp"
    cfg: float = 1.0
    apg_blend: float = 1.0
    apg_blend_mode: str = "lerp"
    predict_image: bool = True
    pre_cfg_mode: bool = False

    @staticmethod
    def fixup_param(k, v):
        if k == "dims":
            if isinstance(v, str):
                dims = v.strip()
                return tuple(int(d) for d in dims.split(",")) if dims else ()
            else:
                return v
        if k == "update_mode":
            return getattr(UpdateMode, v.strip().upper())
        if k == "start_sigma":
            return math.inf if v < 0 else float(v)
        return v

    @classmethod
    def build(cls, *, mode: str = "pure_apg", **params: dict):
        pre_mode, update_mode = mode.split("_", 1)
        params["pre_cfg_mode"] = pre_mode == "pre"
        params["update_mode"] = "default" if update_mode in {"apg", "cfg"} else update_mode
        fields = frozenset(cls._fields)
        params = {k: cls.fixup_param(k, v) for k, v in params.items() if k in fields}
        defaults = cls()
        params |= {k: getattr(defaults, k) for k in fields if k not in params}
        return cls(**params)

    def __str__(self):
        if self.apg_blend == 0 or self.apg_scale == 0:
            fields = ("start_sigma", "cfg")
        else:
            fields = self._fields
        pretty_fields = ", ".join(f"{k}={getattr(self, k)}" for k in fields)
        return f"APGConfig({pretty_fields})"


class APG:
    def __init__(self, config: APGConfig):
        self.config = config
        self.running_average = 0.0

    def __getattr__(self, k):
        return getattr(self.config, k)

    def update(self, val: torch.Tensor) -> torch.Tensor:
        # Maintains a momentum-weighted running average of the guidance difference.
        if self.momentum == 0:
            return val
        avg = self.running_average
        # Restart the running average if it doesn't exist yet or no longer matches val.
        if isinstance(avg, float) or (isinstance(avg, torch.Tensor) and (avg.dtype != val.dtype or avg.device != val.device or avg.shape != val.shape)):
            self.running_average = val.clone()
            return self.running_average
        result = val + self.momentum * avg
        if self.update_mode == UpdateMode.ALT1:
            self.running_average = val + abs(self.momentum) * avg
        elif self.update_mode == UpdateMode.ALT2:
            blend = BLEND_MODES.get(self.update_blend_mode)
            if blend is None:
                raise ValueError("Unknown blend mode")
            result = blend(val, avg.neg() if self.momentum < 0 else avg, abs(self.momentum))
            self.running_average = blend(val, avg, abs(self.momentum))
        else:
            self.running_average = result
        return result

    def reset(self):
        self.running_average = 0.0

    def project(self, v0_orig: torch.Tensor, v1_orig: torch.Tensor) -> tuple:
        # Split v0 into components parallel and orthogonal to v1. Done in double
        # precision; MPS doesn't support float64, so fall back to the CPU there.
        if v0_orig.device.type == "mps":
            v0, v1 = v0_orig.cpu().double(), v1_orig.cpu().double()
        else:
            v0, v1 = v0_orig.double(), v1_orig.double()
        v1 = F.normalize(v1, dim=self.dims)
        v0_p = (v0 * v1).sum(dim=self.dims, keepdim=True) * v1
        v0_o = v0 - v0_p
        return v0_p.to(dtype=v0_orig.dtype).to(v0_orig.device.type), v0_o.to(dtype=v0_orig.dtype).to(v0_orig.device.type)

    def apg(self, cond: torch.Tensor, uncond: torch.Tensor) -> torch.Tensor:
        pred_diff = self.update(cond - uncond)
        if self.norm_threshold > 0:
            # Rescale the difference so its norm doesn't exceed norm_threshold.
            diff_norm = pred_diff.norm(p=2, dim=self.dims, keepdim=True)
            scale_factor = torch.minimum(torch.ones_like(pred_diff), self.norm_threshold / diff_norm)
            pred_diff = pred_diff * scale_factor
        diff_p, diff_o = self.project(pred_diff, cond)
        update = diff_o
        if self.eta != 0:
            update += self.eta * diff_p
        return update

    def cfg_function(self, args: dict) -> torch.Tensor:
        # Used in pure APG mode: replaces the normal CFG combine entirely.
        sigma = args["sigma"]
        cond, uncond = (args["cond_denoised"], args["uncond_denoised"]) if self.predict_image else (args["cond"], args["uncond"])
        result = cond + (self.apg_scale - 1.0) * self.apg(cond, uncond)
        # ComfyUI subtracts the returned value from the input to get the denoised result.
        return args["input"] - result if self.predict_image else result

    def pre_cfg_function(self, args: dict) -> list:
        # Used in pre-CFG mode: replaces cond with the APG result, then the regular
        # CFG combine runs with apg_scale as the scale (like ComfyUI's built-in APG node).
        conds_out = args["conds_out"]
        if len(conds_out) < 2:
            return conds_out
        cond, uncond = conds_out[:2]
        update = self.apg(cond, uncond)
        cond_apg = uncond + update + (cond - uncond) / self.apg_scale
        return [cond_apg, *conds_out[1:]]


class APGGuider(CFGGuider):
    def __init__(self, model, *, positive, negative, rules: tuple, params: dict) -> None:
        super().__init__(model)
        self.set_conds(positive, negative)
        self.set_cfg(1.0)
        self.apg_rules = tuple(APG(rule_config) for rule_config in rules)
        self.apg_params = params
        self.apg_verbose = params.get("verbose", False) == True
        if self.apg_verbose:
            tqdm.write(f"* APG rules: {rules}")

    def apg_reset(self, *, exclude=None):
        for apg_rule in self.apg_rules:
            if apg_rule is not exclude:
                apg_rule.reset()

    def apg_get_match(self, sigma: float) -> APG:
        for rule in self.apg_rules:
            if sigma <= rule.start_sigma:
                return rule
        raise RuntimeError("Could not get APG rule")

    def outer_sample(self, *args: list, **kwargs: dict) -> torch.Tensor:
        self.apg_reset()
        result = super().outer_sample(*args, **kwargs)
        self.apg_reset()
        return result

    def predict_noise(self, x: torch.Tensor, timestep, model_options=None, seed=None, **kwargs: dict) -> torch.Tensor:
        if model_options is None:
            model_options = {}
        sigma = timestep.max().detach().cpu().item() if isinstance(timestep, torch.Tensor) else timestep
        rule = self.apg_get_match(sigma)
        self.apg_reset(exclude=rule)
        matched = rule.apg_blend != 0 and rule.apg_scale != 0
        if self.apg_verbose:
            tqdm.write(f"* APG rule matched: sigma={sigma}, rule={rule.config}")
        if matched:
            model_options = model_options | {"disable_cfg1_optimization": True}
            if rule.pre_cfg_mode:
                pre_cfg_handlers = model_options.get("sampler_pre_cfg_function", []).copy()
                pre_cfg_handlers.append(rule.pre_cfg_function)
                model_options["sampler_pre_cfg_function"] = pre_cfg_handlers
                cfg = rule.apg_scale
            else:
                model_options["sampler_cfg_function"] = rule.cfg_function
                cfg = rule.cfg
        else:
            cfg = rule.cfg
        orig_cfg = self.cfg
        try:
            self.cfg = cfg
            result = super().predict_noise(x, timestep, model_options=model_options, seed=seed, **kwargs)
        finally:
            self.cfg = orig_cfg
        return result


class APGGuiderNode:
    CATEGORY = "sampling/custom_sampling/guiders"
    FUNCTION = "go"
    RETURN_TYPES = ("GUIDER",)
    DESCRIPTION = "Simple APG guider"

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        _ensure_blend_modes()
        return {
            "required": {
                "model": ("MODEL",),
                "positive": ("CONDITIONING",),
                "negative": ("CONDITIONING",),
                "apg_scale": ("FLOAT", {"default": 4.0, "min": -1000.0, "max": 1000.0, "tooltip": "If apg_scale is exactly 0, you will get cfg_before up to and including start_sigma, then cfg_after."}),
                "cfg_before": ("FLOAT", {"default": 1.0, "min": 1.0, "max": 1000.0, "tooltip": "CFG value to use before APG."}),
                "cfg_after": ("FLOAT", {"default": 1.0, "min": 1.0, "max": 1000.0, "tooltip": "CFG value to use after APG."}),
                "eta": ("FLOAT", {"default": 0.0, "min": -1000.0, "max": 1000.0}),
                "norm_threshold": ("FLOAT", {"default": 2.5, "min": -1000.0, "max": 1000.0}),
                "momentum": ("FLOAT", {"default": -0.75, "min": -1000.0, "max": 1000.0}),
                "start_sigma": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 10000.0, "tooltip": "Any negative value means no restriction."}),
                "end_sigma": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 10000.0, "tooltip": "Any negative value means no restriction."}),
                "dims": ("STRING", {"default": "-1, -2", "tooltip": "Comma-separated list of dimensions to normalize for guidance. No error checking."}),
                "predict_image": ("BOOLEAN", {"default": True, "tooltip": "Determines whether APG guides a prediction of the image or a prediction of the noise. Using the image prediction may work better. Only has an effect in pure_apg mode."}),
                "mode": (("pure_apg", "pre_cfg", "pure_alt1", "pre_alt1", "pure_alt2", "pre_alt2"), {"default": "pure_apg", "tooltip": "pure_apg uses exactly the APG calculation when APG mode is active. pre_cfg mode works like the built-in ComfyUI node and will replace cond with APG and then run CFG with apg_scale as the CFG scale. The other modes are experimental and may or may not work."}),
            },
            "optional": {
                "yaml_parameters_opt": (
                    "STRING",
                    {
                        "tooltip": "Allows specifying custom parameters via YAML. Note: When specifying parameters this way, there is no error checking.",
                        "forceInput": True,
                    },
                ),
            },
        }

    @classmethod
    def go(
        cls, *,
        model, positive, negative, apg_scale, cfg_before, cfg_after, eta,
        norm_threshold, momentum, start_sigma, end_sigma, dims, predict_image,
        mode, yaml_parameters_opt=None,
    ) -> tuple:
        if yaml_parameters_opt is None:
            yaml_parameters_opt = ""
        yaml_parameters_opt = yaml_parameters_opt.strip()
        if yaml_parameters_opt:
            params = yaml.safe_load(yaml_parameters_opt)
            if not params:
                params = {}
            elif isinstance(params, (tuple, list)):
                params = {"rules": tuple(params)}
            elif not isinstance(params, dict):
                raise TypeError("Bad format for YAML options")
        else:
            params = {}
        rules = tuple(params.pop("rules", ()))
        if not rules:
            rules = []
            rules.append(APGConfig.build(
                cfg=1.0,
                apg_blend=1.0,
                start_sigma=start_sigma if start_sigma >= 0 else math.inf,
                momentum=momentum,
                eta=eta,
                apg_scale=apg_scale,
                norm_threshold=norm_threshold,
                dims=dims,
                predict_image=predict_image,
                mode=mode,
            ))
            if end_sigma > 0:
                end_sigma = math.nextafter(end_sigma, -math.inf)
            if end_sigma > 0:
                rules.append(APGConfig.build(
                    cfg=cfg_after,
                    start_sigma=end_sigma,
                    apg_blend=0.0,
                ))
        else:
            rules = tuple(cfg for cfg in (APGConfig.build(**rule) for rule in rules) if cfg.start_sigma > 0)
        rules = sorted(rules, key=lambda rule: rule.start_sigma)
        if not rules or rules[-1].start_sigma < math.inf:
            # Make sure there's always a fallback plain-CFG rule that matches any sigma.
            rules = (*rules, APGConfig.build(cfg=cfg_before, start_sigma=math.inf, apg_blend=0.0))
        guider = APGGuider(
            model,
            positive=positive,
            negative=negative,
            rules=rules,
            params=params,
        )
        return (guider,)


NODE_CLASS_MAPPINGS = {}
if "BlehAPGGuider" not in nodes.NODE_CLASS_MAPPINGS:
    NODE_CLASS_MAPPINGS["BlehAPGGuider"] = APGGuiderNode
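
# Example of the optional yaml_parameters_opt input (an illustrative sketch only; the
# values are not recommendations). The input can be a mapping with a "rules" list and
# extra options such as "verbose", or a bare list of rules. Rule keys match the
# APGConfig fields above, plus "mode", which APGConfig.build splits into
# pre_cfg_mode/update_mode:
#
#   verbose: true
#   rules:
#     - start_sigma: 14.0
#       apg_scale: 4.0
#       momentum: -0.5
#       eta: 0.0
#       norm_threshold: 2.5
#       dims: "-2, -1"
#       mode: pure_apg
#     - start_sigma: 0.5
#       apg_blend: 0.0
#       cfg: 3.0
#
# With the automatically added fallback rule, this gives plain CFG at the node's
# cfg_before above sigma 14, APG from 14 down to just above 0.5, and plain CFG 3.0
# at or below sigma 0.5.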
@victornpb

workflow example:

{
  "id":"88ac5dad-efd7-40bb-84fe-fbaefdee1fa9","revision":0,"last_node_id":72,"last_link_id":172,"nodes":[
  {"id":62,"type":"BlehAPGGuider","pos":[1194.76220703125,251.79566955566406],"size":[280.77679443359375,299.2467041015625],"flags":{},"order":8,"mode":0,"inputs":[{"name":"model","type":"MODEL","link":157},{"name":"positive","type":"CONDITIONING","link":154},{"name":"negative","type":"CONDITIONING","link":155}],"outputs":[{"name":"GUIDER","type":"GUIDER","links":[150]}],"properties":{"Node name for S&R":"BlehAPGGuider"},"widgets_values":[10,1,1,0,5,-0.7000000000000001,-1,0.5,"-1, -2"]},
  {"id":47,"type":"ConditioningZeroOut","pos":[958.5279541015625,600.8081665039062],"size":[197.712890625,26],"flags":{},"order":7,"mode":0,"inputs":[{"name":"conditioning","type":"CONDITIONING","link":130}],"outputs":[{"name":"CONDITIONING","type":"CONDITIONING","links":[155]}],"properties":{"cnr_id":"comfy-core","ver":"0.3.34","Node name for S&R":"ConditioningZeroOut"},"widgets_values":[]},
  {"id":63,"type":"SamplerCustomAdvanced","pos":[1543.7059326171875,575.8972778320312],"size":[202.53378295898438,106],"flags":{},"order":9,"mode":0,"inputs":[{"name":"noise","type":"NOISE","link":153},{"name":"guider","type":"GUIDER","link":150},{"name":"sampler","type":"SAMPLER","link":152},{"name":"sigmas","type":"SIGMAS","link":151},{"name":"latent_image","type":"LATENT","link":170}],"outputs":[{"name":"output","type":"LATENT","links":[]},{"name":"denoised_output","type":"LATENT","links":[168]}],"properties":{"cnr_id":"comfy-core","ver":"0.3.34","Node name for S&R":"SamplerCustomAdvanced"},"widgets_values":[]},
  {"id":45,"type":"ModelSamplingSD3","pos":[460.5631408691406,-125.1843032836914],"size":[432.8270263671875,59.71588897705078],"flags":{},"order":4,"mode":0,"inputs":[{"name":"model","type":"MODEL","link":111}],"outputs":[{"name":"MODEL","type":"MODEL","links":[157,167]}],"properties":{"cnr_id":"comfy-core","ver":"0.3.34","Node name for S&R":"ModelSamplingSD3"},"widgets_values":[4.000000000000001]},
  {"id":40,"type":"CheckpointLoaderSimple","pos":[14.46983528137207,53.56587600708008],"size":[375,98],"flags":{},"order":0,"mode":0,"inputs":[],"outputs":[{"name":"MODEL","type":"MODEL","links":[111]},{"name":"CLIP","type":"CLIP","links":[80]},{"name":"VAE","type":"VAE","links":[161]}],"properties":{"cnr_id":"comfy-core","ver":"0.3.34","Node name for S&R":"CheckpointLoaderSimple"},"widgets_values":["ace_step_v1_3.5b.safetensors"]},
  {"id":65,"type":"KSamplerSelect","pos":[1196.9697265625,152.96026611328125],"size":[270,58],"flags":{},"order":1,"mode":0,"inputs":[],"outputs":[{"name":"SAMPLER","type":"SAMPLER","links":[152]}],"properties":{"cnr_id":"comfy-core","ver":"0.3.34","Node name for S&R":"KSamplerSelect"},"widgets_values":["euler"]},
  {"id":66,"type":"RandomNoise","pos":[1197.8623046875,-135.85256958007812],"size":[270,82],"flags":{},"order":2,"mode":0,"inputs":[],"outputs":[{"name":"NOISE","type":"NOISE","links":[153]}],"properties":{"cnr_id":"comfy-core","ver":"0.3.34","Node name for S&R":"RandomNoise"},"widgets_values":[1234,"fixed"]},
  {"id":17,"type":"EmptyAceStepLatentAudio","pos":[1194.0584716796875,605.5031127929688],"size":[278.2037048339844,82.15135955810547],"flags":{},"order":3,"mode":0,"inputs":[],"outputs":[{"name":"LATENT","type":"LATENT","links":[170]}],"properties":{"cnr_id":"comfy-core","ver":"0.3.34","Node name for S&R":"EmptyAceStepLatentAudio"},"widgets_values":[150,1]},
  {"id":14,"type":"TextEncodeAceStepAudio","pos":[458.3848571777344,-17.144275665283203],"size":[433.3960876464844,699.25732421875],"flags":{},"order":5,"mode":0,"inputs":[{"name":"clip","type":"CLIP","link":80}],"outputs":[{"name":"CONDITIONING","type":"CONDITIONING","links":[130,154]}],"properties":{"cnr_id":"comfy-core","ver":"0.3.34","Node name for S&R":"TextEncodeAceStepAudio"},"widgets_values":["funk, pop, soul, rock, melodic, guitar, drums, bass, keyboard, percussion, 105 BPM, energetic, upbeat, groovy, vibrant, dynamic","[verse]\nNeon lights they flicker bright\nCity hums in dead of night\nRhythms pulse through concrete veins\nLost in echoes of refrains\n\n[verse]\nBassline groovin' in my chest\nHeartbeats match the city's zest\nElectric whispers fill the air\nSynthesized dreams everywhere\n\n[chorus]\nTurn it up and let it flow\nFeel the fire let it grow\nIn this rhythm we belong\nHear the night sing out our song\n\n[verse]\nGuitar strings they start to weep\nWake the soul from silent sleep\nEvery note a story told\nIn this night we’re bold and gold\n\n[bridge]\nVoices blend in harmony\nLost in pure cacophony\nTimeless echoes timeless cries\nSoulful shouts beneath the skies\n\n[verse]\nKeyboard dances on the keys\nMelodies on evening breeze\nCatch the tune and hold it tight\nIn this moment we take flight\n",0.9500000000000002]},
  {"id":64,"type":"BasicScheduler","pos":[1196.2774658203125,2.6407196521759033],"size":[270,106],"flags":{},"order":6,"mode":0,"inputs":[{"name":"model","type":"MODEL","link":167}],"outputs":[{"name":"SIGMAS","type":"SIGMAS","links":[151]}],"properties":{"cnr_id":"comfy-core","ver":"0.3.34","Node name for S&R":"BasicScheduler"},"widgets_values":["simple",27,0.9500000000000002]},
  {"id":68,"type":"VAEDecodeAudio","pos":[1574.37060546875,77.70063018798828],"size":[200.38494873046875,46.08256149291992],"flags":{},"order":10,"mode":0,"inputs":[{"name":"samples","type":"LATENT","link":168},{"name":"vae","type":"VAE","link":161}],"outputs":[{"name":"AUDIO","type":"AUDIO","links":[171,172]}],"properties":{"cnr_id":"comfy-core","ver":"0.3.34","Node name for S&R":"VAEDecodeAudio"},"widgets_values":[]},
  {"id":71,"type":"PreviewAudio","pos":[1809.2738037109375,77.15756225585938],"size":[393.26751708984375,182.66552734375],"flags":{},"order":11,"mode":0,"inputs":[{"name":"audio","type":"AUDIO","link":171}],"outputs":[],"properties":{"cnr_id":"comfy-core","ver":"0.3.34","Node name for S&R":"PreviewAudio"},"widgets_values":[]},
  {"id":72,"type":"SaveAudioMP3","pos":[1810.5047607421875,306.7505798339844],"size":[392.2533874511719,136],"flags":{},"order":12,"mode":0,"inputs":[{"name":"audio","type":"AUDIO","link":172}],"outputs":[],"properties":{"cnr_id":"comfy-core","ver":"0.3.34","Node name for S&R":"SaveAudioMP3"},"widgets_values":["audio/ComfyUI","V0"]}
  ],
  "links":[[80,40,1,14,0,"CLIP"],[111,40,0,45,0,"MODEL"],[130,14,0,47,0,"CONDITIONING"],[150,62,0,63,1,"GUIDER"],[151,64,0,63,3,"SIGMAS"],[152,65,0,63,2,"SAMPLER"],[153,66,0,63,0,"NOISE"],[154,14,0,62,1,"CONDITIONING"],[155,47,0,62,2,"CONDITIONING"],[157,45,0,62,0,"MODEL"],[161,40,2,68,1,"VAE"],[167,45,0,64,0,"MODEL"],[168,63,1,68,0,"LATENT"],[170,17,0,63,4,"LATENT"],[171,68,0,71,0,"AUDIO"],[172,68,0,72,0,"AUDIO"]],
  "groups":[],"config":{},"extra":{"ds":{"scale":0.7742308225955729,"offset":[50.58097747787872,426.2705199504457]},"frontendVersion":"1.19.9","ue_links":[]},
  "version":0.4
}

@blepping
Author

@victornpb Thanks for posting an example! Unfortunately I probably broke it with my recent changes that added a few new parameters. Setting predict_image to false should make it work like it did before (but I think that's worse). Setting the mode to pre_cfg should make it (mostly) work like the built-in APG node that ComfyUI just added.
