Original GAN on MNIST
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### MNIST fetching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from sklearn.datasets import fetch_mldata\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "mnist = fetch_mldata('MNIST original', data_home='data/mnist/')\n",
    "np.save('data/mnist/mnist', mnist.data)\n",
    "mnist = mnist.data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "mnist = np.load('data/mnist/mnist.npy')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "-----"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "plt.rcParams.update({'axes.titlesize': 'small'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: THEANO_FLAGS=\"device=gpu7\"\n"
     ]
    }
   ],
   "source": [
    "%env THEANO_FLAGS=\"device=gpu7\"\n",
    "import theano\n",
    "import theano.tensor as T\n",
    "import lasagne\n",
    "from lasagne.layers import *\n",
    "from lasagne.regularization import regularize_network_params, l2\n",
    "from lasagne.objectives import binary_crossentropy\n",
    "\n",
    "theano.config.floatX = 'float32'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "mnistX = mnist.astype(np.float32)/255."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "IMG_SHAPE = (28,28)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "inputX = T.matrix('input_img', 'float32') # [batch_size, height*width] flattened images\n",
    "inputY = T.vector('input_labels') # [batch_size] of zeros or ones\n",
    "\n",
    "inputZ = T.matrix('input_noise', 'float32')\n",
    "\n",
    "CODE_SIZE = 100 # size of uniform noise\n",
    "\n",
    "class generator:\n",
    "    l_in = InputLayer((None, CODE_SIZE), input_var=inputZ, name='g_input')\n",
    "    l_dense0 = DenseLayer(l_in, 256, name='g_dense0')\n",
    "    l_dense1 = DenseLayer(l_dense0, 512, name='g_dense1')\n",
    "    l_gen = DenseLayer(l_dense1, IMG_SHAPE[0]*IMG_SHAPE[1], name='g_gen_layer')\n",
    "    \n",
    "    generated_img = get_output(l_gen)\n",
    "    \n",
    "    weights = get_all_params(l_gen, trainable=True)\n",
    "    \n",
    "    \n",
    "class discriminator:\n",
    "    l_in = InputLayer((None, IMG_SHAPE[0]*IMG_SHAPE[1]), name='d_input')\n",
    "    l_dense0 = dropout(DenseLayer(l_in, 512, name='d_dense0'))\n",
    "    l_dense1 = dropout(DenseLayer(l_dense0, 256, name='d_dense1'))\n",
    "    l_dense2 = dropout(DenseLayer(l_dense1, 128, name='d_dense2'))\n",
    "    l_prob = DenseLayer(l_dense2, 1, nonlinearity=lasagne.nonlinearities.sigmoid, name='d_prob')\n",
    "    \n",
    "    noise_prob = get_output(l_prob, inputs=generator.generated_img)\n",
    "    img_prob = get_output(l_prob, inputs=inputX)\n",
    "    img_prob_determ = get_output(l_prob, inputs=inputX, deterministic=True) # for generating images in reports\n",
    "    \n",
    "    weights = get_all_params(l_prob, trainable=True)\n",
    "    \n",
    "class training:\n",
    "    # Minimizing -log(D(G(z))) instead of log(1 - D(G(z))) is a standard trick to avoid vanishing generator gradients early in training.\n",
    "    g_loss = binary_crossentropy(discriminator.noise_prob, T.ones_like(discriminator.noise_prob)).mean()\n",
    "# g_loss += regularize_network_params(generator.l_gen, l2) * 0.001\n",
    "    \n",
    "    g_updates = lasagne.updates.adam(g_loss, generator.weights)\n",
    "    g_train_step = theano.function([inputZ], g_loss, updates=g_updates, allow_input_downcast=True)\n",
    "\n",
    "    d_loss = binary_crossentropy(discriminator.img_prob, T.ones_like(discriminator.img_prob)).mean()\n",
    "    d_loss += binary_crossentropy(discriminator.noise_prob, T.zeros_like(discriminator.noise_prob)).mean()\n",
    "    \n",
    "    d_updates = lasagne.updates.adam(d_loss, discriminator.weights)\n",
    "    d_train_step = theano.function([inputX, inputZ], d_loss, updates=d_updates, allow_input_downcast=True)"
   ]
  },
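  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the two losses above correspond to the original GAN objective (Goodfellow et al., 2014)\n",
    "\n",
    "$$\\min_G \\max_D \\; \\mathbb{E}_{x \\sim p_{\\mathrm{data}}}[\\log D(x)] + \\mathbb{E}_{z \\sim p_z}[\\log(1 - D(G(z)))],$$\n",
    "\n",
    "with `d_loss` being the negated discriminator term, and `g_loss` the non-saturating variant: the generator minimizes $-\\log D(G(z))$ (binary cross-entropy of $D(G(z))$ against a target of 1) rather than $\\log(1 - D(G(z)))$."
   ]
  },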
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(1.2878267765045166, dtype=float32)"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "training.d_train_step(sample_data_batch(100), sample_noise_batch(100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(0.6745600700378418, dtype=float32)"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "training.g_train_step(sample_noise_batch(100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array(1.1519067287445068, dtype=float32),\n",
       " array(0.6140279173851013, dtype=float32))"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "training.d_train_step(sample_data_batch(256), sample_noise_batch(256)), training.g_train_step(sample_noise_batch(256))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def sample_noise_batch(bsize):\n",
    "    return np.random.uniform(-1.,1.,size=(bsize, CODE_SIZE))\n",
    "\n",
    "def sample_data_batch(bsize):\n",
    "    idxs = np.random.choice(np.arange(mnistX.shape[0]), size=bsize)\n",
    "    return mnistX[idxs]"
   ]
  },
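  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of one way to alternate the two update steps defined above. The batch size, number of iterations and 1:1 discriminator/generator schedule are arbitrary illustrative choices, not values taken from the runs in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Sketch of an alternating GAN training loop; the hyperparameters below are illustrative only.\n",
    "BATCH_SIZE = 100\n",
    "N_ITER = 10000\n",
    "\n",
    "for it in range(N_ITER):\n",
    "    # one discriminator update on a real batch + a generated batch, then one generator update\n",
    "    d_loss_val = training.d_train_step(sample_data_batch(BATCH_SIZE), sample_noise_batch(BATCH_SIZE))\n",
    "    g_loss_val = training.g_train_step(sample_noise_batch(BATCH_SIZE))\n",
    "    if it % 1000 == 0:\n",
    "        print('iter %d: d_loss=%.3f, g_loss=%.3f' % (it, float(d_loss_val), float(g_loss_val)))"
   ]
  },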
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Let's generate some images from the untrained generator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "gen_imgs = theano.function([inputZ], generator.generated_img, allow_input_downcast=True)\n",
    "get_img_prob = theano.function([inputX], discriminator.img_prob_determ, allow_input_downcast=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def generate_from_noise():\n",
    "    generated_imgs = gen_imgs(sample_noise_batch(25))\n",
    "    imgs_probs = get_img_prob(generated_imgs)\n",
    "    plt.figure(figsize=(10,10))\n",
    "    for i in range(25):\n",
    "        plt.subplot(5,5,i+1)\n",
    "        plt.imshow(generated_imgs[i].reshape(IMG_SHAPE))\n",
    "        plt.title('d_prob={:.3f}'.format(imgs_probs[i][0]))\n",
    "        plt.axis('off')\n",
    "    plt.show()"
   ]
  },
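  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Calling `generate_from_noise()` as below plots a 5x5 grid of generated samples, each titled with the probability the discriminator assigns to it being real:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "generate_from_noise()"
   ]
  },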
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.7369349 ],\n",
       "       [ 0.80481601],\n",
       "       [ 0.66464865],\n",
       "       [ 0.88424724],\n",
       "       [ 0.81623596],\n",
       "       [ 0.94289351],\n",
       "       [ 0.81866348],\n",
       "       [ 0.8879593 ],\n",
       "       [ 0.80184805],\n",
       "       [ 0.81047082]], dtype=float32)"
      ]
     },
     "execution_count": 78,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_img_prob(sample_data_batch(10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.62268591],\n",
       "       [ 0.60116392],\n",
       "       [ 0.59812319],\n",
       "       [ 0.61345595],\n",
       "       [ 0.58478087],\n",
       "       [ 0.61283481],\n",
       "       [ 0.57961714],\n",
       "       [ 0.60982937],\n",
       "       [ 0.63371795],\n",
       "       [ 0.59436017]], dtype=float32)"
      ]
     },
     "execution_count": 79,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_img_prob(gen_imgs(sample_noise_batch(10)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {