"""DiM (Diffusion Mixer)."""
import math
import typing
import einops
import torch
class DiMConfig(typing.NamedTuple):
    input_dimension: int = 3
    hidden_dimension: int = 256
    layers: int = 4
    sequence_length: int = 256  # number of patches, i.e. (image_size // patch_size) ** 2.
    patch_size: int = 2
def zero(module: torch.nn.Module) -> torch.nn.Module:
    """Zero-initialize a module's weight (and bias, if present)."""
    torch.nn.init.zeros_(module.weight)
    if module.bias is not None:
        torch.nn.init.zeros_(module.bias)
    return module
def modulate(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Scale-only (adaLN-style) modulation; identity when scale is zero."""
    return x * (1 + scale)
class Fourier(torch.nn.Module):
    """Random Fourier time embedding with fixed random frequencies."""

    def __init__(self, hidden_dimension: int) -> None:
        super().__init__()
        self.register_buffer("scales", torch.randn((hidden_dimension // 2, 1)))
        self.linear = torch.nn.Linear(hidden_dimension, hidden_dimension, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # assume x is (B,).
        x = 2 * math.pi * x.unsqueeze(-1) @ self.scales.T  # (B, hidden_dimension // 2).
        x = self.linear(torch.cat([x.cos(), x.sin()], dim=-1))  # (B, hidden_dimension).
        return x[:, None, :]  # (B, 1, hidden_dimension), broadcastable over the sequence.
class MLP(torch.nn.Module):
    """Gated (SwiGLU-style) MLP with a zero-initialized output projection."""

    def __init__(self, hidden_dimension: int, ratio: int) -> None:
        super().__init__()
        self.linear_1 = torch.nn.Linear(hidden_dimension, hidden_dimension * ratio, bias=False)
        self.linear_2 = torch.nn.Linear(hidden_dimension, hidden_dimension * ratio, bias=False)
        self.linear_3 = zero(torch.nn.Linear(hidden_dimension * ratio, hidden_dimension, bias=False))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear_1(x) * torch.nn.functional.silu(self.linear_2(x))
        x = self.linear_3(x)
        return x
class Block(torch.nn.Module):
    """Mixer block: time-modulated token-mixing MLP, then channel-mixing MLP."""

    def __init__(self, config: DiMConfig) -> None:
        super().__init__()
        self.norm_1 = torch.nn.LayerNorm(config.hidden_dimension, bias=False)
        self.norm_2 = torch.nn.LayerNorm(config.hidden_dimension, bias=False)
        self.mlp_1 = MLP(config.sequence_length, ratio=1)  # token mixing (acts across patches).
        self.mlp_2 = MLP(config.hidden_dimension, ratio=3)  # channel mixing (acts across features).
        self.modulation = zero(torch.nn.Linear(config.hidden_dimension, config.hidden_dimension * 2))

    def forward(self, x: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
        s1, s2 = self.modulation(torch.nn.functional.silu(time)).chunk(2, dim=-1)
        x = x + self.mlp_1(modulate(self.norm_1(x), s1).transpose(-1, -2)).transpose(-1, -2)
        x = x + self.mlp_2(modulate(self.norm_2(x), s2))
        return x
class DiM(torch.nn.Module):
    """Diffusion Mixer: patchify, run time-conditioned mixer blocks, unpatchify."""

    def __init__(self, config: DiMConfig) -> None:
        super().__init__()
        self.norm = torch.nn.LayerNorm(config.hidden_dimension)
        self.fourier = Fourier(config.hidden_dimension)
        self.blocks = torch.nn.ModuleList([Block(config) for _ in range(config.layers)])
        self.patch = torch.nn.Conv2d(
            config.input_dimension,
            config.hidden_dimension,
            config.patch_size,
            config.patch_size,
            padding=0,
            bias=False,
        )
        self.unpatch = zero(
            torch.nn.ConvTranspose2d(
                config.hidden_dimension,
                config.input_dimension,
                config.patch_size,
                config.patch_size,
                padding=0,
                bias=False,
            )
        )

    def forward(self, x: torch.Tensor, time: torch.Tensor) -> torch.Tensor:
        time = self.fourier(time)  # (B,) -> (B, 1, hidden_dimension).
        x = self.patch(x)  # (B, C, H, W) -> (B, hidden_dimension, H // patch, W // patch).
        h = x.size(-2)
        x = einops.rearrange(x, "b c h w -> b (h w) c")
        for block in self.blocks:
            x = block(x, time)
        x = self.norm(x)
        x = einops.rearrange(x, "b (h w) c -> b c h w", h=h)
        x = self.unpatch(x)
        return x
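# --- Usage sketch (not part of the original gist) ---
# A minimal forward-pass example under assumed settings: 32x32 RGB inputs, so that
# with patch_size=2 the patch count is (32 // 2) ** 2 = 256, matching the default
# DiMConfig.sequence_length (the token-mixing MLP is tied to that length, so other
# resolutions need a matching config). What the output represents (noise, velocity,
# etc.) depends on the training objective, which the gist does not specify.
if __name__ == "__main__":
    config = DiMConfig()
    model = DiM(config)
    images = torch.randn(8, config.input_dimension, 32, 32)  # (B, C, H, W).
    times = torch.rand(8)  # one scalar diffusion time per sample.
    out = model(images, times)  # same shape as the input: (8, 3, 32, 32).
    print(out.shape)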