Skip to content

Instantly share code, notes, and snippets.

@persiyanov
Last active December 30, 2016 16:00
Show Gist options
  • Select an option

  • Save persiyanov/bf276852588536e18ed57b8705cfd686 to your computer and use it in GitHub Desktop.

Select an option

Save persiyanov/bf276852588536e18ed57b8705cfd686 to your computer and use it in GitHub Desktop.
Original GAN on MNIST
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### MNIST fetching"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.datasets import fetch_mldata\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# One-time download: fetch MNIST into data/mnist/ and cache its pixel matrix\n",
"# as data/mnist/mnist.npy (np.save appends the .npy suffix).\n",
"# NOTE: fetch_mldata was removed in scikit-learn 0.22 (mldata.org is offline);\n",
"# newer versions need fetch_openml('mnist_784') instead.\n",
"mnist = fetch_mldata('MNIST original', data_home='data/mnist/').data\n",
"np.save('data/mnist/mnist', mnist)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"mnist = np.load('data/mnist/mnist.npy')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"-----"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"plt.rcParams.update({'axes.titlesize': 'small'})"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"env: THEANO_FLAGS=\"device=gpu7\"\n"
]
}
],
"source": [
"%env THEANO_FLAGS=\"device=gpu7\"\n",
"import theano\n",
"import theano.tensor as T\n",
"import lasagne\n",
"from lasagne.layers import *\n",
"from lasagne.regularization import regularize_network_params, l2\n",
"from lasagne.objectives import binary_crossentropy\n",
"\n",
"theano.config.floatX = 'float32'"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"mnistX = mnist.astype(np.float32)/255."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"IMG_SHAPE = (28,28)"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [],
"source": [
"# Symbolic inputs shared by the generator/discriminator graphs below.\n",
"inputX = T.matrix('input_img', 'float32')  # [batch_size, 28*28] flattened real images\n",
"inputY = T.vector('input_labels')  # [batch_size] of zeros or ones; not used by the graphs in this cell\n",
"\n",
"inputZ = T.matrix('input_noise', 'float32')  # [batch_size, CODE_SIZE] generator noise\n",
"\n",
"CODE_SIZE = 100  # dimensionality of the generator's noise vector z\n",
"\n",
"class generator:\n",
"    # MLP generator G(z): noise [batch, CODE_SIZE] -> flattened image [batch, 28*28].\n",
"    l_in = InputLayer((None, CODE_SIZE), input_var=inputZ, name='g_input')\n",
"    l_dense0 = DenseLayer(l_in, 256, name='g_dense0')\n",
"    l_dense1 = DenseLayer(l_dense0, 512, name='g_dense1')\n",
"    # Sigmoid output keeps generated pixels in [0, 1], the same range as the real\n",
"    # images (mnist / 255.). DenseLayer's default nonlinearity is ReLU, which is\n",
"    # unbounded above and can let D separate fakes by pixel range alone.\n",
"    l_gen = DenseLayer(l_dense1, IMG_SHAPE[0]*IMG_SHAPE[1],\n",
"                       nonlinearity=lasagne.nonlinearities.sigmoid, name='g_gen_layer')\n",
"\n",
"    generated_img = get_output(l_gen)\n",
"\n",
"    weights = get_all_params(l_gen, trainable=True)\n",
"\n",
"\n",
"class discriminator:\n",
"    # MLP discriminator D(x): flattened image -> probability that x is real.\n",
"    # Each hidden layer is wrapped in a dropout layer.\n",
"    l_in = InputLayer((None, IMG_SHAPE[0]*IMG_SHAPE[1]), name='d_input')\n",
"    l_dense0 = dropout(DenseLayer(l_in, 512, name='d_dense0'))\n",
"    l_dense1 = dropout(DenseLayer(l_dense0, 256, name='d_dense1'))\n",
"    l_dense2 = dropout(DenseLayer(l_dense1, 128, name='d_dense2'))\n",
"    l_prob = DenseLayer(l_dense2, 1, nonlinearity=lasagne.nonlinearities.sigmoid, name='d_prob')\n",
"\n",
"    # The same D weights score both fakes G(z) and real images. These two graphs\n",
"    # are built without deterministic=True, so dropout is active in training.\n",
"    noise_prob = get_output(l_prob, inputs=generator.generated_img)\n",
"    img_prob = get_output(l_prob, inputs=inputX)\n",
"    img_prob_determ = get_output(l_prob, inputs=inputX, deterministic=True)  # dropout off, for reports\n",
"\n",
"    weights = get_all_params(l_prob, trainable=True)\n",
"\n",
"class training:\n",
"    # Generator step. Uses the non-saturating loss: minimize -log(D(G(z)))\n",
"    # (BCE against label 1) instead of log(1 - D(G(z))), whose gradient\n",
"    # vanishes when D confidently rejects the generated samples.\n",
"    g_loss = binary_crossentropy(discriminator.noise_prob, T.ones_like(discriminator.noise_prob)).mean()\n",
"    # Optional L2 penalty on generator weights (disabled):\n",
"    # g_loss += regularize_network_params(generator.l_gen, l2) * 0.001\n",
"\n",
"    # Only generator weights are updated here; D's weights stay fixed.\n",
"    g_updates = lasagne.updates.adam(g_loss, generator.weights)\n",
"    g_train_step = theano.function([inputZ], g_loss, updates=g_updates, allow_input_downcast=True)\n",
"\n",
"    # Discriminator step: BCE with label 1 for real images and label 0 for G(z).\n",
"    d_loss = binary_crossentropy(discriminator.img_prob, T.ones_like(discriminator.img_prob)).mean()\n",
"    d_loss += binary_crossentropy(discriminator.noise_prob, T.zeros_like(discriminator.noise_prob)).mean()\n",
"\n",
"    # Only discriminator weights are updated here; G's weights stay fixed.\n",
"    d_updates = lasagne.updates.adam(d_loss, discriminator.weights)\n",
"    d_train_step = theano.function([inputX, inputZ], d_loss, updates=d_updates, allow_input_downcast=True)"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"array(1.2878267765045166, dtype=float32)"
]
},
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"training.d_train_step(sample_data_batch(100), sample_noise_batch(100))"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"array(0.6745600700378418, dtype=float32)"
]
},
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"training.g_train_step(sample_noise_batch(100))"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"(array(1.1519067287445068, dtype=float32),\n",
" array(0.6140279173851013, dtype=float32))"
]
},
"execution_count": 74,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"training.d_train_step(sample_data_batch(256), sample_noise_batch(256)), training.g_train_step(sample_noise_batch(256))"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def sample_noise_batch(bsize):\n",
"    # Draw a [bsize, CODE_SIZE] batch of generator inputs z ~ U(-1, 1).\n",
"    shape = (bsize, CODE_SIZE)\n",
"    return np.random.uniform(low=-1., high=1., size=shape)\n",
"\n",
"\n",
"def sample_data_batch(bsize):\n",
"    # Draw bsize rows of mnistX uniformly at random, with replacement.\n",
"    n_images = mnistX.shape[0]\n",
"    picked = np.random.choice(n_images, size=bsize)\n",
"    return mnistX[picked]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Let's generate some images from the untrained generator."
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Compiled helpers for inspecting the model outside of training.\n",
"# gen_imgs(z): noise batch -> generated images.\n",
"gen_imgs = theano.function(\n",
"    [inputZ],\n",
"    generator.generated_img,\n",
"    allow_input_downcast=True\n",
")\n",
"# get_img_prob(x): image batch -> D's probabilities, using the deterministic\n",
"# (dropout-off) graph.\n",
"get_img_prob = theano.function(\n",
"    [inputX],\n",
"    discriminator.img_prob_determ,\n",
"    allow_input_downcast=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def generate_from_noise(n_rows=5, n_cols=5):\n",
"    # Plot an n_rows x n_cols grid of generator samples, each titled with D's\n",
"    # deterministic (dropout-off) probability that the sample is real.\n",
"    # The defaults reproduce the original 5x5 grid in a 10x10-inch figure.\n",
"    n_imgs = n_rows * n_cols\n",
"    generated_imgs = gen_imgs(sample_noise_batch(n_imgs))\n",
"    imgs_probs = get_img_prob(generated_imgs)\n",
"    plt.figure(figsize=(2 * n_cols, 2 * n_rows))\n",
"    for i in range(n_imgs):\n",
"        plt.subplot(n_rows, n_cols, i + 1)\n",
"        # Pixels are single-channel intensities; cmap='gray' avoids the\n",
"        # default false-color map.\n",
"        plt.imshow(generated_imgs[i].reshape(IMG_SHAPE), cmap='gray')\n",
"        plt.title('d_prob={:.3f}'.format(imgs_probs[i][0]))\n",
"        plt.axis('off')\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0.7369349 ],\n",
" [ 0.80481601],\n",
" [ 0.66464865],\n",
" [ 0.88424724],\n",
" [ 0.81623596],\n",
" [ 0.94289351],\n",
" [ 0.81866348],\n",
" [ 0.8879593 ],\n",
" [ 0.80184805],\n",
" [ 0.81047082]], dtype=float32)"
]
},
"execution_count": 78,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_img_prob(sample_data_batch(10))"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0.62268591],\n",
" [ 0.60116392],\n",
" [ 0.59812319],\n",
" [ 0.61345595],\n",
" [ 0.58478087],\n",
" [ 0.61283481],\n",
" [ 0.57961714],\n",
" [ 0.60982937],\n",
" [ 0.63371795],\n",
" [ 0.59436017]], dtype=float32)"
]
},
"execution_count": 79,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_img_prob(gen_imgs(sample_noise_batch(10)))"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment