Visualize VQA attentions
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from __future__ import absolute_import, division, print_function\n",
"\n",
"gpu_id = 0  # set GPU id to use\n",
"import os; os.environ[\"CUDA_VISIBLE_DEVICES\"] = str(gpu_id)\n",
"\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"# Start the session BEFORE importing tensorflow_fold\n",
"# to avoid taking up all GPU memory\n",
"sess = tf.Session(config=tf.ConfigProto(\n",
"    gpu_options=tf.GPUOptions(allow_growth=True),\n",
"    allow_soft_placement=False, log_device_placement=False))\n",
"import json\n",
"\n",
"from models_vqa.nmn3_assembler import Assembler\n",
"from models_vqa.nmn3_model import NMN3Model\n",
"from util.vqa_train.data_reader import DataReader\n",
"\n",
"from models_vqa.nmn3_modules import Modules\n",
"from models_vqa.nmn3_assembler import _module_input_num"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Module parameters\n",
"H_feat = 14\n",
"W_feat = 14\n",
"D_feat = 2048\n",
"embed_dim_txt = 300\n",
"embed_dim_nmn = 300\n",
"lstm_dim = 1000\n",
"num_layers = 2\n",
"T_encoder = 26\n",
"T_decoder = 13\n",
"N = 1\n",
"use_qpn = True\n",
"reduce_visfeat_dim = False\n",
"\n",
"exp_name = \"vqa_gt_layout\"\n",
"snapshot_name = \"00040000\"\n",
"# tst_image_set = 'train2014'\n",
"tst_image_set = 'val2014'\n",
"# tst_image_set = 'test-dev2015'\n",
"# tst_image_set = 'test2015'\n",
"snapshot_file = './exp_vqa/tfmodel/%s/%s' % (exp_name, snapshot_name)\n",
"\n",
"# Data files\n",
"vocab_question_file = './exp_vqa/data/vocabulary_vqa.txt'\n",
"vocab_layout_file = './exp_vqa/data/vocabulary_layout.txt'\n",
"vocab_answer_file = './exp_vqa/data/answers_vqa.txt'\n",
"\n",
"# imdb_file_trn = './exp_vqa/data/imdb/imdb_trainval2014.npy'\n",
"imdb_file_tst = './exp_vqa/data/imdb/imdb_%s.npy' % tst_image_set\n",
"\n",
"save_dir = './exp_vqa/results/%s/%s.%s' % (exp_name, snapshot_name, tst_image_set)\n",
"os.makedirs(save_dir, exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"assembler = Assembler(vocab_layout_file)\n",
"\n",
"# data_reader_trn = DataReader(imdb_file_trn, shuffle=True, one_pass=False,\n",
"#                              batch_size=N,\n",
"#                              T_encoder=T_encoder,\n",
"#                              T_decoder=T_decoder,\n",
"#                              assembler=assembler,\n",
"#                              vocab_question_file=vocab_question_file,\n",
"#                              vocab_answer_file=vocab_answer_file)\n",
"\n",
"data_reader_tst = DataReader(imdb_file_tst, shuffle=False, one_pass=True,\n",
"                             batch_size=N,\n",
"                             T_encoder=T_encoder,\n",
"                             T_decoder=T_decoder,\n",
"                             assembler=assembler,\n",
"                             vocab_question_file=vocab_question_file,\n",
"                             vocab_answer_file=vocab_answer_file)\n",
"\n",
"num_vocab_txt = data_reader_tst.batch_loader.vocab_dict.num_vocab\n",
"num_vocab_nmn = len(assembler.module_names)\n",
"num_choices = data_reader_tst.batch_loader.answer_dict.num_vocab"
]
},
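{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Reference sketch, inferred from how batches are consumed in the\n",
"# visualization loop later in this notebook (not from the DataReader\n",
"# docs): each batch from data_reader_tst.batches() is a dict expected\n",
"# to carry at least these keys:\n",
"#   'input_seq_batch'     int32 (T_encoder, N)  question word indices\n",
"#   'seq_length_batch'    int32 (N,)            question lengths\n",
"#   'image_feat_batch'    float32 (N, H_feat, W_feat, D_feat) image features\n",
"#   'answer_label_batch'  (N,)                  ground-truth answer indices\n",
"#   'image_path_list'     list of N image paths used for visualization"
]
},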
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Network inputs\n",
"input_seq_batch = tf.placeholder(tf.int32, [None, None])\n",
"seq_length_batch = tf.placeholder(tf.int32, [None])\n",
"image_feat_batch = tf.placeholder(tf.float32, [None, H_feat, W_feat, D_feat])\n",
"expr_validity_batch = tf.placeholder(tf.bool, [None])\n",
"\n",
"# The model for testing\n",
"nmn3_model_tst = NMN3Model(\n",
"    image_feat_batch, input_seq_batch,\n",
"    seq_length_batch, T_decoder=T_decoder,\n",
"    num_vocab_txt=num_vocab_txt, embed_dim_txt=embed_dim_txt,\n",
"    num_vocab_nmn=num_vocab_nmn, embed_dim_nmn=embed_dim_nmn,\n",
"    lstm_dim=lstm_dim, num_layers=num_layers,\n",
"    assembler=assembler,\n",
"    encoder_dropout=False,\n",
"    decoder_dropout=False,\n",
"    decoder_sampling=False,\n",
"    num_choices=num_choices,\n",
"    use_qpn=use_qpn, qpn_dropout=False, reduce_visfeat_dim=reduce_visfeat_dim)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"image_feature_grid = nmn3_model_tst.image_feat_grid\n",
"word_vecs = nmn3_model_tst.word_vecs\n",
"atts = nmn3_model_tst.atts\n",
"\n",
"image_feat_grid_ph = tf.placeholder(tf.float32, image_feature_grid.get_shape())\n",
"word_vecs_ph = tf.placeholder(tf.float32, word_vecs.get_shape())\n",
"\n",
"batch_idx = tf.constant([0], tf.int32)\n",
"time_idx = tf.placeholder(tf.int32, [1])\n",
"input_0 = tf.placeholder(tf.float32, [1, H_feat, W_feat, 1])\n",
"input_1 = tf.placeholder(tf.float32, [1, H_feat, W_feat, 1])\n",
"\n",
"# Manually construct each module outside TensorFlow Fold for visualization,\n",
"# reusing the trained variables via the shared variable scope\n",
"module_outputs = {}\n",
"with tf.variable_scope(\"neural_module_network/layout_execution\", reuse=True):\n",
"    modules = Modules(image_feat_grid_ph, word_vecs_ph, None, num_choices)\n",
"    module_outputs['_Scene'] = modules.SceneModule(time_idx, batch_idx)\n",
"    module_outputs['_Find'] = modules.FindModule(time_idx, batch_idx)\n",
"    module_outputs['_FindSameProperty'] = modules.FindSamePropertyModule(input_0, time_idx, batch_idx)\n",
"    module_outputs['_Transform'] = modules.TransformModule(input_0, time_idx, batch_idx)\n",
"    module_outputs['_And'] = modules.AndModule(input_0, input_1, time_idx, batch_idx)\n",
"    module_outputs['_Filter'] = modules.FilterModule(input_0, time_idx, batch_idx)\n",
"    module_outputs['_Or'] = modules.OrModule(input_0, input_1, time_idx, batch_idx)\n",
"    module_outputs['_Exist'] = modules.ExistModule(input_0, time_idx, batch_idx)\n",
"    module_outputs['_Count'] = modules.CountModule(input_0, time_idx, batch_idx)\n",
"    module_outputs['_EqualNum'] = modules.EqualNumModule(input_0, input_1, time_idx, batch_idx)\n",
"    module_outputs['_MoreNum'] = modules.MoreNumModule(input_0, input_1, time_idx, batch_idx)\n",
"    module_outputs['_LessNum'] = modules.LessNumModule(input_0, input_1, time_idx, batch_idx)\n",
"    module_outputs['_SameProperty'] = modules.SamePropertyModule(input_0, input_1, time_idx, batch_idx)\n",
"    module_outputs['_Describe'] = modules.DescribeModule(input_0, time_idx, batch_idx)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def eval_module(module_name, inputs, t, image_feat_grid_val, word_vecs_val):\n",
"    feed_dict = {image_feat_grid_ph: image_feat_grid_val,\n",
"                 word_vecs_ph: word_vecs_val,\n",
"                 time_idx: [t]}\n",
"    # print('evaluating module ' + module_name)\n",
"    if 'input_0' in inputs:\n",
"        feed_dict[input_0] = inputs['input_0']\n",
"    if 'input_1' in inputs:\n",
"        feed_dict[input_1] = inputs['input_1']\n",
"    if module_name in module_outputs:\n",
"        result = sess.run(module_outputs[module_name], feed_dict)\n",
"    else:\n",
"        raise ValueError(\"invalid module name: \" + module_name)\n",
"\n",
"    return result\n",
"\n",
"def eval_expr(layout_tokens, image_feat_grid_val, word_vecs_val):\n",
"    invalid_scores = np.array([[0, 0]], np.float32)\n",
"    # Decode the Reverse Polish Notation layout with a stack\n",
"    decoding_stack = []\n",
"    all_output_stack = []\n",
"    for t in range(len(layout_tokens)):\n",
"        # decode a module/operation\n",
"        module_idx = layout_tokens[t]\n",
"        if module_idx == assembler.EOS_idx:\n",
"            break\n",
"        module_name = assembler.module_names[module_idx]\n",
"        input_num = _module_input_num[module_name]\n",
"\n",
"        # Get the module's inputs from the stack\n",
"        inputs = {}\n",
"        for n_input in range(input_num-1, -1, -1):\n",
"            stack_top = decoding_stack.pop()\n",
"            inputs[\"input_%d\" % n_input] = stack_top\n",
"        result = eval_module(module_name, inputs, t,\n",
"                             image_feat_grid_val, word_vecs_val)\n",
"        decoding_stack.append(result)\n",
"        all_output_stack.append((t, module_name, result[0]))\n",
"\n",
"    assert len(decoding_stack) == 1  # a valid layout reduces to one output\n",
"    result = decoding_stack[0]\n",
"    return result, all_output_stack\n",
"\n",
"def expr2str(expr, indent=4):\n",
"    name = expr['module']\n",
"    input_str = []\n",
"    if 'input_0' in expr:\n",
"        input_str.append('\\n'+' '*indent+expr2str(expr['input_0'], indent+4))\n",
"    if 'input_1' in expr:\n",
"        input_str.append('\\n'+' '*indent+expr2str(expr['input_1'], indent+4))\n",
"    expr_str = name[1:]+('[%d]'%expr['time_idx'])+\"(\"+\", \".join(input_str)+\")\"\n",
"    return expr_str"
]
},
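{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# A minimal sketch (no TensorFlow needed) of the Reverse-Polish-Notation\n",
"# decoding that eval_expr performs above. The token list and input counts\n",
"# here are hypothetical; real layouts come from the decoder and real input\n",
"# counts from _module_input_num.\n",
"_toy_input_num = {'_Find': 0, '_Transform': 1, '_Describe': 1}\n",
"\n",
"def toy_eval_expr(layout_tokens):\n",
"    stack = []\n",
"    for t, name in enumerate(layout_tokens):\n",
"        # pop as many arguments as the module consumes, oldest first\n",
"        args = [stack.pop() for _ in range(_toy_input_num[name])][::-1]\n",
"        stack.append('%s[%d](%s)' % (name, t, ', '.join(args)))\n",
"    assert len(stack) == 1  # a valid layout reduces to a single expression\n",
"    return stack[0]\n",
"\n",
"print(toy_eval_expr(['_Find', '_Transform', '_Describe']))\n",
"# -> _Describe[2](_Transform[1](_Find[0]()))"
]
},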
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"snapshot_saver = tf.train.Saver(max_to_keep=None)  # keep all snapshots\n",
"snapshot_saver.restore(sess, snapshot_file)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import skimage.io\n",
"import skimage.transform\n",
"\n",
"def att_softmax(att):\n",
"    exps = np.exp(att[..., 0] - np.max(att))\n",
"    softmax = exps / np.sum(exps)\n",
"    return softmax\n",
"\n",
"name_map = {\n",
"    'transform': 'relocate',\n",
"    'equalnum': 'eq_count',\n",
"    'morenum': 'more',\n",
"    'lessnum': 'less',\n",
"    'sameproperty': 'compare',\n",
"    'findsameproperty': 'relocate'}\n",
"def get_module_disp_name(name):\n",
"    name = name[1:].lower()\n",
"    if name in name_map:\n",
"        name = name_map[name]\n",
"    return name\n",
"\n",
"def attention_interpolation(im, att):\n",
"    # resize the attention map to the image size (bicubic interpolation),\n",
"    # then use it to blend the image with an all-white canvas\n",
"    softmax = att_softmax(att)\n",
"    att_reshaped = skimage.transform.resize(softmax, im.shape[:2], order=3)\n",
"    att_reshaped /= np.max(att_reshaped)\n",
"    att_reshaped = att_reshaped[..., np.newaxis]\n",
"\n",
"    all_white = np.ones_like(im) * (255 if im.dtype == np.uint8 else 1)\n",
"    vis_im = att_reshaped*im + (1-att_reshaped)*all_white\n",
"    vis_im = vis_im.astype(im.dtype)\n",
"    return vis_im"
]
},
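{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Hypothetical smoke test for the two helpers above: exercise att_softmax\n",
"# and attention_interpolation on synthetic inputs (no snapshot or images\n",
"# needed). Shapes mirror the model: a (H_feat, W_feat, 1) attention map\n",
"# over a dummy uint8 image.\n",
"_dummy_att = np.random.randn(H_feat, W_feat, 1).astype(np.float32)\n",
"_dummy_im = np.zeros((224, 224, 3), np.uint8)\n",
"\n",
"_sm = att_softmax(_dummy_att)\n",
"assert np.isclose(np.sum(_sm), 1.0)  # softmax over the 14x14 grid sums to 1\n",
"\n",
"_vis = attention_interpolation(_dummy_im, _dummy_att)\n",
"assert _vis.shape == _dummy_im.shape and _vis.dtype == _dummy_im.dtype"
]
},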
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import skimage.io\n",
"import skimage.transform\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"plt.rcParams.update({'font.size': 6})\n",
"\n",
"def run_visualization(dataset_tst):\n",
"    if dataset_tst is None:\n",
"        return\n",
"    print('Running test...')\n",
"    answer_word_list = dataset_tst.batch_loader.answer_dict.word_list\n",
"    vocab_list = dataset_tst.batch_loader.vocab_dict.word_list\n",
"    for n, batch in enumerate(dataset_tst.batches()):\n",
"        if n >= 100: break  # visualize only the first 100 examples\n",
"        # set up input and output tensors for two-stage (partial) execution\n",
"        h = sess.partial_run_setup(\n",
"            [nmn3_model_tst.predicted_tokens, nmn3_model_tst.scores,\n",
"             nmn3_model_tst.scores_qpn, word_vecs, atts],\n",
"            [input_seq_batch, seq_length_batch, image_feat_batch,\n",
"             nmn3_model_tst.compiler.loom_input_tensor, expr_validity_batch])\n",
"\n",
"        # Part 0 & 1: encode the question and generate the module layout\n",
"        tokens, word_vecs_val, atts_val, scores_qpn_val = \\\n",
"            sess.partial_run(h, (nmn3_model_tst.predicted_tokens, word_vecs,\n",
"                                 atts, nmn3_model_tst.scores_qpn),\n",
"                             feed_dict={input_seq_batch: batch['input_seq_batch'],\n",
"                                        seq_length_batch: batch['seq_length_batch'],\n",
"                                        image_feat_batch: batch['image_feat_batch']})\n",
"        image_feat_grid_val = batch['image_feat_batch']\n",
"\n",
"        # Assemble the layout tokens into network structure\n",
"        expr_list, expr_validity_array = assembler.assemble(tokens)\n",
"        labels = batch['answer_label_batch']\n",
"        # Build TensorFlow Fold input for NMN\n",
"        expr_feed = nmn3_model_tst.compiler.build_feed_dict(expr_list)\n",
"        expr_feed[expr_validity_batch] = expr_validity_array\n",
"\n",
"        # Part 2: run the assembled NMN to get the answer scores\n",
"        scores_val = sess.partial_run(h, nmn3_model_tst.scores, feed_dict=expr_feed)\n",
"\n",
"        predictions = np.argmax(scores_val, axis=1)\n",
"\n",
"        # Part 3: visualization\n",
"        print('visualizing %d' % n)\n",
"        layout_tokens = tokens.T[0]\n",
"        result, all_output_stack = eval_expr(layout_tokens, image_feat_grid_val, word_vecs_val)\n",
"        result += scores_qpn_val\n",
"        # check that the re-executed modules reproduce the model's scores\n",
"        diff = np.max(np.abs(result - scores_val))\n",
"        assert diff < 1e-4\n",
"\n",
"        encoder_words = [vocab_list[w]\n",
"                         for n_w, w in enumerate(batch['input_seq_batch'][:, 0])\n",
"                         if n_w < batch['seq_length_batch'][0]]\n",
"        decoder_words = [get_module_disp_name(assembler.module_names[w])+'[%d]'%n_w\n",
"                         for n_w, w in enumerate(layout_tokens)\n",
"                         if w != assembler.EOS_idx]\n",
"        atts_val = atts_val[:len(decoder_words), :len(encoder_words)]\n",
" plt.figure(figsize=(12, 12))\n", | |
" plt.subplot(4, 3, 1)\n", | |
" im = skimage.io.imread(batch['image_path_list'][0])[..., :3]\n", | |
" plt.imshow(im)\n", | |
" skimage.io.imsave(os.path.join(save_dir, '%08d_image.jpg' % n), im)\n", | |
" plt.axis('off')\n", | |
" plt.subplot(4, 3, 2)\n", | |
" plt.axis('off')\n", | |
" plt.imshow(np.ones((3, 3, 3), np.float32))\n", | |
" plt.text(-1, 2, 'predicted layout:\\n\\n' + expr2str(expr_list[0]))\n", | |
" question = 'question: ' + ' '.join(encoder_words[:18]) + '\\n' + \\\n", | |
" ' '.join(encoder_words[18:36]) + '\\n' + \\\n", | |
" ' '.join(encoder_words[36:]) + '\\n' + \\\n", | |
" 'ground-truth answer: \"'+ answer_word_list[labels[0]] + '\" ' + \\\n", | |
" 'predicted answer: \"'+ answer_word_list[predictions[0]] + '\"\\n'\n", | |
" plt.title(question)\n", | |
" plt.subplot(4, 3, 3)\n", | |
" plt.imshow(atts_val.reshape(atts_val.shape[:2]), interpolation='nearest', cmap='Reds')\n", | |
" plt.xticks(np.arange(len(encoder_words)), encoder_words, rotation=90, fontsize=7.5)\n", | |
" plt.yticks(np.arange(len(decoder_words)), decoder_words, fontsize=7.5)\n", | |
" plt.colorbar()\n", | |
" for t, module_name, results in all_output_stack:\n", | |
" if t + 4 > 12:\n", | |
" break\n", | |
" result = all_output_stack[0][2]\n", | |
" np.save(os.path.join(save_dir, '%08d_out_%02d.npy' % (n,t)), result)\n", | |
" plt.subplot(4, 3, t+4)\n", | |
" if results.ndim > 2:\n", | |
" im_vis = attention_interpolation(im, results)\n", | |
" plt.imshow(im_vis)\n", | |
" skimage.io.imsave(os.path.join(save_dir, '%08d_out_%02d.jpg' % (n,t)), im_vis)\n", | |
" plt.axis('off')\n", | |
" else:\n", | |
" pass # not printing answer\n", | |
" plt.axis('off')\n", | |
" plt.imshow(np.ones((320, 480, 3), np.float32))\n", | |
" plt.text(150, 180, '\"%s\"' % answer_word_list[predictions[0]], fontsize=20)\n", | |
"# plot = np.tile(results.reshape((1, num_choices)), (2, 1))\n", | |
"# plt.imshow(plot, interpolation='nearest', vmin=-1.5, vmax=1.5, cmap='Reds')\n", | |
"# plt.xticks(range(len(answer_word_list)), answer_word_list, rotation=90, fontsize=7.5)\n", | |
"# plt.yticks([], [])\n", | |
"# plt.colorbar()\n", | |
" \n", | |
" plt.title('output from '+get_module_disp_name(module_name)+'[%d]'%t)\n", | |
"\n", | |
" plt.savefig(os.path.join(save_dir, '%08d.jpg' % n))\n", | |
" plt.close('all')\n", | |
" \n", | |
" # individually visualize the textual attention map\n", | |
" plt.figure(figsize=(5, 2.5))\n", | |
" plt.imshow(atts_val.reshape(atts_val.shape[:2]), interpolation='nearest', cmap='Reds')\n", | |
" plt.xticks(np.arange(len(encoder_words)), encoder_words, rotation=90, fontsize=7.5)\n", | |
" plt.yticks(np.arange(len(decoder_words)), decoder_words, fontsize=7.5)\n", | |
" plt.colorbar()\n", | |
" plt.savefig(os.path.join(save_dir, '%08d_txt_att.jpg' % n))\n", | |
" plt.close('all')" | |
] | |
}, | |
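{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Background sketch of tf.Session.partial_run, which run_visualization\n",
"# uses above to execute the graph in two stages (layout prediction, then\n",
"# NMN execution) over a single feed. Toy graph with hypothetical values,\n",
"# following the pattern from the TF1 partial_run documentation.\n",
"a = tf.placeholder(tf.float32, [])\n",
"b = tf.placeholder(tf.float32, [])\n",
"r1 = a + 1.0\n",
"r2 = r1 * b\n",
"\n",
"h_toy = sess.partial_run_setup([r1, r2], [a, b])\n",
"print(sess.partial_run(h_toy, r1, feed_dict={a: 1.0}))  # stage 1 -> 2.0\n",
"print(sess.partial_run(h_toy, r2, feed_dict={b: 3.0}))  # stage 2 reuses r1 -> 6.0"
]
},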
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"scrolled": false
},
"outputs": [],
"source": [
"run_visualization(data_reader_tst)"
]
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.1"
}
},
"nbformat": 4,
"nbformat_minor": 1
}