Skip to content

Instantly share code, notes, and snippets.

@blepping
Last active December 7, 2024 12:40
Show Gist options
  • Save blepping/183fc215f9efb518d5fc11207535f277 to your computer and use it in GitHub Desktop.
Save blepping/183fc215f9efb518d5fc11207535f277 to your computer and use it in GitHub Desktop.
Sampler thought experiments
from functools import partial
import torch
# Double precision throughout, so algebraically equivalent solver formulations
# can be compared with the tight tolerances used by run_tests().
dtype = torch.float64
# Print tensors in full (no "..." elision) so every element can be inspected.
torch.set_printoptions(profile="full")
def randn_like(x, generator=None):
    """Return standard-normal noise with the same shape, dtype and device as ``x``.

    The noise is written into a detached copy of ``x``, so ``x`` itself is never
    modified. Values come from ``generator`` when given, otherwise from torch's
    global RNG.
    """
    scratch = x.detach().clone()
    return torch.randn(*x.shape, out=scratch, generator=generator)
# From ComfyUI
def get_ancestral_step_(sigma_from, sigma_to, eta=1.0):
    """Split an ancestral step into a deterministic part and a noise part.

    Returns ``(sigma_down, sigma_up)``: the noise level to step down to, and
    the amount of fresh noise to add afterwards. When ``eta`` is falsy or
    ``sigma_to`` is 0 the step is fully deterministic: ``(sigma_to, 0)``.
    """
    if not eta or sigma_to == 0:
        return sigma_to, sigma_to * 0
    to_sq = sigma_to**2
    from_sq = sigma_from**2
    # Noise to re-inject, scaled by eta and capped so it never exceeds sigma_to.
    up_limit = eta * (to_sq * (from_sq - to_sq) / from_sq) ** 0.5
    sigma_up = min(sigma_to, up_limit)
    # Whatever variance the injected noise does not cover is stepped down deterministically.
    sigma_down = (to_sq - sigma_up**2) ** 0.5
    return sigma_down, sigma_up
# From ComfyUI
def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
    """Compute ``(sigma_down, sigma_up)`` for one ancestral sampling step.

    ``sigma_down`` is the noise level the deterministic part of the step targets.
    ``sigma_up`` is the standard deviation of the noise added afterwards.
    A falsy ``eta`` or a ``sigma_to`` of 0 yields ``(sigma_to, 0)``: no noise.
    """
    if not eta or sigma_to == 0:
        return sigma_to, sigma_to * 0
    variance_ratio = sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2
    # eta scales the injected noise; the cap keeps sigma_down real-valued.
    sigma_up = min(sigma_to, eta * variance_ratio**0.5)
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    return sigma_down, sigma_up
# I hope I never need you again.
def surch(a, b, target, increment=0.00001, start=-0.01, limit=1.01, f=torch.lerp):
    """Brute-force scan for the weight ``w`` that makes ``f(b, a, w)`` closest to ``target``.

    ``w`` walks from ``start`` to ``limit`` (inclusive) in steps of ``increment``.
    The residual ``target - f(b, a, w)`` is scored by its mean absolute value.
    Nothing is returned: the best weight, its residual, and the number of
    weights tried are printed.
    """
    best_residual = torch.zeros_like(a)
    best_weight = None
    evaluations = 0
    weight = start
    while weight <= limit:
        residual = target - f(b, a, weight)
        # The first candidate always wins; later ones must strictly improve.
        improved = evaluations == 0 or residual.abs().mean() < best_residual.abs().mean()
        if improved:
            best_weight, best_residual = weight, residual
        evaluations += 1
        weight += increment
    print("BEST", best_weight, best_residual, evaluations)
class Solver:
    """Base class for a sampler step under test.

    Subclasses implement :meth:`step`, which advances ``x`` from
    ``sigmas[step]`` to ``sigmas[step + 1]`` and returns the new sample.
    """

    name = "unknown"

    def __init__(self, state):
        self.state = state
        # Shortcut to the model held by the shared state.
        self.model = state.model

    def get_common(self, x, sigmas, step):
        """Evaluate the model once at the current sigma.

        Returns ``(sigma, sigma_next, denoised)``.
        """
        sigma = sigmas[step]
        sigma_next = sigmas[step + 1]
        return sigma, sigma_next, self.model(x, sigma)

    def step(self, x, sigmas, step):
        raise NotImplementedError
class SolverHistory(Solver):
    """Solver base that keeps the denoised prediction of every step it has taken."""

    def __init__(self, state):
        super().__init__(state)
        # Newest entry last; multistep subclasses read [-1] as the previous step's prediction.
        self.denoised_history = list()
class SolverUncond(Solver):
    """Solver base whose model evaluation also exposes the unconditional prediction."""

    def get_common(self, x, sigmas, step):
        """Return ``(sigma, sigma_next, denoised, uncond)`` for ``step``.

        The model is called with ``ext=True`` and its result is unpacked as
        ``(denoised, cond, uncond)``; ``cond`` is discarded.
        """
        sigma = sigmas[step]
        sigma_next = sigmas[step + 1]
        denoised, _cond, uncond = self.model(x, sigma, ext=True)
        return sigma, sigma_next, denoised, uncond
class SolverAltCFGPP(SolverUncond):
    """Uncond-aware solver that pushes ``denoised`` further away from ``uncond``.

    ``altcfgpp_scale`` controls how far; 0 leaves ``denoised`` as-is.
    """

    def __init__(self, *args, altcfgpp_scale=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.altcfgpp_scale = altcfgpp_scale

    def alt_cfgpp_mix(self, denoised, uncond):
        """Return ``denoised * (1 + scale) - uncond * scale``.

        Mathematically this is ``denoised + scale * (denoised - uncond)``.
        """
        scale = self.altcfgpp_scale
        return denoised * (1 + scale) - uncond * scale
class SolverETAMixin:
    """Mixin adding an ``eta`` setting and a solver-private noise generator.

    List it before a :class:`Solver` subclass in the bases, so that
    ``super().__init__`` reaches a class that sets ``self.model``.
    """

    def __init__(self, *args, eta=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.eta = eta
        # Independent copy of the model generator's current state, so drawing
        # noise here does not advance the model's RNG.
        # NOTE(review): because the *current* state is copied, the solver's noise
        # draws appear to repeat the values the model will draw next; confirm that
        # correlation is intended.
        self.generator = self.model.generator.clone_state()

    def noise_sampler(self, x):
        """Draw standard-normal noise shaped like ``x`` from this solver's generator."""
        return randn_like(x, generator=self.generator)
class SolverAncestral(SolverETAMixin, Solver):
    """Plain solver base with ``eta`` support and its own noise generator."""
class SolverAncestralUncond(SolverETAMixin, SolverUncond):
    """Uncond-aware solver base with ``eta`` support and its own noise generator."""
class SolverAncestralAltCFGPP(SolverETAMixin, SolverAltCFGPP):
    """Alternative-CFG++ solver base with ``eta`` support and its own noise generator."""
class EulerBasic(Solver):
    """Textbook Euler step along the slope ``(x - denoised) / sigma``."""

    name = "normal Euler"

    def step(self, x, sigmas, step):
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        slope = (x - denoised) / sigma
        return x + slope * (sigma_next - sigma)
class EulerBasicAlt1(Solver):
    """Euler step written as ``denoised`` plus the current slope scaled to ``sigma_next``."""

    name = "Euler add to denoised"

    def step(self, x, sigmas, step):
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        return denoised + sigma_next * ((x - denoised) / sigma)
class EulerBasicAlt2(Solver):
    """Euler step as ``denoised`` plus the noise residual scaled by ``sigma_next / sigma``."""

    name = "Euler add to denoised with ratio"

    def step(self, x, sigmas, step):
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        residual = x - denoised
        return residual * (sigma_next / sigma) + denoised
class EulerBasicAlt3(Solver):
    """Euler step as a lerp from ``denoised`` towards ``x`` by ``sigma_next / sigma``."""

    name = "Euler ratios"

    def step(self, x, sigmas, step):
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        return denoised.lerp(x, sigma_next / sigma)
class EulerPP(SolverUncond):
    """Euler CFG++ step: the slope comes from ``uncond``, the base from the guided ``denoised``."""

    name = "normal Euler CFG++"

    def step(self, x, sigmas, step):
        sigma, sigma_next, denoised, uncond = self.get_common(x, sigmas, step)
        uncond_slope = (x - uncond) / sigma
        return denoised + uncond_slope * sigma_next
class EulerPPAlt1(SolverUncond):
    """Euler CFG++ step using the ``sigma_next / sigma`` ratio instead of an explicit slope."""

    name = "Euler CFG++ add to denoised with ratio"

    def step(self, x, sigmas, step):
        sigma, sigma_next, denoised, uncond = self.get_common(x, sigmas, step)
        return (x - uncond) * (sigma_next / sigma) + denoised
class EulerAltCFGPP(SolverAltCFGPP):
    """Euler step whose slope is taken from the alt-CFG++ mix of ``denoised`` and ``uncond``."""

    name = "normal Euler alternative CFG++"

    def step(self, x, sigmas, step):
        sigma, sigma_next, denoised, uncond = self.get_common(x, sigmas, step)
        mixed = self.alt_cfgpp_mix(denoised, uncond)
        slope = (x - mixed) / sigma
        return x + slope * (sigma_next - sigma)
class EulerAltCFGPPAlt1(SolverAltCFGPP):
    """Alt-CFG++ Euler step as a lerp from the mixed prediction towards ``x``."""

    name = "Euler alternative CFG++ ratios"

    def step(self, x, sigmas, step):
        sigma, sigma_next, denoised, uncond = self.get_common(x, sigmas, step)
        mixed = self.alt_cfgpp_mix(denoised, uncond)
        return mixed.lerp(x, sigma_next / sigma)
class HeunBasic(Solver):
    """Heun (predictor-corrector) step: average the slopes at both ends of an Euler step."""

    name = "Heun"

    def step(self, x, sigmas, step):
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        dt = sigma_next - sigma
        # Predictor: plain Euler step to sigma_next.
        slope_start = (x - denoised) / sigma
        x_euler = x + slope_start * dt
        # Corrector: slope at the predicted point, averaged with the starting slope.
        slope_end = (x_euler - self.model(x_euler, sigma_next)) / sigma_next
        return x + ((slope_start + slope_end) / 2) * dt
class HeunBasicAlt1(Solver):
    """Heun step expressed purely with lerps between predictions and ``x``."""

    name = "Heun ratios"

    def step(self, x, sigmas, step):
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        ratio = sigma_next / sigma
        x_euler = denoised.lerp(x, ratio)
        denoised_next = self.model(x_euler, sigma_next)
        # Heun's slope average, rewritten as a lerp between the two denoised
        # predictions. The weight is negative (extrapolation) when sigma > 2 * sigma_next.
        weight = 1 - (sigma / sigma_next) / 2
        return torch.lerp(denoised_next, denoised, weight).lerp(x, ratio)
class HeunBasicAlt2(Solver):
    """Heun step that turns the averaged slope back into a denoised estimate.

    When ``sigma_next <= 0``, ``denoised`` is returned directly (an Euler step to
    0 lands there) and the second model evaluation is skipped.
    """

    name = "Heun alt2"

    def step(self, x, sigmas, step):
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        if sigma_next <= 0:
            return denoised
        d = (x - denoised) / sigma
        dt = sigma_next - sigma
        x_2 = x + d * dt
        denoised_2 = self.model(x_2, sigma_next)
        d_2 = (x_2 - denoised_2) / sigma_next
        d_prime = (d + d_2) / 2
        # Denoised estimate implied by the averaged slope at the current sigma.
        denoised_prime = x - d_prime * sigma
        # Equivalently: torch.lerp(denoised_prime, x, sigma_next / sigma).
        return denoised_prime + d_prime * sigma_next
class HeunBasicAlt3(Solver):
    """Heun step finishing as ``blended + (x - blended) * ratio`` instead of a lerp."""

    name = "Heun add to denoised with ratio"

    def step(self, x, sigmas, step):
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        ratio = sigma_next / sigma
        x_euler = torch.lerp(denoised, x, ratio)
        denoised_next = self.model(x_euler, sigma_next)
        weight = 1 - (sigma / sigma_next) / 2
        blended = denoised_next.lerp(denoised, weight)
        return blended + (x - blended) * ratio
class HeunAncestral(SolverAncestral):
    """Heun step to ``sigma_down``, followed by adding ``sigma_up``-scaled fresh noise."""

    name = "Heun ancestral"

    def step(self, x, sigmas, step):
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next, eta=self.eta)
        dt = sigma_down - sigma
        slope = (x - denoised) / sigma
        x_pred = x + slope * dt
        # No noise is added on a step that ends at sigma 0.
        if sigma_next > 0:
            noise = self.noise_sampler(x) * sigma_up
        else:
            noise = torch.zeros_like(x)
        denoised_pred = self.model(x_pred, sigma_down)
        slope_pred = (x_pred - denoised_pred) / sigma_down
        return x + ((slope + slope_pred) / 2) * dt + noise
class HeunAncestralAlt1(SolverAncestral):
    """Lerp-based ancestral Heun step (targets ``sigma_down``, then adds noise)."""

    name = "Heun ancestral ratios"

    def step(self, x, sigmas, step):
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next, eta=self.eta)
        ratio = sigma_down / sigma
        x_pred = denoised.lerp(x, ratio)
        if sigma_next > 0:
            noise = self.noise_sampler(x) * sigma_up
        else:
            noise = torch.zeros_like(x)
        denoised_pred = self.model(x_pred, sigma_down)
        weight = 1 - (sigma / sigma_down) / 2
        blended = denoised_pred.lerp(denoised, weight)
        return blended.lerp(x, ratio) + noise
class HeunCFGPP(SolverUncond):
    """Heun step in CFG++ style: slopes from ``uncond``, base from guided predictions.

    When ``sigma_next <= 0``, ``denoised`` is returned directly (an Euler CFG++
    step to 0 lands there) and the second model evaluation is skipped.
    """

    name = "Heun CFG++"

    def step(self, x, sigmas, step):
        sigma, sigma_next, denoised, uncond = self.get_common(x, sigmas, step)
        if sigma_next <= 0:
            return denoised
        d = (x - uncond) / sigma
        x_2 = denoised + d * sigma_next
        # Lerp weight blending the two guided predictions (extrapolates when
        # sigma > 2 * sigma_next).
        dratio = 1 - (sigma / sigma_next) / 2
        denoised_2, _cond_2, uncond_2 = self.model(x_2, sigma_next, ext=True)
        d_2 = (x_2 - uncond_2) / sigma_next
        d_prime = (d + d_2) / 2
        denoised_prime = torch.lerp(denoised_2, denoised, dratio)
        return denoised_prime + d_prime * sigma_next
class HeunCFGPPAlt1(SolverUncond):
    """Closed-form rewrite of the Heun CFG++ step using ``sigma_next / sigma`` ratios.

    Intended to match ``HeunCFGPP``; ``main()`` compares the two.
    """

    name = "Heun CFG++ ratios"

    def step(self, x, sigmas, step):
        sigma, sigma_next, denoised, uncond = self.get_common(x, sigmas, step)
        ratio = sigma_next / sigma
        x_pred = denoised + (x - uncond) * ratio
        weight = 1 - (sigma / sigma_next) / 2
        denoised_pred, _, uncond_pred = self.model(x_pred, sigma_next, ext=True)
        denoised_blend = denoised_pred.lerp(denoised, weight)
        # (1 - weight) * ratio == 1/2, so uncond_pred enters with coefficient 1/2,
        # the same as averaging the two uncond slopes.
        residual = x - uncond - uncond_pred * (1 - weight)
        return denoised_blend + denoised * 0.5 + residual * ratio
class HeunAltCFGPP(SolverAltCFGPP):
    """Heun step where both slopes use the alt-CFG++ mix of ``denoised`` and ``uncond``."""

    name = "Heun alt CFG++"

    def step(self, x, sigmas, step):
        sigma, sigma_next, denoised, uncond = self.get_common(x, sigmas, step)
        dt = sigma_next - sigma
        slope = (x - self.alt_cfgpp_mix(denoised, uncond)) / sigma
        x_pred = x + slope * dt
        denoised_pred, _, uncond_pred = self.model(x_pred, sigma_next, ext=True)
        mixed_pred = self.alt_cfgpp_mix(denoised_pred, uncond_pred)
        slope_pred = (x_pred - mixed_pred) / sigma_next
        return x + ((slope + slope_pred) / 2) * dt
class HeunAltCFGPPAlt1(SolverAltCFGPP):
    """Lerp-based alt-CFG++ Heun step (``main()`` compares it with ``HeunAltCFGPP``)."""

    name = "Heun alt CFG++ ratios"

    def step(self, x, sigmas, step):
        sigma, sigma_next, denoised, uncond = self.get_common(x, sigmas, step)
        ratio = sigma_next / sigma
        x_pred = self.alt_cfgpp_mix(denoised, uncond).lerp(x, ratio)
        weight = 1 - (sigma / sigma_next) / 2
        denoised_pred, _, uncond_pred = self.model(x_pred, sigma_next, ext=True)
        # The mix is (mathematically) linear, so blending denoised and uncond
        # separately and then mixing matches blending the mixed predictions.
        mixed = self.alt_cfgpp_mix(
            torch.lerp(denoised_pred, denoised, weight),
            torch.lerp(uncond_pred, uncond, weight),
        )
        return mixed.lerp(x, ratio)
class BogackiBasic(Solver):
    """Third-order Bogacki-Shampine step in (x, sigma) space.

    Standard tableau:
    - stage 2 at ``sigma + dt/2`` from ``x + k1/2``;
    - stage 3 at ``sigma + 3*dt/4`` from ``x + 3*k2/4``;
    - combination weights 2/9, 1/3, 4/9.

    Each ``k`` is a full-interval increment ``slope * dt``, where
    ``slope = (x - denoised) / sigma``.
    """

    name = "Bogacki"

    def step(self, x, sigmas, step):
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        d = (x - denoised) / sigma
        dt = sigma_next - sigma
        k1 = d * dt
        sigma_2 = sigma + dt / 2
        x_2 = x + k1 / 2
        denoised_2 = self.model(x_2, sigma_2)
        k2 = ((x_2 - denoised_2) / sigma_2) * dt
        sigma_3 = sigma + 3 * dt / 4
        # Stage 3 depends only on k2 (a31 = 0, a32 = 3/4), so its coefficients
        # sum to c3 = 3/4 and agree with the sigma_3 it is evaluated at.
        x_3 = x + 3 * k2 / 4
        denoised_3 = self.model(x_3, sigma_3)
        k3 = ((x_3 - denoised_3) / sigma_3) * dt
        return x + 2 * k1 / 9 + k2 / 3 + 4 * k3 / 9


class BogackiAlt1(BogackiBasic):
    """Bogacki-Shampine step re-expressed as a lerp from an implied denoised towards ``x``.

    Reuses :class:`BogackiBasic` for the step itself, so the two cannot drift
    apart. It then solves ``x_new = lerp(denoised_prime, x, ratio)`` for
    ``denoised_prime`` and re-applies that lerp.
    """

    name = "Bogacki2"

    def step(self, x, sigmas, step):
        x_new = super().step(x, sigmas, step)
        ratio = sigmas[step + 1] / sigmas[step]
        # NOTE: undefined when ratio == 1 (sigma_next == sigma).
        denoised_prime = (x_new - x * ratio) / (1 - ratio)
        # Or: return denoised_prime + (x - denoised_prime) * ratio
        return torch.lerp(denoised_prime, x, ratio)
class HeunReversible(Solver):
    """Heun step minus a correction term built from the eta=1 ancestral step size.

    The correction is ``dtr**2 * (slope_end - slope_start) / 4``, where ``dtr``
    is ``sigma_down(eta=1) - sigma``.
    """

    name = "Heun reversible"

    def step(self, x, sigmas, step):
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        dtr = get_ancestral_step(sigma, sigma_next, 1.0)[0] - sigma
        dt = sigma_next - sigma
        slope = (x - denoised) / sigma
        x_pred = x + slope * dt
        slope_pred = (x_pred - self.model(x_pred, sigma_next)) / sigma_next
        correction = dtr**2 * (slope_pred - slope) / 4
        return x + ((slope + slope_pred) / 2) * dt - correction
class DPMPP2MBasic(SolverHistory):
    """DPM-Solver++(2M) step in the time variable ``t = -log(sigma)``.

    The first step (empty history) is first order. Later steps extrapolate
    using the previous step's denoised prediction, which is stored in
    ``denoised_history``.
    """

    name = "dpmpp_2m basic"

    @staticmethod
    def sigma_fn(t):
        """Map time ``t`` back to ``sigma = exp(-t)``."""
        return (-t).exp()

    @staticmethod
    def t_fn(sigma):
        """Map ``sigma`` to time ``t = -log(sigma)``."""
        return -sigma.log()

    def step(self, x, sigmas, step):
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        history = self.denoised_history
        denoised_last = history[-1] if history else None
        history.append(denoised)
        t, t_next = self.t_fn(sigma), self.t_fn(sigma_next)
        h = t_next - t
        x_scale = self.sigma_fn(t_next) / self.sigma_fn(t)
        if denoised_last is None:
            return x_scale * x - (-h).expm1() * denoised
        # Assumes the previous call handled sigmas[step - 1].
        h_last = t - self.t_fn(sigmas[step - 1])
        r = h_last / h
        denoised_prime = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * denoised_last
        return x_scale * x - (-h).expm1() * denoised_prime
class DPMPP2MAlt1(DPMPP2MBasic):
    """DPM-Solver++(2M) with the update factored into separate ``x`` and denoised scales."""

    name = "dpmpp_2m alt1"

    def step(self, x, sigmas, step):
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        previous = self.denoised_history[-1] if self.denoised_history else None
        self.denoised_history.append(denoised)
        t = self.t_fn(sigma)
        t_next = self.t_fn(sigma_next)
        h = t_next - t
        x_scale = self.sigma_fn(t_next) / self.sigma_fn(t)
        denoised_scale = (-h).expm1()
        denoised_prime = denoised
        if previous is not None:
            h_last = t - self.t_fn(sigmas[step - 1])
            r_ = 1 / (2 * (h_last / h))
            denoised_prime = (1 + r_) * denoised - r_ * previous
        return x * x_scale - denoised_prime * denoised_scale
class DPMPP2MAlt2(DPMPP2MBasic):
    """DPM-Solver++(2M) written as a lerp from the extrapolated denoised towards ``x``.

    Uses ``exp(-h) == sigma_next / sigma``, which turns the exponential update
    into ``lerp(denoised_prime, x, sigma_next / sigma)``.
    """

    name = "dpmpp_2m alt2"

    def step(self, x, sigmas, step):
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        history = self.denoised_history
        previous = history[-1] if history else None
        history.append(denoised)
        if previous is None:
            return torch.lerp(denoised, x, sigma_next / sigma)
        t = self.t_fn(sigma)
        h = self.t_fn(sigma_next) - t
        h_last = t - self.t_fn(sigmas[step - 1])
        r_ = 1 / (2 * (h_last / h))
        return torch.lerp((1 + r_) * denoised - r_ * previous, x, sigma_next / sigma)
class Model:
    """Toy stand-in for a guided diffusion model.

    Each call draws two fresh ``|N(0, 1)| * sigma * 0.1`` tensors (cond and
    uncond) from a seeded private generator. It combines them with CFG scale
    ``cfg``. The values of ``x`` are ignored; only its shape, dtype and device
    are used.
    """

    def __init__(self, *, cfg=2, seed=0):
        self.seed = seed
        self.cfg = cfg
        self.generator = torch.Generator()
        self.generator.manual_seed(seed)

    def randn_like(self, x):
        """Noise shaped like ``x`` from this model's generator (delegates to module ``randn_like``)."""
        return randn_like(x, generator=self.generator)

    def __call__(self, x, sigma, *, ext=False):
        """Return ``denoised``, or ``(denoised, cond, uncond)`` when ``ext`` is true."""
        scale = sigma * 0.1
        cond = self.randn_like(x).abs() * scale
        uncond = self.randn_like(x).abs() * scale
        denoised = uncond + (cond - uncond) * self.cfg
        if ext:
            return denoised, cond, uncond
        return denoised
class State:
    """Fixture for one solver run: toy model, sigma schedule, start latent, trajectory buffer.

    Args:
        max_sigma: First sigma of the default linear schedule (ignored when
            ``sigmas`` is given).
        sigmas: Optional explicit schedule. Defaults to ``steps + 1`` values
            spaced linearly from ``max_sigma`` down to 0, cast to ``dtype``.
        steps: Number of steps in the default schedule.
        seed: Seed for torch's global RNG and for the model's generator. The
            model's generator also draws the start latent.
        cfg: Guidance scale passed to :class:`Model`.
    """

    def __init__(self, *, max_sigma=14.56, sigmas=None, steps=5, seed=0, cfg=7):
        torch.manual_seed(seed)
        # cfg must be forwarded; otherwise the argument has no effect and the
        # model silently runs with its own default scale.
        self.model = Model(cfg=cfg, seed=seed)
        if sigmas is None:
            sigmas = torch.linspace(max_sigma, 0, steps + 1).to(dtype)
        self.sigmas = sigmas
        # 2x2 start latent: |N(0, 1)| drawn from the model's generator, scaled
        # to the first sigma.
        self.x = (
            torch.randn(2 * 2, dtype=dtype, generator=self.model.generator)
            .abs()
            .view(2, 2)
            * self.sigmas[0]
        )
        # result[0] is the start point; result[i + 1] is the sample after step i.
        self.result = self.x.new_zeros(len(self.sigmas), *self.x.shape)
        self.last_solver = None

    def run_steps(self, solver):
        """Sample from a copy of ``self.x`` with ``solver`` and record every step in ``self.result``.

        The final transition is not delegated to the solver: it takes the
        model's plain denoised prediction at the second-to-last sigma.
        Returns the final sample.
        """
        self.last_solver = solver
        x = self.x.clone()
        self.result[0] = x
        sigma_last_i = len(self.sigmas) - 2
        for i in range(sigma_last_i + 1):
            if i == sigma_last_i:
                x = self.model(x, self.sigmas[i])
            else:
                x = solver.step(x, self.sigmas, i)
            self.result[i + 1] = x
        return x
def run_test(solver_class, *, state_args, solver_args):
    """Build a fresh ``State``, run ``solver_class`` over its schedule, and return that state.

    ``solver_class`` is called as ``solver_class(state, **solver_args)``, so a
    ``functools.partial`` around a solver class works too.
    """
    state = State(**state_args)
    state.run_steps(solver_class(state, **solver_args))
    return state
def run_tests(
    tests,
    *,
    state_args=None,
    solver_args=None,
    rtol=1e-06,
    atol=1e-08,
    verbose=False,
):
    """Run each solver on an identical fresh State and report pairs whose trajectories differ.

    Args:
        tests: Solver classes, or ``functools.partial`` wrappers around them.
        state_args: Keyword arguments for every ``State``.
        solver_args: Keyword arguments for every solver.
        rtol, atol: Tolerances passed to ``torch.allclose``.
        verbose: Also print each solver's schedule and trajectory.

    Every unordered pair of runs is compared over the whole trajectory, and
    mismatches are printed with their difference.

    Returns:
        The number of mismatching pairs. The original returned ``None``; this is
        an additive, backward-compatible change.
    """
    state_args = {} if state_args is None else state_args
    solver_args = {} if solver_args is None else solver_args
    results = []
    for test in tests:
        state = run_test(test, state_args=state_args, solver_args=solver_args)
        if verbose:
            # Read the name from the instantiated solver: functools.partial
            # entries have no `name` attribute of their own.
            print(f"\n\n===== {state.last_solver.name}")
            print(state.sigmas)
            print(state.result[1:])
        results.append(state)
    if verbose:
        print("\n\n-------------------")
    failcount = 0
    # Visit each unordered pair exactly once (i < j).
    for i, curr_state in enumerate(results):
        for vs_state in results[i + 1 :]:
            if torch.allclose(curr_state.result, vs_state.result, atol=atol, rtol=rtol):
                continue
            if failcount == 0:
                print("\n***\n")
            failcount += 1
            print(
                f"\n\n!!!!!!!! Fail: {curr_state.last_solver.name} vs {vs_state.last_solver.name} !!!!!!!!\n"
            )
            print(curr_state.result[1:])
            print()
            print(vs_state.result[1:])
            diff = curr_state.result[1:] - vs_state.result[1:]
            print(f"\nDifference(mean={diff.mean()}, absmax={diff.abs().max()}):\n")
            print(diff)
            print()
    return failcount
def main():
    """Run every solver-equivalence experiment in order.

    Each entry pairs a tuple of solvers that should produce matching
    trajectories with extra keyword arguments for ``run_tests``.
    """
    experiments = (
        ((EulerBasic, EulerBasicAlt1, EulerBasicAlt2, EulerBasicAlt3), {}),
        ((EulerPP, EulerPPAlt1), {}),
        # altcfgpp_scale=0 should reduce alternative CFG++ to plain Euler.
        ((EulerBasic, partial(EulerAltCFGPP, altcfgpp_scale=0)), {}),
        ((EulerAltCFGPP, EulerAltCFGPPAlt1), {}),
        ((EulerAltCFGPP, EulerAltCFGPPAlt1), {"solver_args": {"altcfgpp_scale": 0.5}}),
        ((HeunBasic, HeunBasicAlt1, HeunBasicAlt2, HeunBasicAlt3), {}),
        # eta=0 should reduce ancestral Heun to plain Heun.
        ((HeunBasic, partial(HeunAncestral, eta=0)), {}),
        ((HeunAncestral, HeunAncestralAlt1), {}),
        ((HeunCFGPP, HeunCFGPPAlt1), {}),
        (
            (
                HeunBasic,
                partial(HeunAltCFGPP, altcfgpp_scale=0),
                partial(HeunAltCFGPPAlt1, altcfgpp_scale=0),
            ),
            {},
        ),
        ((HeunAltCFGPP, HeunAltCFGPPAlt1), {}),
        ((BogackiBasic, BogackiAlt1), {}),
        ((DPMPP2MBasic, DPMPP2MAlt1, DPMPP2MAlt2), {}),
    )
    for solvers, options in experiments:
        run_tests(solvers, **options)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment