Skip to content

Instantly share code, notes, and snippets.

@KohakuBlueleaf
Created March 12, 2025 08:36
Show Gist options
  • Save KohakuBlueleaf/85cd080e9e5cd9e84582e22a8e8770c1 to your computer and use it in GitHub Desktop.
Save KohakuBlueleaf/85cd080e9e5cd9e84582e22a8e8770c1 to your computer and use it in GitHub Desktop.
A simple implementation of IMM (Inductive Moment Matching): https://arxiv.org/pdf/2503.07565
"""
A Minimal Implementation of IMM (Inductive Moment Matching)
"""
import math
import torch
import torch.nn.functional as F
def compute_mmd_loss_fully_vectorized(
    ys_t: torch.Tensor,
    ys_r: torch.Tensor,
    M: int = 4,
    kernel_type: str = "laplace",
    kernel_bandwidth: float = 1.0,
) -> torch.Tensor:
    """Compute the group-wise MMD loss with fully vectorized operations.

    The batch is split into ``B // M`` consecutive groups of ``M`` particles.
    For each group, the estimate

        (sum K(t, t) + sum K(r, r) - 2 * sum K(t, r)) / M**2

    is taken over all ``M x M`` pairs (diagonal included). The per-group
    estimates are then averaged.

    Args:
        ys_t: First set of samples, shape [B, ...]. Every dimension after the
            batch dimension is flattened (e.g. [B, C, H, W] or [B, D]).
        ys_r: Second set of samples, same shape as ``ys_t``.
        M: Number of particles per group. Must divide B.
        kernel_type: "laplace" for exp(-||x - y||_1 / h),
            "rbf" for exp(-||x - y||_2^2 / (2 h^2)), or
            "energy" for -||x - y||_2.
        kernel_bandwidth: Bandwidth h. Not used by the energy kernel.

    Returns:
        Scalar tensor: the mean MMD estimate over all groups.

    Raises:
        AssertionError: If B is not divisible by M (when asserts are enabled).
        ValueError: If ``kernel_type`` is not one of the supported kernels.
    """
    batch_size = ys_t.shape[0]
    assert batch_size % M == 0, f"Batch size {batch_size} must be divisible by M={M}"
    # Validate before building any [groups, M, M, D] intermediates.
    if kernel_type not in ("laplace", "rbf", "energy"):
        raise ValueError(f"Unknown kernel type: {kernel_type}")

    groups = batch_size // M
    ys_t_flat = ys_t.reshape(groups, M, -1)  # [groups, M, D]
    ys_r_flat = ys_r.reshape(groups, M, -1)  # [groups, M, D]

    def pairwise_kernel(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Kernel matrix [groups, M, M] whose entry [g, i, j] is k(a[g, i], b[g, j])."""
        diff = a.unsqueeze(2) - b.unsqueeze(1)  # [groups, M, M, D]
        if kernel_type == "laplace":
            return torch.exp(-diff.abs().sum(dim=-1) / kernel_bandwidth)
        sq_dist = (diff**2).sum(dim=-1)
        if kernel_type == "rbf":
            return torch.exp(-sq_dist / (2 * kernel_bandwidth**2))
        # Energy kernel. The epsilon keeps sqrt differentiable at zero distance
        # (the self-pairs on the diagonal), at the cost of a -1e-4 value there.
        return -torch.sqrt(sq_dist + 1e-8)

    K_t_t = pairwise_kernel(ys_t_flat, ys_t_flat)
    K_r_r = pairwise_kernel(ys_r_flat, ys_r_flat)
    K_t_r = pairwise_kernel(ys_t_flat, ys_r_flat)

    # Per-group MMD estimate, then average over groups.
    group_mmd = (
        K_t_t.sum(dim=[1, 2]) + K_r_r.sum(dim=[1, 2]) - 2 * K_t_r.sum(dim=[1, 2])
    ) / (M * M)
    return group_mmd.mean()
def get_alpha_sigma_otfm(t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the OT-FM (linear) schedule coefficients at time ``t``.

    The signal weight is ``alpha_t = 1 - t`` and the noise weight is
    ``sigma_t = t``.

    Args:
        t: Time steps tensor of any shape.

    Returns:
        Tuple ``(alpha_t, sigma_t)``, each with the same shape as ``t``.
    """
    sigma = t
    alpha = 1 - sigma
    return alpha, sigma
def get_alpha_sigma_cosine(t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the cosine schedule coefficients at time ``t``.

    The signal weight is ``alpha_t = cos(pi * t / 2)`` and the noise weight
    is ``sigma_t = sin(pi * t / 2)``.

    Args:
        t: Time steps tensor of any shape.

    Returns:
        Tuple ``(alpha_t, sigma_t)``, each with the same shape as ``t``.
    """
    half_pi = 0.5 * math.pi
    angle = half_pi * t
    return angle.cos(), angle.sin()
def ddim_interpolant(
    xt: torch.Tensor,
    x0: torch.Tensor,
    s: torch.Tensor,
    t: torch.Tensor,
    schedule: str = "otfm",
) -> torch.Tensor:
    """Move ``xt`` from time ``t`` to time ``s`` along the DDIM path anchored at ``x0``.

    Computes

        x_s = alpha_s * x0 + (sigma_s / sigma_t) * (x_t - alpha_t * x0)

    That is, the residual ``x_t - alpha_t * x0`` is rescaled from noise level
    ``sigma_t`` to ``sigma_s``.

    Args:
        xt: Noisy samples at time t (e.g. [B, C, H, W] or flattened [B, D]).
        x0: Clean samples, same shape as ``xt``.
        s: Target times, broadcastable against ``xt`` (e.g. [B, 1, 1, 1] or [B, 1]).
        t: Current times, broadcastable against ``xt``.
        schedule: Noise schedule, one of "otfm" or "cosine".

    Returns:
        Samples at time s, same shape as ``xt``.

    Raises:
        ValueError: If ``schedule`` is unknown.

    Note:
        Divides by sigma_t. Under both schedules sigma_t is 0 at t == 0.
    """
    schedules = {
        "otfm": get_alpha_sigma_otfm,
        "cosine": get_alpha_sigma_cosine,
    }
    if schedule not in schedules:
        raise ValueError(f"Unknown schedule: {schedule}")
    coefficients = schedules[schedule]
    alpha_t, sigma_t = coefficients(t)
    alpha_s, sigma_s = coefficients(s)
    residual = xt - alpha_t * x0
    return alpha_s * x0 + (sigma_s / sigma_t) * residual
def mapping_fn_eta_decay(
    s: torch.Tensor,
    t: torch.Tensor,
    eta_max: float = 160.0,
    eta_min: float = 0.0,
    k: int = 12,
    min_gap: float = 1e-4,
    schedule: str = "otfm",
) -> torch.Tensor:
    """Compute the intermediate time r(s, t) using a constant decrement in eta space.

    With eta = sigma / alpha (t / (1 - t) for OT-FM, tan(pi * t / 2) for cosine),
    r is first chosen so that

        eta(r) = max(eta(s), eta(t) - (eta_max - eta_min) / 2**k).

    r is then capped at ``t - min_gap`` and finally clamped to be no smaller
    than ``s``. So for s <= t the result satisfies s <= r <= t.

    Args:
        s: Target times, e.g. shape [B, 1, 1, 1] or [B, 1].
        t: Current times, same shape as ``s``.
        eta_max: Maximum eta value.
        eta_min: Minimum eta value.
        k: Power factor for the step size (larger k means a smaller step).
        min_gap: Minimum gap between r and t, applied when it does not
            push r below s.
        schedule: Noise schedule, one of "otfm" or "cosine".

    Returns:
        Intermediate times r, same shape as ``s`` and ``t``.

    Raises:
        ValueError: If ``schedule`` is unknown.
    """
    # The decrement in eta space does not depend on the schedule.
    decrement = (eta_max - eta_min) / (2**k)
    if schedule == "otfm":
        # OT-FM: eta_t = sigma_t / alpha_t = t / (1 - t)
        eta_t = t / (1 - t)
        eta_r = torch.maximum(s / (1 - s), eta_t - decrement)
        r = eta_r / (1 + eta_r)
    elif schedule == "cosine":
        # Cosine: eta_t = sigma_t / alpha_t = tan(pi * t / 2)
        eta_t = torch.tan(0.5 * math.pi * t)
        eta_r = torch.maximum(torch.tan(0.5 * math.pi * s), eta_t - decrement)
        r = 2 * torch.atan(eta_r) / math.pi
    else:
        raise ValueError(f"Unknown schedule: {schedule}")
    # Keep r at least min_gap below t for numerical stability...
    r = torch.minimum(r, t - min_gap)
    # ...but that cap alone pushes r below s when t - s < min_gap, and below 0
    # when t < min_gap. Clamp so r stays inside [s, t].
    r = torch.maximum(r, s)
    return r
def compute_mmd_loss(
    ys_t: torch.Tensor,
    ys_r: torch.Tensor,
    M: int = 4,
    kernel_type: str = "laplace",
    kernel_bandwidth: float = 1.0,
) -> torch.Tensor:
    """Compute the MMD loss between two sets of samples using explicit loops.

    THIS IS A LOW-EFFICIENCY REFERENCE IMPLEMENTATION. It should match
    ``compute_mmd_loss_fully_vectorized`` (up to that function's 1e-8 epsilon
    in the energy kernel).

    The batch is split into ``B // M`` consecutive groups of ``M`` particles.
    For each group it computes

        (sum K(t, t) + sum K(r, r) - 2 * sum K(t, r)) / M**2

    over all ``M x M`` pairs, and the group estimates are averaged.

    Args:
        ys_t: First set of samples, shape [B, ...] (flattened per sample).
        ys_r: Second set of samples, same shape as ``ys_t``.
        M: Number of particles per group. Must divide B.
        kernel_type: Type of kernel, one of "laplace", "rbf", or "energy".
        kernel_bandwidth: Bandwidth parameter for the laplace/rbf kernels.

    Returns:
        Scalar tensor in the dtype of ``ys_t``.

    Raises:
        AssertionError: If B is not divisible by M (when asserts are enabled).
        ValueError: If ``kernel_type`` is unknown.
    """
    batch_size = ys_t.shape[0]
    assert batch_size % M == 0, f"Batch size {batch_size} must be divisible by M={M}"
    groups = batch_size // M
    ys_t_flat = ys_t.reshape(groups, M, -1)  # [groups, M, D]
    ys_r_flat = ys_r.reshape(groups, M, -1)  # [groups, M, D]
    total_loss = 0.0
    for g in range(groups):
        # Allocate in the inputs' dtype so float64/half inputs are not cast
        # to the default dtype.
        K_t_t = torch.zeros((M, M), device=ys_t.device, dtype=ys_t.dtype)
        K_r_r = torch.zeros((M, M), device=ys_t.device, dtype=ys_t.dtype)
        K_t_r = torch.zeros((M, M), device=ys_t.device, dtype=ys_t.dtype)
        for i in range(M):
            for j in range(M):
                if kernel_type == "laplace":
                    # Laplace kernel: k(x,y) = exp(-||x-y||_1 / h)
                    dist_t_t = torch.norm(ys_t_flat[g, i] - ys_t_flat[g, j], p=1)
                    dist_r_r = torch.norm(ys_r_flat[g, i] - ys_r_flat[g, j], p=1)
                    dist_t_r = torch.norm(ys_t_flat[g, i] - ys_r_flat[g, j], p=1)
                    K_t_t[i, j] = torch.exp(-dist_t_t / kernel_bandwidth)
                    K_r_r[i, j] = torch.exp(-dist_r_r / kernel_bandwidth)
                    K_t_r[i, j] = torch.exp(-dist_t_r / kernel_bandwidth)
                elif kernel_type == "rbf":
                    # RBF kernel: k(x,y) = exp(-||x-y||_2^2 / (2 h^2))
                    dist_t_t = torch.sum((ys_t_flat[g, i] - ys_t_flat[g, j]) ** 2)
                    dist_r_r = torch.sum((ys_r_flat[g, i] - ys_r_flat[g, j]) ** 2)
                    dist_t_r = torch.sum((ys_t_flat[g, i] - ys_r_flat[g, j]) ** 2)
                    K_t_t[i, j] = torch.exp(-dist_t_t / (2 * kernel_bandwidth**2))
                    K_r_r[i, j] = torch.exp(-dist_r_r / (2 * kernel_bandwidth**2))
                    K_t_r[i, j] = torch.exp(-dist_t_r / (2 * kernel_bandwidth**2))
                elif kernel_type == "energy":
                    # Energy kernel: k(x,y) = -||x-y||_2
                    # Self-pairs within one set (i == j) have distance 0, and
                    # the zero initialization already holds that value.
                    if i != j:
                        K_t_t[i, j] = -torch.norm(ys_t_flat[g, i] - ys_t_flat[g, j])
                        K_r_r[i, j] = -torch.norm(ys_r_flat[g, i] - ys_r_flat[g, j])
                    # The cross pair (t_i, r_j) compares samples from different
                    # sets and must be computed for i == j too.
                    K_t_r[i, j] = -torch.norm(ys_t_flat[g, i] - ys_r_flat[g, j])
                else:
                    raise ValueError(f"Unknown kernel type: {kernel_type}")
        # MMD estimate for this group.
        mmd_estimate = (K_t_t.sum() + K_r_r.sum() - 2 * K_t_r.sum()) / (M * M)
        total_loss += mmd_estimate
    return total_loss / groups
if __name__ == "__main__":
    # Sanity check: the vectorized MMD should agree with the loop reference.
    # All three kernels are compared. With these inputs the Laplace kernel's
    # off-diagonal entries underflow to ~0, so Laplace alone is a weak check.
    M = 4
    test_x = torch.randn(16, 1, 28, 28)
    test_x2 = test_x + torch.randn_like(test_x) * 0.05
    for kernel in ("laplace", "rbf", "energy"):
        loss1 = compute_mmd_loss(test_x, test_x2, M=M, kernel_type=kernel)
        loss2 = compute_mmd_loss_fully_vectorized(
            test_x, test_x2, M=M, kernel_type=kernel
        )
        print(kernel, (loss1 - loss2).abs().item())
import os
import ffmpeg
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.transforms import transforms as trns
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from PIL import Image
from tqdm import tqdm, trange
from anyschedule import AnySchedule
from imm import (
mapping_fn_eta_decay,
compute_mmd_loss_fully_vectorized,
ddim_interpolant
)
class Gen(nn.Module):
    """MLP that maps (x, t, s[, class]) to an output the size of ``x``.

    The network input is the flattened sample concatenated with the two time
    columns ``t`` and ``s``. When ``classes`` is truthy, a LayerNorm'ed class
    embedding is appended as well.
    """

    def __init__(
        self,
        input_dim=28 * 28 * 1,
        hidden_dim=1024,
        classes=None,
        class_embed_dim=128,
    ):
        super().__init__()
        # One flag drives both the embedding and the first layer's input width.
        # A falsy ``classes`` (None or 0) therefore never creates an embedding
        # that the Linear layer does not account for.
        use_class_embed = bool(classes)
        if use_class_embed:
            self.class_embed = nn.Sequential(
                nn.Embedding(classes, class_embed_dim),
                nn.LayerNorm(class_embed_dim),
            )
        self.model = nn.Sequential(
            nn.Linear(input_dim + 2 + use_class_embed * class_embed_dim, hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Mish(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, input_dim),
        )
        # Optional zero-init of the output layer (disabled):
        # nn.init.constant_(self.model[-1].weight, 0)
        # nn.init.constant_(self.model[-1].bias, 0)

    def forward(self, x, t, s, cond=None):
        """Run the MLP.

        Args:
            x: Flattened samples, shape [B, input_dim].
            t: Current times, shape [B, 1].
            s: Target times, shape [B, 1].
            cond: Optional integer class labels, shape [B]. Requires the model
                to have been built with a truthy ``classes``.

        Returns:
            Tensor of shape [B, input_dim].
        """
        x = torch.cat([x, t, s], dim=-1)
        if cond is not None:
            x = torch.cat([x, self.class_embed(cond)], dim=-1)
        return self.model(x)
def train(epoch, model, optimizer, lr_sch, dataloader, train_with_cond=False):
    """Run one epoch of IMM training, then save the weights to ``mnist-gen.pth``.

    Each batch is split into groups of ``M`` samples that share one random
    time pair: t in [min_time, max_time) and s in [min_time, t). An
    intermediate time r comes from ``mapping_fn_eta_decay``. The model's
    one-step prediction of x_s from x_t is matched, via the MMD loss, against
    a gradient-free one-step prediction of x_s from x_r.

    Args:
        epoch: Epoch index, used only for the progress-bar label.
        model: Network called as ``model(x, t, s, cond)``.
        optimizer: Optimizer stepped once per batch.
        lr_sch: Learning-rate scheduler stepped once per batch.
        dataloader: Yields ``(images, labels)`` batches.
        train_with_cond: If True, pass the labels to the model.

    Notes:
        Reads the module-level ``DEVICE``. Assumes every batch size is a
        multiple of ``M``. This holds for the 60000-image MNIST train set with
        batch size 512; TODO: verify if the batch size changes.
    """
    model.train()
    ema_loss = 0
    # Flow Matching parameterization + IMM training hyperparameters.
    M = 8  # particles per group sharing one (s, t) pair
    max_time = 0.994  # presumably keeps t away from 1, where eta = t / (1 - t) diverges
    min_time = 0
    with tqdm(dataloader, desc=f"Epoch {epoch}", smoothing=0.01) as pbar:
        for i, (x, cond) in enumerate(dataloader):
            x = x.view(-1, 28 * 28 * 1).to(DEVICE)
            cond = cond.to(DEVICE).long()
            model_cond = cond if train_with_cond else None

            groups = x.shape[0] // M
            t = torch.rand((groups, 1), device=DEVICE) * (max_time - min_time) + min_time
            s = torch.rand((groups, 1), device=DEVICE) * (t - min_time) + min_time
            t = t.repeat_interleave(M, 0)
            s = s.repeat_interleave(M, 0)
            r = mapping_fn_eta_decay(s, t)

            eps = torch.randn_like(x)
            xt = x * (1 - t) + eps * t
            xs = x * (1 - s) + eps * s
            # x_r must be the DDIM interpolant at time r, sharing xt's noise.
            # Interpolating to s here would simply reproduce xs.
            xr = ddim_interpolant(xt, x, r, t)

            optimizer.zero_grad()
            with torch.no_grad():
                # Target branch: no gradient flows through the r -> s prediction.
                pred_xs_r = xr + (s - r) * model(xr, r, s, model_cond)
            pred_xs_t = xt + (s - t) * model(xt, t, s, model_cond)
            loss = compute_mmd_loss_fully_vectorized(pred_xs_t, pred_xs_r, M=M)
            loss.backward()
            optimizer.step()
            lr_sch.step()

            # Log the (detached) regression error against the true x_s instead
            # of the MMD loss.
            target_loss = F.mse_loss(pred_xs_t.detach(), xs)
            ema_decay = min(0.99, i / 100)
            ema_loss = ema_decay * ema_loss + (1 - ema_decay) * target_loss.item()
            pbar.update(1)
            pbar.set_postfix({"loss": ema_loss})
    torch.save(model.state_dict(), "mnist-gen.pth")
def test(epoch, model, gen_with_cond=False):
    """Sample a 16x16 grid of digits with a fixed seed and save it as a PNG.

    Seeds the CPU RNG with 0 so every epoch starts from the same noise, and
    restores the previous CPU RNG state afterwards. Starting at t = 1, the
    samples take 4 Euler steps of size 1/4 down to t = 0. The result is written
    to ``./mnist-imm-result/gen-{epoch}.png``. Reads the module-level ``DEVICE``
    and leaves the model in eval mode.

    Args:
        epoch: Index used in the output file name.
        model: Network called as ``model(x, t, s, cond)``.
        gen_with_cond: If True, condition on labels 0..9 repeating across the grid.
    """
    saved_rng = torch.get_rng_state()
    torch.manual_seed(0)
    model.eval()
    n_images = 16 * 16
    n_steps = 4
    step = 1 / n_steps
    with torch.no_grad():
        samples = torch.randn(n_images, 28 * 28 * 1).to(DEVICE)
        labels = torch.arange(n_images).long().to(DEVICE) % 10
        time = torch.ones(n_images, 1).to(DEVICE)
        model_cond = labels if gen_with_cond else None
        for _ in range(n_steps):
            velocity = model(samples, time, time - step, model_cond)
            samples = samples - velocity * step
            time = time - step
    # Arrange the samples as one 16x16 grid of 28x28 tiles and map [-1, 1] -> [0, 1].
    tiles = samples.reshape(16, 16, 28, 28).permute(0, 2, 1, 3) * 0.5 + 0.5
    pixels = tiles.reshape(16 * 28, 16 * 28).cpu().numpy()
    pixels = (pixels * 255).clip(0, 255).astype(np.uint8)
    Image.fromarray(pixels).save(f"./mnist-imm-result/gen-{epoch}.png")
    torch.set_rng_state(saved_rng)
if __name__ == "__main__":
    os.makedirs("./mnist-imm-result", exist_ok=True)
    # Prefer the GPU, but fall back to CPU so the script still runs without CUDA.
    # DEVICE is a module-level global read by train() and test().
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    CLASSES = 10
    EPOCHS = 100

    # ToTensor gives [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1].
    # test() inverts this with `* 0.5 + 0.5`.
    transform = trns.Compose([trns.ToTensor(), trns.Normalize((0.5,), (0.5,))])
    dataset = MNIST("./data", download=True, transform=transform)
    dataloader = DataLoader(
        dataset,
        batch_size=512,
        shuffle=True,
        num_workers=16,
        # Pinned host memory only benefits host-to-GPU copies.
        pin_memory=DEVICE == "cuda",
        persistent_workers=True,
    )

    model = Gen(
        input_dim=784, hidden_dim=4096, classes=CLASSES, class_embed_dim=239
    ).to(DEVICE)
    # Parameter count, in millions.
    print(sum(p.numel() for p in model.parameters()) / 1e6)

    optimizer = optim.AdamW(
        model.parameters(), 5e-3
    )
    lr_sch = AnySchedule(optimizer, config={
        "lr": {
            "mode": "cosine",
            "min_value": 0.01,
            "end": len(dataloader) * EPOCHS + 1,
        }
    })

    use_cond = bool(CLASSES)
    for i in trange(EPOCHS):
        test(i, model, gen_with_cond=use_cond)
        train(i, model, optimizer, lr_sch, dataloader, train_with_cond=use_cond)
    # Final sample after the last epoch. Index EPOCHS equals the old `i + 1`
    # and is also defined when EPOCHS == 0.
    test(EPOCHS, model, gen_with_cond=use_cond)

    # Stitch the per-epoch sample grids into a video.
    stream = ffmpeg.input(
        "./mnist-imm-result/gen-%d.png", pattern_type="sequence", framerate=24
    )
    stream = ffmpeg.output(stream, "mnist-imm-gen.mp4", crf=20, pix_fmt="yuv420p")
    ffmpeg.run(stream, overwrite_output=True, capture_stdout=True, capture_stderr=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment