Visualize VQA attentions
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from __future__ import absolute_import, division, print_function\n",
"\n",
"gpu_id = 0 # set GPU id to use\n",
"import os; os.environ[\"CUDA_VISIBLE_DEVICES\"] = str(gpu_id)\n",
"\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"# Start the session BEFORE importing tensorflow_fold\n",
"# to avoid taking up all GPU memory\n",
"sess = tf.Session(config=tf.ConfigProto(\n",
" gpu_options=tf.GPUOptions(allow_growth=True),\n",
" allow_soft_placement=False, log_device_placement=False))\n",
"import json\n",
"\n",
"from models_vqa.nmn3_assembler import Assembler\n",
"from models_vqa.nmn3_model import NMN3Model\n",
"from util.vqa_train.data_reader import DataReader\n",
"\n",
"from models_vqa.nmn3_modules import Modules\n",
"from models_vqa.nmn3_assembler import _module_input_num"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Module parameters\n",
"H_feat = 14\n",
"W_feat = 14\n",
"D_feat = 2048\n",
"embed_dim_txt = 300\n",
"embed_dim_nmn = 300\n",
"lstm_dim = 1000\n",
"num_layers = 2\n",
"T_encoder = 26\n",
"T_decoder = 13\n",
"N = 1\n",
"use_qpn = True\n",
"reduce_visfeat_dim = False\n",
"\n",
"exp_name = \"vqa_gt_layout\"\n",
"snapshot_name = \"00040000\"\n",
"# tst_image_set = 'train2014'\n",
"tst_image_set = 'val2014'\n",
"# tst_image_set = 'test-dev2015'\n",
"# tst_image_set = 'test2015'\n",
"snapshot_file = './exp_vqa/tfmodel/%s/%s' % (exp_name, snapshot_name)\n",
"\n",
"# Data files\n",
"vocab_question_file = './exp_vqa/data/vocabulary_vqa.txt'\n",
"vocab_layout_file = './exp_vqa/data/vocabulary_layout.txt'\n",
"vocab_answer_file = './exp_vqa/data/answers_vqa.txt'\n",
"\n",
"# imdb_file_trn = './exp_vqa/data/imdb/imdb_trainval2014.npy'\n",
"imdb_file_tst = './exp_vqa/data/imdb/imdb_%s.npy' % tst_image_set\n",
"\n",
"save_dir = './exp_vqa/results/%s/%s.%s' % (exp_name, snapshot_name, tst_image_set)\n",
"os.makedirs(save_dir, exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"assembler = Assembler(vocab_layout_file)\n",
"\n",
"# data_reader_trn = DataReader(imdb_file_trn, shuffle=True, one_pass=False,\n",
"# batch_size=N,\n",
"# T_encoder=T_encoder,\n",
"# T_decoder=T_decoder,\n",
"# assembler=assembler,\n",
"# vocab_question_file=vocab_question_file,\n",
"# vocab_answer_file=vocab_answer_file)\n",
"\n",
"data_reader_tst = DataReader(imdb_file_tst, shuffle=False, one_pass=True,\n",
" batch_size=N,\n",
" T_encoder=T_encoder,\n",
" T_decoder=T_decoder,\n",
" assembler=assembler,\n",
" vocab_question_file=vocab_question_file,\n",
" vocab_answer_file=vocab_answer_file)\n",
"\n",
"num_vocab_txt = data_reader_tst.batch_loader.vocab_dict.num_vocab\n",
"num_vocab_nmn = len(assembler.module_names)\n",
"num_choices = data_reader_tst.batch_loader.answer_dict.num_vocab"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Network inputs\n",
"input_seq_batch = tf.placeholder(tf.int32, [None, None])\n",
"seq_length_batch = tf.placeholder(tf.int32, [None])\n",
"image_feat_batch = tf.placeholder(tf.float32, [None, H_feat, W_feat, D_feat])\n",
"expr_validity_batch = tf.placeholder(tf.bool, [None])\n",
"\n",
"# The model for testing\n",
"nmn3_model_tst = NMN3Model(\n",
" image_feat_batch, input_seq_batch,\n",
" seq_length_batch, T_decoder=T_decoder,\n",
" num_vocab_txt=num_vocab_txt, embed_dim_txt=embed_dim_txt,\n",
" num_vocab_nmn=num_vocab_nmn, embed_dim_nmn=embed_dim_nmn,\n",
" lstm_dim=lstm_dim, num_layers=num_layers,\n",
" assembler=assembler,\n",
" encoder_dropout=False,\n",
" decoder_dropout=False,\n",
" decoder_sampling=False,\n",
" num_choices=num_choices,\n",
" use_qpn=use_qpn, qpn_dropout=False, reduce_visfeat_dim=reduce_visfeat_dim)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"image_feature_grid = nmn3_model_tst.image_feat_grid\n",
"word_vecs = nmn3_model_tst.word_vecs\n",
"atts = nmn3_model_tst.atts\n",
"\n",
"image_feat_grid_ph = tf.placeholder(tf.float32, image_feature_grid.get_shape())\n",
"word_vecs_ph = tf.placeholder(tf.float32, word_vecs.get_shape())\n",
"\n",
"batch_idx = tf.constant([0], tf.int32)\n",
"time_idx = tf.placeholder(tf.int32, [1])\n",
"input_0 = tf.placeholder(tf.float32, [1, H_feat, W_feat, 1])\n",
"input_1 = tf.placeholder(tf.float32, [1, H_feat, W_feat, 1])\n",
"\n",
"# Manually construct each module outside TensorFlow fold for visualization\n",
"module_outputs = {}\n",
"with tf.variable_scope(\"neural_module_network/layout_execution\", reuse=True):\n",
" modules = Modules(image_feat_grid_ph, word_vecs_ph, None, num_choices)\n",
" module_outputs['_Scene'] = modules.SceneModule(time_idx, batch_idx)\n",
" module_outputs['_Find'] = modules.FindModule(time_idx, batch_idx)\n",
" module_outputs['_FindSameProperty'] = modules.FindSamePropertyModule(input_0, time_idx, batch_idx)\n",
" module_outputs['_Transform'] = modules.TransformModule(input_0, time_idx, batch_idx)\n",
" module_outputs['_And'] = modules.AndModule(input_0, input_1, time_idx, batch_idx)\n",
" module_outputs['_Filter'] = modules.FilterModule(input_0, time_idx, batch_idx)\n",
" module_outputs['_Or'] = modules.OrModule(input_0, input_1, time_idx, batch_idx)\n",
" module_outputs['_Exist'] = modules.ExistModule(input_0, time_idx, batch_idx)\n",
" module_outputs['_Count'] = modules.CountModule(input_0, time_idx, batch_idx)\n",
" module_outputs['_EqualNum'] = modules.EqualNumModule(input_0, input_1, time_idx, batch_idx)\n",
" module_outputs['_MoreNum'] = modules.MoreNumModule(input_0, input_1, time_idx, batch_idx)\n",
" module_outputs['_LessNum'] = modules.LessNumModule(input_0, input_1, time_idx, batch_idx)\n",
" module_outputs['_SameProperty'] = modules.SamePropertyModule(input_0, input_1, time_idx, batch_idx)\n",
" module_outputs['_Describe'] = modules.DescribeModule(input_0, time_idx, batch_idx)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def eval_module(module_name, inputs, t, image_feat_grid_val, word_vecs_val):\n",
" feed_dict = {image_feat_grid_ph: image_feat_grid_val,\n",
" word_vecs_ph: word_vecs_val,\n",
" time_idx: [t]}\n",
" # print('evaluating module ' + module_name)\n",
" if 'input_0' in inputs:\n",
" feed_dict[input_0] = inputs['input_0']\n",
" if 'input_1' in inputs:\n",
" feed_dict[input_1] = inputs['input_1']\n",
" if module_name in module_outputs:\n",
" result = sess.run(module_outputs[module_name], feed_dict)\n",
" else:\n",
" raise ValueError(\"invalid module name: \" + module_name)\n",
"\n",
" return result\n",
"\n",
"def eval_expr(layout_tokens, image_feat_grid_val, word_vecs_val):\n",
" invalid_scores = np.array([[0, 0]], np.float32)\n",
" # Decoding Reverse Polish Notation with a stack\n",
" decoding_stack = []\n",
" all_output_stack = []\n",
" for t in range(len(layout_tokens)):\n",
" # decode a module/operation\n",
" module_idx = layout_tokens[t]\n",
" if module_idx == assembler.EOS_idx:\n",
" break\n",
" module_name = assembler.module_names[module_idx]\n",
" input_num = _module_input_num[module_name]\n",
"\n",
" # Get the input from stack\n",
" inputs = {}\n",
" for n_input in range(input_num-1, -1, -1):\n",
" stack_top = decoding_stack.pop()\n",
" inputs[\"input_%d\" % n_input] = stack_top\n",
" result = eval_module(module_name, inputs, t,\n",
" image_feat_grid_val, word_vecs_val)\n",
" decoding_stack.append(result)\n",
" all_output_stack.append((t, module_name, result[0]))\n",
"\n",
" assert(len(decoding_stack) == 1)\n",
" result = decoding_stack[0]\n",
" return result, all_output_stack\n",
"\n",
"def expr2str(expr, indent=4):\n",
" name = expr['module']\n",
" input_str = []\n",
" if 'input_0' in expr:\n",
" input_str.append('\\n'+' '*indent+expr2str(expr['input_0'], indent+4))\n",
" if 'input_1' in expr:\n",
" input_str.append('\\n'+' '*indent+expr2str(expr['input_1'], indent+4))\n",
" expr_str = name[1:]+('[%d]'%expr['time_idx'])+\"(\"+\", \".join(input_str)+\")\"\n",
" return expr_str"
]
},
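{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Added illustration (not part of the original gist):* `eval_expr` above consumes the predicted layout tokens in Reverse Polish Notation, popping each module's inputs off a stack and pushing its output back on. The toy sketch below re-creates that stack walk with made-up module names and arities, independent of TensorFlow and the `assembler`, just to show the control flow."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Toy Reverse-Polish evaluation, mirroring the stack logic of eval_expr above.\n",
"# The module names and arities here are made up for illustration only.\n",
"toy_input_num = {'_Find': 0, '_Transform': 1, '_And': 2, '_Describe': 1}\n",
"\n",
"def toy_eval(layout):\n",
"    stack = []\n",
"    for t, module_name in enumerate(layout):\n",
"        n_inputs = toy_input_num[module_name]\n",
"        # pop this module's inputs (the deepest one becomes input_0)\n",
"        inputs = [stack.pop() for _ in range(n_inputs)][::-1]\n",
"        result = '%s[%d](%s)' % (module_name, t, ', '.join(inputs))\n",
"        stack.append(result)\n",
"    assert len(stack) == 1  # a valid layout leaves exactly one output\n",
"    return stack[0]\n",
"\n",
"print(toy_eval(['_Find', '_Transform', '_Find', '_And', '_Describe']))"
]
},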
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"snapshot_saver = tf.train.Saver(max_to_keep=None) # keep all snapshots\n",
"snapshot_saver.restore(sess, snapshot_file)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import skimage.io\n",
"import skimage.transform\n",
"\n",
"def att_softmax(att):\n",
" exps = np.exp(att[..., 0] - np.max(att))\n",
" softmax = exps / np.sum(exps)\n",
" return softmax\n",
"\n",
"name_map = {\n",
" 'transform': 'relocate',\n",
" 'equalnum': 'eq_count',\n",
" 'morenum': 'more',\n",
" 'lessnum': 'less',\n",
" 'sameproperty': 'compare',\n",
" 'findsameproperty': 'relocate'}\n",
"def get_module_disp_name(name):\n",
" name = name[1:].lower()\n",
" if name in name_map:\n",
" name = name_map[name]\n",
" return name\n",
"\n",
"def attention_interpolation(im, att):\n",
" # steps:\n",
" # 1. reshape the attention to image size (with cubic)\n",
" softmax = att_softmax(att)\n",
" att_reshaped = skimage.transform.resize(softmax, im.shape[:2], order=3)\n",
" att_reshaped /= np.max(att_reshaped)\n",
" att_reshaped = att_reshaped[..., np.newaxis]\n",
" \n",
" all_white = np.ones_like(im) * (255 if im.dtype == np.uint8 else 1)\n",
" vis_im = att_reshaped*im + (1-att_reshaped)*all_white\n",
" vis_im = vis_im.astype(im.dtype)\n",
" return vis_im"
]
},
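{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Added sanity check (not part of the original gist):* a minimal sketch that runs `attention_interpolation` on a synthetic gray image and a random attention map, just to confirm the output keeps the shape and dtype of the input image. It reuses `H_feat`/`W_feat` from the parameters cell above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Quick synthetic check of attention_interpolation on dummy data (illustrative only).\n",
"dummy_im = np.full((224, 224, 3), 128, np.uint8)\n",
"dummy_att = np.random.randn(H_feat, W_feat, 1).astype(np.float32)\n",
"vis = attention_interpolation(dummy_im, dummy_att)\n",
"print(vis.shape, vis.dtype)  # same size and dtype as the input image"
]
},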
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import skimage.io\n",
"import skimage.transform\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"plt.rcParams.update({'font.size': 6})\n",
"def run_visualization(dataset_tst):\n",
" if dataset_tst is None:\n",
" return\n",
" print('Running test...')\n",
" answer_word_list = dataset_tst.batch_loader.answer_dict.word_list\n",
" vocab_list = dataset_tst.batch_loader.vocab_dict.word_list\n",
" for n, batch in enumerate(dataset_tst.batches()):\n",
" if n >= 100: break\n",
" # set up input and output tensors\n",
" h = sess.partial_run_setup(\n",
" [nmn3_model_tst.predicted_tokens, nmn3_model_tst.scores, nmn3_model_tst.scores_qpn, word_vecs, atts],\n",
" [input_seq_batch, seq_length_batch, image_feat_batch,\n",
" nmn3_model_tst.compiler.loom_input_tensor, expr_validity_batch])\n",
"\n",
" # Part 0 & 1: Run Convnet and generate module layout\n",
" tokens, word_vecs_val, atts_val, scores_qpn_val =\\\n",
" sess.partial_run(h, (nmn3_model_tst.predicted_tokens, word_vecs, atts, nmn3_model_tst.scores_qpn),\n",
" feed_dict={input_seq_batch: batch['input_seq_batch'],\n",
" seq_length_batch: batch['seq_length_batch'],\n",
" image_feat_batch: batch['image_feat_batch']})\n",
" image_feat_grid_val = batch['image_feat_batch']\n",
"\n",
" # Assemble the layout tokens into network structure\n",
" expr_list, expr_validity_array = assembler.assemble(tokens)\n",
" labels = batch['answer_label_batch']\n",
" # Build TensorFlow Fold input for NMN\n",
" expr_feed = nmn3_model_tst.compiler.build_feed_dict(expr_list)\n",
" expr_feed[expr_validity_batch] = expr_validity_array\n",
"\n",
" # Part 2: Run NMN and learning steps\n",
" scores_val = sess.partial_run(h, nmn3_model_tst.scores, feed_dict=expr_feed)\n",
"\n",
" predictions = np.argmax(scores_val, axis=1)\n",
"\n",
" # Part 3: Visualization\n",
" print('visualizing %d' % n)\n",
" layout_tokens = tokens.T[0]\n",
" result, all_output_stack = eval_expr(layout_tokens, image_feat_grid_val, word_vecs_val)\n",
" result += scores_qpn_val\n",
" # check that the results are consistent\n",
" diff = np.max(np.abs(result - scores_val))\n",
" assert(np.all(diff < 1e-4))\n",
"\n",
" encoder_words = [vocab_list[w]\n",
" for n_w, w in enumerate(batch['input_seq_batch'][:, 0])\n",
" if n_w < batch['seq_length_batch'][0]]\n",
" decoder_words = [get_module_disp_name(assembler.module_names[w])+'[%d]'%n_w\n",
" for n_w, w in enumerate(layout_tokens)\n",
" if w != assembler.EOS_idx]\n",
" atts_val = atts_val[:len(decoder_words), :len(encoder_words)]\n",
" plt.figure(figsize=(12, 12))\n",
" plt.subplot(4, 3, 1)\n",
" im = skimage.io.imread(batch['image_path_list'][0])[..., :3]\n",
" plt.imshow(im)\n",
" skimage.io.imsave(os.path.join(save_dir, '%08d_image.jpg' % n), im)\n",
" plt.axis('off')\n",
" plt.subplot(4, 3, 2)\n",
" plt.axis('off')\n",
" plt.imshow(np.ones((3, 3, 3), np.float32))\n",
" plt.text(-1, 2, 'predicted layout:\\n\\n' + expr2str(expr_list[0]))\n",
" question = 'question: ' + ' '.join(encoder_words[:18]) + '\\n' + \\\n",
" ' '.join(encoder_words[18:36]) + '\\n' + \\\n",
" ' '.join(encoder_words[36:]) + '\\n' + \\\n",
" 'ground-truth answer: \"'+ answer_word_list[labels[0]] + '\" ' + \\\n",
" 'predicted answer: \"'+ answer_word_list[predictions[0]] + '\"\\n'\n",
" plt.title(question)\n",
" plt.subplot(4, 3, 3)\n",
" plt.imshow(atts_val.reshape(atts_val.shape[:2]), interpolation='nearest', cmap='Reds')\n",
" plt.xticks(np.arange(len(encoder_words)), encoder_words, rotation=90, fontsize=7.5)\n",
" plt.yticks(np.arange(len(decoder_words)), decoder_words, fontsize=7.5)\n",
" plt.colorbar()\n",
" for t, module_name, results in all_output_stack:\n",
" if t + 4 > 12:\n",
" break\n",
" result = all_output_stack[0][2]\n",
" np.save(os.path.join(save_dir, '%08d_out_%02d.npy' % (n,t)), result)\n",
" plt.subplot(4, 3, t+4)\n",
" if results.ndim > 2:\n",
" im_vis = attention_interpolation(im, results)\n",
" plt.imshow(im_vis)\n",
" skimage.io.imsave(os.path.join(save_dir, '%08d_out_%02d.jpg' % (n,t)), im_vis)\n",
" plt.axis('off')\n",
" else:\n",
" pass # not printing answer\n",
" plt.axis('off')\n",
" plt.imshow(np.ones((320, 480, 3), np.float32))\n",
" plt.text(150, 180, '\"%s\"' % answer_word_list[predictions[0]], fontsize=20)\n",
"# plot = np.tile(results.reshape((1, num_choices)), (2, 1))\n",
"# plt.imshow(plot, interpolation='nearest', vmin=-1.5, vmax=1.5, cmap='Reds')\n",
"# plt.xticks(range(len(answer_word_list)), answer_word_list, rotation=90, fontsize=7.5)\n",
"# plt.yticks([], [])\n",
"# plt.colorbar()\n",
" \n",
" plt.title('output from '+get_module_disp_name(module_name)+'[%d]'%t)\n",
"\n",
" plt.savefig(os.path.join(save_dir, '%08d.jpg' % n))\n",
" plt.close('all')\n",
" \n",
" # individually visualize the textual attention map\n",
" plt.figure(figsize=(5, 2.5))\n",
" plt.imshow(atts_val.reshape(atts_val.shape[:2]), interpolation='nearest', cmap='Reds')\n",
" plt.xticks(np.arange(len(encoder_words)), encoder_words, rotation=90, fontsize=7.5)\n",
" plt.yticks(np.arange(len(decoder_words)), decoder_words, fontsize=7.5)\n",
" plt.colorbar()\n",
" plt.savefig(os.path.join(save_dir, '%08d_txt_att.jpg' % n))\n",
" plt.close('all')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"scrolled": false
},
"outputs": [],
"source": [
"run_visualization(data_reader_tst)"
]
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.1"
}
},
"nbformat": 4,
"nbformat_minor": 1
}