Created
August 19, 2019 08:48
-
-
Save koshian2/9be6d617ba26452a51a9fc2e34477e14 to your computer and use it in GitHub Desktop.
ACGAN(3) AnimeFace, full, resnet
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
class ResidualBlock(nn.Module):
    """Identity-skip block: ``inputs + stage2(stage1(inputs))``.

    Each stage is a 3x3 convolution with padding 1 (channel count and spatial
    size unchanged), followed by BatchNorm and an in-place ReLU.
    """

    def __init__(self, ch):
        super().__init__()
        # Left-to-right evaluation keeps registration order conv1 -> conv2.
        self.conv1, self.conv2 = self.conv_bn_relu(ch), self.conv_bn_relu(ch)

    def conv_bn_relu(self, ch):
        """Return a ``ch -> ch`` 3x3 conv + BatchNorm2d + ReLU stage."""
        conv = nn.Conv2d(ch, ch, kernel_size=3, padding=1)
        norm = nn.BatchNorm2d(ch)
        act = nn.ReLU(True)
        return nn.Sequential(conv, norm, act)

    def forward(self, inputs):
        branch = inputs
        for stage in (self.conv1, self.conv2):
            branch = stage(branch)
        # Skip connection: add the untouched input back onto the branch output.
        return inputs + branch
class Generator(nn.Module):
    """ResNet-style conditional generator.

    The input has ``z_dim + n_classes`` channels (276 with the defaults). The
    training script feeds it noise concatenated with a one-hot label, at 1x1
    spatial size. Spatial size is multiplied by 4, 2, 2, 2, 2, 2 across the six
    stages, so a 1x1 input becomes a 3-channel 128x128 image squashed into
    [-1, 1] by Tanh.

    Args:
        upsampling_type: "nearest_neighbor", "transpose_conv" or "pixel_shuffler".
        n_classes: number of condition classes (default 176).
        z_dim: number of noise channels (default 100).

    Raises:
        ValueError: if *upsampling_type* is not supported.
    """

    UPSAMPLING_TYPES = ("nearest_neighbor", "transpose_conv", "pixel_shuffler")

    def __init__(self, upsampling_type, n_classes=176, z_dim=100):
        super().__init__()
        # A real exception instead of `assert`: under `python -O` the assert
        # vanished and an unknown type silently produced a broken network.
        if upsampling_type not in self.UPSAMPLING_TYPES:
            raise ValueError(
                f"upsampling_type must be one of {self.UPSAMPLING_TYPES}, "
                f"got {upsampling_type!r}"
            )
        self.upsampling_type = upsampling_type
        # NOTE: "inital" (sic) is kept as-is: the attribute name is part of the
        # state_dict keys, so renaming it would break existing checkpoints.
        self.inital = nn.Sequential(
            nn.Conv2d(z_dim + n_classes, 768, 1),
            nn.BatchNorm2d(768),
            nn.ReLU(True),
        )
        # Arguments: in_ch, out_ch, upsampling factor, number of residual blocks.
        self.conv1 = self.generator_block(768, 512, 4, 2)
        self.conv2 = self.generator_block(512, 256, 2, 2)
        self.conv3 = self.generator_block(256, 128, 2, 2)
        self.conv4 = self.generator_block(128, 64, 2, 2)
        self.conv5 = self.generator_block(64, 32, 2, 2)
        self.conv6 = self.generator_block(32, 16, 2, 1)
        self.out = nn.Sequential(
            nn.Conv2d(16, 3, kernel_size=3, padding=1),
            nn.Tanh(),
        )

    def generator_block(self, in_ch, out_ch, upsampling_factor, n_residual_block):
        """Build one stage: upsample by *upsampling_factor* to *out_ch* channels,
        then apply *n_residual_block* ``ResidualBlock`` modules.

        Raises:
            ValueError: if ``self.upsampling_type`` is not a supported type.
        """
        layers = []
        if self.upsampling_type == "transpose_conv":
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=upsampling_factor,
                                   stride=upsampling_factor),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]
        elif self.upsampling_type == "nearest_neighbor":
            layers += [
                # nn.Upsample(mode="nearest") replaces the deprecated
                # nn.UpsamplingNearest2d; it has no parameters, so state_dicts match.
                nn.Upsample(scale_factor=upsampling_factor, mode="nearest"),
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]
        elif self.upsampling_type == "pixel_shuffler":
            # The 1x1 conv produces out_ch * f^2 channels, which PixelShuffle
            # rearranges into out_ch channels at f times the resolution.
            expanded_ch = out_ch * upsampling_factor ** 2
            layers += [
                nn.Conv2d(in_ch, expanded_ch, kernel_size=1),
                nn.BatchNorm2d(expanded_ch),
                nn.ReLU(True),
                nn.PixelShuffle(upscale_factor=upsampling_factor),
            ]
        else:
            raise ValueError(f"unsupported upsampling_type {self.upsampling_type!r}")
        layers += [ResidualBlock(out_ch) for _ in range(n_residual_block)]
        return nn.Sequential(*layers)

    def forward(self, inputs):
        x = self.inital(inputs)
        for stage in (self.conv1, self.conv2, self.conv3,
                      self.conv4, self.conv5, self.conv6):
            x = stage(x)
        return self.out(x)
class Discriminator(nn.Module):
    """Convolutional ACGAN discriminator with two linear heads.

    Five conv stages: the first keeps resolution and the other four halve it
    with AvgPool2d(2), so 128x128 inputs reach 8x8. Features are then globally
    average-pooled to 512 values per sample.

    Args:
        n_classes: size of the auxiliary classification head (default 176).

    Returns (from ``forward``):
        ``(prob, classes)``: a (B, 1) unnormalised real/fake score and
        (B, n_classes) class logits. Neither head applies an activation.
    """

    def __init__(self, n_classes=176):
        super().__init__()
        self.conv1 = self.conv_bn_relu(3, 32, 2, 1)
        self.conv2 = self.conv_bn_relu(32, 64, 2, 2)
        self.conv3 = self.conv_bn_relu(64, 128, 2, 2)
        self.conv4 = self.conv_bn_relu(128, 256, 2, 2)
        self.conv5 = self.conv_bn_relu(256, 512, 2, 2)
        self.prob = nn.Linear(512, 1)
        self.classes = nn.Linear(512, n_classes)

    def conv_bn_relu(self, in_ch, out_ch, reps, pooling_size):
        """Optional AvgPool2d(pooling_size), then *reps* x (3x3 conv, BN, LeakyReLU(0.2))."""
        layers = []
        if pooling_size > 1:
            layers.append(nn.AvgPool2d(pooling_size))
        for i in range(reps):
            layers += [
                nn.Conv2d(in_ch if i == 0 else out_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, True),
            ]
        return nn.Sequential(*layers)

    def forward(self, inputs):
        x = self.conv5(self.conv4(self.conv3(self.conv2(self.conv1(inputs)))))
        # Global average pooling. For 128x128 inputs this equals the former
        # avg_pool2d(kernel_size=8), and it also works at other resolutions,
        # where the fixed 8x8 kernel broke the Linear heads.
        x = F.adaptive_avg_pool2d(x, 1).flatten(1)
        return self.prob(x), self.classes(x)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
import torchvision | |
from torchvision import transforms | |
from tqdm import tqdm | |
import numpy as np | |
from models import Generator, Discriminator | |
import os | |
import shutil | |
import pickle | |
import statistics | |
import glob | |
def load_dataset(batch_size, root="./thumb", image_size=128, num_workers=6, drop_last=True):
    """Return a shuffled DataLoader over the class-per-subdirectory tree at *root*.

    Before the dataset is built, every sub-directory of *root* that contains
    no ``*.png`` file is deleted with ``shutil.rmtree``.
    NOTE: this is destructive, so any non-PNG files inside such a directory
    are deleted too.

    Images are resized to (image_size, image_size), converted to tensors and
    normalised with mean 0.5 / std 0.5 per channel.

    Args:
        batch_size: samples per batch.
        root: dataset root; each sub-directory is one class (ImageFolder layout).
        image_size: square output resolution (default 128, matching the models).
        num_workers: DataLoader worker processes.
        drop_last: drop the final incomplete batch. Defaults to True because a
            trailing batch of one sample makes the generator's first
            BatchNorm2d, which sees 1x1 spatial input, fail in training mode.
    """
    # Preprocessing: remove class directories that hold no PNG images.
    for class_dir in sorted(glob.glob(os.path.join(root, "*"))):
        # Only directories are classes; rmtree on a plain file would raise.
        if os.path.isdir(class_dir) and not glob.glob(os.path.join(class_dir, "*.png")):
            shutil.rmtree(class_dir)

    trans = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = torchvision.datasets.ImageFolder(root=root, transform=trans)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=drop_last,
    )
def weight_init(layer):
    """Re-draw convolution weights from N(0, 0.02); meant for ``nn.Module.apply``.

    Only ``nn.Conv2d`` and ``nn.ConvTranspose2d`` weights are touched. Biases
    and all other layer types keep PyTorch's default initialisation.
    """
    if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(layer.weight, 0.0, 0.02)
class HingeLoss(nn.Module):
    """Hinge adversarial loss for both the generator and the discriminator.

    ``batch_size`` and ``device`` are accepted only so existing call sites keep
    working. The loss no longer needs preallocated buffers. The previous
    ``torch.min(logits - 1, zeros[:B])`` broadcast (B, 1) against (B,) into a
    (B, B) matrix, so the ``relu`` form below gives the same value with O(B)
    work.
    """

    _CONDITIONS = ("gen", "dis_real", "dis_fake")

    def __init__(self, batch_size, device):
        super().__init__()

    def forward(self, logits, condition):
        """Return the scalar hinge loss of *logits* for *condition*.

        Args:
            logits: discriminator real/fake scores, reduced by mean.
            condition: one of
                "gen":      -mean(logits)              (generator pushes scores up)
                "dis_real": mean(relu(1 - logits))     (real scores pushed above +1)
                "dis_fake": mean(relu(1 + logits))     (fake scores pushed below -1)

        Raises:
            ValueError: if *condition* is not one of the names above.
        """
        if condition == "gen":
            return -torch.mean(logits)
        if condition == "dis_real":
            return torch.mean(torch.relu(1.0 - logits))
        if condition == "dis_fake":
            return torch.mean(torch.relu(1.0 + logits))
        raise ValueError(f"condition must be one of {self._CONDITIONS}, got {condition!r}")
def train(upsampling_type):
    """Train the conditional GAN on the ``./thumb`` image folders for 401 epochs.

    Each iteration takes one generator step, then one discriminator step, both
    with Adam (lr=2e-4, betas=(0.5, 0.999)). The adversarial term is
    ``HingeLoss`` and the auxiliary term is cross-entropy over 176 classes.
    Outputs are written under ``anime_acgan_<upsampling_type>/``:

    * ``epoch_XXX.png``: up to 36 samples from the last batch of each epoch;
    * ``models/{gen,dis}_epoch_XXX.pytorch``: state dicts every 10 epochs;
    * ``logs.pkl``: pickled ``{"d_loss": [...], "g_loss": [...]}`` of per-epoch means.

    Args:
        upsampling_type: "nearest_neighbor", "transpose_conv" or "pixel_shuffler";
            forwarded to ``Generator``.

    Raises:
        ValueError: if *upsampling_type* is unsupported, or the data loader
            yields no batches.
    """
    supported = ("nearest_neighbor", "transpose_conv", "pixel_shuffler")
    if upsampling_type not in supported:
        raise ValueError(f"upsampling_type must be one of {supported}, got {upsampling_type!r}")

    output_dir = "anime_acgan_" + upsampling_type
    # Prefer the GPU, but keep the script runnable on CPU-only machines.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = 128
    n_classes = 176  # number of labels; must match the Discriminator's class head
    z_dim = 100      # noise channels; z_dim + n_classes = Generator input channels (276)

    dataloader = load_dataset(batch_size)
    if len(dataloader) == 0:
        # Without this, statistics.mean([]) and the sample dump below would fail.
        raise ValueError("the data loader yielded no batches; check the dataset and batch_size")

    model_G = Generator(upsampling_type)
    model_D = Discriminator()
    model_G.apply(weight_init)
    model_D.apply(weight_init)
    model_G, model_D = model_G.to(device), model_D.to(device)
    if device == "cuda":
        model_G, model_D = torch.nn.DataParallel(model_G), torch.nn.DataParallel(model_D)

    param_G = torch.optim.Adam(model_G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    param_D = torch.optim.Adam(model_D.parameters(), lr=0.0002, betas=(0.5, 0.999))

    hinge_loss = HingeLoss(batch_size, device)
    softmax_loss = torch.nn.CrossEntropyLoss()
    # One-hot lookup table created on the training device, so it can be indexed
    # with the device-resident labels. Indexing a CPU eye() with CUDA labels fails.
    onehot_table = torch.eye(n_classes, device=device)

    os.makedirs(os.path.join(output_dir, "models"), exist_ok=True)

    result = {"d_loss": [], "g_loss": []}
    for epoch in range(401):
        log_loss_D, log_loss_G = [], []
        for real_img, real_label in tqdm(dataloader):
            batch_len = len(real_img)
            real_img, real_label = real_img.to(device), real_label.to(device)

            # ---- train G: raise D's score on fakes and match the conditioned class ----
            noise = torch.randn(batch_len, z_dim, 1, 1, device=device)
            label_onehot = onehot_table[real_label].view(batch_len, n_classes, 1, 1)
            rand_X = torch.cat([noise, label_onehot], dim=1)
            fake_img = model_G(rand_X)
            fake_img_tensor = fake_img.detach()  # reused by the D step and the sample dump
            g_out = model_D(fake_img)
            loss = hinge_loss(g_out[0], "gen") + softmax_loss(g_out[1], real_label)
            log_loss_G.append(loss.item())
            param_D.zero_grad()
            param_G.zero_grad()
            loss.backward()
            param_G.step()

            # ---- train D: real vs. detached fake, plus class prediction on both ----
            d_out_real = model_D(real_img)
            loss_real = hinge_loss(d_out_real[0], "dis_real") + softmax_loss(d_out_real[1], real_label)
            d_out_fake = model_D(fake_img_tensor)
            loss_fake = hinge_loss(d_out_fake[0], "dis_fake") + softmax_loss(d_out_fake[1], real_label)
            loss = (loss_real + loss_fake) / 2.0
            log_loss_D.append(loss.item())
            param_D.zero_grad()
            param_G.zero_grad()
            loss.backward()
            param_D.step()

        # Logging
        result["d_loss"].append(statistics.mean(log_loss_D))
        result["g_loss"].append(statistics.mean(log_loss_G))
        print(f"epoch = {epoch}, g_loss = {result['g_loss'][-1]}, d_loss = {result['d_loss'][-1]}")

        # Map [-1, 1] to [0, 1] explicitly. This is what normalize=True with
        # range=(-1.0, 1.0) did. The `range` keyword was deprecated in favour of
        # `value_range` and later removed from torchvision.
        samples = (fake_img_tensor[:36].clamp(-1.0, 1.0) + 1.0) / 2.0
        torchvision.utils.save_image(samples, f"{output_dir}/epoch_{epoch:03}.png", nrow=6, padding=3)

        # Save weights (state_dicts keep the DataParallel "module." prefix on CUDA).
        if epoch % 10 == 0:
            torch.save(model_G.state_dict(), f"{output_dir}/models/gen_epoch_{epoch:03}.pytorch")
            torch.save(model_D.state_dict(), f"{output_dir}/models/dis_epoch_{epoch:03}.pytorch")

        # Loss history, rewritten every epoch.
        with open(output_dir + "/logs.pkl", "wb") as fp:
            pickle.dump(result, fp)
if __name__ == "__main__":
    # Only the pixel-shuffle generator is trained by default; add
    # "nearest_neighbor" or "transpose_conv" to the tuple to train those too.
    for variant in ("pixel_shuffler",):
        train(variant)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment