Created
December 19, 2017 06:08
-
-
Save ronekko/b339dfd30319e97ddf074cd74f1d3cdd to your computer and use it in GitHub Desktop.
Generative adversarial network (GAN) for 1-dimensional Gaussian
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
""" | |
Created on Fri Dec 15 15:22:04 2017 | |
@author: sakurai | |
""" | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import chainer | |
import chainer.functions as F | |
import chainer.links as L | |
from chainer import cuda | |
class Generator(chainer.Chain):
    """Fully connected generator network.

    A noise vector of ``c_in`` dimensions is passed through two hidden
    layers of 200 units each (leaky ReLU activations) and a final linear
    layer that produces ``c_out`` values per example.
    """

    def __init__(self, c_in=10, c_out=1):
        """Create the three linear layers.

        Args:
            c_in: Dimensionality of the noise input ``z``; also stored as
                ``self.dim_z`` so that ``draw_sample`` can size its noise.
            c_out: Dimensionality of each generated sample.
        """
        hidden_1 = 200
        hidden_2 = 200
        layers = dict(
            fc1=L.Linear(c_in, hidden_1),
            fc2=L.Linear(hidden_1, hidden_2),
            fc3=L.Linear(hidden_2, c_out),
        )
        super(Generator, self).__init__(**layers)
        self.dim_z = c_in

    def __call__(self, z):
        """Map a batch of noise vectors ``z`` to generated samples."""
        hidden = z
        for layer in (self.fc1, self.fc2):
            hidden = F.leaky_relu(layer(hidden))
        # The output layer is linear: no activation is applied.
        return self.fc3(hidden)

    def draw_sample(self, size=100):
        """Generate ``size`` samples from freshly drawn noise.

        The noise is drawn uniformly from [-1, 1) with shape
        ``(size, self.dim_z)`` using the array module ``self.xp`` and cast
        to float32 before being fed through the network.
        """
        noise_shape = (size, self.dim_z)
        noise = self.xp.random.uniform(-1, 1, size=noise_shape)
        return self(noise.astype(np.float32))
class Discriminator(chainer.Chain):
    """Fully connected discriminator network.

    Inputs of ``c_in`` dimensions go through two hidden layers of 200
    units each (ELU activations) followed by a linear layer with a single
    output. No sigmoid is applied here; the returned value is the raw
    score of the last linear layer.
    """

    def __init__(self, c_in=1):
        """Create the three linear layers.

        Args:
            c_in: Dimensionality of each input example.
        """
        hidden_1 = 200
        hidden_2 = 200
        layers = dict(
            fc1=L.Linear(c_in, hidden_1),
            fc2=L.Linear(hidden_1, hidden_2),
            fc3=L.Linear(hidden_2, 1),
        )
        super(Discriminator, self).__init__(**layers)

    def __call__(self, x):
        """Return one un-squashed score per example in the batch ``x``."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.elu(layer(hidden))
        return self.fc3(hidden)
if __name__ == '__main__':
    # Train a GAN whose generator learns to imitate samples drawn from the
    # 1-D Gaussian with mean `real_mean` and standard deviation `real_std`,
    # showing diagnostic plots at regular intervals.

    # Target distribution and size of the training set.
    real_mean = 10
    real_std = 2
    N = 1000

    # Model and optimization settings.
    dim_z = 100  # dimensionality of the generator's noise input
    update_generator_interval = 5  # generator steps once per N iterations
    num_epochs = 10000  # counts mini-batch iterations, not dataset passes
    batch_size = 100
    alpha = 0.001  # Adam `alpha` used for both optimizers

    # Visualization settings.
    plot_x_scale = 20  # half-width of the wide plots, in units of real_std
    test_size = 50000  # samples per distribution in the large histograms

    x_train = real_std * np.random.randn(N, 1).astype(np.float32) + real_mean
    ds_train = chainer.datasets.TupleDataset(x_train)
    it_real = chainer.iterators.SerialIterator(ds_train, batch_size)

    generator = Generator(dim_z)
    discriminator = Discriminator()
    opt_g = chainer.optimizers.Adam(alpha)
    opt_d = chainer.optimizers.Adam(alpha)
    opt_g.setup(generator)
    opt_d.setup(discriminator)
    # Optional regularization of the discriminator (disabled):
    # opt_d.add_hook(chainer.optimizer.WeightDecay(0.01))

    # Binary targets for the sigmoid cross entropy: 0 = fake, 1 = real.
    t_fake = np.zeros((batch_size, 1), dtype=np.int32)
    t_real = np.ones((batch_size, 1), dtype=np.int32)

    # Horizontal range shared by the per-mini-batch plots.
    wide_lo = real_mean - plot_x_scale * real_std
    wide_hi = real_mean + plot_x_scale * real_std

    for epoch in range(num_epochs + 1):
        # One forward pass on fake data serves both losses: the generator
        # wants the fakes labelled real, the discriminator wants them
        # labelled fake.
        x_fake = generator.draw_sample(batch_size)
        y_fake = discriminator(x_fake)
        loss_g = F.sigmoid_cross_entropy(y_fake, t_real)
        loss_d = F.sigmoid_cross_entropy(y_fake, t_fake)

        x_real = chainer.dataset.concat_examples(next(it_real))[0]
        x_real = chainer.Variable(x_real)
        y_real = discriminator(x_real)
        loss_d += F.sigmoid_cross_entropy(y_real, t_real)

        discriminator.cleargrads()
        loss_d.backward()
        opt_d.update()

        if epoch % update_generator_interval == 0:
            # Clear first: loss_d.backward() above also back-propagated
            # into the generator through x_fake.
            generator.cleargrads()
            loss_g.backward()
            opt_g.update()

        # plot
        if epoch % 10 == 0 and epoch != 0:
            print(f'# {epoch}')

            # Histograms of the current real and fake mini-batches.
            plt.hist(cuda.to_cpu(x_real.data).ravel(), color='b')
            plt.hist(cuda.to_cpu(x_fake.data).ravel(), color='r', alpha=0.5)
            plt.title('Histograms of real and fake sample (mini-batch)')
            plt.legend(['Real', 'Fake'])
            plt.xlabel('x')
            plt.ylabel('Counts')
            plt.xlim(wide_lo, wide_hi)
            plt.ylim(0, 30)
            plt.grid()
            plt.show()

            # Discriminator's output probability over a grid of x values.
            # This is evaluation only, so no backprop graph is needed.
            x_grid = np.linspace(wide_lo, wide_hi, 100,
                                 dtype='f').reshape(-1, 1)
            with chainer.no_backprop_mode():
                d = F.sigmoid(discriminator(x_grid))
            plt.plot(x_grid, cuda.to_cpu(d.data).ravel())
            plt.plot(x_grid, np.full_like(x_grid, 0.5), '--')
            plt.title("Discriminator's values for each x")
            # Raw string: '\s' is not a valid escape in a normal literal.
            plt.legend([r'$\sigma(D(x))$', '0.5'])
            plt.xlabel('x')
            plt.ylabel("Discriminator's prediction (with sigmoid)")
            plt.xlim(wide_lo, wide_hi)
            plt.ylim(0 - 0.05, 1 + 0.05)
            plt.grid()
            plt.show()

        if epoch % 100 == 0 and epoch != 0:
            # Large-sample comparison of the true distribution and the
            # generator's current output distribution.
            x_real_test = real_std * np.random.randn(test_size) + real_mean
            plt.hist(x_real_test, 50, color='b')

            # Generate in chunks of 10 mini-batches; move every chunk to
            # host memory before concatenating with NumPy.
            chunk_size = batch_size * 10
            with chainer.no_backprop_mode():
                chunks = [
                    cuda.to_cpu(generator.draw_sample(chunk_size).data)
                    for _ in range(test_size // chunk_size)
                ]
            x_fake_test = np.concatenate(chunks).ravel()
            plt.hist(x_fake_test, 50, color='r', alpha=0.5)
            plt.title(
                f'Histograms of real and fake sample ({test_size} examples)')
            plt.legend(['Real', 'Fake'])
            plt.xlabel('x')
            plt.ylabel('Counts')
            plt.xlim(real_mean - 3 * real_std, real_mean + 3 * real_std)
            plt.grid()
            plt.show()

            # Statistics of the generated samples, to compare against
            # real_mean and real_std.
            print('Sample mean = {}'.format(x_fake_test.mean()))
            print('Sample std = {}'.format(x_fake_test.std()))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment