Skip to content

Instantly share code, notes, and snippets.

@TadaoYamaoka
Last active February 27, 2025 05:08
Show Gist options
  • Save TadaoYamaoka/f3dd151a6994071a774a1604815484a2 to your computer and use it in GitHub Desktop.
Save TadaoYamaoka/f3dd151a6994071a774a1604815484a2 to your computer and use it in GitHub Desktop.
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
import torchvision
from scipy import integrate
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from unet import Unet
# Training hyper-parameters and global switches read by the functions below.

# Number of images per training mini-batch (passed to the DataLoader in main()).
batch_size = 1024
# Step size handed to the Adam optimizer in main().
learning_rate = 0.001
# Number of passes over the training set performed by main().
num_epochs = 1000
# Smallest time value used: training draws t from [eps, 1) and both samplers
# start their integration at t = eps.
eps = 0.001
# When True, class labels are passed to the model as a third argument during
# training and sampling; when False, None is passed instead.
condition = True
def euler_sampler(model, shape, sample_N, device):
    """Draw samples by integrating the model's predicted velocity with Euler steps.

    Starts from Gaussian noise of the given ``shape`` and takes ``sample_N``
    fixed-size steps of length ``1 / sample_N``, evaluating the model at times
    spaced evenly from ``eps`` upwards.

    Args:
        model: Callable invoked as ``model(x, t * 999, cond)``. It is switched
            to eval mode and left there; callers that keep training must call
            ``model.train()`` again.
        shape: Shape of the batch to generate; ``shape[0]`` is the batch size.
        sample_N: Number of Euler steps (must be >= 1).
        device: Device on which sampling is performed.

    Returns:
        A tuple ``(x, nfe)`` where ``x`` is the generated batch moved to the
        CPU and ``nfe`` is the number of model evaluations (equal to
        ``sample_N``).

    Raises:
        ValueError: If ``sample_N`` is smaller than 1.
    """
    if sample_N < 1:
        raise ValueError(f"sample_N must be >= 1, got {sample_N}")

    model.eval()
    # One label per sample, cycling through 0..9. Deriving it from shape[0]
    # keeps the label vector the same length as the batch even when the batch
    # size is not a multiple of 10.
    cond = torch.arange(shape[0], device=device) % 10 if condition else None

    with torch.no_grad():
        x = torch.randn(shape, device=device)
        dt = 1.0 / sample_N
        for i in range(sample_N):
            num_t = i / sample_N * (1 - eps) + eps
            t = torch.full((shape[0],), num_t, device=device)
            # Time is scaled by 999 before being given to the model
            # (presumably to match its timestep embedding range - confirm
            # against the Unet implementation).
            pred = model(x, t * 999, cond)
            x = x + pred * dt

    nfe = sample_N
    return x.cpu(), nfe
def to_flattened_numpy(x):
    """Return tensor ``x`` as a one-dimensional NumPy array on the host."""
    host_array = x.detach().cpu().numpy()
    return host_array.reshape((-1,))
def from_flattened_numpy(x, shape):
    """Reshape NumPy array ``x`` to ``shape`` and wrap it as a torch tensor."""
    reshaped = x.reshape(shape)
    return torch.from_numpy(reshaped)
def rk45_sampler(model, shape, device):
    """Draw samples by integrating the model's velocity with SciPy's adaptive RK45.

    Starts from Gaussian noise of the given ``shape`` and integrates from
    ``t = eps`` to ``t = 1`` using ``scipy.integrate.solve_ivp`` with relative
    and absolute tolerances of 1e-5.

    Args:
        model: Callable invoked as ``model(x, t * 999, cond)``. It is switched
            to eval mode and left there; callers that keep training must call
            ``model.train()`` again.
        shape: Shape of the batch to generate; ``shape[0]`` is the batch size.
        device: Device on which the model is evaluated.

    Returns:
        A tuple ``(x, nfe)`` where ``x`` is the final state as a float32 CPU
        tensor of the requested shape and ``nfe`` is the number of function
        evaluations reported by the solver.
    """
    rtol = atol = 1e-05
    model.eval()
    # One label per sample, cycling through 0..9. Deriving it from shape[0]
    # keeps the label vector the same length as the batch even when the batch
    # size is not a multiple of 10.
    cond = torch.arange(shape[0], device=device) % 10 if condition else None

    with torch.no_grad():
        x0 = torch.randn(shape, device=device)

        def ode_func(t, x):
            # The solver works on flat float64 arrays; convert to a float32
            # batch on the target device before calling the model.
            x = from_flattened_numpy(x, shape).to(device).type(torch.float32)
            vec_t = torch.full((shape[0],), float(t), device=x.device)
            drift = model(x, vec_t * 999, cond)
            return to_flattened_numpy(drift)

        # The solve must run inside no_grad because it is what invokes ode_func.
        solution = integrate.solve_ivp(
            ode_func,
            (eps, 1),
            to_flattened_numpy(x0),
            rtol=rtol,
            atol=atol,
            method="RK45",
        )

    nfe = solution.nfev
    # The last column of solution.y is the state at the final time point.
    x = torch.tensor(solution.y[:, -1]).reshape(shape).type(torch.float32)
    return x, nfe
def imshow(img, filename):
    """Render an image tensor and save it to ``filename``.

    The tensor is mapped with ``img * 0.5 + 0.5`` (the inverse of a
    mean-0.5 / std-0.5 normalisation), clamped to ``[0, 1]`` and drawn without
    axes.

    Args:
        img: Torch tensor with three dimensions; it is permuted with
            ``(1, 2, 0)`` before plotting (presumably channels-first
            ``(C, H, W)`` - confirm against callers).
        filename: Path of the image file to write.
    """
    img = (img.detach().cpu() * 0.5 + 0.5).clamp(0, 1)
    npimg = img.permute(1, 2, 0).numpy()

    # Use a dedicated figure and always close it: drawing on the implicit
    # global figure would accumulate one image per call for the whole run.
    fig, ax = plt.subplots()
    try:
        ax.imshow(npimg)
        ax.axis("off")
        fig.savefig(filename, bbox_inches="tight", pad_inches=0)
    finally:
        plt.close(fig)
def save_img_grid(img, filename):
    """Tile a batch of images ten per row and save it under ``output_cifar10``."""
    target_path = os.path.join("output_cifar10", filename)
    grid = torchvision.utils.make_grid(img, nrow=10)
    imshow(grid, target_path)
def eval(model, epoch, method, device, sample_N=None, batch_size=100):
    """Sample a batch of 3x32x32 images with the chosen solver and save them as a grid.

    NOTE: this function shadows the built-in ``eval`` inside this module; the
    name is kept so existing callers continue to work.

    Args:
        model: Model handed to the sampler.
        epoch: Zero-based epoch index; ``epoch + 1`` appears in the file name.
        method: ``"euler"`` or ``"rk45"``.
        device: Device on which sampling is performed.
        sample_N: Number of Euler steps; required when ``method`` is
            ``"euler"`` and ignored for ``"rk45"``.
        batch_size: Number of images to generate.

    Raises:
        ValueError: If ``method`` is not a supported solver, or if ``method``
            is ``"euler"`` and ``sample_N`` is not given.
    """
    shape = (batch_size, 3, 32, 32)
    if method == "euler":
        if sample_N is None:
            raise ValueError('sample_N is required when method is "euler"')
        images, nfe = euler_sampler(
            model, shape=shape, sample_N=sample_N, device=device
        )
    elif method == "rk45":
        images, nfe = rk45_sampler(model, shape=shape, device=device)
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            f'unknown sampling method {method!r}; expected "euler" or "rk45"'
        )
    save_img_grid(images, f"{method}_epoch_{epoch + 1}_nfe_{nfe}.png")
def main():
    """Train the conditional U-Net on CIFAR-10 and periodically sample from it.

    Each step draws noise ``z0`` and a time ``t`` in ``[eps, 1)``, builds the
    interpolation ``t * batch + (1 - t) * z0`` and regresses the model output
    onto ``batch - z0`` with a mean-squared error. The mean training loss per
    epoch is logged to TensorBoard under ``runs/experiment1``. Sample grids
    are written to ``output_cifar10`` for the first ten epochs and every tenth
    epoch after that, and model weights are saved every 100 epochs.
    """
    os.makedirs("output_cifar10", exist_ok=True)
    writer = SummaryWriter(log_dir="runs/experiment1")
    # Close the writer even if training is interrupted so that pending
    # events are flushed to disk.
    try:
        transform = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        train_dataset = datasets.CIFAR10(
            root="./data", train=True, download=True, transform=transform
        )
        dataloader = DataLoader(
            dataset=train_dataset, batch_size=batch_size, shuffle=True
        )

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = Unet(
            dim=32,
            channels=3,
            dim_mults=(1, 2, 4),
            condition=condition,
        )
        model.to(device)
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)

        for epoch in range(num_epochs):
            total_loss = 0
            # The samplers switch the model to eval mode, so training mode is
            # re-enabled at the start of every epoch.
            model.train()
            for batch, cond in dataloader:
                batch = batch.to(device)
                optimizer.zero_grad()

                z0 = torch.randn_like(batch)
                t = torch.rand(batch.shape[0], device=device) * (1 - eps) + eps
                # A (B, 1, 1, 1) view broadcasts over the image dimensions, so
                # no explicit repeat is needed.
                t_expand = t.view(-1, 1, 1, 1)
                perturbed_data = t_expand * batch + (1 - t_expand) * z0
                target = batch - z0

                pred = model(
                    perturbed_data, t * 999, cond.to(device) if condition else None
                )

                # Per-sample mean squared error, then the mean over the batch.
                losses = torch.square(pred - target)
                losses = torch.mean(losses.reshape(losses.shape[0], -1), dim=-1)
                loss = torch.mean(losses)

                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            writer.add_scalar("Loss/train", total_loss / len(dataloader), epoch)

            if epoch < 10 or (epoch + 1) % 10 == 0:
                eval(model, epoch, "euler", device, sample_N=1)
                eval(model, epoch, "euler", device, sample_N=2)
                eval(model, epoch, "euler", device, sample_N=10)
                eval(model, epoch, "rk45", device)

            if (epoch + 1) % 100 == 0:
                torch.save(
                    model.state_dict(),
                    os.path.join("output_cifar10", f"model_epoch_{epoch + 1}.pt"),
                )
    finally:
        writer.close()
# Run training only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
@drscotthawley
Copy link

drscotthawley commented Jan 15, 2025

Dear Yamaoka-san,
Thank you for sharing your code. I am adapting it a bit to work on the Oxford Flowers dataset at higher resolution as an academic learning project: https://github.com/drscotthawley/flow-matching-flowers
Since the code there is mostly yours, is there a license that you have assigned to your code? If so, I will use the same license in my repository.
Best wishes,
Scott Hawley

@TadaoYamaoka
Copy link
Author

Dear Scott,

Thank you for reaching out. I’m glad to hear you’re finding the code useful for your academic project!

I have not assigned a specific license to the code, so please feel free to use it as you see fit.

Best wishes,
Yamaoka

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment