Created November 7, 2023 12:08
A transcript of the Consistency Decoder in plain PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
class TimestepEmbedding(nn.Module):
    def __init__(self, n_time=1024, n_emb=320, n_out=1280) -> None:
        super().__init__()
        self.emb = nn.Embedding(n_time, n_emb)
        self.f_1 = nn.Linear(n_emb, n_out)
        # self.act = nn.SiLU()
        self.f_2 = nn.Linear(n_out, n_out)
    def forward(self, x) -> torch.Tensor:
        x = self.emb(x)
        x = self.f_1(x)
        x = F.silu(x)
        return self.f_2(x)
class ImageEmbedding(nn.Module):
    def __init__(self, in_channels=7, out_channels=320) -> None:
        super().__init__()
        self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
    def forward(self, x) -> torch.Tensor:
        return self.f(x)
class ImageUnembedding(nn.Module):
    def __init__(self, in_channels=320, out_channels=6) -> None:
        super().__init__()
        self.gn = nn.GroupNorm(32, in_channels)
        self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
    def forward(self, x) -> torch.Tensor:
        return self.f(F.silu(self.gn(x)))
class ConvResblock(nn.Module):
    def __init__(self, in_features=320, out_features=320, skip_conv=False) -> None:
        super().__init__()
        self.f_t = nn.Linear(1280, out_features * 2)
        self.gn_1 = nn.GroupNorm(32, in_features)
        self.f_1 = nn.Conv2d(in_features, out_features, kernel_size=3, padding=1)
        self.gn_2 = nn.GroupNorm(32, out_features)
        self.f_2 = nn.Conv2d(out_features, out_features, kernel_size=3, padding=1)
        self.f_s = nn.Identity() if not skip_conv else nn.Conv2d(in_features, out_features, kernel_size=1)
    def forward(self, x, t):
        x_skip = x
        t: torch.Tensor = self.f_t(F.silu(t))
        t = t.chunk(2, dim=1)
        # Scale/shift from the timestep embedding: the first chunk gets +1 and multiplies
        # the normed features, the second chunk is added as-is. (The traced original
        # probably did this with that loop of `None`-indexing unsqueezes; here it's just two.)
        t_1 = t[0].unsqueeze(dim=2).unsqueeze(dim=3) + 1
        t_2 = t[1].unsqueeze(dim=2).unsqueeze(dim=3)
        gn_1 = F.silu(self.gn_1(x))
        f_1 = self.f_1(gn_1)
        gn_2 = self.gn_2(f_1)
        # addcmul(input, tensor1, tensor2) = input + tensor1 * tensor2, routed as
        # input=t_2, tensor1=gn_2, tensor2=t_1, i.e. gn_2 * t_1 + t_2.
        addcmul = F.silu(gn_2 * t_1 + t_2)
        return self.f_s(x_skip) + self.f_2(addcmul)
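# Sanity note on the addcmul translation above (my own check, not from the original
# transcript): torch.addcmul(input, tensor1, tensor2) returns input + tensor1 * tensor2,
# so the expanded form should match the traced op exactly, e.g.:
#   a, b, c = torch.randn(2, 3), torch.randn(2, 3), torch.randn(2, 3)
#   assert torch.allclose(torch.addcmul(a, b, c), a + b * c)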
# Also ConvResblock
class Downsample(nn.Module):
    def __init__(self, in_channels=320) -> None:
        super().__init__()
        self.f_t = nn.Linear(1280, in_channels * 2)
        self.gn_1 = nn.GroupNorm(32, in_channels)
        self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.gn_2 = nn.GroupNorm(32, in_channels)
        self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
    def forward(self, x, t) -> torch.Tensor:
        x_skip = x
        t = F.silu(self.f_t(t))
        t_1, t_2 = t.chunk(2, dim=1)
        t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1
        t_2 = t_2.unsqueeze(2).unsqueeze(3)
        gn_1 = F.silu(self.gn_1(x))
        avg_pool2d = F.avg_pool2d(gn_1, kernel_size=(2, 2), stride=None)
        f_1 = self.f_1(avg_pool2d)
        gn_2 = self.gn_2(f_1)
        addcmul = F.silu(t_2 + (t_1 * gn_2))
        f_2 = self.f_2(addcmul)
        return f_2 + F.avg_pool2d(x_skip, kernel_size=(2, 2), stride=None)
# Also ConvResblock
class Upsample(nn.Module):
    def __init__(self, in_channels=1024) -> None:
        super().__init__()
        self.f_t = nn.Linear(1280, in_channels * 2)
        self.gn_1 = nn.GroupNorm(32, in_channels)
        self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.gn_2 = nn.GroupNorm(32, in_channels)
        self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
    def forward(self, x, t) -> torch.Tensor:
        x_skip = x
        t = F.silu(self.f_t(t))
        t_1, t_2 = t.chunk(2, dim=1)
        t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1
        t_2 = t_2.unsqueeze(2).unsqueeze(3)
        gn_1 = F.silu(self.gn_1(x))
        # avg_pool2d = F.avg_pool2d(gn_1, kernel_size=(2, 2), stride=None)
        upsample = F.interpolate(gn_1, scale_factor=2, mode='nearest-exact')
        f_1 = self.f_1(upsample)
        gn_2 = self.gn_2(f_1)
        addcmul = F.silu(t_2 + (t_1 * gn_2))
        f_2 = self.f_2(addcmul)
        return f_2 + F.interpolate(x_skip, scale_factor=2, mode='nearest-exact')
# ConsistencyDecoder, aka super resolution from 4 to 3 channels!
class Decoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embed_image = ImageEmbedding()
        self.embed_time = TimestepEmbedding()
        # No attention is needed here!
        # We only "upscale": 8x per side, i.e. 64x spatially
        # (or 48x in raw element count if you count the 4ch -> 3ch difference, lulw).
        # I was close to doing it that way myself, but I had CrossAttn over the VAE latent
        # reshaped to Bx(HW div by whatever, or -1 if you prefer)x1024 alongside DiffNeXt's skip.
        # 3 ResBlocks before each downsample, repeated 4 times;
        # down channels are [320, 640, 1024, 1024].
        # In reality the checkpoint distinguishes between conv and downsamp blocks.
        # Chess Battle Advanced
        down_0 = nn.ModuleList([
            ConvResblock(320, 320),
            ConvResblock(320, 320),
            ConvResblock(320, 320),
            Downsample(320),
        ])
        down_1 = nn.ModuleList([
            ConvResblock(320, 640, skip_conv=True),
            ConvResblock(640, 640),
            ConvResblock(640, 640),
            Downsample(640),
        ])
        down_2 = nn.ModuleList([
            ConvResblock(640, 1024, skip_conv=True),
            ConvResblock(1024, 1024),
            ConvResblock(1024, 1024),
            Downsample(1024),
        ])
        down_3 = nn.ModuleList([
            ConvResblock(1024, 1024),
            ConvResblock(1024, 1024),
            ConvResblock(1024, 1024),
        ])
        self.down = nn.ModuleList([
            down_0,
            down_1,
            down_2,
            down_3,
        ])
        # mid has 2
        self.mid = nn.ModuleList([
            ConvResblock(1024, 1024),
            ConvResblock(1024, 1024),
        ])
        # Again,
        # Chess Battle Advanced
        up_3 = nn.ModuleList([
            ConvResblock(1024 * 2, 1024, skip_conv=True),
            ConvResblock(1024 * 2, 1024, skip_conv=True),
            ConvResblock(1024 * 2, 1024, skip_conv=True),
            ConvResblock(1024 * 2, 1024, skip_conv=True),
            Upsample(1024),
        ])
        up_2 = nn.ModuleList([
            ConvResblock(1024 * 2, 1024, skip_conv=True),
            ConvResblock(1024 * 2, 1024, skip_conv=True),
            ConvResblock(1024 * 2, 1024, skip_conv=True),
            ConvResblock(1024 + 640, 1024, skip_conv=True),
            Upsample(1024),
        ])
        up_1 = nn.ModuleList([
            ConvResblock(1024 + 640, 640, skip_conv=True),
            ConvResblock(640 * 2, 640, skip_conv=True),
            ConvResblock(640 * 2, 640, skip_conv=True),
            ConvResblock(320 + 640, 640, skip_conv=True),
            Upsample(640),
        ])
        up_0 = nn.ModuleList([
            ConvResblock(320 + 640, 320, skip_conv=True),
            ConvResblock(320 * 2, 320, skip_conv=True),
            ConvResblock(320 * 2, 320, skip_conv=True),
            ConvResblock(320 * 2, 320, skip_conv=True),
        ])
        self.up = nn.ModuleList([
            up_0,
            up_1,
            up_2,
            up_3,
        ])
        self.output = ImageUnembedding()
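        # Channel / resolution trace (my own summary of the lists above; the 256 / 32 sizes
        # are just an assumed example, consistent with the scale_factor=8 used in forward):
        #   embed_image: 7ch -> 320ch @ 256
        #   down_0..down_3: 320 @ 128, 640 @ 64, 1024 @ 32, 1024 @ 32
        #   mid: 1024 @ 32
        #   up_3..up_0 (applied in reverse order): 1024 @ 64, 1024 @ 128, 640 @ 256, 320 @ 256
        #   output: 320ch -> 6ch @ 256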
    @torch.no_grad()
    def forward(self, x, t, features) -> torch.Tensor:
        t = self.embed_time(t)
        # LITERAL SUPER RESOLUTION
        x = torch.cat(
            [x, F.interpolate(features, scale_factor=8, mode='nearest-exact')],
            dim=1
        )
        x = self.embed_image(x)
        # DOWN
        block_outs = [x]
        for down in self.down:
            for block in down:
                x = block(x, t)
                block_outs.append(x)
        # mid
        for i in range(2):
            x = self.mid[i](x, t)
        # UP
        for up in reversed(self.up):
            for block in up:
                if not isinstance(block, Upsample) and block_outs:
                    x = torch.cat([x, block_outs.pop()], dim=1)
                x = block(x, t)
        # OUT
        # GN -> silu -> f
        x = self.output(x)
        return x
def load_model():
    model = Decoder()
    import safetensors.torch
    cd_orig = safetensors.torch.load_file("consistency_decoder.safetensors")
    # print(cd_orig.keys())
    # prefix
    cd_orig = {k.replace("blocks.", ""): v for k, v in cd_orig.items()}
    # layer names
    cd_orig = {k.replace("down_0_", "down.0."): v for k, v in cd_orig.items()}
    cd_orig = {k.replace("down_1_", "down.1."): v for k, v in cd_orig.items()}
    cd_orig = {k.replace("down_2_", "down.2."): v for k, v in cd_orig.items()}
    cd_orig = {k.replace("down_3_", "down.3."): v for k, v in cd_orig.items()}
    cd_orig = {k.replace("up_0_", "up.0."): v for k, v in cd_orig.items()}
    cd_orig = {k.replace("up_1_", "up.1."): v for k, v in cd_orig.items()}
    cd_orig = {k.replace("up_2_", "up.2."): v for k, v in cd_orig.items()}
    cd_orig = {k.replace("up_3_", "up.3."): v for k, v in cd_orig.items()}
    cd_orig = {k.replace("up_4_", "up.4."): v for k, v in cd_orig.items()}
    cd_orig = {k.replace("conv_0.", "0."): v for k, v in cd_orig.items()}
    cd_orig = {k.replace("conv_1.", "1."): v for k, v in cd_orig.items()}
    cd_orig = {k.replace("conv_2.", "2."): v for k, v in cd_orig.items()}
    cd_orig = {k.replace("conv_3.", "3."): v for k, v in cd_orig.items()}
    cd_orig = {k.replace("upsamp.", "4."): v for k, v in cd_orig.items()}
    cd_orig = {k.replace("downsamp.", "3."): v for k, v in cd_orig.items()}
    cd_orig = {k.replace("mid_0", "mid.0"): v for k, v in cd_orig.items()}
    cd_orig = {k.replace("mid_1", "mid.1"): v for k, v in cd_orig.items()}
    # conv+linear
    cd_orig = {k.replace("f_t.w", "f_t.weight").replace("f_t.b", "f_t.bias"): v for k, v in cd_orig.items()}
    cd_orig = {k.replace("f_1.w", "f_1.weight").replace("f_1.b", "f_1.bias"): v for k, v in cd_orig.items()}
    cd_orig = {k.replace("f_2.w", "f_2.weight").replace("f_2.b", "f_2.bias"): v for k, v in cd_orig.items()}
    cd_orig = {k.replace("f_s.w", "f_s.weight").replace("f_s.b", "f_s.bias"): v for k, v in cd_orig.items()}
    cd_orig = {k.replace("f.w", "f.weight").replace("f.b", "f.bias"): v for k, v in cd_orig.items()}
    # GN
    cd_orig = {k.replace("gn_1.g", "gn_1.weight").replace("gn_1.b", "gn_1.bias"): v for k, v in cd_orig.items()}
    cd_orig = {k.replace("gn_2.g", "gn_2.weight").replace("gn_2.b", "gn_2.bias"): v for k, v in cd_orig.items()}
    cd_orig = {k.replace("gn.g", "gn.weight").replace("gn.b", "gn.bias"): v for k, v in cd_orig.items()}
    # output sequence
    # cd_orig = {k.replace("output.gn.", "output.0."): v for k, v in cd_orig.items()}
    # cd_orig = {k.replace("output.f.", "output.2."): v for k, v in cd_orig.items()}
    cd_orig["embed_time.emb.weight"] = safetensors.torch.load_file("embedding.safetensors")["weight"]
    model.load_state_dict(cd_orig)
    print(cd_orig["embed_time.emb.weight"][1][0])
    return model
if __name__ == '__main__':
    model = load_model()
    print(model)
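    # Hedged smoke test (my own addition, not part of the original transcript): random
    # tensors, shape check only, no actual sampling schedule. `latent` stands in for a
    # 4-channel VAE latent at 1/8 resolution, `x` for the 3-channel noisy image, and the
    # timestep index must lie in [0, 1023] since n_time=1024.
    x = torch.randn(1, 3, 256, 256)
    latent = torch.randn(1, 4, 32, 32)
    t = torch.tensor([0])
    out = model(x, t, latent)
    print(out.shape)  # expected: torch.Size([1, 6, 256, 256])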