Last active
December 7, 2024 12:40
-
-
Save blepping/183fc215f9efb518d5fc11207535f277 to your computer and use it in GitHub Desktop.
Sampler thought experiments
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import partial | |
import torch | |
# Print tensors in full (the "full" profile) instead of torch's default
# abbreviated form.
torch.set_printoptions(profile="full")

# Floating-point type used for the sigma schedule and the initial latent.
dtype = torch.float64
def randn_like(x, generator=None):
    """Return standard-normal noise shaped like ``x``.

    Unlike ``torch.randn_like`` this accepts an explicit ``generator`` so the
    caller controls which RNG stream the noise is drawn from.

    Args:
        x: Tensor whose shape, dtype and device the result copies. It is not
            modified.
        generator: Optional ``torch.Generator``; ``None`` uses the global RNG.

    Returns:
        A new tensor of normal samples with the same shape/dtype/device as ``x``.
    """
    # Pass the shape as a single tuple: unpacking it (``*x.shape``) leaves
    # ``torch.randn`` without a size argument when ``x`` is 0-dimensional.
    # ``empty_like`` gives an output buffer of the right dtype/device without
    # copying data that is about to be overwritten anyway.
    return torch.randn(x.shape, out=torch.empty_like(x), generator=generator)
# From ComfyUI | |
def get_ancestral_step_(sigma_from, sigma_to, eta=1.0):
    """Split an ancestral sampling step into (sigma_down, sigma_up).

    Variant of ``get_ancestral_step`` that squares both sigmas once up front
    and reuses the squares.

    Args:
        sigma_from: Noise level at the start of the step.
        sigma_to: Noise level the step is heading to.
        eta: Scale on the added noise; a falsy value disables it.

    Returns:
        ``(sigma_down, sigma_up)``: the level to step down to and the amount
        of noise to add afterwards. When ``eta`` is falsy or ``sigma_to`` is
        zero this is ``(sigma_to, 0)``.
    """
    skip_noise = not eta or sigma_to == 0
    if skip_noise:
        return sigma_to, sigma_to * 0

    to_sq = sigma_to**2
    from_sq = sigma_from**2
    uncapped_up = eta * (to_sq * (from_sq - to_sq) / from_sq) ** 0.5
    # The added noise is capped at sigma_to.
    sigma_up = min(sigma_to, uncapped_up)
    sigma_down = (to_sq - sigma_up**2) ** 0.5
    return sigma_down, sigma_up
# From ComfyUI | |
def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
    """Compute the ancestral-step noise split ``(sigma_down, sigma_up)``.

    Args:
        sigma_from: Noise level at the start of the step.
        sigma_to: Noise level the step is heading to.
        eta: Scale on the added noise; a falsy value disables it.

    Returns:
        ``(sigma_down, sigma_up)``: the noise level to step down to and the
        amount of noise to add back. When ``eta`` is falsy or ``sigma_to`` is
        zero no noise is added and ``(sigma_to, 0)`` is returned.
    """
    deterministic = not eta or sigma_to == 0
    if deterministic:
        return sigma_to, sigma_to * 0

    requested_up = (
        eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
    )
    # The added noise is capped at sigma_to.
    sigma_up = min(sigma_to, requested_up)
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    return sigma_down, sigma_up
# I hope I never need you again. | |
def surch(a, b, target, increment=0.00001, start=-0.01, limit=1.01, f=torch.lerp):
    """Brute-force the blend weight for which ``f(b, a, weight)`` best matches ``target``.

    Weights ``start, start + increment, ...`` up to ``limit`` are tried and the
    one with the smallest mean absolute residual ``target - f(b, a, weight)``
    wins (the earliest weight wins ties). The outcome is printed and returned.

    Args:
        a, b: Tensors handed to ``f`` as ``f(b, a, weight)``.
        target: Tensor the blend is compared against.
        increment: Positive spacing between candidate weights.
        start: First candidate weight.
        limit: Largest candidate weight (inclusive).
        f: Blend function; defaults to ``torch.lerp``.

    Returns:
        ``(best_weight, best_residual, tried_count)``. ``best_weight`` is
        ``None`` and the residual is all zeros when no candidate was tried.

    Raises:
        ValueError: If ``increment`` is not positive (the scan would never end).
    """
    if increment <= 0:
        raise ValueError("increment must be positive")

    best = torch.zeros_like(a)
    best_error = None
    bestincr = None
    count = 0
    cur = start
    while cur <= limit:
        tried = target - f(b, a, cur)
        error = tried.abs().mean()
        # Keep the running best error so it is not recomputed every iteration.
        if best_error is None or error < best_error:
            bestincr, best, best_error = cur, tried, error
        count += 1
        # Derive the weight from the counter; repeated ``cur += increment``
        # accumulates floating-point drift over a long scan.
        cur = start + count * increment
    print("BEST", bestincr, best, count)
    return bestincr, best, count
class Solver:
    """Base class for the single-step samplers compared in this file.

    Subclasses implement ``step``; ``get_common`` supplies the values every
    step needs.
    """

    name = "unknown"

    def __init__(self, state):
        """Keep the shared ``state`` and a shortcut to its ``model``."""
        self.model = state.model
        self.state = state

    def get_common(self, x, sigmas, step):
        """Return ``(sigma, sigma_next, denoised)`` for step index ``step``.

        ``denoised`` is the model evaluated on ``x`` at the current sigma.
        """
        current, upcoming = sigmas[step], sigmas[step + 1]
        return current, upcoming, self.model(x, current)

    def step(self, x, sigmas, step):
        """Advance ``x`` from ``sigmas[step]`` to ``sigmas[step + 1]`` (abstract)."""
        raise NotImplementedError
class SolverHistory(Solver):
    """Solver base that owns a list for model outputs from earlier steps."""

    def __init__(self, state):
        """Initialise the base solver and start with an empty history."""
        super().__init__(state)
        # Subclasses append each step's denoised prediction here.
        self.denoised_history = list()
class SolverUncond(Solver):
    """Solver base whose ``get_common`` also exposes the unconditional prediction."""

    def get_common(self, x, sigmas, step):
        """Return ``(sigma, sigma_next, denoised, uncond)`` for step ``step``.

        The model is called with ``ext=True``; its second return value (the
        conditional prediction) is discarded.
        """
        current, upcoming = sigmas[step], sigmas[step + 1]
        outputs = self.model(x, current, ext=True)
        denoised, _cond, uncond = outputs
        return current, upcoming, denoised, uncond
class SolverAltCFGPP(SolverUncond):
    """Solver base for the "alternative CFG++" experiments.

    Adds ``alt_cfgpp_mix``, which pushes ``denoised`` away from ``uncond`` by
    ``altcfgpp_scale``.
    """

    def __init__(self, *args, altcfgpp_scale=1.0, **kwargs):
        """Forward all other arguments to the base class and store the scale."""
        super().__init__(*args, **kwargs)
        self.altcfgpp_scale = altcfgpp_scale

    def alt_cfgpp_mix(self, denoised, uncond):
        """Return ``denoised * (1 + scale) - uncond * scale``.

        With a scale of 0 this is just ``denoised``.
        """
        scale = self.altcfgpp_scale
        return denoised * (1 + scale) - uncond * scale
class SolverETAMixin:
    """Mixin adding an ``eta`` setting and a private noise source to a solver.

    It must come before a class whose ``__init__`` sets ``self.model``,
    because the generator is cloned from ``self.model.generator``.
    """

    def __init__(self, *args, eta=1.0, **kwargs):
        """Initialise the next class in the MRO, then store ``eta`` and the generator."""
        super().__init__(*args, **kwargs)
        # Solver noise is drawn from a generator obtained via clone_state()
        # on the model's generator, not from the model's generator itself.
        self.generator = self.model.generator.clone_state()
        self.eta = eta

    def noise_sampler(self, x):
        """Return normal noise shaped like ``x`` drawn from this solver's generator."""
        return randn_like(x, generator=self.generator)
class SolverAncestral(SolverETAMixin, Solver):
    """Plain ``Solver`` combined with the eta/noise-sampler mixin."""
class SolverAncestralUncond(SolverETAMixin, SolverUncond):
    """``SolverUncond`` combined with the eta/noise-sampler mixin."""
class SolverAncestralAltCFGPP(SolverETAMixin, SolverAltCFGPP):
    """``SolverAltCFGPP`` combined with the eta/noise-sampler mixin."""
class EulerBasic(Solver):
    """Reference Euler step: ``x + d * dt`` with ``d = (x - denoised) / sigma``."""

    name = "normal Euler"

    def step(self, x, sigmas, step):
        """Move ``x`` along the derivative by ``sigma_next - sigma``."""
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        derivative = (x - denoised) / sigma
        step_size = sigma_next - sigma
        return x + derivative * step_size
class EulerBasicAlt1(Solver):
    """Euler written as ``denoised + d * sigma_next``.

    Since ``x = denoised + d * sigma``, this is algebraically the same update
    as ``x + d * (sigma_next - sigma)``.
    """

    name = "Euler add to denoised"

    def step(self, x, sigmas, step):
        """Start from the denoised prediction and add the derivative scaled by ``sigma_next``."""
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        derivative = (x - denoised) / sigma
        return denoised + derivative * sigma_next
class EulerBasicAlt2(Solver):
    """Euler expressed with the sigma ratio instead of an explicit derivative."""

    name = "Euler add to denoised with ratio"

    def step(self, x, sigmas, step):
        """Return ``denoised + (x - denoised) * (sigma_next / sigma)``."""
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        residual = x - denoised
        return denoised + residual * (sigma_next / sigma)
class EulerBasicAlt3(Solver):
    """Euler expressed as a linear interpolation between ``denoised`` and ``x``."""

    name = "Euler ratios"

    def step(self, x, sigmas, step):
        """Interpolate from ``denoised`` towards ``x`` with weight ``sigma_next / sigma``."""
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        weight = sigma_next / sigma
        return torch.lerp(denoised, x, weight)
class EulerPP(SolverUncond):
    """Euler "CFG++" step.

    The derivative is built from the unconditional prediction and then added
    to the guided ``denoised`` result.
    """

    name = "normal Euler CFG++"

    def step(self, x, sigmas, step):
        """Return ``denoised + ((x - uncond) / sigma) * sigma_next``."""
        sigma, sigma_next, denoised, uncond = self.get_common(x, sigmas, step)
        uncond_derivative = (x - uncond) / sigma
        return denoised + uncond_derivative * sigma_next
class EulerPPAlt1(SolverUncond):
    """Euler "CFG++" step written with the sigma ratio."""

    name = "Euler CFG++ add to denoised with ratio"

    def step(self, x, sigmas, step):
        """Return ``denoised + (x - uncond) * (sigma_next / sigma)``."""
        sigma, sigma_next, denoised, uncond = self.get_common(x, sigmas, step)
        uncond_residual = x - uncond
        return denoised + uncond_residual * (sigma_next / sigma)
class EulerAltCFGPP(SolverAltCFGPP):
    """Ordinary Euler step whose denoised target is the ``alt_cfgpp_mix`` blend."""

    name = "normal Euler alternative CFG++"

    def step(self, x, sigmas, step):
        """Euler update using ``alt_cfgpp_mix(denoised, uncond)`` in place of ``denoised``."""
        sigma, sigma_next, denoised, uncond = self.get_common(x, sigmas, step)
        target = self.alt_cfgpp_mix(denoised, uncond)
        derivative = (x - target) / sigma
        step_size = sigma_next - sigma
        return x + derivative * step_size
class EulerAltCFGPPAlt1(SolverAltCFGPP):
    """Interpolation form of the alternative CFG++ Euler step."""

    name = "Euler alternative CFG++ ratios"

    def step(self, x, sigmas, step):
        """Interpolate from the mixed target towards ``x`` with weight ``sigma_next / sigma``."""
        sigma, sigma_next, denoised, uncond = self.get_common(x, sigmas, step)
        weight = sigma_next / sigma
        target = self.alt_cfgpp_mix(denoised, uncond)
        return torch.lerp(target, x, weight)
class HeunBasic(Solver):
    """Reference Heun step: Euler predictor, then the average of both derivatives."""

    name = "Heun"

    def step(self, x, sigmas, step):
        """Take an Euler trial step, re-evaluate the model there and average the slopes."""
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        slope = (x - denoised) / sigma
        step_size = sigma_next - sigma
        x_trial = x + slope * step_size
        # Second model evaluation at the trial point / next sigma.
        denoised_trial = self.model(x_trial, sigma_next)
        slope_trial = (x_trial - denoised_trial) / sigma_next
        slope_avg = (slope + slope_trial) / 2
        return x + slope_avg * step_size
class HeunBasicAlt1(Solver):
    """Heun step written purely with interpolations (compared with ``HeunBasic`` in ``main``)."""

    name = "Heun ratios"

    def step(self, x, sigmas, step):
        """Blend the two denoised predictions, then interpolate towards ``x``."""
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        weight = sigma_next / sigma
        x_trial = torch.lerp(denoised, x, weight)
        denoised_trial = self.model(x_trial, sigma_next)
        # Weight given to the first prediction when blending the two.
        blend = 1 - (sigma / sigma_next) / 2
        denoised_blend = torch.lerp(denoised_trial, denoised, blend)
        return torch.lerp(denoised_blend, x, weight)
class HeunBasicAlt2(Solver):
    """Heun step that rebuilds an effective denoised prediction from the averaged slope."""

    name = "Heun alt2"

    def step(self, x, sigmas, step):
        """Advance ``x`` one Heun step.

        Returns the first denoised prediction directly when ``sigma_next`` is
        not positive. Otherwise the averaged derivative ``d_prime`` is turned
        into ``denoised_prime = x - d_prime * sigma`` and the result is
        ``denoised_prime + d_prime * sigma_next``.
        """
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        if sigma_next <= 0:
            return denoised
        d = (x - denoised) / sigma
        dt = sigma_next - sigma
        x_2 = x + d * dt
        # The original repeated the ``sigma_next <= 0`` test here; that branch
        # could never run after the guard above, so it has been removed.
        denoised_2 = self.model(x_2, sigma_next)
        d_2 = (x_2 - denoised_2) / sigma_next
        d_prime = (d + d_2) / 2
        denoised_prime = x - d_prime * sigma
        # Also possible: denoised_prime * iratio + x * ratio
        return denoised_prime + d_prime * sigma_next
class HeunBasicAlt3(Solver):
    """Heun via blended denoised predictions, finished with an explicit ratio add."""

    name = "Heun add to denoised with ratio"

    def step(self, x, sigmas, step):
        """Blend both predictions, then return ``blend + (x - blend) * ratio``."""
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        weight = sigma_next / sigma
        x_trial = torch.lerp(denoised, x, weight)
        denoised_trial = self.model(x_trial, sigma_next)
        blend = 1 - (sigma / sigma_next) / 2
        denoised_blend = torch.lerp(denoised_trial, denoised, blend)
        return denoised_blend + (x - denoised_blend) * weight
class HeunAncestral(SolverAncestral):
    """Heun step towards ``sigma_down`` followed by adding ``sigma_up`` worth of noise."""

    name = "Heun ancestral"

    def step(self, x, sigmas, step):
        """Run Heun from ``sigma`` to ``sigma_down`` and add the ancestral noise."""
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next, eta=self.eta)
        slope = (x - denoised) / sigma
        step_size = sigma_down - sigma
        x_trial = x + slope * step_size
        # Noise is drawn before the second model call, as in the original.
        if sigma_next > 0:
            noise = self.noise_sampler(x) * sigma_up
        else:
            noise = torch.zeros_like(x)
        denoised_trial = self.model(x_trial, sigma_down)
        slope_trial = (x_trial - denoised_trial) / sigma_down
        slope_avg = (slope + slope_trial) / 2
        return x + slope_avg * step_size + noise
class HeunAncestralAlt1(SolverAncestral):
    """Interpolation form of the ancestral Heun step."""

    name = "Heun ancestral ratios"

    def step(self, x, sigmas, step):
        """Same structure as ``HeunBasicAlt1`` but targeting ``sigma_down``, plus noise."""
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next, eta=self.eta)
        weight = sigma_down / sigma
        x_trial = torch.lerp(denoised, x, weight)
        # Noise is drawn before the second model call, as in the original.
        if sigma_next > 0:
            noise = self.noise_sampler(x) * sigma_up
        else:
            noise = torch.zeros_like(x)
        denoised_trial = self.model(x_trial, sigma_down)
        blend = 1 - (sigma / sigma_down) / 2
        denoised_blend = torch.lerp(denoised_trial, denoised, blend)
        return torch.lerp(denoised_blend, x, weight) + noise
class HeunCFGPP(SolverUncond):
    """Heun step in the "CFG++" style.

    Derivatives come from the unconditional predictions, while the point they
    are added to is a blend of the guided denoised predictions.
    """

    name = "Heun CFG++"

    def step(self, x, sigmas, step):
        """Advance ``x`` one step.

        Returns the first guided prediction directly when ``sigma_next`` is
        not positive; otherwise returns
        ``denoised_prime + d_prime * sigma_next``.
        """
        sigma, sigma_next, denoised, uncond = self.get_common(x, sigmas, step)
        if sigma_next <= 0:
            return denoised
        d = (x - uncond) / sigma
        x_2 = denoised + d * sigma_next
        # The original repeated the ``sigma_next <= 0`` test here; that branch
        # could never run after the guard above, so it has been removed.
        dratio = 1 - (sigma / sigma_next) / 2
        denoised_2, _cond_2, uncond_2 = self.model(x_2, sigma_next, ext=True)
        d_2 = (x_2 - uncond_2) / sigma_next
        d_prime = (d + d_2) / 2
        # Blend of the two guided predictions, weighted towards the first by dratio.
        denoised_prime = torch.lerp(denoised_2, denoised, dratio)
        return denoised_prime + d_prime * sigma_next
class HeunCFGPPAlt1(SolverUncond):
    """Ratio-based attempt at the Heun CFG++ step (compared with ``HeunCFGPP`` in ``main``)."""

    name = "Heun CFG++ ratios"

    def step(self, x, sigmas, step):
        """Combine blended guided predictions with ratio-scaled unconditional terms."""
        sigma, sigma_next, denoised, uncond = self.get_common(x, sigmas, step)
        weight = sigma_next / sigma
        x_trial = denoised + (x - uncond) * weight
        blend = 1 - (sigma / sigma_next) / 2
        denoised_trial, _cond_trial, uncond_trial = self.model(
            x_trial, sigma_next, ext=True
        )
        denoised_blend = torch.lerp(denoised_trial, denoised, blend)
        uncond_term = (x - uncond - uncond_trial * (1 - blend)) * weight
        return denoised_blend + denoised * 0.5 + uncond_term
class HeunAltCFGPP(SolverAltCFGPP):
    """Heun step where both derivative evaluations use the ``alt_cfgpp_mix`` target."""

    name = "Heun alt CFG++"

    def step(self, x, sigmas, step):
        """Standard Heun update with the mixed prediction substituted for ``denoised``."""
        sigma, sigma_next, denoised, uncond = self.get_common(x, sigmas, step)
        slope = (x - self.alt_cfgpp_mix(denoised, uncond)) / sigma
        step_size = sigma_next - sigma
        x_trial = x + slope * step_size
        denoised_trial, _cond_trial, uncond_trial = self.model(
            x_trial, sigma_next, ext=True
        )
        target_trial = self.alt_cfgpp_mix(denoised_trial, uncond_trial)
        slope_trial = (x_trial - target_trial) / sigma_next
        slope_avg = (slope + slope_trial) / 2
        return x + slope_avg * step_size
class HeunAltCFGPPAlt1(SolverAltCFGPP):
    """Interpolation form of the alternative CFG++ Heun step."""

    name = "Heun alt CFG++ ratios"

    def step(self, x, sigmas, step):
        """Blend guided and unconditional predictions separately, mix them, then interpolate."""
        sigma, sigma_next, denoised, uncond = self.get_common(x, sigmas, step)
        weight = sigma_next / sigma
        x_trial = torch.lerp(self.alt_cfgpp_mix(denoised, uncond), x, weight)
        blend = 1 - (sigma / sigma_next) / 2
        denoised_trial, _cond_trial, uncond_trial = self.model(
            x_trial, sigma_next, ext=True
        )
        denoised_blend = torch.lerp(denoised_trial, denoised, blend)
        uncond_blend = torch.lerp(uncond_trial, uncond, blend)
        target = self.alt_cfgpp_mix(denoised_blend, uncond_blend)
        return torch.lerp(target, x, weight)
class BogackiBasic(Solver):
    """Three-stage step with stages at ``dt/2`` and ``3*dt/4`` and weights 2/9, 1/3, 4/9."""

    name = "Bogacki"

    def step(self, x, sigmas, step):
        """Evaluate three stages and return their weighted combination added to ``x``."""
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        slope = (x - denoised) / sigma
        dt = sigma_next - sigma
        k1 = slope * dt

        # Stage 2: half-way along the step.
        sigma_mid = sigma + dt / 2
        x_mid = x + k1 / 2
        denoised_mid = self.model(x_mid, sigma_mid)
        k2 = ((x_mid - denoised_mid) / sigma_mid) * dt

        # Stage 3: three quarters along the step.
        sigma_late = sigma + 3 * dt / 4
        x_late = x + 3 * k1 / 4 + k2 / 4
        denoised_late = self.model(x_late, sigma_late)
        k3 = ((x_late - denoised_late) / sigma_late) * dt

        return x + 2 * k1 / 9 + k2 / 3 + 4 * k3 / 9
class BogackiAlt1(Solver):
    """``BogackiBasic`` result re-expressed as an interpolation from an effective denoised."""

    name = "Bogacki2"

    def step(self, x, sigmas, step):
        """Compute the three-stage result, derive ``denoised_prime`` from it and interpolate."""
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        slope = (x - denoised) / sigma
        dt = sigma_next - sigma
        k1 = slope * dt

        # Stage 2: half-way along the step.
        sigma_mid = sigma + dt / 2
        x_mid = x + k1 / 2
        denoised_mid = self.model(x_mid, sigma_mid)
        k2 = ((x_mid - denoised_mid) / sigma_mid) * dt

        # Stage 3: three quarters along the step.
        sigma_late = sigma + 3 * dt / 4
        x_late = x + 3 * k1 / 4 + k2 / 4
        denoised_late = self.model(x_late, sigma_late)
        k3 = ((x_late - denoised_late) / sigma_late) * dt

        x_new = x + 2 * k1 / 9 + k2 / 3 + 4 * k3 / 9
        weight = sigma_next / sigma
        # Solve lerp(denoised_prime, x, weight) == x_new for denoised_prime.
        denoised_prime = (x_new - x * weight) / (1 - weight)
        # Or: return denoised_prime + (x - denoised_prime) * weight
        return torch.lerp(denoised_prime, x, weight)
class HeunReversible(Solver):
    """Heun step minus a correction built from the slope difference."""

    name = "Heun reversible"

    def step(self, x, sigmas, step):
        """Return the Heun result minus ``dtr**2 * (d_2 - d) / 4``.

        ``dtr`` is the distance from ``sigma`` to the ``sigma_down`` that
        ``get_ancestral_step`` reports for ``eta=1.0``.
        """
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        sigma_down, _sigma_up = get_ancestral_step(sigma, sigma_next, 1.0)
        dtr = sigma_down - sigma
        slope = (x - denoised) / sigma
        step_size = sigma_next - sigma
        x_trial = x + slope * step_size
        denoised_trial = self.model(x_trial, sigma_next)
        slope_trial = (x_trial - denoised_trial) / sigma_next
        slope_avg = (slope + slope_trial) / 2
        correction = dtr**2 * (slope_trial - slope) / 4
        return x + slope_avg * step_size - correction
class DPMPP2MBasic(SolverHistory):
    """Two-step multistep solver working in ``t = -log(sigma)`` space.

    The first step uses only the current prediction; later steps extrapolate
    from the current and previous predictions.
    """

    name = "dpmpp_2m basic"

    @staticmethod
    def sigma_fn(t):
        """Map ``t`` back to sigma: ``exp(-t)``."""
        return t.neg().exp()

    @staticmethod
    def t_fn(sigma):
        """Map sigma to ``t``: ``-log(sigma)``."""
        return sigma.log().neg()

    def step(self, x, sigmas, step):
        """Advance ``x`` one step, recording the current prediction in the history."""
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        history = self.denoised_history
        previous = history[-1] if history else None
        history.append(denoised)

        t, t_next = self.t_fn(sigma), self.t_fn(sigma_next)
        s, s_next = self.sigma_fn(t), self.sigma_fn(t_next)
        h = t_next - t
        if previous is None:
            # No history yet: use the current prediction unchanged.
            return (s_next / s) * x - (-h).expm1() * denoised

        h_last = t - self.t_fn(sigmas[step - 1])
        r = h_last / h
        denoised_prime = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * previous
        return (s_next / s) * x - (-h).expm1() * denoised_prime
class DPMPP2MAlt1(DPMPP2MBasic):
    """``DPMPP2MBasic`` restructured so both cases share one return expression."""

    name = "dpmpp_2m alt1"

    def step(self, x, sigmas, step):
        """Pick ``denoised_prime`` (plain or extrapolated), then apply the shared update."""
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        history = self.denoised_history
        previous = history[-1] if history else None
        history.append(denoised)

        t, t_next = self.t_fn(sigma), self.t_fn(sigma_next)
        s, s_next = self.sigma_fn(t), self.sigma_fn(t_next)
        h = t_next - t
        x_scale = s_next / s
        denoised_scale = (-h).expm1()

        denoised_prime = denoised
        if previous is not None:
            h_last = t - self.t_fn(sigmas[step - 1])
            # r_ is 1 / (2 * r) with r = h_last / h.
            r_ = 1 / (2 * (h_last / h))
            denoised_prime = (1 + r_) * denoised - r_ * previous
        return x * x_scale - denoised_prime * denoised_scale
class DPMPP2MAlt2(DPMPP2MBasic):
    """``DPMPP2MBasic`` with the final update written as a single interpolation."""

    name = "dpmpp_2m alt2"

    def step(self, x, sigmas, step):
        """Pick ``denoised_prime``, then interpolate towards ``x`` by ``sigma_next / sigma``."""
        sigma, sigma_next, denoised = self.get_common(x, sigmas, step)
        history = self.denoised_history
        previous = history[-1] if history else None
        history.append(denoised)

        t, t_next = self.t_fn(sigma), self.t_fn(sigma_next)
        h = t_next - t

        denoised_prime = denoised
        if previous is not None:
            h_last = t - self.t_fn(sigmas[step - 1])
            # r_ is 1 / (2 * r) with r = h_last / h.
            r_ = 1 / (2 * (h_last / h))
            denoised_prime = (1 + r_) * denoised - r_ * previous
        return torch.lerp(denoised_prime, x, sigma_next / sigma)
class Model:
    """Stand-in model returning seeded pseudo-random "predictions".

    The output does not depend on the values in ``x``, only on its shape,
    ``sigma`` and how many draws the generator has made so far.
    """

    def __init__(self, *, cfg=2, seed=0):
        """Store the guidance scale and create a generator seeded with ``seed``."""
        self.cfg = cfg
        self.seed = seed
        self.generator = torch.Generator()
        self.generator.manual_seed(seed)

    def randn_like(self, x):
        """Draw normal noise shaped like ``x`` from this model's generator."""
        return randn_like(x, generator=self.generator)

    def __call__(self, x, sigma, *, ext=False):
        """Return the guided prediction, or ``(denoised, cond, uncond)`` when ``ext`` is true.

        ``cond`` is drawn before ``uncond``; both are absolute-valued noise
        scaled by ``sigma * 0.1``.
        """
        scale = sigma * 0.1
        cond = self.randn_like(x).abs() * scale
        uncond = self.randn_like(x).abs() * scale
        denoised = uncond + (cond - uncond) * self.cfg
        if ext:
            return denoised, cond, uncond
        return denoised
class State:
    """One experiment run: model, sigma schedule, initial latent and recorded results."""

    def __init__(self, *, max_sigma=14.56, sigmas=None, steps=5, seed=0, cfg=7):
        """Build the model, schedule and starting latent.

        Args:
            max_sigma: First value of the default linear schedule.
            sigmas: Explicit schedule; when ``None`` a linear schedule from
                ``max_sigma`` down to 0 with ``steps + 1`` entries is used.
            steps: Number of steps in the default schedule.
            seed: Seed for the model's generator (also used for the latent).
            cfg: Guidance scale handed to the model.
        """
        torch.manual_seed(0)
        # Forward ``cfg`` to the model. It used to be accepted here but never
        # passed on, so the model always ran with its own default scale.
        self.model = Model(seed=seed, cfg=cfg)
        if sigmas is None:
            sigmas = torch.linspace(max_sigma, 0, steps + 1).to(dtype)
        self.sigmas = sigmas
        # 2x2 latent of absolute-valued normal draws scaled by the first sigma.
        self.x = (
            torch.randn(2 * 2, dtype=dtype, generator=self.model.generator)
            .abs()
            .view(2, 2)
            * self.sigmas[0]
        )
        # result[0] holds the initial latent, result[i + 1] the output of step i.
        self.result = self.x.new_zeros(len(self.sigmas), *self.x.shape)
        self.last_solver = None

    def run_steps(self, solver):
        """Run ``solver`` over the whole schedule and return the final latent.

        Every intermediate latent is written into ``self.result``. The final
        step does not use the solver: the model output at the last non-final
        sigma is taken directly.
        """
        self.last_solver = solver
        x = self.x.clone()
        self.result[0] = x
        sigma_last_i = len(self.sigmas) - 2
        for i in range(sigma_last_i + 1):
            if i == sigma_last_i:
                x = self.model(x, self.sigmas[i])
            else:
                x = solver.step(x, self.sigmas, i)
            self.result[i + 1] = x
        return x
def run_test(solver_class, *, state_args, solver_args):
    """Run one solver on a freshly built ``State`` and return that state.

    Args:
        solver_class: Callable producing a solver from ``(state, **solver_args)``.
        state_args: Keyword arguments for ``State``.
        solver_args: Extra keyword arguments for ``solver_class``.
    """
    fresh_state = State(**state_args)
    fresh_state.run_steps(solver_class(fresh_state, **solver_args))
    return fresh_state
def run_tests(
    tests,
    *,
    state_args=None,
    solver_args=None,
    rtol=1e-06,
    atol=1e-08,
    verbose=False,
):
    """Run each solver and report every pair whose recorded results differ.

    Args:
        tests: Iterable of solver classes, or callables such as
            ``functools.partial`` objects wrapping one.
        state_args: Keyword arguments for each ``State`` (default: none).
        solver_args: Keyword arguments for each solver (default: none).
        rtol, atol: Tolerances passed to ``torch.allclose``.
        verbose: Also print each solver's schedule and results.

    Returns:
        The number of mismatching pairs (0 when all solvers agree).
    """
    from itertools import combinations

    state_args = {} if state_args is None else state_args
    solver_args = {} if solver_args is None else solver_args

    results = []
    for test in tests:
        if verbose:
            # partial objects carry no ``name``; look through to the wrapped
            # class so verbose mode does not raise AttributeError.
            label = getattr(getattr(test, "func", test), "name", repr(test))
            print(f"\n\n===== {label}")
        state = run_test(test, state_args=state_args, solver_args=solver_args)
        if verbose:
            print(state.sigmas)
            print(state.result[1:])
        results.append(state)
    if verbose:
        print("\n\n-------------------")

    failcount = 0
    # Each unordered pair is compared exactly once, in input order.
    for curr_state, vs_state in combinations(results, 2):
        if torch.allclose(curr_state.result, vs_state.result, atol=atol, rtol=rtol):
            continue
        if failcount == 0:
            print("\n***\n")
        failcount += 1
        print(
            f"\n\n!!!!!!!! Fail: {curr_state.last_solver.name} vs {vs_state.last_solver.name} !!!!!!!!\n"
        )
        print(curr_state.result[1:])
        print()
        print(vs_state.result[1:])
        diff = curr_state.result[1:] - vs_state.result[1:]
        print(f"\nDifference(mean={diff.mean()}, absmax={diff.abs().max()}):\n")
        print(diff)
        print()
    return failcount
def main():
    """Run every equivalence suite in order; ``run_tests`` prints any mismatches."""
    # Each entry is (solvers to compare, extra keyword arguments for run_tests).
    suites = (
        ((EulerBasic, EulerBasicAlt1, EulerBasicAlt2, EulerBasicAlt3), {}),
        ((EulerPP, EulerPPAlt1), {}),
        ((EulerBasic, partial(EulerAltCFGPP, altcfgpp_scale=0)), {}),
        ((EulerAltCFGPP, EulerAltCFGPPAlt1), {}),
        (
            (EulerAltCFGPP, EulerAltCFGPPAlt1),
            {"solver_args": {"altcfgpp_scale": 0.5}},
        ),
        ((HeunBasic, HeunBasicAlt1, HeunBasicAlt2, HeunBasicAlt3), {}),
        ((HeunBasic, partial(HeunAncestral, eta=0)), {}),
        ((HeunAncestral, HeunAncestralAlt1), {}),
        ((HeunCFGPP, HeunCFGPPAlt1), {}),
        (
            (
                HeunBasic,
                partial(HeunAltCFGPP, altcfgpp_scale=0),
                partial(HeunAltCFGPPAlt1, altcfgpp_scale=0),
            ),
            {},
        ),
        ((HeunAltCFGPP, HeunAltCFGPPAlt1), {}),
        ((BogackiBasic, BogackiAlt1), {}),
        ((DPMPP2MBasic, DPMPP2MAlt1, DPMPP2MAlt2), {}),
    )
    for solvers, options in suites:
        run_tests(solvers, **options)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment