Visualize the discriminative ability of hidden layers as feature representations of the original images: a Keras callback snapshots the outputs of selected layers (conv1, conv2, dense1, dense2) on a sample of MNIST test images and saves them as TensorBoard embedding-projector checkpoints. A usage sketch for viewing the result follows the notebook.
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import shutil\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import keras.backend.tensorflow_backend as K\n",
    "from keras.datasets.mnist import load_data\n",
    "from keras.models import Model, Sequential\n",
    "from keras.layers import Input, Conv2D, MaxPool2D, Dense, \\\n",
    "    Flatten, Activation, Dropout, \\\n",
    "    Embedding\n",
    "from keras.losses import sparse_categorical_crossentropy\n",
    "from keras.callbacks import TensorBoard, Callback\n",
    "from keras.optimizers import Adam"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_train.shape (60000, 28, 28, 1)\n",
      "y_train.shape (60000, 1)\n",
      "x_test.shape (10000, 28, 28, 1)\n",
      "y_test.shape (10000, 1)\n"
     ]
    }
   ],
   "source": [
    "img_rows, img_cols = 28, 28\n",
    "num_classes = 10\n",
    "\n",
    "(x_train, y_train), (x_test, y_test) = load_data()\n",
    "x_train = x_train.reshape(-1, img_rows, img_cols, 1)\n",
    "y_train = y_train.reshape(-1, 1)\n",
    "x_test = x_test.reshape(-1, img_rows, img_cols, 1)\n",
    "y_test = y_test.reshape(-1, 1)\n",
    "\n",
    "x_train = x_train.astype('float32') / 255\n",
    "x_test = x_test.astype('float32') / 255\n",
    "print('x_train.shape', x_train.shape)\n",
    "print('y_train.shape', y_train.shape)\n",
    "print('x_test.shape', x_test.shape)\n",
    "print('y_test.shape', y_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_model(input_shape, name='mnist_cnn'):\n",
    "    model = Sequential(name=name)\n",
    "    model.add(Conv2D(32, kernel_size=(3, 3), padding='same',\n",
    "                     name='conv1',\n",
    "                     activation='relu', input_shape=input_shape))\n",
    "    model.add(MaxPool2D(pool_size=(2, 2)))\n",
    "    model.add(Conv2D(64, kernel_size=(3, 3), padding='same',\n",
    "                     name='conv2',\n",
    "                     activation='relu'))\n",
    "    model.add(MaxPool2D(pool_size=(2, 2)))\n",
    "    model.add(Dropout(0.25))\n",
    "    model.add(Flatten())\n",
    "    model.add(Dense(128, activation='relu', name='dense1'))\n",
    "    model.add(Dropout(0.5))\n",
    "    model.add(Dense(num_classes, activation='softmax', name='dense2'))\n",
    "    model.compile(loss=sparse_categorical_crossentropy,\n",
    "                  optimizer=Adam(),\n",
    "                  metrics=['accuracy'])\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_sprites_image(images, n_rows=None, n_cols=None, inverse_color=True):\n",
    "    num_imgs, img_height, img_width = images.shape[:3]\n",
    "    if len(images.shape) == 4:\n",
    "        num_channels = images.shape[3]\n",
    "    else:\n",
    "        num_channels = 1\n",
    "    if n_rows is None or n_cols is None:\n",
    "        n_rows = n_cols = np.sqrt(num_imgs).astype('int')\n",
    "    assert num_imgs == n_rows * n_cols\n",
    "    assert num_channels == 1 or num_channels == 3\n",
    "    # Normalize to [0, 1] before tiling so the sprite colors are consistent.\n",
    "    sprites = (images - images.min()) / (images.max() - images.min())\n",
    "    sprites = sprites.reshape(n_rows, n_cols, img_height, img_width, num_channels)\n",
    "    sprites = sprites.swapaxes(1, 2)\n",
    "    sprites = sprites.reshape(n_rows * img_height, n_cols * img_width, num_channels)\n",
    "    if num_channels == 1:\n",
    "        sprites = np.repeat(sprites, 3, axis=-1)\n",
    "    if inverse_color:\n",
    "        sprites = 1 - sprites\n",
    "    return sprites"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class HiddenLayerFeatureVisCallback(Callback):\n",
    "    \"\"\"Visualize hidden layers' outputs as feature\n",
    "    representations of the original images.\n",
    "    \"\"\"\n",
    "    def __init__(self, log_dir='./logs',\n",
    "                 images=None,\n",
    "                 labels=None,\n",
    "                 layer_names=None,\n",
    "                 save_freq=5):\n",
    "        # Imported lazily and bound as module globals, so TensorFlow and the\n",
    "        # projector plugin are only required when the callback is used.\n",
    "        global tf, projector\n",
    "        import tensorflow as tf\n",
    "        from tensorflow.contrib.tensorboard.plugins import projector\n",
    "\n",
    "        assert isinstance(layer_names, list) and len(layer_names) > 0\n",
    "        assert images is not None\n",
    "        super(HiddenLayerFeatureVisCallback, self).__init__()\n",
    "        self.log_dir = log_dir\n",
    "        self.layer_names = layer_names\n",
    "        self.images = images\n",
    "        self.num_imgs = images.shape[0]\n",
    "        self.labels = labels.squeeze()\n",
    "        self.save_freq = save_freq\n",
    "\n",
    "    def save_sprites_image(self):\n",
    "        global plt\n",
    "        import matplotlib.pyplot as plt\n",
    "        sprite_image = gen_sprites_image(self.images)\n",
    "        plt.imsave(os.path.join(self.log_dir, 'sprites.png'), sprite_image)\n",
    "\n",
    "    def save_meta_data(self):\n",
    "        self.metadata_file = 'metadata.tsv'\n",
    "        with open(os.path.join(self.log_dir, self.metadata_file), 'w') as f:\n",
    "            f.write('id\\tlabel\\n')\n",
    "            for i, label in enumerate(self.labels):\n",
    "                f.write('%d\\t%s\\n' % (i, str(label)))\n",
    "\n",
    "    def set_model(self, model):\n",
    "        self.model = model\n",
    "        self.sess = K.get_session()\n",
    "        self.save_sprites_image()\n",
    "        self.save_meta_data()\n",
    "        # One flat embedding variable per watched layer; on_epoch_end fills it\n",
    "        # with that layer's output for every image in the visualization set.\n",
    "        self.feature_reprs = dict()\n",
    "        for layer in self.model.layers:\n",
    "            if layer.name in self.layer_names:\n",
    "                feature_dim = np.product(layer.output_shape[1:])\n",
    "                self.feature_reprs[layer.name] = K.variable(np.zeros((self.num_imgs, feature_dim)),\n",
    "                                                            name=layer.name)\n",
    "\n",
    "        config = projector.ProjectorConfig()\n",
    "        self.feature_reprs_ckpt_path = os.path.join(self.log_dir, 'feature_reprs.ckpt')\n",
    "        for layer_name, tensor in self.feature_reprs.items():\n",
    "            embedding = config.embeddings.add()\n",
    "            embedding.tensor_name = tensor.name\n",
    "            embedding.metadata_path = 'metadata.tsv'\n",
    "            embedding.sprite.image_path = 'sprites.png'\n",
    "            embedding.sprite.single_image_dim.extend(self.images.shape[1:3])\n",
    "\n",
    "        self.writer = tf.summary.FileWriter(self.log_dir)\n",
    "        self.saver = tf.train.Saver(list(self.feature_reprs.values()))\n",
    "        projector.visualize_embeddings(self.writer, config)\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        if (epoch + 1) % self.save_freq != 0:\n",
    "            return\n",
    "        for layer in self.model.layers:\n",
    "            if layer.name in self.feature_reprs:\n",
    "                func = K.function([self.model.input, K.learning_phase()], [layer.output])\n",
    "                layer_out = func([self.images, 0])[0].reshape(self.num_imgs, -1)\n",
    "                K.set_value(self.feature_reprs[layer.name], layer_out)\n",
    "        self.saver.save(self.sess, self.feature_reprs_ckpt_path, epoch)\n",
    "\n",
    "    def on_train_end(self, _):\n",
    "        self.writer.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = build_model(input_shape=(28, 28, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "conv1 (Conv2D)               (None, 28, 28, 32)        320       \n",
      "_________________________________________________________________\n",
      "max_pooling2d_1 (MaxPooling2 (None, 14, 14, 32)        0         \n",
      "_________________________________________________________________\n",
      "conv2 (Conv2D)               (None, 14, 14, 64)        18496     \n",
      "_________________________________________________________________\n",
      "max_pooling2d_2 (MaxPooling2 (None, 7, 7, 64)          0         \n",
      "_________________________________________________________________\n",
      "dropout_1 (Dropout)          (None, 7, 7, 64)          0         \n",
      "_________________________________________________________________\n",
      "flatten_1 (Flatten)          (None, 3136)              0         \n",
      "_________________________________________________________________\n",
      "dense1 (Dense)               (None, 128)               401536    \n",
      "_________________________________________________________________\n",
      "dropout_2 (Dropout)          (None, 128)               0         \n",
      "_________________________________________________________________\n",
      "dense2 (Dense)               (None, 10)                1290      \n",
      "=================================================================\n",
      "Total params: 421,642\n",
      "Trainable params: 421,642\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 60000 samples, validate on 10000 samples\n",
      "Epoch 1/20\n",
      "60000/60000 [==============================] - 3s 53us/step - loss: 0.2926 - acc: 0.9087 - val_loss: 0.0628 - val_acc: 0.9795\n",
      "Epoch 2/20\n",
      "60000/60000 [==============================] - 2s 35us/step - loss: 0.1009 - acc: 0.9691 - val_loss: 0.0427 - val_acc: 0.9858\n",
      "Epoch 3/20\n",
      "60000/60000 [==============================] - 2s 35us/step - loss: 0.0736 - acc: 0.9776 - val_loss: 0.0336 - val_acc: 0.9881\n",
      "Epoch 4/20\n",
      "60000/60000 [==============================] - 2s 35us/step - loss: 0.0625 - acc: 0.9813 - val_loss: 0.0283 - val_acc: 0.9902\n",
      "Epoch 5/20\n",
      "60000/60000 [==============================] - 2s 41us/step - loss: 0.0545 - acc: 0.9833 - val_loss: 0.0290 - val_acc: 0.9909\n",
      "Epoch 6/20\n",
      "60000/60000 [==============================] - 2s 36us/step - loss: 0.0462 - acc: 0.9861 - val_loss: 0.0232 - val_acc: 0.9926\n",
      "Epoch 7/20\n",
      "60000/60000 [==============================] - 2s 35us/step - loss: 0.0424 - acc: 0.9869 - val_loss: 0.0250 - val_acc: 0.9921\n",
      "Epoch 8/20\n",
      "60000/60000 [==============================] - 2s 35us/step - loss: 0.0401 - acc: 0.9876 - val_loss: 0.0231 - val_acc: 0.9916\n",
      "Epoch 9/20\n",
      "60000/60000 [==============================] - 2s 35us/step - loss: 0.0354 - acc: 0.9890 - val_loss: 0.0214 - val_acc: 0.9920\n",
      "Epoch 10/20\n",
      "60000/60000 [==============================] - 2s 38us/step - loss: 0.0328 - acc: 0.9894 - val_loss: 0.0243 - val_acc: 0.9922\n",
      "Epoch 11/20\n",
      "60000/60000 [==============================] - 2s 36us/step - loss: 0.0319 - acc: 0.9900 - val_loss: 0.0218 - val_acc: 0.9928\n",
      "Epoch 12/20\n",
      "60000/60000 [==============================] - 2s 35us/step - loss: 0.0294 - acc: 0.9903 - val_loss: 0.0198 - val_acc: 0.9941\n",
      "Epoch 13/20\n",
      "60000/60000 [==============================] - 2s 35us/step - loss: 0.0274 - acc: 0.9913 - val_loss: 0.0203 - val_acc: 0.9934\n",
      "Epoch 14/20\n",
      "60000/60000 [==============================] - 2s 35us/step - loss: 0.0267 - acc: 0.9911 - val_loss: 0.0195 - val_acc: 0.9935\n",
      "Epoch 15/20\n",
      "60000/60000 [==============================] - 2s 38us/step - loss: 0.0237 - acc: 0.9922 - val_loss: 0.0174 - val_acc: 0.9940\n",
      "Epoch 16/20\n",
      "60000/60000 [==============================] - 2s 36us/step - loss: 0.0226 - acc: 0.9922 - val_loss: 0.0225 - val_acc: 0.9937\n",
      "Epoch 17/20\n",
      "60000/60000 [==============================] - 2s 35us/step - loss: 0.0219 - acc: 0.9930 - val_loss: 0.0238 - val_acc: 0.9934\n",
      "Epoch 18/20\n",
      "60000/60000 [==============================] - 2s 35us/step - loss: 0.0210 - acc: 0.9932 - val_loss: 0.0216 - val_acc: 0.9922\n",
      "Epoch 19/20\n",
      "60000/60000 [==============================] - 2s 35us/step - loss: 0.0204 - acc: 0.9931 - val_loss: 0.0238 - val_acc: 0.9928\n",
      "Epoch 20/20\n",
      "60000/60000 [==============================] - 2s 39us/step - loss: 0.0194 - acc: 0.9932 - val_loss: 0.0204 - val_acc: 0.9935\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7ff496b1ad30>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_size = 128\n",
    "epochs = 20\n",
    "log_dir = os.path.join(os.getcwd(), 'logs', 'mnist_cnn_visualize')\n",
    "shutil.rmtree(log_dir, ignore_errors=True)\n",
    "os.makedirs(log_dir, exist_ok=True)\n",
    "\n",
    "callbacks = [\n",
    "    HiddenLayerFeatureVisCallback(log_dir, images=x_test[:400], labels=y_test[:400],\n",
    "                                  layer_names=['conv1', 'conv2', 'dense1', 'dense2'])\n",
    "]\n",
    "\n",
    "model.fit(x_train, y_train, batch_size=batch_size,\n",
    "          validation_data=(x_test, y_test),\n",
    "          epochs=epochs, callbacks=callbacks)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
} |
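The notebook writes only the projector artifacts into log_dir (sprites.png, metadata.tsv, the projector config, and a feature_reprs.ckpt checkpoint every save_freq epochs); it does not start TensorBoard itself. Below is a minimal sketch for browsing the saved embeddings, assuming the log_dir used in the last cell and a tensorboard executable available on the PATH.

import os
import subprocess

# The same directory HiddenLayerFeatureVisCallback wrote to in the last cell.
log_dir = os.path.join(os.getcwd(), 'logs', 'mnist_cnn_visualize')

# Equivalent to running `tensorboard --logdir <log_dir>` in a shell.
# TensorBoard serves on http://localhost:6006 by default; open the "Projector"
# tab and pick conv1, conv2, dense1 or dense2 from the tensor dropdown to
# compare how separable the digit classes are in each layer's feature space.
subprocess.Popen(['tensorboard', '--logdir', log_dir])

The projector loads the most recent feature_reprs.ckpt checkpoint, so the embeddings shown correspond to the last epoch at which the callback saved.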