Last active
June 11, 2018 11:13
-
-
Save htoyryla/d57cf3889bdf32efd696a82547390ca5 to your computer and use it in GitHub Desktop.
HT-GAN including AE and growing image size
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
code for a gan trainer
adding an encoder to assist in generator training
htoyryla 8 Jun 2018
support for progressive training with larger image size
htoyryla 11.6.2018
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import print_function | |
import argparse | |
import os | |
import random | |
import torch | |
import torch.nn as nn | |
import torch.nn.parallel | |
import torch.backends.cudnn as cudnn | |
import torch.optim as optim | |
import torch.utils.data | |
import torchvision.datasets as dset | |
import torchvision.transforms as transforms | |
import torchvision.utils as vutils | |
from torch.autograd import Variable | |
import functools | |
from collections import OrderedDict | |
# | |
# code for models for traindcg2 | |
# htoyryla 8 Jun 2018 | |
# | |
# experimental | |
# seriously in need of refactoring | |
# | |
# v.2c3 try naming of layers | |
# v.2c4 larger kernels on larger layers | |
# v.2c5 models changed so that in D & E, layers are added to input size as image size grows | |
# Module-level default sizes (latent length, G/D base widths, image size).
# The model classes below do not read these; they take their sizes from
# ``opt``. They are kept for any caller that imports them directly.
nz, ngf, ndf, size = 100, 64, 64, 64
def weights_init(m):
    """Initialise one module's parameters in the DCGAN style.

    Meant to be used as ``net.apply(weights_init)``, which calls it for every
    sub-module.

    * Conv / ConvTranspose layers: weight ~ N(0, 0.02), bias = 0 (if present).
    * BatchNorm layers: weight ~ N(1, 0.02), bias = 0 (if affine).

    All other modules are left untouched.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
        if getattr(m, 'bias', None) is not None:
            m.bias.data.fill_(0)
    # The PyTorch class is spelled ``BatchNorm2d``; the former test for
    # 'Batchnorm' never matched, so batch-norm layers kept the default init.
    elif classname.find('BatchNorm') != -1:
        # Non-affine batch-norm layers have weight/bias set to None.
        if getattr(m, 'weight', None) is not None:
            m.weight.data.normal_(1.0, 0.02)
        if getattr(m, 'bias', None) is not None:
            m.bias.data.fill_(0)
# total variation loss module
# use as a loss module, not a layer
class TVLoss(nn.Module):
    """Total-variation penalty for a batch of images.

    For input ``x`` the result is
    ``weight * 2 * (sum(dy**2) / count_h + sum(dx**2) / count_w) / x.size(0)``
    where ``dy``/``dx`` are differences between vertically/horizontally
    adjacent pixels and ``count_*`` is the per-sample element count of the
    corresponding difference tensor (dims 1..3).

    Args:
        TVLoss_weight: multiplier applied to the penalty.
    """

    def __init__(self, TVLoss_weight=1):
        super(TVLoss, self).__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self, x):
        dims = x.size()
        n_batch, height, width = dims[0], dims[2], dims[3]
        lower, upper = x[:, :, 1:, :], x[:, :, :height - 1, :]
        right, left = x[:, :, :, 1:], x[:, :, :, :width - 1]
        vertical = torch.pow(lower - upper, 2).sum()
        horizontal = torch.pow(right - left, 2).sum()
        per_pixel = vertical / self._tensor_size(lower) + horizontal / self._tensor_size(right)
        return self.TVLoss_weight * 2 * per_pixel / n_batch

    def _tensor_size(self, t):
        """Return the product of dims 1, 2 and 3 of ``t``."""
        dims = t.size()
        return dims[1] * dims[2] * dims[3]
# generator, duplicates my mygan architecture as of 4/2017 | |
class _netG(nn.Module): | |
def __init__(self, ngpu=1, norm_layer=nn.BatchNorm2d, opt=None): | |
super(_netG, self).__init__() | |
assert(opt is not None) | |
nz = opt.nz | |
nc = opt.nc | |
ndf = opt.ndf | |
ngf = opt.ndf | |
size = opt.imageSize | |
self.ngpu = ngpu | |
if type(norm_layer) == functools.partial: | |
use_bias = norm_layer.func == nn.InstanceNorm2d | |
else: | |
use_bias = norm_layer == nn.InstanceNorm2d | |
use_bias = norm_layer==nn.InstanceNorm2d | |
layers = [ | |
# input is Z, going into a convolution | |
("deconv1", nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=use_bias)), | |
("norm1", norm_layer(ngf * 8)), | |
("relu1", nn.ReLU(True)), | |
# state size. (ngf*8) x 4 x 4 | |
("deconv2", nn.ConvTranspose2d(ngf * 8, ngf * 4, 3, 2, 1, 1, bias=use_bias)), | |
("norm2", norm_layer(ngf * 4)), | |
("relu2", nn.ReLU(True)), | |
# state size. (ngf*4) x 8 x 8 | |
("deconv3", nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, 2, 1, 1, bias=use_bias)), | |
("norm3", norm_layer(ngf * 2)), | |
("relu3", nn.ReLU(True)), | |
# state size. (ngf*2) x 16 x 16 | |
("deconv4", nn.ConvTranspose2d(ngf * 2, ngf, 3, 2, 1, 1, bias=use_bias)), | |
("norm4", norm_layer(ngf)), | |
("relu4", nn.ReLU(True)), | |
# state size. (ngf) x 32 x 32 | |
("deconv5", nn.ConvTranspose2d( ngf, ngf, 3, 2, 1, 1, bias=use_bias)), | |
("norm5", norm_layer(ngf)), | |
("relu5", nn.ReLU(True)) ] | |
if size > 64: | |
#ht 30.4.2018 | |
more = [ | |
("deconv6", nn.ConvTranspose2d( ngf, ngf, 3, 2, 1, 1, bias=use_bias)), | |
("norm6", norm_layer(ngf)), | |
("relu6", nn.ReLU(True)) ] | |
layers.extend(more) | |
if size > 128: | |
more = [ | |
("deconv7", nn.ConvTranspose2d( ngf, ngf, 3, 2, 1, 1, bias=use_bias)), | |
("norm7", norm_layer(ngf)), | |
("relu7", nn.ReLU(True)) ] | |
layers.extend(more) | |
if size > 256: | |
more = [ | |
("deconv8", nn.ConvTranspose2d( ngf, ngf, 6, 2, 2, 0, bias=use_bias)), | |
("norm8", norm_layer(ngf)), | |
("relu8", nn.ReLU(True)) ] | |
layers.extend(more) | |
if size > 512: | |
more = [ | |
("deconv9", nn.ConvTranspose2d( ngf, ngf, 6, 2, 2, 0, bias=use_bias)), | |
("norm9", norm_layer(ngf)), | |
("relu9", nn.ReLU(True)) ] | |
layers.extend(more) | |
final = [ | |
("outconv", nn.Conv2d(ngf, nc, kernel_size=3, stride=1, padding=1, bias=use_bias)), | |
("outactiv", nn.Tanh()) ] | |
layers.extend(final) | |
print(layers) | |
self.main = nn.Sequential(OrderedDict(layers)) | |
def forward(self, input): | |
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: | |
output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) | |
else: | |
output = self.main(input) | |
return output | |
# discriminator | |
class _netD(nn.Module): | |
def __init__(self, ngpu=1, use_sigmoid=True, norm_layer=nn.BatchNorm2d, opt=None): | |
super(_netD, self).__init__() | |
nz = opt.nz | |
nc = opt.nc | |
ndf = opt.ndf | |
ngf = opt.ndf | |
size = opt.imageSize | |
self.ngpu = ngpu | |
if type(norm_layer) == functools.partial: | |
use_bias = norm_layer.func==nn.InstanceNorm2d | |
else: | |
use_bias = norm_layer==nn.InstanceNorm2d | |
sequence = [ | |
("inconv", nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1, bias=use_bias)), | |
("inrelu", nn.LeakyReLU(0.2, inplace=True))] | |
if size > 512: | |
more = [ | |
("conv1024", nn.Conv2d(ndf, ndf, kernel_size=4, stride=2, padding=1, bias=use_bias)), | |
("norm1024", norm_layer(ndf)), | |
("relu1024", nn.LeakyReLU(0.2, inplace=True)) ] | |
sequence.extend(more) | |
if size > 256: | |
more = [ | |
("conv512", nn.Conv2d(ndf, ndf, kernel_size=4, stride=2, padding=1, bias=use_bias)), | |
("norm5", norm_layer(ndf)), | |
("relu5", nn.LeakyReLU(0.2, inplace=True)) ] | |
sequence.extend(more) | |
if size > 128: | |
more = [ | |
("conv256", nn.Conv2d(ndf, ndf, kernel_size=4, stride=2, padding=1, bias=use_bias)), | |
("norm256", norm_layer(ndf)), | |
("relu256", nn.LeakyReLU(0.2, inplace=True)) ] | |
sequence.extend(more) | |
if size > 64: | |
more = [ | |
("conv128", nn.Conv2d(ndf, ndf, kernel_size=4, stride=2, padding=1, bias=use_bias)), | |
("norm128", norm_layer(ndf)), | |
("relu128", nn.LeakyReLU(0.2, inplace=True)) ] | |
sequence.extend(more) | |
body = [ | |
("conv2", nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1, bias=use_bias)), | |
("norm2", norm_layer(ndf * 2)), | |
("relu2", nn.LeakyReLU(0.2, inplace=True)), | |
("conv3", nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1, bias=use_bias)), | |
("norm3", norm_layer(ndf * 4)), | |
("relu3", nn.LeakyReLU(0.2, inplace=True)), | |
('convf4', nn.Conv2d(ndf * 4, 1, kernel_size=4, stride=4, padding=1, bias=use_bias)), | |
('outconv', nn.Conv2d(1, 1, kernel_size=4, stride=4, padding=1, bias=use_bias))] | |
sequence.extend(body) | |
if use_sigmoid: | |
sequence += [("outactiv", nn.Sigmoid())] | |
self.main = nn.Sequential(OrderedDict(sequence)) | |
def forward(self, input): | |
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: | |
output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) | |
else: | |
output = self.main(input) | |
return output.view(-1, 1).squeeze(1) | |
# encoder, duplicates my mygan architecture as of 4/2017 | |
class _netE(nn.Module): | |
def __init__(self, ngpu=1, use_sigmoid=True, norm_layer=nn.BatchNorm2d, opt=None): | |
super(_netE, self).__init__() | |
self.ngpu = ngpu | |
nz = opt.nz | |
nc = opt.nc | |
ndf = opt.ndf | |
ngf = opt.ndf | |
size = opt.imageSize | |
if type(norm_layer) == functools.partial: | |
use_bias = norm_layer.func==nn.InstanceNorm2d | |
else: | |
use_bias = norm_layer==nn.InstanceNorm2d | |
sequence = [ | |
("inconv", nn.Conv2d(nc, ndf, kernel_size=3, stride=2, padding=1, bias=use_bias)), | |
("inrelu", nn.LeakyReLU(0.2, inplace=True))] | |
if size > 512: | |
more = [ | |
("conv1024a", nn.Conv2d(ndf, ndf, kernel_size=3, stride=2, padding=1, bias=use_bias)), | |
("conv1024b", nn.Conv2d(ndf, ndf, kernel_size=3, stride=1, padding=1, bias=use_bias)), | |
("norm1024", norm_layer(ndf)), | |
("relu1024", nn.LeakyReLU(0.2, inplace=True))] | |
sequence.extend(more) | |
if size > 256: | |
more = [ | |
("conv512a", nn.Conv2d(ndf, ndf, kernel_size=3, stride=2, padding=1, bias=use_bias)), | |
("conv512b", nn.Conv2d(ndf, ndf, kernel_size=3, stride=1, padding=1, bias=use_bias)), | |
("norm512", norm_layer(ndf)), | |
("relu512", nn.LeakyReLU(0.2, inplace=True))] | |
sequence.extend(more) | |
if size > 128: | |
more = [ | |
("conv256a", nn.Conv2d(ndf, ndf, kernel_size=3, stride=2, padding=1, bias=use_bias)), | |
("conv256b", nn.Conv2d(ndf, ndf, kernel_size=3, stride=1, padding=1, bias=use_bias)), | |
("norm256", norm_layer(ndf)), | |
("relu256", nn.LeakyReLU(0.2, inplace=True))] | |
sequence.extend(more) | |
if size > 64: | |
more = [ | |
("conv128a", nn.Conv2d(ndf, ndf, kernel_size=3, stride=2, padding=1, bias=use_bias)), | |
("conv128b", nn.Conv2d(ndf, ndf, kernel_size=3, stride=1, padding=1, bias=use_bias)), | |
("norm128", norm_layer(ndf)), | |
("relu128", nn.LeakyReLU(0.2, inplace=True))] | |
sequence.extend(more) | |
body = [ | |
("conv2", nn.Conv2d(ndf, ndf * 2, kernel_size=3, stride=2, padding=1, bias=use_bias)), | |
("conv3", nn.Conv2d(ndf*2, ndf * 2, kernel_size=3, stride=1, padding=1, bias=use_bias)), | |
("norm3", norm_layer(ndf * 2)), | |
("relu3", nn.LeakyReLU(0.2, inplace=True)), | |
("conv4", nn.Conv2d(ndf * 2, ndf * 4, kernel_size=3, stride=2, padding=1, bias=use_bias)), | |
("conv5", nn.Conv2d(ndf*4, ndf * 4, kernel_size=3, stride=1, padding=1, bias=use_bias)), | |
("norm5", norm_layer(ndf * 4)), | |
("relu5", nn.LeakyReLU(0.2, inplace=True)), | |
("conv6", nn.Conv2d(ndf * 4, ndf * 8, kernel_size=3, stride=2, padding=1, bias=use_bias)), | |
("conv7", nn.Conv2d(ndf*8, ndf * 8, kernel_size=3, stride=1, padding=1, bias=use_bias)), | |
("norm7", norm_layer(ndf * 8)), | |
("relu7", nn.LeakyReLU(0.2, inplace=True)), | |
("outconv", nn.Conv2d( ndf * 8, 100, kernel_size=4, stride=1, padding=0, bias=use_bias)) ] | |
sequence.extend(body) | |
if use_sigmoid: | |
sequence += [("outactiv", nn.Sigmoid())] | |
self.main = nn.Sequential(OrderedDict(sequence)) | |
def forward(self, input): | |
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1: | |
output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) | |
else: | |
output = self.main(input) | |
return output.view(-1, 1).squeeze(1) | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import print_function | |
import argparse | |
import os | |
import sys | |
import random | |
import torch | |
import torch.nn as nn | |
import torch.nn.parallel | |
import torch.backends.cudnn as cudnn | |
import torch.optim as optim | |
import torch.utils.data | |
import torchvision.datasets as dset | |
import torchvision.transforms as transforms | |
import torchvision.utils as vutils | |
from torch.autograd import Variable | |
from model2c5 import _netG, _netD, _netE, weights_init, TVLoss | |
from torch.optim import lr_scheduler | |
import numpy as np | |
import cv2 | |
# gan trainer | |
# htoyryla 8 Jun 2018 | |
# | |
# put images in datasets/<name>/train/ | |
# | |
# v.2c3 adds noisy labels | |
# v.2c4 add loading of netE weights | |
# v.2c5 add non-strict (partial) loading of pretrained weights (requires model2c3) | |
# v.2c6 add optional freezing of already trained layers | |
# v.2c7 use noisy labels only in discriminator, models: use larger kernels in upper layers | |
# v.2c8 models changed so that in D & E, layers are added to input size as image size grows | |
# Command-line options.
parser = argparse.ArgumentParser()
# 'imagenet' and 'lfw' are loaded exactly like 'folder' (an ImageFolder).
# (Previously required=True made the default unreachable.)
parser.add_argument('--dataset', default='folder', choices=['folder', 'imagenet', 'lfw', 'fake'], help='folder | imagenet | lfw | fake (default: folder)')
parser.add_argument('--dataroot', required=True, help='path to dataset')
parser.add_argument('--workers', type=int, default=8, help='number of data loading workers')
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network')
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--nc', type=int, default=3)
parser.add_argument('--lmbd', type=float, default=100, help='lambda, default=100')
parser.add_argument('--tvloss', type=float, default=0.0002, help='tv loss, default=0.0002')
parser.add_argument('--niter', type=int, default=60, help='number of epochs to train for')
parser.add_argument('--save_every', type=int, default=10, help='number of epochs between saves')
parser.add_argument('--imgStep', type=int, default=0, help='minibatches between image folder saves')
parser.add_argument('--lrG', type=float, default=0.0002, help='learning G rate, default=0.0002')
# help text previously claimed default=0.0002, which did not match the value
parser.add_argument('--lrD', type=float, default=0.00005, help='learning D rate, default=0.00005')
parser.add_argument('--lrE', type=float, default=0.0002, help='learning E rate, default=0.0002')
parser.add_argument('--step', type=int, default=40, help='lr step, default=40')
parser.add_argument('--gamma', type=float, default=0.1, help='gamma, default=0.1')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--netE', default='', help="path to netE (to continue training)")
parser.add_argument('--name', default='baseline', help='folder to output images and model checkpoints')
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--gpu_ids', default='0', type=str, help='gpu_ids: e.g. 0 0,1,2 0,2')
parser.add_argument('--lsgan', action='store_true', help='use lsgan')
parser.add_argument('--instance', action='store_true', help='use instance norm')
parser.add_argument('--withoutE', action='store_true', help='do not use Encoder Network')
parser.add_argument('--debug', action='store_true', help='show debug info')
parser.add_argument('--hsv', action='store_true', help='use hsv color space')
parser.add_argument('--weight_decay', type=float, default=0, help='L2 regularization weight. Greatly helps convergence but leads to artifacts in images, not recommended.')
parser.add_argument('--nlabels', action='store_true', help='use noisy labels')
parser.add_argument('--nostrict', action='store_true', help='allow partial loading of pretrained nets')
parser.add_argument('--freeze', action='store_true', help='freeze already trained layers')
opt = parser.parse_args()
# Parse --gpu_ids ("0", "0,1,2", ...); negative ids such as "-1" are dropped.
gpu_ids = []
for str_id in opt.gpu_ids.split(','):
    gpu_id = int(str_id)
    if gpu_id >= 0:
        gpu_ids.append(gpu_id)
print(opt)


def _ensure_dir(path):
    """Create ``path`` (with parents); an already existing directory is fine."""
    try:
        os.makedirs(path)
    except OSError:
        if not os.path.isdir(path):
            raise


# Each output directory is created independently. Both calls used to share
# one try/except, so an existing ./model/<name> stopped ./visual/<name> from
# being created.
_ensure_dir(os.path.join('./model', opt.name))
_ensure_dir(os.path.join('./visual', opt.name))

if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

# Use CUDA only when available and at least one GPU id was requested; with
# an empty id list, set_device(gpu_ids[0]) used to raise IndexError.
opt.cuda = torch.cuda.is_available() and len(gpu_ids) > 0
if opt.cuda:
    torch.cuda.manual_seed_all(opt.manualSeed)
    torch.cuda.set_device(gpu_ids[0])
    cudnn.benchmark = True
# Dataset: images are resized / centre-cropped to imageSize and normalised
# to [-1, 1] (matching the generator's tanh output).
if opt.dataset in ['imagenet', 'folder', 'lfw']:
    dataset = dset.ImageFolder(root=opt.dataroot,
                               transform=transforms.Compose([
                                   transforms.Resize(opt.imageSize),
                                   transforms.CenterCrop(opt.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
elif opt.dataset == 'fake':
    dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize),
                            transform=transforms.ToTensor())
else:
    # Previously this fell through to a NameError at ``assert dataset``.
    raise ValueError("unknown --dataset %r (expected folder, imagenet, lfw or fake)" % opt.dataset)
if len(dataset) == 0:
    raise RuntimeError("dataset at %s contains no images" % opt.dataroot)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))
# Frequently used hyper-parameters as plain module-level names.
ngpu = len(gpu_ids)
nz = int(opt.nz)
ngf = int(opt.ngf)
ndf = int(opt.ndf)
nc = 3  # the dataset transforms and HSV pre-processing assume 3 channels
lmbd = opt.lmbd
# The models grow in powers of two from 64, so only these sizes line up.
# Explicit check instead of ``assert``, which ``python -O`` strips.
if opt.imageSize not in (64, 128, 256, 512, 1024):
    raise ValueError("--imageSize must be one of 64, 128, 256, 512, 1024 (got %d)" % opt.imageSize)
# pre/deprocessing for using hsv color space in model
# not fully working at the moment
def preproc(im, blur=0):
    """Convert one RGB image tensor into the model's HSV representation.

    Args:
        im: tensor of shape (3, H, W). Assumed to come from the folder
            dataset transform, i.e. values in [-1, 1] after
            ``Normalize((0.5,)*3, (0.5,)*3)`` - TODO confirm for 'fake' data.
        blur: median-blur aperture (odd); only applied when > 2.

    Returns:
        float numpy array of shape (3, W, H) with H, S and V each scaled to
        [-1, 1], so that ``deproc`` (which computes ``x / 2 + 0.5``) inverts
        it. The spatial axes come out swapped; ``deproc`` swaps them back.
    """
    # Undo Normalize(0.5, 0.5) first. Multiplying the raw [-1, 1] values by
    # 255 and casting to uint8 made negative values wrap around.
    rgb = np.clip(im.numpy() * 0.5 + 0.5, 0.0, 1.0)
    # cv2 expects contiguous (rows, cols, channels) buffers.
    rgb = np.ascontiguousarray((rgb * 255).astype(np.uint8).transpose((2, 1, 0)))
    if blur > 2:
        rgb = cv2.medianBlur(rgb, blur)
    hsv = cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV)
    h, s, v = cv2.split(hsv)
    h = h / 179.9  # 8-bit OpenCV hue is 0..179
    s = s / 255.
    v = v / 255.
    hsv = cv2.merge((h, s, v)).transpose(2, 0, 1)
    # [0, 1] -> [-1, 1]; the former "- 0.5" gave [-0.5, 0.5], which deproc
    # mapped to [0.25, 0.75] instead of back to [0, 1].
    return hsv * 2 - 1
def deproc(hsv, blur=0):
    """Turn one model-space HSV tensor back into an RGB numpy image.

    Args:
        hsv: tensor of shape (3, A, B) holding H, S, V; presumably in [-1, 1]
            (a tanh generator output) - confirm.
        blur: unused; present for symmetry with ``preproc``.

    Returns:
        float numpy array of shape (3, B, A) with RGB values in [0, 1].
    """
    unit = (hsv / 2 + 0.5).cpu().numpy().transpose(1, 2, 0)
    scales = (179.9, 255., 255.)  # OpenCV 8-bit ranges for H, S, V
    scaled = [channel * factor for channel, factor in zip(cv2.split(unit), scales)]
    hsv8 = cv2.merge(tuple(scaled)).astype(np.uint8)
    rgb = cv2.cvtColor(hsv8, cv2.COLOR_HSV2RGB) / 255.
    return rgb.transpose(2, 1, 0)
# create models
# generator
norm_layer = nn.InstanceNorm2d if opt.instance else nn.BatchNorm2d
netG = _netG(ngpu, norm_layer=norm_layer, opt=opt)
netG.apply(weights_init)
# Load pre-trained weights if any; --nostrict tolerates missing/extra keys.
if opt.netG != '':
    # The model is still on the CPU here; map_location also lets GPU-saved
    # checkpoints load on machines without CUDA.
    Gpar = torch.load(opt.netG, map_location=lambda storage, loc: storage)
    try:
        netG.load_state_dict(Gpar, strict=not opt.nostrict)
    except RuntimeError as err:
        # Best-effort: report the problem and keep the freshly initialised /
        # partially loaded weights.
        print("Layer size mismatch during loading")
        print(err)
print(netG)
# discriminator
norm_layer = nn.InstanceNorm2d if opt.instance else nn.BatchNorm2d
netD = _netD(ngpu, use_sigmoid=(not opt.lsgan), norm_layer=norm_layer, opt=opt)
netD.apply(weights_init)
# Load pre-trained weights if any; --nostrict tolerates missing/extra keys.
if opt.netD != '':
    # Model is still on the CPU; map_location keeps GPU checkpoints loadable.
    Dpar = torch.load(opt.netD, map_location=lambda storage, loc: storage)
    try:
        netD.load_state_dict(Dpar, strict=not opt.nostrict)
    except RuntimeError as err:
        print("Layer size mismatch during loading")
        print(err)
print(netD)
# encoder (only when used)
if not opt.withoutE:
    norm_layer = nn.InstanceNorm2d if opt.instance else nn.BatchNorm2d
    netE = _netE(ngpu, use_sigmoid=(not opt.lsgan), norm_layer=norm_layer, opt=opt)
    netE.apply(weights_init)
    # Load pre-trained weights if any; --nostrict tolerates missing/extra keys.
    if opt.netE != '':
        # Model is still on the CPU; map_location keeps GPU checkpoints loadable.
        Epar = torch.load(opt.netE, map_location=lambda storage, loc: storage)
        try:
            netE.load_state_dict(Epar, strict=not opt.nostrict)
        except RuntimeError as err:
            print("Layer size mismatch during loading")
            print(err)
    print(netE)
# freeze
def _freeze_pretrained(net, pretrained, exempt, label):
    """Stop gradients for every child of ``net.main`` found in ``pretrained``.

    A child counts as pretrained when ``"main.<name>.weight"`` is a key of the
    loaded state dict. Children without a weight (e.g. ReLU) are unaffected.
    BatchNorm running statistics still update in train mode; only the
    parameters are frozen.

    Args:
        net: model exposing a ``main`` nn.Sequential.
        pretrained: state dict that was loaded into ``net``.
        exempt: child name that is never frozen (it is resized as the image
            size grows).
        label: prefix for the progress message, e.g. "netG".
    """
    for name, module in net.main.named_children():
        if name == exempt:
            continue
        key = "main." + name + ".weight"
        if key in pretrained:
            print("freezing " + label + "." + key)
            # ``module.requires_grad = False`` (as before) merely set an unused
            # attribute on the Module; the flag lives on each Parameter.
            for param in module.parameters():
                param.requires_grad = False


if opt.freeze:
    # Only nets that actually received a checkpoint can be frozen; the
    # *par dicts do not exist otherwise.
    if opt.netG != '':
        _freeze_pretrained(netG, Gpar, "outconv", "netG")  # keep output layer trainable
    if opt.netD != '':
        _freeze_pretrained(netD, Dpar, "inconv", "netD")   # keep input layer trainable
    if not opt.withoutE and opt.netE != '':
        _freeze_pretrained(netE, Epar, "inconv", "netE")   # keep input layer trainable
# loss module for real / fake testing
class GANLoss(nn.Module):
    """Real/fake loss against constant or noisy target labels.

    Args:
        use_lsgan: use mean-squared error (LSGAN) instead of BCE.
        target_real_label: target value for "real" when not noisy.
        target_fake_label: target value for "fake" when not noisy.
        noisy: draw targets uniformly from [0.8, 1.0] for real and [0, 0.2]
            for fake instead of the constants (the trainer's --nlabels).
        tensor: kept for backward compatibility only. Targets are now
            allocated with the input's own dtype and device, so the loss
            also works when that differs from ``tensor`` (e.g. on CPU).
    """

    def __init__(self, use_lsgan=False, target_real_label=1.0, target_fake_label=0.0, noisy=False, tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        self.noisy = noisy
        self.loss = nn.MSELoss() if use_lsgan else nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        """Return a gradient-free target shaped like ``input``.

        The most recent real/fake target is also stored in
        ``real_label_var`` / ``fake_label_var``.
        """
        target = input.detach().new_empty(input.size())
        if target_is_real:
            if self.noisy:
                target.uniform_(0.8, 1.0)
            else:
                target.fill_(self.real_label)
            self.real_label_var = target
        else:
            if self.noisy:
                target.uniform_(0, 0.2)
            else:
                target.fill_(self.fake_label)
            self.fake_label_var = target
        return target

    def forward(self, input, target_is_real):
        """Loss of ``input`` against the real (True) or fake (False) target."""
        return self.loss(input, self.get_target_tensor(input, target_is_real))
# additional loss functions for AE loss
def mse_loss(input, target):
    """Mean squared error: sum of squared differences over input's element count."""
    squared = (input - target).pow(2)
    return squared.sum() / input.numel()
def l1_loss(input, target):
    """Mean absolute error: sum of |input - target| over input's element count."""
    absolute = (input - target).abs()
    return absolute.sum() / input.numel()
# Loss modules. The generator uses clean labels; the discriminator uses noisy
# labels when --nlabels is set. The label tensor type follows opt.cuda, so
# the script no longer forces CUDA label tensors on CPU-only runs.
label_tensor = torch.cuda.FloatTensor if opt.cuda else torch.FloatTensor
criterion = GANLoss(use_lsgan=opt.lsgan, tensor=label_tensor)
Dcriterion = GANLoss(use_lsgan=opt.lsgan, noisy=opt.nlabels, tensor=label_tensor)
criterionL1 = nn.L1Loss()
tvloss = TVLoss(opt.tvloss)
# general purpose vectors
# Reusable input / latent buffers; they are resized in place during training.
input = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)
noise = torch.FloatTensor(opt.batchSize, nz, 1, 1)
# One fixed latent batch so successive visualisations are comparable.
fixed_noise = torch.FloatTensor(opt.batchSize, nz, 1, 1).normal_(0, 1)
# Label buffer and label constants (the training loop below does not read them).
label = torch.FloatTensor(opt.batchSize)
real_label, fake_label = 1, 0

# Move networks, parameterised losses and buffers to the selected GPU.
if opt.cuda:
    netD.cuda()
    netG.cuda()
    if not opt.withoutE:
        netE.cuda()
    criterion.cuda()
    criterionL1.cuda()
    input = input.cuda()
    label = label.cuda()
    noise = noise.cuda()
    fixed_noise = fixed_noise.cuda()
fixed_noise = Variable(fixed_noise)
# setup optimizer
# Adam for every network (shared beta1 / weight decay, per-network learning
# rate), each paired with a StepLR schedule stepped once per epoch.
optimizerD = optim.Adam(netD.parameters(), lr=opt.lrD, betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay)
optimizerG = optim.Adam(netG.parameters(), lr=opt.lrG, betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay)
schedulers = [
    lr_scheduler.StepLR(optimizerD, step_size=opt.step, gamma=opt.gamma),
    lr_scheduler.StepLR(optimizerG, step_size=opt.step, gamma=opt.gamma),
]
if not opt.withoutE:
    optimizerE = optim.Adam(netE.parameters(), lr=opt.lrE, betas=(opt.beta1, 0.999), weight_decay=opt.weight_decay)
    schedulers.append(lr_scheduler.StepLR(optimizerE, step_size=opt.step, gamma=opt.gamma))
# main training loop starts here
# Single samples (--imgStep) are written to ./images/, which nothing else
# creates.
if opt.imgStep != 0:
    try:
        os.makedirs('./images')
    except OSError:
        if not os.path.isdir('./images'):
            raise

imgCtr = 0
for epoch in range(opt.niter):
    # get a batch of input images
    for i, data in enumerate(dataloader, 0):
        iimgs = data[0].clone()  # keep the RGB input for later display

        # convert to hsv if needed (float32 batch of preproc'd images)
        if opt.hsv:
            data[0] = torch.from_numpy(np.stack([preproc(im, 0) for im in data[0]])).float()

        # ---- update netD: first with real images ----
        netD.zero_grad()
        real_cpu, _ = data
        batch_size = real_cpu.size(0)  # the last batch of an epoch may be smaller
        if opt.cuda:
            real_cpu = real_cpu.cuda()
        input.resize_as_(real_cpu).copy_(real_cpu)
        inputv = Variable(input)
        output = netD(inputv)                  # D(x)
        errD_real = Dcriterion(output, True)   # error w.r.t. "real"
        errD_real.backward()
        D_x = output.data.mean()

        # ---- then with fake images ----
        noise.resize_(batch_size, nz, 1, 1).normal_(0, 1)  # z
        noisev = Variable(noise)
        fake_z = netG(noisev)                  # G(z)
        output = netD(fake_z.detach())         # D(G(z)); no gradient into G here
        errD_fake = Dcriterion(output, False)  # error w.r.t. "fake"
        errD_fake.backward()
        D_G_z1 = output.data.mean()
        errD = errD_real + errD_fake           # total D error, for display
        optimizerD.step()

        # ---- update netG ----
        netG.zero_grad()
        output = netD(fake_z)  # reuse G(z) from above, TODO should we use another?
        tvl = 0
        if not opt.withoutE:
            # z = E(x); detached so that E is trained only in its own step
            embedding = netE(inputv.detach()).view(batch_size, opt.nz, 1, 1)
            fake_e = netG(embedding.detach())   # G(E(x))
            errG = criterion(output, True)      # adversarial error of D(G(z))
            dist = l1_loss(inputv.detach(), fake_e) * lmbd  # reconstruction error
            errG = errG + dist
            if opt.tvloss:
                tvl = tvloss(fake_e)            # TV loss on G(E(x))
                errG = errG + tvl
        else:
            errG = criterion(output, True)      # plain GAN loss
            if opt.tvloss:
                tvl = tvloss(fake_z)            # TV loss on G(z)
                errG = errG + tvl
            dist = 0
        errG.backward()
        D_G_z2 = output.data.mean()
        optimizerG.step()

        # ---- update netE: make E(G(z)) reproduce z ----
        if not opt.withoutE:
            netE.zero_grad()
            embedding = netE(fake_z.detach())   # E(G(z))
            errE = criterionL1(embedding.view(batch_size, opt.nz, 1, 1), noisev)
            errE.backward()
            optimizerE.step()
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_E: %.4f D(x): %.4f Dist: %4f TVLoss: %4f D(G(z)): %.4f / %.4f'
                  % (epoch, opt.niter, i, len(dataloader),
                     errD.data, errG.data, errE.data, D_x, dist, tvl, D_G_z1, D_G_z2))
        else:
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f TVLss: %4f'
                  % (epoch, opt.niter, i, len(dataloader),
                     errD.data, errG.data, D_x, D_G_z1, D_G_z2, tvl))

        # save single samples if opt.imgStep > 0
        if opt.imgStep != 0 and imgCtr % opt.imgStep == 0:
            sampleNoise = torch.FloatTensor(opt.batchSize, nz, 1, 1).normal_(0, 1)
            if opt.cuda:
                sampleNoise = sampleNoise.cuda()
            with torch.no_grad():  # sampling only: do not build an autograd graph
                fakeimg = netG(sampleNoise).data
            if opt.hsv:
                fakeimg = torch.Tensor([deproc(hsv) for hsv in fakeimg])
            vutils.save_image(fakeimg,
                              './images/' + opt.name + '-sample%06d.png' % (int(imgCtr / opt.imgStep)),
                              normalize=True)
        imgCtr = imgCtr + 1

        # visualize results
        if i % 100 == 0:
            vutils.save_image(iimgs,
                              './visual/%s/real_samples.png' % opt.name,
                              normalize=True)
            with torch.no_grad():
                fake = netG(fixed_noise).data
            if opt.hsv:
                fake = torch.Tensor([deproc(hsv) for hsv in fake])
            print('saving fakes ', fake.shape)
            vutils.save_image(fake,
                              './visual/%s/fake_samples_epoch_%03d.png' % (opt.name, epoch),
                              normalize=True)

    # do checkpointing; the final epoch is always saved (the modulo test
    # alone skipped it unless niter-1 was a multiple of save_every)
    if epoch % opt.save_every == 0 or epoch == opt.niter - 1:
        torch.save(netG.state_dict(), './model/%s/netG_epoch_%d.pth' % (opt.name, epoch))
        torch.save(netD.state_dict(), './model/%s/netD_epoch_%d.pth' % (opt.name, epoch))
        if not opt.withoutE:
            torch.save(netE.state_dict(), './model/%s/netE_epoch_%d.pth' % (opt.name, epoch))

    # step the learning-rate schedules once per epoch
    for scheduler in schedulers:
        scheduler.step()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment