A simple implementation of IMM (Inductive Moment Matching): https://arxiv.org/pdf/2503.07565
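For each group of M particles that share the same times (s, r, t), the training objective implemented below is the (biased) squared MMD between one-step predictions made from time t and one-step predictions made from the intermediate time r, both targeting time s. In the notation of the code (f_theta is the network, x_t the noised data, x_r its DDIM push-down to time r):

\mathrm{MMD}^2 = \frac{1}{M^2}\left[\sum_{i,j} k\left(y^t_i, y^t_j\right) + \sum_{i,j} k\left(y^r_i, y^r_j\right) - 2\sum_{i,j} k\left(y^t_i, y^r_j\right)\right]

y^t_i = x^t_i + (s - t)\, f_\theta\left(x^t_i, t, s\right), \qquad y^r_i = \operatorname{stopgrad}\left[x^r_i + (s - r)\, f_\theta\left(x^r_i, r, s\right)\right]

Here k is a Laplace (default), RBF, or energy kernel, x_r comes from ddim_interpolant with r given by mapping_fn_eta_decay(s, t), and the per-group estimates are averaged. The first file below (imm.py) provides the MMD estimators, the noise schedules, and the DDIM/time-mapping helpers; the second file is a MNIST training script that uses them.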
""" | |
An Minimal Implementation of IMM (Inductive Moment Matching) | |
""" | |
import math | |
import torch | |
import torch.nn.functional as F | |
def compute_mmd_loss_fully_vectorized(
    ys_t: torch.Tensor,
    ys_r: torch.Tensor,
    M: int = 4,
    kernel_type: str = "laplace",
    kernel_bandwidth: float = 1.0,
) -> torch.Tensor:
    """Compute MMD loss with fully vectorized operations across all groups.

    Args:
        ys_t: First set of samples, tensor of shape [B, C, H, W]
        ys_r: Second set of samples, tensor of shape [B, C, H, W]
        M: Number of particles for MMD estimation
        kernel_type: Type of kernel, one of "laplace", "rbf", or "energy"
        kernel_bandwidth: Bandwidth parameter for kernel

    Returns:
        mmd_loss: Scalar tensor
    """
    batch_size = ys_t.shape[0]
    assert batch_size % M == 0, f"Batch size {batch_size} must be divisible by M={M}"

    # Reshape for particle grouping
    groups = batch_size // M
    # Flatten spatial dimensions
    ys_t_flat = ys_t.reshape(groups, M, -1)  # [groups, M, C*H*W]
    ys_r_flat = ys_r.reshape(groups, M, -1)  # [groups, M, C*H*W]

    # Vectorized computation of kernels for all groups at once
    if kernel_type == "laplace":
        # For each group, compute expanded tensors for broadcasting
        y_t = ys_t_flat.unsqueeze(2)  # [groups, M, 1, C*H*W]
        y_t_2 = ys_t_flat.unsqueeze(1)  # [groups, 1, M, C*H*W]
        y_r = ys_r_flat.unsqueeze(2)  # [groups, M, 1, C*H*W]
        y_r_2 = ys_r_flat.unsqueeze(1)  # [groups, 1, M, C*H*W]

        # Compute pairwise L1 distances
        t_t_dist = torch.abs(y_t - y_t_2).sum(dim=-1)  # [groups, M, M]
        r_r_dist = torch.abs(y_r - y_r_2).sum(dim=-1)  # [groups, M, M]
        t_r_dist = torch.abs(y_t - y_r_2).sum(dim=-1)  # [groups, M, M]

        # Apply Laplace kernel
        K_t_t = torch.exp(-t_t_dist / kernel_bandwidth)
        K_r_r = torch.exp(-r_r_dist / kernel_bandwidth)
        K_t_r = torch.exp(-t_r_dist / kernel_bandwidth)
    elif kernel_type == "rbf":
        # For each group, compute expanded tensors for broadcasting
        y_t = ys_t_flat.unsqueeze(2)  # [groups, M, 1, C*H*W]
        y_t_2 = ys_t_flat.unsqueeze(1)  # [groups, 1, M, C*H*W]
        y_r = ys_r_flat.unsqueeze(2)  # [groups, M, 1, C*H*W]
        y_r_2 = ys_r_flat.unsqueeze(1)  # [groups, 1, M, C*H*W]

        # Compute pairwise squared L2 distances
        t_t_dist = ((y_t - y_t_2) ** 2).sum(dim=-1)  # [groups, M, M]
        r_r_dist = ((y_r - y_r_2) ** 2).sum(dim=-1)  # [groups, M, M]
        t_r_dist = ((y_t - y_r_2) ** 2).sum(dim=-1)  # [groups, M, M]

        # Apply RBF kernel
        K_t_t = torch.exp(-t_t_dist / (2 * kernel_bandwidth**2))
        K_r_r = torch.exp(-r_r_dist / (2 * kernel_bandwidth**2))
        K_t_r = torch.exp(-t_r_dist / (2 * kernel_bandwidth**2))
    elif kernel_type == "energy":
        # For each group, compute expanded tensors for broadcasting
        y_t = ys_t_flat.unsqueeze(2)  # [groups, M, 1, C*H*W]
        y_t_2 = ys_t_flat.unsqueeze(1)  # [groups, 1, M, C*H*W]
        y_r = ys_r_flat.unsqueeze(2)  # [groups, M, 1, C*H*W]
        y_r_2 = ys_r_flat.unsqueeze(1)  # [groups, 1, M, C*H*W]

        # Compute pairwise L2 distances
        t_t_dist = torch.sqrt(((y_t - y_t_2) ** 2).sum(dim=-1) + 1e-8)  # [groups, M, M]
        r_r_dist = torch.sqrt(((y_r - y_r_2) ** 2).sum(dim=-1) + 1e-8)  # [groups, M, M]
        t_r_dist = torch.sqrt(((y_t - y_r_2) ** 2).sum(dim=-1) + 1e-8)  # [groups, M, M]

        # Apply energy kernel (-distance)
        K_t_t = -t_t_dist
        K_r_r = -r_r_dist
        K_t_r = -t_r_dist
    else:
        raise ValueError(f"Unknown kernel type: {kernel_type}")

    # Compute MMD for each group and average
    group_mmd = (
        K_t_t.sum(dim=[1, 2]) + K_r_r.sum(dim=[1, 2]) - 2 * K_t_r.sum(dim=[1, 2])
    ) / (M * M)
    mmd_loss = group_mmd.mean()
    return mmd_loss

def get_alpha_sigma_otfm(t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Get alpha and sigma for OT-FM schedule.

    Args:
        t: Time steps tensor of any shape

    Returns:
        Tuple of (alpha_t, sigma_t) with same shape as t
    """
    return 1 - t, t


def get_alpha_sigma_cosine(t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Get alpha and sigma for cosine schedule.

    Args:
        t: Time steps tensor of any shape

    Returns:
        Tuple of (alpha_t, sigma_t) with same shape as t
    """
    return torch.cos(0.5 * math.pi * t), torch.sin(0.5 * math.pi * t)

def ddim_interpolant(
    xt: torch.Tensor,
    x0: torch.Tensor,
    s: torch.Tensor,
    t: torch.Tensor,
    schedule: str = "otfm",
) -> torch.Tensor:
    """DDIM interpolant from xt to xs given x0.

    Args:
        xt: Noisy images at time t, tensor of shape [B, C, H, W]
        x0: Clean images, tensor of shape [B, C, H, W]
        s: Target time steps tensor of shape [B, 1, 1, 1]
        t: Current time steps tensor of shape [B, 1, 1, 1]
        schedule: Noise schedule, one of "otfm" or "cosine"

    Returns:
        xs: Images at time s, tensor of shape [B, C, H, W]
    """
    # Get alpha and sigma for both s and t
    if schedule == "otfm":
        alpha_t, sigma_t = get_alpha_sigma_otfm(t)
        alpha_s, sigma_s = get_alpha_sigma_otfm(s)
    elif schedule == "cosine":
        alpha_t, sigma_t = get_alpha_sigma_cosine(t)
        alpha_s, sigma_s = get_alpha_sigma_cosine(s)
    else:
        raise ValueError(f"Unknown schedule: {schedule}")

    # DDIM interpolant formula (note: divides by sigma_t, so it requires t > 0)
    xs = alpha_s * x0 + (sigma_s / sigma_t) * (xt - alpha_t * x0)
    return xs

def mapping_fn_eta_decay(
    s: torch.Tensor,
    t: torch.Tensor,
    eta_max: float = 160.0,
    eta_min: float = 0.0,
    k: int = 12,
    min_gap: float = 1e-4,
    schedule: str = "otfm",
) -> torch.Tensor:
    """Compute r(s, t) with constant decrement in eta space.

    Args:
        s: Target time steps tensor of shape [B, 1, 1, 1]
        t: Current time steps tensor of shape [B, 1, 1, 1]
        eta_max: Maximum eta value
        eta_min: Minimum eta value
        k: Power factor for step size (larger k = smaller step)
        min_gap: Minimum gap between r and t
        schedule: Noise schedule, one of "otfm" or "cosine"

    Returns:
        r: Intermediate time steps tensor of shape [B, 1, 1, 1]
    """
    # Convert from time to eta space depending on schedule
    if schedule == "otfm":
        # For OT-FM: eta_t = sigma_t / alpha_t = t / (1 - t)
        eta_t = t / (1 - t)
        # Desired decrement in eta space
        decrement = (eta_max - eta_min) / (2**k)
        # Apply decrement to get eta_r
        eta_r = torch.maximum(s / (1 - s), eta_t - decrement)
        # Convert back to time
        r = eta_r / (1 + eta_r)
    elif schedule == "cosine":
        # For cosine: eta_t = sigma_t / alpha_t = tan(pi * t / 2)
        eta_t = torch.tan(0.5 * math.pi * t)
        # Desired decrement in eta space
        decrement = (eta_max - eta_min) / (2**k)
        # Apply decrement to get eta_r
        eta_r = torch.maximum(torch.tan(0.5 * math.pi * s), eta_t - decrement)
        # Convert back to time
        r = 2 * torch.atan(eta_r) / math.pi
    else:
        raise ValueError(f"Unknown schedule: {schedule}")

    # Ensure minimum gap between r and t for numerical stability
    r = torch.minimum(r, t - min_gap)
    return r

def compute_mmd_loss(
    ys_t: torch.Tensor,
    ys_r: torch.Tensor,
    M: int = 4,
    kernel_type: str = "laplace",
    kernel_bandwidth: float = 1.0,
) -> torch.Tensor:
    """Compute MMD loss between two sets of samples.

    This is a low-efficiency reference implementation; prefer
    compute_mmd_loss_fully_vectorized for training.

    Args:
        ys_t: First set of samples, tensor of shape [B, C, H, W]
        ys_r: Second set of samples, tensor of shape [B, C, H, W]
        M: Number of particles for MMD estimation
        kernel_type: Type of kernel, one of "laplace", "rbf", or "energy"
        kernel_bandwidth: Bandwidth parameter for kernel

    Returns:
        mmd_loss: Scalar tensor
    """
    batch_size = ys_t.shape[0]
    assert batch_size % M == 0, f"Batch size {batch_size} must be divisible by M={M}"

    # Reshape for particle grouping
    groups = batch_size // M
    # Flatten spatial dimensions for easier distance computation
    ys_t_flat = ys_t.reshape(groups, M, -1)  # [groups, M, C*H*W]
    ys_r_flat = ys_r.reshape(groups, M, -1)  # [groups, M, C*H*W]

    # Initialize total loss
    total_loss = 0.0
    for g in range(groups):
        # Initialize kernel matrices
        K_t_t = torch.zeros((M, M), device=ys_t.device)
        K_r_r = torch.zeros((M, M), device=ys_t.device)
        K_t_r = torch.zeros((M, M), device=ys_t.device)

        # Compute pairwise kernel values
        for i in range(M):
            for j in range(M):
                if kernel_type == "laplace":
                    # Laplace kernel: k(x,y) = exp(-||x-y||₁/h)
                    dist_t_t = torch.norm(ys_t_flat[g, i] - ys_t_flat[g, j], p=1)
                    dist_r_r = torch.norm(ys_r_flat[g, i] - ys_r_flat[g, j], p=1)
                    dist_t_r = torch.norm(ys_t_flat[g, i] - ys_r_flat[g, j], p=1)
                    K_t_t[i, j] = torch.exp(-dist_t_t / kernel_bandwidth)
                    K_r_r[i, j] = torch.exp(-dist_r_r / kernel_bandwidth)
                    K_t_r[i, j] = torch.exp(-dist_t_r / kernel_bandwidth)
                elif kernel_type == "rbf":
                    # RBF kernel: k(x,y) = exp(-||x-y||²/2h²)
                    dist_t_t = torch.sum((ys_t_flat[g, i] - ys_t_flat[g, j]) ** 2)
                    dist_r_r = torch.sum((ys_r_flat[g, i] - ys_r_flat[g, j]) ** 2)
                    dist_t_r = torch.sum((ys_t_flat[g, i] - ys_r_flat[g, j]) ** 2)
                    K_t_t[i, j] = torch.exp(-dist_t_t / (2 * kernel_bandwidth**2))
                    K_r_r[i, j] = torch.exp(-dist_r_r / (2 * kernel_bandwidth**2))
                    K_t_r[i, j] = torch.exp(-dist_t_r / (2 * kernel_bandwidth**2))
                elif kernel_type == "energy":
                    # Energy kernel: k(x,y) = -||x-y||₂
                    if i != j:  # within-set self-distance is always 0
                        K_t_t[i, j] = -torch.norm(ys_t_flat[g, i] - ys_t_flat[g, j])
                        K_r_r[i, j] = -torch.norm(ys_r_flat[g, i] - ys_r_flat[g, j])
                    # The cross term is generally nonzero even when i == j
                    K_t_r[i, j] = -torch.norm(ys_t_flat[g, i] - ys_r_flat[g, j])
                else:
                    raise ValueError(f"Unknown kernel type: {kernel_type}")

        # MMD estimate using kernel matrices
        mmd_estimate = (K_t_t.sum() + K_r_r.sum() - 2 * K_t_r.sum()) / (M * M)
        total_loss += mmd_estimate

    return total_loss / groups

if __name__ == "__main__": | |
M = 4 | |
test_x = torch.randn(16, 1, 28, 28) | |
test_x2 = test_x + torch.randn_like(test_x) * 0.05 | |
loss1 = compute_mmd_loss(test_x, test_x2, M=M) | |
loss2 = compute_mmd_loss_fully_vectorized(test_x, test_x2, M=M) | |
print(F.mse_loss(loss1, loss2)) |
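
For orientation, here is a short, self-contained sketch of how the helpers above combine into one IMM objective evaluation. It assumes the file above is saved as imm.py (the training script below imports it under that name) and uses a zero-output dummy_model as a placeholder for a real network. The full MNIST training script follows after the sketch.

# Minimal sketch of one IMM objective evaluation, assuming imm.py is on the path.
import torch

from imm import (
    compute_mmd_loss_fully_vectorized,
    ddim_interpolant,
    mapping_fn_eta_decay,
)


def dummy_model(z, t, s):
    """Placeholder for a real velocity network (returns zero velocity)."""
    return torch.zeros_like(z)


M = 4                                  # particles per (s, r, t) group
x = torch.randn(2 * M, 1, 28, 28)      # "clean" data: 2 groups of M particles
eps = torch.randn_like(x)              # shared noise for the forward process

# One (t, s) pair per group, broadcast over the M particles in that group
t = torch.rand(2, 1, 1, 1).repeat_interleave(M, 0) * 0.9 + 0.05  # t in (0.05, 0.95)
s = t * torch.rand(2, 1, 1, 1).repeat_interleave(M, 0)           # s < t
r = mapping_fn_eta_decay(s, t)                                   # intermediate time r

x_t = (1 - t) * x + t * eps            # OT-FM forward process at time t
x_r = ddim_interpolant(x_t, x, r, t)   # DDIM push from t down to r (same eps)

y_t = x_t + (s - t) * dummy_model(x_t, t, s)          # one-step prediction from t
with torch.no_grad():                                 # stop-gradient branch
    y_r = x_r + (s - r) * dummy_model(x_r, r, s)      # one-step prediction from r
loss = compute_mmd_loss_fully_vectorized(y_t, y_r, M=M)
print(loss)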
import os

import ffmpeg
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.transforms import transforms as trns
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from PIL import Image
from tqdm import tqdm, trange
from anyschedule import AnySchedule

from imm import (
    mapping_fn_eta_decay,
    compute_mmd_loss_fully_vectorized,
    ddim_interpolant,
)

class Gen(nn.Module):
    def __init__(
        self,
        input_dim=28 * 28 * 1,
        hidden_dim=1024,
        classes=None,
        class_embed_dim=128,
    ):
        super(Gen, self).__init__()
        if classes is not None:
            self.class_embed = nn.Sequential(
                nn.Embedding(classes, class_embed_dim),
                nn.LayerNorm(class_embed_dim),
            )
        # Input: flattened image + scalar t + scalar s (+ class embedding if conditional)
        self.model = nn.Sequential(
            nn.Linear(input_dim + 2 + bool(classes) * class_embed_dim, hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Mish(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, input_dim),
        )
        # nn.init.constant_(self.model[-1].weight, 0)
        # nn.init.constant_(self.model[-1].bias, 0)

    def forward(self, x, t, s, cond=None):
        x = torch.cat([x, t, s], dim=-1)
        if cond is not None:
            x = torch.cat([x, self.class_embed(cond)], dim=-1)
        return self.model(x)

def train(epoch, model, optimizer, lr_sch, dataloader, train_with_cond=False):
    model.train()
    ema_loss = 0
    with tqdm(dataloader, desc=f"Epoch {epoch}", smoothing=0.01) as pbar:
        for i, (x, cond) in enumerate(dataloader):
            x = x.view(-1, 28 * 28 * 1).to(DEVICE)
            cond = cond.to(DEVICE).long()

            ## Flow Matching parameterization + IMM training
            M = 8
            max_time = 0.994
            min_time = 0
            # Sample one (t, s) pair per group of M particles, with s < t
            t = torch.rand((x.shape[0] // M, 1), device=DEVICE) * (max_time - min_time) + min_time
            s = torch.rand((x.shape[0] // M, 1), device=DEVICE) * (t - min_time) + min_time
            t = t.repeat_interleave(M, 0)
            s = s.repeat_interleave(M, 0)
            r = mapping_fn_eta_decay(s, t)

            eps = torch.randn_like(x)
            xt = x * (1 - t) + eps * t
            xs = x * (1 - s) + eps * s
            # DDIM push of xt from time t down to the intermediate time r
            xr = ddim_interpolant(xt, x, r, t)

            optimizer.zero_grad()
            # Stop-gradient branch: one-step prediction from (xr, r) to s
            with torch.no_grad():
                pred_xs_r = xr + (s - r) * model(xr, r, s, cond if train_with_cond else None)
            # Trainable branch: one-step prediction from (xt, t) to s
            pred_xs_t = xt + (s - t) * model(xt, t, s, cond if train_with_cond else None)

            mmd_loss = compute_mmd_loss_fully_vectorized(pred_xs_t, pred_xs_r, M=M)
            target_loss = F.mse_loss(pred_xs_t, xs)
            loss = mmd_loss
            loss.backward()
            optimizer.step()
            lr_sch.step()

            # Log the target (MSE) loss instead of the MMD loss
            ema_decay = min(0.99, i / 100)
            ema_loss = ema_decay * ema_loss + (1 - ema_decay) * target_loss.item()
            pbar.update(1)
            pbar.set_postfix({"loss": ema_loss})
    torch.save(model.state_dict(), "mnist-gen.pth")

def test(epoch, model, gen_with_cond=False):
    rng_state = torch.get_rng_state()
    torch.manual_seed(0)
    model.eval()
    IMAGE_COUNT = 16 * 16
    with torch.no_grad():
        pred_x = torch.randn(IMAGE_COUNT, 28 * 28 * 1).to(DEVICE)
        cond = torch.arange(IMAGE_COUNT).long().to(DEVICE) % 10
        t = torch.ones(IMAGE_COUNT, 1).to(DEVICE)
        STEPS = 4
        dt = 1 / STEPS
        # Few-step sampling: repeatedly apply the one-step rule
        # x_s = x_t + (s - t) * model(x_t, t, s) with s = t - dt
        for i in range(STEPS):
            pred = model(pred_x, t, t - dt, cond if gen_with_cond else None)
            pred_x = pred_x - pred * dt
            t = t - dt

        # Save the samples as a single 16x16 image grid
        pred_x = pred_x.reshape(16, 16, 28, 28).permute(0, 2, 1, 3) * 0.5 + 0.5
        pred_x = pred_x.reshape(16 * 28, 16 * 28).cpu().numpy()
        pred_x = (pred_x * 255).clip(0, 255).astype(np.uint8)
        pred_x = Image.fromarray(pred_x)
        pred_x.save(f"./mnist-imm-result/gen-{epoch}.png")
    torch.set_rng_state(rng_state)

if __name__ == "__main__": | |
os.makedirs("./mnist-imm-result", exist_ok=True) | |
DEVICE = "cuda" | |
CLASSES = 10 | |
EPOCHS = 100 | |
transform = trns.Compose([trns.ToTensor(), trns.Normalize((0.5,), (0.5,))]) | |
dataset = MNIST("./data", download=True, transform=transform) | |
dataloader = DataLoader( | |
dataset, | |
batch_size=512, | |
shuffle=True, | |
num_workers=16, | |
pin_memory=True, | |
persistent_workers=True, | |
) | |
model = Gen( | |
input_dim=784, hidden_dim=4096, classes=CLASSES, class_embed_dim=239 | |
).to(DEVICE) | |
print(sum(p.numel() for p in model.parameters()) / 1e6) | |
optimizer = optim.AdamW( | |
model.parameters(), 5e-3 | |
) | |
lr_sch = AnySchedule(optimizer, config={ | |
"lr": { | |
"mode": "cosine", | |
"min_value": 0.01, | |
"end": len(dataloader) * EPOCHS + 1, | |
} | |
}) | |
for i in trange(EPOCHS): | |
test(i, model, gen_with_cond=bool(CLASSES)) | |
train(i, model, optimizer, lr_sch, dataloader, train_with_cond=bool(CLASSES)) | |
test(i + 1, model, gen_with_cond=bool(CLASSES)) | |
stream = ffmpeg.input( | |
"./mnist-imm-result/gen-%d.png", pattern_type="sequence", framerate=24 | |
) | |
stream = ffmpeg.output(stream, "mnist-imm-gen.mp4", crf=20, pix_fmt="yuv420p") | |
ffmpeg.run(stream, overwrite_output=True, capture_stdout=True, capture_stderr=True) |
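
After training, the saved checkpoint can be used for standalone few-step sampling. A minimal sketch, assuming the training script above is saved as train_mnist_imm.py (the module name is a placeholder, not part of the gist), that it has been run at least once so mnist-gen.pth exists, that its dependencies are installed, and that a CUDA device is available; the 4-step loop mirrors test() above.

# Standalone sampling sketch (assumptions noted in the text above).
import torch
from PIL import Image

from train_mnist_imm import Gen  # hypothetical module name for the script above

DEVICE = "cuda"
model = Gen(input_dim=784, hidden_dim=4096, classes=10, class_embed_dim=239).to(DEVICE)
model.load_state_dict(torch.load("mnist-gen.pth", map_location=DEVICE))
model.eval()

N, STEPS = 64, 4
x = torch.randn(N, 784, device=DEVICE)        # start from pure noise (t = 1)
cond = torch.arange(N, device=DEVICE) % 10    # one digit class per sample
t = torch.ones(N, 1, device=DEVICE)
dt = 1 / STEPS
with torch.no_grad():
    for _ in range(STEPS):
        # Same update rule as test(): x_s = x_t + (s - t) * model(x_t, t, s), s = t - dt
        x = x - model(x, t, t - dt, cond) * dt
        t = t - dt

# De-normalize from [-1, 1] and save an 8x8 grid of 28x28 digits
grid = x.reshape(8, 8, 28, 28).permute(0, 2, 1, 3).reshape(8 * 28, 8 * 28)
grid = ((grid * 0.5 + 0.5) * 255).clamp(0, 255).byte().cpu().numpy()
Image.fromarray(grid).save("mnist-imm-samples.png")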