Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save beader/bdac84577e9731435bf73cee772191ed to your computer and use it in GitHub Desktop.
Save beader/bdac84577e9731435bf73cee772191ed to your computer and use it in GitHub Desktop.
visualize the discriminative ability of hidden layers as feature representations of original images
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"import os\n",
"import shutil\n",
"import numpy as np\n",
"import pandas as pd\n",
"import keras.backend.tensorflow_backend as K\n",
"from keras.datasets.mnist import load_data\n",
"from keras.models import Model, Sequential \n",
"from keras.layers import Input, Conv2D, MaxPool2D, Dense, \\\n",
" Flatten, Activation, Dropout, \\\n",
" Embedding\n",
"from keras.losses import sparse_categorical_crossentropy\n",
"from keras.callbacks import TensorBoard, Callback\n",
"from keras.optimizers import Adam"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x_train.shape (60000, 28, 28, 1)\n",
"y_train.shape (60000, 1)\n",
"x_test.shape (10000, 28, 28, 1)\n",
"y_test.shape (10000, 1)\n"
]
}
],
"source": [
"# Image geometry and label count used by the rest of the notebook.\n",
"img_rows, img_cols = 28, 28\n",
"num_classes = 10\n",
"\n",
"(x_train, y_train), (x_test, y_test) = load_data()\n",
"\n",
"# Add a trailing channel axis, cast to float32 and divide by 255.\n",
"x_train = x_train.reshape(-1, img_rows, img_cols, 1).astype('float32') / 255\n",
"x_test = x_test.reshape(-1, img_rows, img_cols, 1).astype('float32') / 255\n",
"\n",
"# Labels become column vectors of shape (n, 1).\n",
"y_train = y_train.reshape(-1, 1)\n",
"y_test = y_test.reshape(-1, 1)\n",
"\n",
"for array_name, array in (('x_train', x_train), ('y_train', y_train),\n",
"                          ('x_test', x_test), ('y_test', y_test)):\n",
"    print(array_name + '.shape', array.shape)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def build_model(input_shape, name='mnist_cnn'):\n",
"    \"\"\"Assemble and compile the two-convolution image classifier.\n",
"\n",
"    Args:\n",
"        input_shape: shape of one input image, handed to the first\n",
"            convolution as its `input_shape`.\n",
"        name: name given to the Sequential model.\n",
"\n",
"    Returns:\n",
"        A compiled Sequential model whose final softmax layer has\n",
"        `num_classes` units (`num_classes` is read from the notebook's\n",
"        global scope).\n",
"    \"\"\"\n",
"    layers = [\n",
"        Conv2D(32, kernel_size=(3, 3), padding='same', name='conv1',\n",
"               activation='relu', input_shape=input_shape),\n",
"        MaxPool2D(pool_size=(2, 2)),\n",
"        Conv2D(64, kernel_size=(3, 3), padding='same', name='conv2',\n",
"               activation='relu'),\n",
"        MaxPool2D(pool_size=(2, 2)),\n",
"        Dropout(0.25),\n",
"        Flatten(),\n",
"        Dense(128, activation='relu', name='dense1'),\n",
"        Dropout(0.5),\n",
"        Dense(num_classes, activation='softmax', name='dense2'),\n",
"    ]\n",
"    model = Sequential(name=name)\n",
"    for layer in layers:\n",
"        model.add(layer)\n",
"    model.compile(loss=sparse_categorical_crossentropy,\n",
"                  optimizer=Adam(),\n",
"                  metrics=['accuracy'])\n",
"    return model"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def gen_sprites_image(images, n_rows=None, n_cols=None, inverse_color=True):\n",
"    \"\"\"Tile a batch of images into a single RGB sprite sheet.\n",
"\n",
"    Args:\n",
"        images: numpy array of shape (num_imgs, height, width) or\n",
"            (num_imgs, height, width, channels) with 1 or 3 channels.\n",
"        n_rows: number of grid rows.\n",
"        n_cols: number of grid columns.  When either `n_rows` or `n_cols`\n",
"            is None, both are set to int(sqrt(num_imgs)).\n",
"        inverse_color: when True, return `1 - sprites`.\n",
"\n",
"    Returns:\n",
"        Array of shape (n_rows * height, n_cols * width, 3) holding the\n",
"        min-max normalised images laid out row by row.\n",
"\n",
"    Raises:\n",
"        AssertionError: if `n_rows * n_cols != num_imgs` or the channel\n",
"            count is neither 1 nor 3.\n",
"    \"\"\"\n",
"    num_imgs, img_height, img_width = images.shape[:3]\n",
"    num_channels = images.shape[3] if len(images.shape) == 4 else 1\n",
"    if n_rows is None or n_cols is None:\n",
"        n_rows = n_cols = int(np.sqrt(num_imgs))\n",
"    assert num_imgs == n_rows * n_cols\n",
"    assert num_channels == 1 or num_channels == 3\n",
"\n",
"    lowest = images.min()\n",
"    value_range = images.max() - lowest\n",
"    if value_range > 0:\n",
"        sprites = (images - lowest) / value_range\n",
"    else:\n",
"        # A constant batch would divide 0 by 0; map it to all zeros.\n",
"        sprites = np.zeros(images.shape, dtype='float32')\n",
"\n",
"    # Reshape the normalised array.  The previous code reshaped `images`\n",
"    # here, which silently threw the normalisation away.\n",
"    sprites = sprites.reshape(n_rows, n_cols, img_height, img_width, num_channels)\n",
"    # (row, col, h, w, c) -> (row, h, col, w, c) so the merge below keeps\n",
"    # each image's pixels contiguous inside its grid cell.\n",
"    sprites = sprites.swapaxes(1, 2)\n",
"    sprites = sprites.reshape(n_rows * img_height, n_cols * img_width, num_channels)\n",
"    if num_channels == 1:\n",
"        # Replicate the single channel three times to obtain RGB.\n",
"        sprites = np.repeat(sprites, 3, axis=-1)\n",
"    if inverse_color:\n",
"        sprites = 1 - sprites\n",
"    return sprites"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"class HiddenLayerFeatureVisCallback(Callback):\n",
"    \"\"\"Export hidden-layer outputs for the TensorBoard embedding projector.\n",
"\n",
"    When the model is attached, a sprite image of `images`, a\n",
"    `metadata.tsv` label file and a projector config are written to\n",
"    `log_dir`.  Every `save_freq` epochs the flattened outputs of the\n",
"    selected layers for `images` are copied into backend variables and\n",
"    checkpointed to `log_dir`.\n",
"\n",
"    Args:\n",
"        log_dir: directory receiving the sprite image, metadata, projector\n",
"            config and checkpoints; created if it does not exist.\n",
"        images: array of images fed to the model; its first axis indexes\n",
"            the images.  Required.\n",
"        labels: array with one label per image.  Required.\n",
"        layer_names: non-empty list of names of the layers to export.\n",
"        save_freq: export the features every `save_freq` epochs.\n",
"    \"\"\"\n",
"    def __init__(self, log_dir='./logs',\n",
"                 images=None,\n",
"                 labels=None,\n",
"                 layer_names=None,\n",
"                 save_freq=5):\n",
"        # TensorFlow is imported lazily and published as module globals\n",
"        # because the other methods refer to `tf` and `projector` by name.\n",
"        global tf, projector\n",
"        import tensorflow as tf\n",
"        from tensorflow.contrib.tensorboard.plugins import projector\n",
"\n",
"        assert isinstance(layer_names, list) and len(layer_names) > 0\n",
"        assert images is not None\n",
"        # `labels` defaults to None but is dereferenced below; fail here\n",
"        # with an assertion instead of an AttributeError from squeeze().\n",
"        assert labels is not None\n",
"        super(HiddenLayerFeatureVisCallback, self).__init__()\n",
"        self.log_dir = log_dir\n",
"        self.layer_names = layer_names\n",
"        self.images = images\n",
"        self.num_imgs = images.shape[0]\n",
"        self.labels = labels.squeeze()\n",
"        self.save_freq = save_freq\n",
"\n",
"    def save_sprites_image(self):\n",
"        \"\"\"Write the tiled sprite image of `self.images` to sprites.png.\"\"\"\n",
"        global plt\n",
"        import matplotlib.pyplot as plt\n",
"        sprite_image = gen_sprites_image(self.images)\n",
"        plt.imsave(os.path.join(self.log_dir, 'sprites.png'), sprite_image)\n",
"\n",
"    def save_meta_data(self):\n",
"        \"\"\"Write a tab-separated id/label row per image to metadata.tsv.\"\"\"\n",
"        self.metadata_file = 'metadata.tsv'\n",
"        with open(os.path.join(self.log_dir, self.metadata_file), 'w') as f:\n",
"            f.write('id\\tlabel\\n')\n",
"            for i, label in enumerate(self.labels):\n",
"                f.write('%d\\t%s\\n' % (i, str(label)))\n",
"\n",
"    def set_model(self, model):\n",
"        \"\"\"Attach the model and create the projector files and variables.\"\"\"\n",
"        self.model = model\n",
"        self.sess = K.get_session()\n",
"        # The sprite and metadata files are written straight into log_dir,\n",
"        # so make sure it exists first.\n",
"        os.makedirs(self.log_dir, exist_ok=True)\n",
"        self.save_sprites_image()\n",
"        self.save_meta_data()\n",
"\n",
"        # One (num_imgs, flattened feature size) variable per selected layer.\n",
"        self.feature_reprs = dict()\n",
"        for layer in self.model.layers:\n",
"            if layer.name in self.layer_names:\n",
"                # np.prod replaces the deprecated alias np.product.\n",
"                feature_dim = int(np.prod(layer.output_shape[1:]))\n",
"                self.feature_reprs[layer.name] = K.variable(\n",
"                    np.zeros((self.num_imgs, feature_dim)), name=layer.name)\n",
"\n",
"        config = projector.ProjectorConfig()\n",
"        self.feature_reprs_ckpt_path = os.path.join(self.log_dir, 'feature_reprs.ckpt')\n",
"        for tensor in self.feature_reprs.values():\n",
"            embedding = config.embeddings.add()\n",
"            embedding.tensor_name = tensor.name\n",
"            embedding.metadata_path = self.metadata_file\n",
"            embedding.sprite.image_path = 'sprites.png'\n",
"            embedding.sprite.single_image_dim.extend(self.images.shape[1:3])\n",
"\n",
"        self.writer = tf.summary.FileWriter(self.log_dir)\n",
"        self.saver = tf.train.Saver(list(self.feature_reprs.values()))\n",
"        projector.visualize_embeddings(self.writer, config)\n",
"\n",
"    def on_epoch_end(self, epoch, logs=None):\n",
"        \"\"\"Refresh and checkpoint the layer features every `save_freq` epochs.\"\"\"\n",
"        if (epoch + 1) % self.save_freq != 0:\n",
"            return\n",
"        for layer in self.model.layers:\n",
"            if layer.name in self.feature_reprs:\n",
"                # Learning phase 0 requests test-time behaviour from Keras.\n",
"                func = K.function([self.model.input, K.learning_phase()], [layer.output])\n",
"                layer_out = func([self.images, 0])[0].reshape(self.num_imgs, -1)\n",
"                K.set_value(self.feature_reprs[layer.name], layer_out)\n",
"        # Save once per export, after every variable has been refreshed.\n",
"        self.saver.save(self.sess, self.feature_reprs_ckpt_path, epoch)\n",
"\n",
"    def on_train_end(self, _):\n",
"        \"\"\"Close the summary writer opened in `set_model`.\"\"\"\n",
"        self.writer.close()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Reuse the image geometry defined with the data instead of repeating 28.\n",
"model = build_model(input_shape=(img_rows, img_cols, 1))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"conv1 (Conv2D) (None, 28, 28, 32) 320 \n",
"_________________________________________________________________\n",
"max_pooling2d_1 (MaxPooling2 (None, 14, 14, 32) 0 \n",
"_________________________________________________________________\n",
"conv2 (Conv2D) (None, 14, 14, 64) 18496 \n",
"_________________________________________________________________\n",
"max_pooling2d_2 (MaxPooling2 (None, 7, 7, 64) 0 \n",
"_________________________________________________________________\n",
"dropout_1 (Dropout) (None, 7, 7, 64) 0 \n",
"_________________________________________________________________\n",
"flatten_1 (Flatten) (None, 3136) 0 \n",
"_________________________________________________________________\n",
"dense1 (Dense) (None, 128) 401536 \n",
"_________________________________________________________________\n",
"dropout_2 (Dropout) (None, 128) 0 \n",
"_________________________________________________________________\n",
"dense2 (Dense) (None, 10) 1290 \n",
"=================================================================\n",
"Total params: 421,642\n",
"Trainable params: 421,642\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 60000 samples, validate on 10000 samples\n",
"Epoch 1/20\n",
"60000/60000 [==============================] - 3s 53us/step - loss: 0.2926 - acc: 0.9087 - val_loss: 0.0628 - val_acc: 0.9795\n",
"Epoch 2/20\n",
"60000/60000 [==============================] - 2s 35us/step - loss: 0.1009 - acc: 0.9691 - val_loss: 0.0427 - val_acc: 0.9858\n",
"Epoch 3/20\n",
"60000/60000 [==============================] - 2s 35us/step - loss: 0.0736 - acc: 0.9776 - val_loss: 0.0336 - val_acc: 0.9881\n",
"Epoch 4/20\n",
"60000/60000 [==============================] - 2s 35us/step - loss: 0.0625 - acc: 0.9813 - val_loss: 0.0283 - val_acc: 0.9902\n",
"Epoch 5/20\n",
"60000/60000 [==============================] - 2s 41us/step - loss: 0.0545 - acc: 0.9833 - val_loss: 0.0290 - val_acc: 0.9909\n",
"Epoch 6/20\n",
"60000/60000 [==============================] - 2s 36us/step - loss: 0.0462 - acc: 0.9861 - val_loss: 0.0232 - val_acc: 0.9926\n",
"Epoch 7/20\n",
"60000/60000 [==============================] - 2s 35us/step - loss: 0.0424 - acc: 0.9869 - val_loss: 0.0250 - val_acc: 0.9921\n",
"Epoch 8/20\n",
"60000/60000 [==============================] - 2s 35us/step - loss: 0.0401 - acc: 0.9876 - val_loss: 0.0231 - val_acc: 0.9916\n",
"Epoch 9/20\n",
"60000/60000 [==============================] - 2s 35us/step - loss: 0.0354 - acc: 0.9890 - val_loss: 0.0214 - val_acc: 0.9920\n",
"Epoch 10/20\n",
"60000/60000 [==============================] - 2s 38us/step - loss: 0.0328 - acc: 0.9894 - val_loss: 0.0243 - val_acc: 0.9922\n",
"Epoch 11/20\n",
"60000/60000 [==============================] - 2s 36us/step - loss: 0.0319 - acc: 0.9900 - val_loss: 0.0218 - val_acc: 0.9928\n",
"Epoch 12/20\n",
"60000/60000 [==============================] - 2s 35us/step - loss: 0.0294 - acc: 0.9903 - val_loss: 0.0198 - val_acc: 0.9941\n",
"Epoch 13/20\n",
"60000/60000 [==============================] - 2s 35us/step - loss: 0.0274 - acc: 0.9913 - val_loss: 0.0203 - val_acc: 0.9934\n",
"Epoch 14/20\n",
"60000/60000 [==============================] - 2s 35us/step - loss: 0.0267 - acc: 0.9911 - val_loss: 0.0195 - val_acc: 0.9935\n",
"Epoch 15/20\n",
"60000/60000 [==============================] - 2s 38us/step - loss: 0.0237 - acc: 0.9922 - val_loss: 0.0174 - val_acc: 0.9940\n",
"Epoch 16/20\n",
"60000/60000 [==============================] - 2s 36us/step - loss: 0.0226 - acc: 0.9922 - val_loss: 0.0225 - val_acc: 0.9937\n",
"Epoch 17/20\n",
"60000/60000 [==============================] - 2s 35us/step - loss: 0.0219 - acc: 0.9930 - val_loss: 0.0238 - val_acc: 0.9934\n",
"Epoch 18/20\n",
"60000/60000 [==============================] - 2s 35us/step - loss: 0.0210 - acc: 0.9932 - val_loss: 0.0216 - val_acc: 0.9922\n",
"Epoch 19/20\n",
"60000/60000 [==============================] - 2s 35us/step - loss: 0.0204 - acc: 0.9931 - val_loss: 0.0238 - val_acc: 0.9928\n",
"Epoch 20/20\n",
"60000/60000 [==============================] - 2s 39us/step - loss: 0.0194 - acc: 0.9932 - val_loss: 0.0204 - val_acc: 0.9935\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7ff496b1ad30>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Training hyper-parameters.\n",
"epochs = 20\n",
"batch_size = 128\n",
"\n",
"# Recreate the log directory so every run starts from an empty one.\n",
"log_dir = os.path.join(os.getcwd(), 'logs', 'mnist_cnn_visualize')\n",
"shutil.rmtree(log_dir, ignore_errors=True)\n",
"os.makedirs(log_dir, exist_ok=True)\n",
"\n",
"# Export the features of the first 400 test images for four named layers.\n",
"callbacks = [\n",
"    HiddenLayerFeatureVisCallback(\n",
"        log_dir,\n",
"        images=x_test[:400],\n",
"        labels=y_test[:400],\n",
"        layer_names=['conv1', 'conv2', 'dense1', 'dense2'],\n",
"    ),\n",
"]\n",
"\n",
"model.fit(\n",
"    x_train,\n",
"    y_train,\n",
"    batch_size=batch_size,\n",
"    epochs=epochs,\n",
"    validation_data=(x_test, y_test),\n",
"    callbacks=callbacks,\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment