from __future__ import print_function
import argparse
import os
import random

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
parser = argparse.ArgumentParser()
parser.add_argument('--dataroot', required=True, help='path to dataset')
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers')
parser.add_argument('--epochs', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=2e-4, help='learning rate')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--gpus', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--output-folder', default='out', help='output directory')
options = parser.parse_args()
print(options)
random.seed(123)
torch.manual_seed(123)
cudnn.benchmark = True

if not os.path.exists(options.output_folder):
    os.makedirs(options.output_folder)

device = torch.device("cuda:0" if options.cuda else "cpu")

noise_size = 100
batch_size = 64
dataset = dset.MNIST(root=options.dataroot, download=True,
                     transform=transforms.Compose([
                         transforms.Resize(28),
                         transforms.ToTensor(),
                         transforms.Normalize([0.5], [0.5]),
                     ]))
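# Note: Normalize([0.5], [0.5]) shifts pixel values from [0, 1] to [-1, 1],
# matching the Tanh output range of the generator defined below. Resize(28) is
# a no-op for MNIST (already 28x28) but keeps the pipeline explicit.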
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=options.workers,
    drop_last=True)
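# drop_last=True discards the final, smaller batch so every batch has exactly
# batch_size samples; the label tensors built in the training loop below are
# sized with batch_size and would otherwise mismatch on the last iteration.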
# custom weights initialization called on generator and discriminator
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
class Generator(nn.Module):
    def __init__(self, gpus):
        super(Generator, self).__init__()
        self.gpus = gpus
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(noise_size, 256, kernel_size=4, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            # state size. 256 x 4 x 4
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # state size. 128 x 7 x 7
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # state size. 64 x 14 x 14
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh()
            # state size. 1 x 28 x 28
        )

    def forward(self, input):
        if input.is_cuda and self.gpus > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.gpus))
        else:
            output = self.main(input)
        return output
generator = Generator(options.gpus).to(device)
generator.apply(weights_init)
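# Optional shape sanity check: a (N, noise_size, 1, 1) noise tensor should map
# to (N, 1, 28, 28) images, following the state-size comments above. For example:
#   with torch.no_grad():
#       assert generator(torch.randn(2, noise_size, 1, 1, device=device)).shape == (2, 1, 28, 28)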
class Discriminator(nn.Module):
    def __init__(self, gpus):
        super(Discriminator, self).__init__()
        self.gpus = gpus
        self.main = nn.Sequential(
            # input is 1 x 28 x 28
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. 64 x 14 x 14
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. 128 x 7 x 7
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. 256 x 3 x 3
            nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=0, bias=False),
            nn.Sigmoid()
            # state size. 1 x 1 x 1
        )

    def forward(self, input):
        if input.is_cuda and self.gpus > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.gpus))
        else:
            output = self.main(input)
        return output.view(-1, 1).squeeze(1)
discriminator = Discriminator(options.gpus).to(device)
discriminator.apply(weights_init)

criterion = nn.BCELoss()
fixed_noise = torch.randn(batch_size, noise_size, 1, 1, device=device)
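# BCELoss expects probabilities in [0, 1], which the discriminator's final
# Sigmoid provides. fixed_noise is never re-sampled: reusing the same latent
# vectors whenever samples are saved makes the generator's progress easy to
# compare across epochs.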
# setup optimizer
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=options.lr, betas=(0.5, 0.999))
generator_optimizer = optim.Adam(generator.parameters(), lr=options.lr, betas=(0.5, 0.999))
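# beta1=0.5 (instead of Adam's default 0.9) follows the DCGAN paper
# (Radford et al., 2015), which reports more stable GAN training with the
# lower momentum term.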
for epoch in range(options.epochs):
    for i, data in enumerate(dataloader, 0):
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        # train with real images; real labels are smoothed into [0.8, 1.0]
        # rather than fixed at 1.0 (one-sided label smoothing)
        discriminator.zero_grad()
        real_images = data[0].to(device)
        real_labels = torch.empty([batch_size], device=device).uniform_(0.8, 1.0)
        output = discriminator(real_images)
        d_loss_real = criterion(output, real_labels)
        d_loss_real.backward()

        # train with fake images; detach() keeps the generator out of this
        # backward pass so only the discriminator receives gradients
        noise = torch.randn([batch_size, noise_size, 1, 1], device=device)
        fake_images = generator(noise)
        fake_labels = torch.zeros([batch_size], device=device)
        output = discriminator(fake_images.detach())
        d_loss_fake = criterion(output, fake_labels)
        d_loss_fake.backward()
        d_loss = d_loss_real + d_loss_fake
        discriminator_optimizer.step()

        # (2) Update G network: maximize log(D(G(z)))
        generator.zero_grad()
        fake_labels = torch.ones([batch_size], device=device)
        output = discriminator(fake_images)
        g_loss = criterion(output, fake_labels)
        g_loss.backward()
        generator_optimizer.step()

        print('[{}/{}][{}/{}] Loss_D: {:.4f} Loss_G: {:.4f}'
              .format(epoch, options.epochs, i, len(dataloader), d_loss.item(), g_loss.item()))

        if i % 100 == 0:
            vutils.save_image(real_images,
                              os.path.join(options.output_folder, 'real_samples.png'),
                              normalize=True)
            fake_images = generator(fixed_noise)
            vutils.save_image(fake_images.detach(),
                              os.path.join(options.output_folder, 'fake_samples_epoch_{}.png'.format(epoch)),
                              normalize=True)
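# Example invocation (assuming the script is saved as dcgan_mnist.py; the
# filename is illustrative):
#   python dcgan_mnist.py --dataroot ./data --cuda --epochs 25 --output-folder out
# MNIST is downloaded into --dataroot on first use; real and generated samples
# are written to --output-folder every 100 batches.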