Skip to content

Instantly share code, notes, and snippets.

@KohakuBlueleaf
Created November 7, 2023 12:08
Show Gist options
  • Save KohakuBlueleaf/e8bc8c562d4d0ecd383ddf4a7cfebabd to your computer and use it in GitHub Desktop.
A transcription of OpenAI's Consistency Decoder
import torch
import torch.nn as nn
import torch.nn.functional as F
class TimestepEmbedding(nn.Module):
    """Map an integer timestep index to a 1280-dim conditioning vector.

    Pipeline: embedding lookup -> linear -> SiLU -> linear.
    """

    def __init__(self, n_time: int = 1024, n_emb: int = 320, n_out: int = 1280) -> None:
        super().__init__()
        self.emb = nn.Embedding(n_time, n_emb)
        self.f_1 = nn.Linear(n_emb, n_out)
        self.f_2 = nn.Linear(n_out, n_out)

    def forward(self, x) -> torch.Tensor:
        hidden = self.f_1(self.emb(x))
        return self.f_2(F.silu(hidden))
class ImageEmbedding(nn.Module):
    """Project the 7-channel input (image + upsampled latents) to 320 feature channels."""

    def __init__(self, in_channels: int = 7, out_channels: int = 320) -> None:
        super().__init__()
        self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x) -> torch.Tensor:
        projected = self.f(x)
        return projected
class ImageUnembedding(nn.Module):
    """Final head: GroupNorm -> SiLU -> 3x3 conv down to the output channels."""

    def __init__(self, in_channels: int = 320, out_channels: int = 6) -> None:
        super().__init__()
        self.gn = nn.GroupNorm(32, in_channels)
        self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x) -> torch.Tensor:
        normed = F.silu(self.gn(x))
        return self.f(normed)
class ConvResblock(nn.Module):
    """Residual conv block with per-channel scale/shift timestep conditioning.

    The 1280-dim timestep embedding is projected to a (scale, shift) pair
    applied between the two 3x3 convolutions; the input rejoins through an
    identity skip, or a 1x1 conv when the channel count changes.
    """

    def __init__(self, in_features: int = 320, out_features: int = 320, skip_conv: bool = False) -> None:
        super().__init__()
        self.f_t = nn.Linear(1280, out_features * 2)
        self.gn_1 = nn.GroupNorm(32, in_features)
        self.f_1 = nn.Conv2d(in_features, out_features, kernel_size=3, padding=1)
        self.gn_2 = nn.GroupNorm(32, out_features)
        self.f_2 = nn.Conv2d(out_features, out_features, kernel_size=3, padding=1)
        # Skip path: only needs a projection when in/out widths differ.
        self.f_s = nn.Conv2d(in_features, out_features, kernel_size=1) if skip_conv else nn.Identity()

    def forward(self, x, t):
        residual = x
        scale, shift = self.f_t(F.silu(t)).chunk(2, dim=1)
        # Broadcast over H, W; +1 centres the multiplicative term at identity.
        scale = scale[:, :, None, None] + 1
        shift = shift[:, :, None, None]
        h = self.f_1(F.silu(self.gn_1(x)))
        h = F.silu(self.gn_2(h) * scale + shift)
        return self.f_s(residual) + self.f_2(h)
class Downsample(nn.Module):
    """Timestep-conditioned residual block that also halves spatial resolution.

    Same scale/shift conditioning as ConvResblock, with a 2x2 average pool
    applied to both the main branch (before the first conv) and the skip.
    """

    def __init__(self, in_channels: int = 320) -> None:
        super().__init__()
        self.f_t = nn.Linear(1280, in_channels * 2)
        self.gn_1 = nn.GroupNorm(32, in_channels)
        self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.gn_2 = nn.GroupNorm(32, in_channels)
        self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x, t) -> torch.Tensor:
        residual = x
        # NOTE(review): SiLU is applied *after* f_t here, whereas ConvResblock
        # applies it before the projection — preserved as transcribed; confirm
        # against the upstream weights/implementation.
        scale, shift = F.silu(self.f_t(t)).chunk(2, dim=1)
        scale = scale[:, :, None, None] + 1
        shift = shift[:, :, None, None]
        h = F.avg_pool2d(F.silu(self.gn_1(x)), kernel_size=2)
        h = self.f_1(h)
        h = F.silu(self.gn_2(h) * scale + shift)
        return self.f_2(h) + F.avg_pool2d(residual, kernel_size=2)
class Upsample(nn.Module):
    """Timestep-conditioned residual block that also doubles spatial resolution.

    Mirror of Downsample, with nearest-exact interpolation (x2) applied to
    both the main branch (before the first conv) and the skip.
    """

    def __init__(self, in_channels: int = 1024) -> None:
        super().__init__()
        self.f_t = nn.Linear(1280, in_channels * 2)
        self.gn_1 = nn.GroupNorm(32, in_channels)
        self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.gn_2 = nn.GroupNorm(32, in_channels)
        self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x, t) -> torch.Tensor:
        residual = x
        # NOTE(review): SiLU after f_t, unlike ConvResblock — preserved as
        # transcribed; confirm against the upstream implementation.
        scale, shift = F.silu(self.f_t(t)).chunk(2, dim=1)
        scale = scale[:, :, None, None] + 1
        shift = shift[:, :, None, None]
        h = F.interpolate(F.silu(self.gn_1(x)), scale_factor=2, mode='nearest-exact')
        h = self.f_1(h)
        h = F.silu(self.gn_2(h) * scale + shift)
        return self.f_2(h) + F.interpolate(residual, scale_factor=2, mode='nearest-exact')
# The consistency decoder: a timestep-conditioned UNet with no attention,
# decoding 4 latent channels (upsampled 8x) plus a 3-channel image input.
class Decoder(nn.Module):
    """Consistency Decoder UNet.

    Encoder/decoder over channel widths 320/640/1024/1024, three residual
    blocks per encoder stage (plus a Downsample between stages), a 2-block
    middle, and decoder stages that concatenate one encoder skip per block.
    """

    def __init__(self) -> None:
        super().__init__()
        self.embed_image = ImageEmbedding()
        self.embed_time = TimestepEmbedding()

        def up_stage(in_widths, out_ch, upsample):
            # Decoder-side stage: one ConvResblock per skip concatenation
            # (every block needs a 1x1 skip conv since the concatenated input
            # never matches out_ch), then an optional Upsample.
            blocks = [ConvResblock(w, out_ch, skip_conv=True) for w in in_widths]
            if upsample:
                blocks.append(Upsample(out_ch))
            return nn.ModuleList(blocks)

        # Encoder: 3 residual blocks per stage; the first three stages end in
        # a Downsample. A width change at a stage boundary needs a skip conv.
        stage_widths = [320, 640, 1024, 1024]
        encoder_stages = []
        prev = stage_widths[0]
        for i, width in enumerate(stage_widths):
            stage = [
                ConvResblock(prev, width, skip_conv=prev != width),
                ConvResblock(width, width),
                ConvResblock(width, width),
            ]
            if i + 1 < len(stage_widths):
                stage.append(Downsample(width))
            encoder_stages.append(nn.ModuleList(stage))
            prev = width
        self.down = nn.ModuleList(encoder_stages)

        # Middle: two plain residual blocks at the deepest width.
        self.mid = nn.ModuleList([
            ConvResblock(1024, 1024),
            ConvResblock(1024, 1024),
        ])

        # Decoder stages, stored shallow-to-deep (forward walks them in
        # reverse). Each block's input width = width-from-below + popped skip.
        self.up = nn.ModuleList([
            up_stage([320 + 640, 320 + 320, 320 + 320, 320 + 320], 320, upsample=False),
            up_stage([640 + 1024, 640 + 640, 640 + 640, 640 + 320], 640, upsample=True),
            up_stage([1024 + 1024, 1024 + 1024, 1024 + 1024, 1024 + 640], 1024, upsample=True),
            up_stage([1024 + 1024] * 4, 1024, upsample=True),
        ])

        self.output = ImageUnembedding()

    @torch.no_grad()
    def forward(self, x, t, features) -> torch.Tensor:
        t_emb = self.embed_time(t)
        # Nearest-upsample the latents 8x and concatenate onto the image
        # channels before the stem conv.
        x = torch.cat(
            [x, F.interpolate(features, scale_factor=8, mode='nearest-exact')],
            dim=1,
        )
        x = self.embed_image(x)

        # Encoder: every intermediate (including the stem output) is a skip.
        skips = [x]
        for stage in self.down:
            for block in stage:
                x = block(x, t_emb)
                skips.append(x)

        for block in self.mid:
            x = block(x, t_emb)

        # Decoder: each residual block consumes one skip; Upsamples take none.
        for stage in reversed(self.up):
            for block in stage:
                if not isinstance(block, Upsample) and skips:
                    x = torch.cat([x, skips.pop()], dim=1)
                x = block(x, t_emb)

        # GN -> SiLU -> conv head.
        return self.output(x)
def load_model():
    """Build a Decoder and load the original consistency-decoder weights.

    Reads "consistency_decoder.safetensors" and "embedding.safetensors" from
    the working directory and remaps the original flat key names onto this
    module hierarchy.
    """
    model = Decoder()
    import safetensors.torch

    state = safetensors.torch.load_file("consistency_decoder.safetensors")

    # Ordered key rewrites: strip the "blocks." prefix, map the original flat
    # layer names onto nn.ModuleList indices, then rename weight/bias/gain
    # suffixes to PyTorch conventions. Order matters — e.g. "f_t.w" must be
    # rewritten before the bare "f.w" pattern.
    rewrites = [
        ("blocks.", ""),
        ("down_0_", "down.0."),
        ("down_1_", "down.1."),
        ("down_2_", "down.2."),
        ("down_3_", "down.3."),
        ("up_0_", "up.0."),
        ("up_1_", "up.1."),
        ("up_2_", "up.2."),
        ("up_3_", "up.3."),
        ("up_4_", "up.4."),
        ("conv_0.", "0."),
        ("conv_1.", "1."),
        ("conv_2.", "2."),
        ("conv_3.", "3."),
        ("upsamp.", "4."),
        ("downsamp.", "3."),
        ("mid_0", "mid.0"),
        ("mid_1", "mid.1"),
        ("f_t.w", "f_t.weight"),
        ("f_t.b", "f_t.bias"),
        ("f_1.w", "f_1.weight"),
        ("f_1.b", "f_1.bias"),
        ("f_2.w", "f_2.weight"),
        ("f_2.b", "f_2.bias"),
        ("f_s.w", "f_s.weight"),
        ("f_s.b", "f_s.bias"),
        ("f.w", "f.weight"),
        ("f.b", "f.bias"),
        ("gn_1.g", "gn_1.weight"),
        ("gn_1.b", "gn_1.bias"),
        ("gn_2.g", "gn_2.weight"),
        ("gn_2.b", "gn_2.bias"),
        ("gn.g", "gn.weight"),
        ("gn.b", "gn.bias"),
    ]
    remapped = {}
    for key, value in state.items():
        for old, new in rewrites:
            key = key.replace(old, new)
        remapped[key] = value

    # The timestep embedding table ships in a separate file.
    remapped["embed_time.emb.weight"] = safetensors.torch.load_file("embedding.safetensors")["weight"]
    model.load_state_dict(remapped)
    print(remapped["embed_time.emb.weight"][1][0])
    return model
if __name__ == '__main__':
    # Smoke test: load the weights and dump the module tree.
    decoder = load_model()
    print(decoder)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment