# U-Net for denoising diffusion models.
# Source gist: TadaoYamaoka/77e2239f1075006980fba44b37e129f4 (created September 29, 2024).
import math
from functools import partial

import torch
from einops import rearrange
from einops.layers.torch import Rearrange
from torch import einsum, nn
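

# Sinusoidal timestep embeddings in the style of Transformer positional
# encodings: half of the output channels are sine features and half cosine,
# with frequencies spaced geometrically from 1 down to 1/10000.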
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings
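

# Wraps a module so that its output is added back to its input (a residual
# skip connection around the wrapped function).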
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x
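

# Nearest-neighbour upsampling followed by a 3x3 convolution, a common
# alternative to transposed convolutions that avoids checkerboard artifacts.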
def Upsample(dim, dim_out):
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(dim, dim_out, 3, padding=1),
    )


def Downsample(dim, dim_out):
    # No More Strided Convolutions or Pooling
    return nn.Sequential(
        Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
        nn.Conv2d(dim * 4, dim_out, 1),
    )
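

# Basic conv -> GroupNorm -> SiLU unit. An optional (scale, shift) pair from
# the time embedding modulates the normalized features (FiLM-style) before
# the activation.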
class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)

        if scale_shift is not None:
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x
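

# ResNet-style block: two Blocks plus a skip path (a 1x1 conv when the channel
# count changes). The time embedding is projected to a per-channel
# (scale, shift) pair that conditions the first Block.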
class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim, groups=8):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, dim_out * 2),
        )
        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        scale_shift = None
        if time_emb is not None:
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, "b c -> b c 1 1")
            scale_shift = time_emb.chunk(2, dim=1)

        h = self.block1(x, scale_shift=scale_shift)
        h = self.block2(h)
        return h + self.res_conv(x)
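

# Full multi-head self-attention over the spatial positions (quadratic in the
# number of pixels); the qkv and output projections are 1x1 convolutions.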
class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        q = q * self.scale

        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()  # numerical stability
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.to_out(out)
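

# Linear attention: softmax is applied separately to the queries (over the
# feature dimension) and keys (over the spatial dimension), so the cost scales
# linearly with the number of pixels rather than quadratically.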
class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        q = q * self.scale
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)
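

# Applies group normalization with a single group (layer-norm-like over
# channels and spatial positions) before the wrapped function; used around
# the attention layers.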
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)
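

# U-Net backbone for the denoising model: a down path of ResNet blocks with
# linear attention, a bottleneck with full attention, and an up path that
# consumes skip connections from the down path. The output has `channels`
# channels (e.g. the predicted noise).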
class Unet(nn.Module):
    def __init__(
        self,
        dim,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        self_condition=False,
        resnet_block_groups=4,
    ):
        super().__init__()

        # determine dimensions
        self.channels = channels
        self.self_condition = self_condition
        input_channels = channels * (2 if self_condition else 1)

        init_dim = dim
        self.init_conv = nn.Conv2d(
            input_channels, init_dim, 1, padding=0
        )  # changed to 1 and 0 from 7,3

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # time embeddings
        time_dim = dim * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(dim),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        (
                            Downsample(dim_in, dim_out)
                            if not is_last
                            else nn.Conv2d(dim_in, dim_out, 3, padding=1)
                        ),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        (
                            Upsample(dim_out, dim_in)
                            if not is_last
                            else nn.Conv2d(dim_out, dim_in, 3, padding=1)
                        ),
                    ]
                )
            )

        self.out_dim = channels

        self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim)
        self.final_conv = nn.Conv2d(dim, self.out_dim, 1)
    def forward(self, x, time, x_self_cond=None):
        if self.self_condition:
            # self-conditioning: concatenate the previous prediction (or zeros)
            # along the channel dimension to match the doubled input_channels
            x_self_cond = x_self_cond if x_self_cond is not None else torch.zeros_like(x)
            x = torch.cat((x_self_cond, x), dim=1)

        x = self.init_conv(x)
        r = x.clone()

        t = self.time_mlp(time)

        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim=1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        x = torch.cat((x, r), dim=1)

        x = self.final_res_block(x, t)
        return self.final_conv(x)
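

if __name__ == "__main__":
    # Minimal smoke test (not part of the original gist): the batch size,
    # image size, and hyperparameters below are illustrative assumptions only.
    # With dim=28 and dim_mults=(1, 2, 4), a 28x28 single-channel input is
    # downsampled twice (28 -> 14 -> 7) and the model returns a tensor of the
    # same shape as its input.
    model = Unet(dim=28, channels=1, dim_mults=(1, 2, 4))

    x = torch.randn(8, 1, 28, 28)     # batch of noisy images
    t = torch.randint(0, 1000, (8,))  # random diffusion timesteps
    with torch.no_grad():
        noise_pred = model(x, t)

    print(noise_pred.shape)  # torch.Size([8, 1, 28, 28])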