Created
October 28, 2018 09:44
-
-
Save jkjung-avt/ab01a4f2ab861d21d2345b7f9ebe80f4 to your computer and use it in GitHub Desktop.
A simple DCGAN with MNIST
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""dcgan_mnist.py | |
This script was originally written by Rowel Atienza (see below), and | |
was modified by JK Jung <[email protected]>. | |
------ | |
DCGAN on MNIST using Keras | |
Author: Rowel Atienza | |
Project: https://github.com/roatienza/Deep-Learning-Experiments | |
Dependencies: tensorflow 1.0 and keras 2.0 | |
Usage: python3 dcgan_mnist.py | |
""" | |
import time | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import tensorflow as tf | |
from tensorflow.examples.tutorials.mnist import input_data | |
from keras import backend as K | |
from keras.models import Sequential | |
from keras.layers import Dense, Activation, Flatten, Reshape | |
from keras.layers import Conv2D, Conv2DTranspose, UpSampling2D | |
from keras.layers import LeakyReLU, Dropout | |
from keras.layers import BatchNormalization | |
from keras.optimizers import Adam, RMSprop | |
# Length of the random noise vector fed to the generator (its input_dim).
EMBEDDING_DIM = 1024 | |
# Dropout rate used by both the discriminator and the generator.
DROPOUT_RATE = 0.4 | |
# Value assigned to per_process_gpu_memory_fraction of the TF session config.
GPU_MEM_FRACTION = 0.2 | |
class ElapsedTimer(object):
    """Stopwatch that starts when it is constructed."""

    def __init__(self):
        # Wall-clock timestamp taken at construction.
        self.start_time = time.time()

    def elapsed(self, sec):
        """Return a duration given in seconds as a human-readable string.

        Durations below one minute are reported in seconds, durations below
        one hour in minutes, and anything else in hours.
        """
        one_minute = 60
        one_hour = 60 * 60
        if sec < one_minute:
            amount, unit = sec, "sec"
        elif sec < one_hour:
            amount, unit = sec / one_minute, "min"
        else:
            amount, unit = sec / one_hour, "hr"
        return str(amount) + " " + unit

    def elapsed_time(self):
        """Print the time that has passed since the timer was created."""
        duration = time.time() - self.start_time
        print("Elapsed: %s" % self.elapsed(duration))
class DCGAN(object):
    """Builds and caches the four Keras models that make up the DCGAN.

    * ``discriminator()``       - CNN mapping an image to a real/fake score.
    * ``generator()``           - network mapping a noise vector of length
                                  EMBEDDING_DIM to an image.
    * ``discriminator_model()`` - the discriminator, compiled for training.
    * ``adversarial_model()``   - generator followed by discriminator,
                                  compiled for training the generator.

    Every model is created on first request and the same instance is
    returned on later calls.
    """

    def __init__(self, img_rows=28, img_cols=28, channel=1):
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.channel = channel
        self.D = None   # discriminator
        self.G = None   # generator
        self.AM = None  # adversarial model
        self.DM = None  # discriminator model

    def _set_discriminator_trainable(self, trainable):
        """Set the trainable flag on the discriminator and all its layers."""
        self.D.trainable = trainable
        for layer in self.D.layers:
            layer.trainable = trainable

    @staticmethod
    def _add_deconv_block(model, filters):
        """Append Conv2DTranspose -> BatchNormalization -> ReLU to `model`."""
        model.add(Conv2DTranspose(filters, 5, padding='same'))
        model.add(BatchNormalization(momentum=0.9))
        model.add(Activation('relu'))

    # Convolution output size: (W - F + 2P) / S + 1
    def discriminator(self):
        """Return the (cached) discriminator network.

        Four 5x5 convolutions with 64, 128, 256 and 512 filters (strides
        2, 2, 2, 1), each followed by LeakyReLU and Dropout, then a single
        sigmoid unit giving the real/fake score.
        """
        if self.D is not None:
            return self.D
        depth = 64
        # In: img_rows x img_cols x channel (28 x 28 x 1 by default)
        input_shape = (self.img_rows, self.img_cols, self.channel)
        # (number of filters, stride) of each convolution, in order.
        conv_specs = [(depth * 1, 2), (depth * 2, 2),
                      (depth * 4, 2), (depth * 8, 1)]
        model = Sequential()
        for index, (filters, strides) in enumerate(conv_specs):
            # Only the first layer of a Sequential model declares the input.
            extra = {'input_shape': input_shape} if index == 0 else {}
            model.add(Conv2D(filters, 5, strides=strides, padding='same',
                             **extra))
            model.add(LeakyReLU(alpha=0.2))
            model.add(Dropout(DROPOUT_RATE))
        # Out: 1-dim probability
        model.add(Flatten())
        model.add(Dense(1))
        model.add(Activation('sigmoid'))
        model.summary()
        # Cache only after the model was built successfully.
        self.D = model
        return self.D

    def generator(self):
        """Return the (cached) generator network.

        The noise vector is projected to a (img_rows/4, img_cols/4, 256)
        tensor, upsampled twice by a factor of 2 and refined by transposed
        convolutions, ending in a sigmoid image with `channel` channels.

        Raises:
            ValueError: if img_rows or img_cols is not a multiple of 4, in
                which case two 2x upsamplings cannot reach the image size.
        """
        if self.G is not None:
            return self.G
        if self.img_rows % 4 != 0 or self.img_cols % 4 != 0:
            raise ValueError(
                'img_rows and img_cols must be multiples of 4, got %r x %r'
                % (self.img_rows, self.img_cols))
        depth = 256
        # Two UpSampling2D layers multiply the spatial size by 4 overall,
        # so start from a quarter of the target size (7 x 7 for MNIST).
        rows = self.img_rows // 4
        cols = self.img_cols // 4
        model = Sequential()
        # In: EMBEDDING_DIM
        # Out: rows x cols x depth
        model.add(Dense(rows * cols * depth, input_dim=EMBEDDING_DIM))
        model.add(BatchNormalization(momentum=0.9))
        model.add(Activation('relu'))
        model.add(Reshape((rows, cols, depth)))
        model.add(Dropout(DROPOUT_RATE))
        # Out: 2*rows x 2*cols x depth/2
        model.add(UpSampling2D())
        self._add_deconv_block(model, depth // 2)
        self._add_deconv_block(model, depth // 2)
        # Out: 4*rows x 4*cols x depth/4, then depth/8
        model.add(UpSampling2D())
        self._add_deconv_block(model, depth // 4)
        self._add_deconv_block(model, depth // 4)
        self._add_deconv_block(model, depth // 8)
        # Out: img_rows x img_cols x channel image, [0.0, 1.0] per pixel
        model.add(Conv2DTranspose(self.channel, 5, padding='same'))
        model.add(Activation('sigmoid'))
        model.summary()
        # Cache only after the model was built successfully.
        self.G = model
        return self.G

    def discriminator_model(self):
        """Return the (cached) compiled model used to train the discriminator."""
        if self.DM is not None:
            return self.DM
        optimizer = RMSprop(lr=0.0002, decay=6e-8)
        model = Sequential()
        model.add(self.discriminator())
        model.compile(loss='binary_crossentropy',
                      optimizer=optimizer,
                      metrics=['accuracy'])
        self.DM = model
        return self.DM

    def adversarial_model(self):
        """Return the (cached) compiled generator+discriminator stack.

        The discriminator is frozen while this model is compiled so that,
        on Keras versions that collect the trainable weights at compile()
        time, a training step on this model updates the generator only.
        The flags are set back to True afterwards so the discriminator
        model is unaffected.
        """
        if self.AM is not None:
            return self.AM
        optimizer = RMSprop(lr=0.0001, decay=3e-8)
        generator = self.generator()
        discriminator = self.discriminator()
        self._set_discriminator_trainable(False)
        try:
            model = Sequential()
            model.add(generator)
            model.add(discriminator)
            model.compile(loss='binary_crossentropy',
                          optimizer=optimizer,
                          metrics=['accuracy'])
        finally:
            # Always hand the discriminator back in a trainable state.
            self._set_discriminator_trainable(True)
        self.AM = model
        return self.AM
class MNIST_DCGAN(object):
    """Trains a DCGAN on the MNIST training images and plots samples."""

    def __init__(self):
        """Load the MNIST training images and build all DCGAN models."""
        self.img_rows = 28
        self.img_cols = 28
        self.channel = 1
        mnist = input_data.read_data_sets('mnist', one_hot=True)
        # Reshape the loaded training images to
        # (num_images, rows, cols, channel) as float32.
        self.x_train = mnist.train.images.reshape(
            -1, self.img_rows, self.img_cols, self.channel)
        self.x_train = self.x_train.astype(np.float32)
        self.DCGAN = DCGAN(self.img_rows, self.img_cols, self.channel)
        self.generator = self.DCGAN.generator()
        self.discriminator = self.DCGAN.discriminator()
        self.discriminator_model = self.DCGAN.discriminator_model()
        self.adversarial_model = self.DCGAN.adversarial_model()

    @staticmethod
    def _sample_noise(count):
        """Return `count` noise vectors drawn uniformly from [-1.0, 1.0)."""
        return np.random.uniform(-1.0, 1.0, size=[count, EMBEDDING_DIM])

    def set_discriminator_trainable(self):
        """Mark every layer of the discriminator as trainable."""
        for layer in self.discriminator.layers:
            layer.trainable = True

    def set_discriminator_untrainable(self):
        """Mark every layer of the discriminator as not trainable."""
        for layer in self.discriminator.layers:
            layer.trainable = False

    def train(self, train_steps=2000, batch_size=256, save_interval=0):
        """Alternate one discriminator step and one generator step.

        Args:
            train_steps: number of training iterations.
            batch_size: number of real images (and of fake images) per step.
            save_interval: if > 0, save a grid of generated samples every
                `save_interval` steps, always from the same 16 noise
                vectors so the grids are comparable across steps.
        """
        noise_input = None
        if save_interval > 0:
            noise_input = self._sample_noise(16)
        for i in range(train_steps):
            # Sample some real images (with replacement).
            picks = np.random.randint(0, self.x_train.shape[0],
                                      size=batch_size)
            images_train = self.x_train[picks, :, :, :]
            # Use the generator to produce the same number of fake images.
            images_fake = self.generator.predict(
                self._sample_noise(batch_size))
            # Stack real and fake images; label real as 1 and fake as 0.
            x = np.concatenate((images_train, images_fake))
            y = np.ones([batch_size * 2, 1])
            y[batch_size:, :] = 0
            # Train the discriminator for 1 step.
            self.set_discriminator_trainable()
            d_loss = self.discriminator_model.train_on_batch(x, y)
            # Then train the whole GAN for 1 step on fresh noise with all
            # labels set to 1 and the discriminator layers flagged as not
            # trainable, so that (hopefully) only the generator learns.
            # NOTE(review): whether flipping the flags after compile() takes
            # effect depends on the Keras version - confirm for the one used.
            y = np.ones([batch_size, 1])
            noise = self._sample_noise(batch_size)
            self.set_discriminator_untrainable()
            a_loss = self.adversarial_model.train_on_batch(noise, y)
            print("%d: [D loss: %f, acc: %f] [A loss: %f, acc: %f]"
                  % (i, d_loss[0], d_loss[1], a_loss[0], a_loss[1]))
            if save_interval > 0 and (i + 1) % save_interval == 0:
                self.plot_images(save2file=True,
                                 samples=noise_input.shape[0],
                                 noise=noise_input,
                                 step=(i + 1))

    def plot_images(self, save2file=False, fake=True, samples=16,
                    noise=None, step=0):
        """Show or save a square grid of generated or real images.

        Args:
            save2file: save the figure to a PNG file instead of showing it.
            fake: plot generator output if True, random training images
                otherwise.
            samples: number of images when `noise` is not given.
            noise: optional generator input; when given, the file is named
                "mnist_<step>.png", otherwise "sample.png".
            step: training step used in the file name.
        """
        filename = 'sample.png'
        if fake:
            if noise is None:
                noise = self._sample_noise(samples)
            else:
                filename = "mnist_%d.png" % step
            images = self.generator.predict(noise)
        else:
            picks = np.random.randint(0, self.x_train.shape[0], samples)
            images = self.x_train[picks, :, :, :]
        count = images.shape[0]
        # Grow the grid for more than 16 images (the fixed 4x4 grid used to
        # make plt.subplot fail); keep at least 4x4 so the default layout
        # stays the same.
        grid = max(4, int(np.ceil(np.sqrt(count))))
        plt.figure(figsize=(10, 10))
        for index in range(count):
            plt.subplot(grid, grid, index + 1)
            image = np.reshape(images[index, :, :, :],
                               [self.img_rows, self.img_cols])
            plt.imshow(image, cmap='gray')
            plt.axis('off')
        plt.tight_layout()
        if save2file:
            plt.savefig(filename)
            plt.close('all')
        else:
            plt.show()
if __name__ == '__main__':
    # Limit the fraction of GPU memory this process may use, then register
    # the configured session with the Keras backend.
    config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(
            per_process_gpu_memory_fraction=GPU_MEM_FRACTION))
    session = tf.Session(config=config)
    K.set_session(session)

    mnist_dcgan = MNIST_DCGAN()
    # The timer is created after the models are built, so only the training
    # run itself is measured.
    timer = ElapsedTimer()
    mnist_dcgan.train(train_steps=10000, batch_size=256, save_interval=500)
    timer.elapsed_time()
    # Optional: mnist_dcgan.plot_images(fake=True, save2file=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment