Faster R-CNN (TensorFlow) demo notebook
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Populating the interactive namespace from numpy and matplotlib\n"
     ]
    }
   ],
   "source": [
    "import _init_paths\n",
    "import tensorflow as tf\n",
    "\n",
    "#import matplotlib\n",
    "#matplotlib.use('Agg') # Must be before importing matplotlib.pyplot or pylab!\n",
    "import matplotlib.pyplot as plt\n",
    "%pylab inline\n",
    "\n",
    "from fast_rcnn.config import cfg\n",
    "from fast_rcnn.test import im_detect\n",
    "from fast_rcnn.nms_wrapper import nms\n",
    "from utils.timer import Timer\n",
    "import numpy as np\n",
    "import os, sys, cv2\n",
    "import argparse\n",
    "from networks.factory import get_network\n",
    "\n",
    "CLASSES = ('__background__',\n",
    "           'aeroplane', 'bicycle', 'bird', 'boat',\n",
    "           'bottle', 'bus', 'car', 'cat', 'chair',\n",
    "           'cow', 'diningtable', 'dog', 'horse',\n",
    "           'motorbike', 'person', 'pottedplant',\n",
    "           'sheep', 'sofa', 'train', 'tvmonitor')\n"
   ]
  },
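  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`CLASSES` is the standard PASCAL VOC label set for the pre-trained VGGnet model: 20 object classes plus `__background__` at index 0. A quick sanity check of how a class name maps to its score column and box columns (the same indexing `demo()` uses below):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Class i owns column i of `scores` and columns 4*i:4*(i+1) of `boxes`.\n",
    "print(len(CLASSES))  # 21\n",
    "i = CLASSES.index('car')\n",
    "print('{} {} {}'.format(i, 4 * i, 4 * (i + 1)))  # 7 28 32\n"
   ]
  },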
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "\n",
    "#CLASSES = ('__background__','person','bike','motorbike','car','bus')\n",
    "\n",
    "def vis_detections(im, class_name, dets, ax, thresh=0.5):\n",
    "    \"\"\"Draw detected bounding boxes.\"\"\"\n",
    "    inds = np.where(dets[:, -1] >= thresh)[0]\n",
    "    if len(inds) == 0:\n",
    "        return\n",
    "\n",
    "    for i in inds:\n",
    "        bbox = dets[i, :4]\n",
    "        score = dets[i, -1]\n",
    "\n",
    "        ax.add_patch(\n",
    "            plt.Rectangle((bbox[0], bbox[1]),\n",
    "                          bbox[2] - bbox[0],\n",
    "                          bbox[3] - bbox[1], fill=False,\n",
    "                          edgecolor='red', linewidth=3.5)\n",
    "            )\n",
    "        ax.text(bbox[0], bbox[1] - 2,\n",
    "                '{:s} {:.3f}'.format(class_name, score),\n",
    "                bbox=dict(facecolor='blue', alpha=0.5),\n",
    "                fontsize=14, color='white')\n",
    "\n",
    "    ax.set_title(('{} detections with '\n",
    "                  'p({} | box) >= {:.1f}').format(class_name, class_name,\n",
    "                                                  thresh),\n",
    "                 fontsize=14)\n",
    "    plt.axis('off')\n",
    "    plt.tight_layout()\n",
    "    plt.draw()\n"
   ]
  },
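  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal smoke test for `vis_detections`, assuming `dets` is an N x 5 array of `[x1, y1, x2, y2, score]` rows as built in `demo()` below: one synthetic box above and one below the default 0.5 threshold, drawn on a blank image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Only the 0.9-score box should be drawn; the 0.3-score box falls below thresh.\n",
    "toy_im = np.zeros((200, 200, 3), dtype=np.uint8)\n",
    "toy_dets = np.array([[ 20.,  20., 120., 120., 0.9],\n",
    "                     [ 40.,  60.,  90., 110., 0.3]], dtype=np.float32)\n",
    "fig, ax = plt.subplots(figsize=(4, 4))\n",
    "ax.imshow(toy_im, aspect='equal')\n",
    "vis_detections(toy_im, 'person', toy_dets, ax, thresh=0.5)\n"
   ]
  },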
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "\n",
    "def demo(sess, net, image_name):\n",
    "    \"\"\"Detect object classes in an image using pre-computed object proposals.\"\"\"\n",
    "\n",
    "    # Load the demo image\n",
    "    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)\n",
    "    #im_file = os.path.join('/home/corgi/Lab/label/pos_frame/ACCV/training/000001/',image_name)\n",
    "    im = cv2.imread(im_file)\n",
    "\n",
    "    # Detect all object classes and regress object bounds\n",
    "    timer = Timer()\n",
    "    timer.tic()\n",
    "    scores, boxes = im_detect(sess, net, im)\n",
    "    timer.toc()\n",
    "    print('Detection took {:.3f}s for '\n",
    "          '{:d} object proposals'.format(timer.total_time, boxes.shape[0]))\n",
    "\n",
    "    # Visualize detections for each class\n",
    "    im = im[:, :, (2, 1, 0)]\n",
    "    fig, ax = plt.subplots(figsize=(12, 12))\n",
    "    ax.imshow(im, aspect='equal')\n",
    "\n",
    "    CONF_THRESH = 0.8\n",
    "    NMS_THRESH = 0.3\n",
    "    for cls_ind, cls in enumerate(CLASSES[1:]):\n",
    "        cls_ind += 1 # because we skipped background\n",
    "        cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]\n",
    "        cls_scores = scores[:, cls_ind]\n",
    "        dets = np.hstack((cls_boxes,\n",
    "                          cls_scores[:, np.newaxis])).astype(np.float32)\n",
    "        keep = nms(dets, NMS_THRESH)\n",
    "        dets = dets[keep, :]\n",
    "        vis_detections(im, cls, dets, ax, thresh=CONF_THRESH)\n"
   ]
  },
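  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Inside the per-class loop, `nms(dets, thresh)` returns the indices of the detections to keep. A tiny standalone illustration with made-up boxes: two heavily overlapping detections and one isolated one, so only the higher-scoring box of the overlapping pair should survive."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Boxes 0 and 1 overlap almost completely, so NMS at 0.3 suppresses the\n",
    "# lower-scoring one; box 2 is disjoint and is always kept.\n",
    "toy = np.array([[ 10.,  10., 100., 100., 0.95],\n",
    "                [ 12.,  12., 102., 102., 0.80],\n",
    "                [150., 150., 200., 200., 0.60]], dtype=np.float32)\n",
    "print(nms(toy, 0.3))  # indices 0 and 2 should remain\n"
   ]
  },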
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "cfg.TEST.HAS_RPN = True # Use RPN for proposals\n",
    "gpu_id = 1\n",
    "demo_net = \"VGGnet_test\"\n",
    "model = \"/home/ubuntu/dev/Faster-RCNN_TF/VGGnet_fast_rcnn_iter_70000.ckpt\""
   ]
  },
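  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The checkpoint path above is specific to this machine. A hypothetical guard (assuming the single-file `.ckpt` layout used by these TensorFlow 0.x checkpoints) fails fast with a clear message instead of a cryptic restore error:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Hypothetical guard: fail early if the machine-specific checkpoint is missing.\n",
    "if not os.path.isfile(model):\n",
    "    raise IOError('Checkpoint not found: {:s}'.format(model))\n"
   ]
  },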
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(\"Placeholder:0\", shape=(?, ?, ?, 3), dtype=float32)\n",
      "Tensor(\"conv5_3/conv5_3:0\", shape=(?, ?, ?, 512), dtype=float32)\n",
      "Tensor(\"rpn_conv/3x3/rpn_conv/3x3:0\", shape=(?, ?, ?, 512), dtype=float32)\n",
      "Tensor(\"rpn_cls_score/rpn_cls_score:0\", shape=(?, ?, ?, 18), dtype=float32)\n",
      "Tensor(\"rpn_cls_prob:0\", shape=(?, ?, ?, ?), dtype=float32)\n",
      "Tensor(\"rpn_cls_prob_reshape:0\", shape=(?, ?, ?, 18), dtype=float32)\n",
      "Tensor(\"rpn_bbox_pred/rpn_bbox_pred:0\", shape=(?, ?, ?, 36), dtype=float32)\n",
      "Tensor(\"Placeholder_1:0\", shape=(?, 3), dtype=float32)\n",
      "Tensor(\"conv5_3/conv5_3:0\", shape=(?, ?, ?, 512), dtype=float32)\n",
      "Tensor(\"rois:0\", shape=(?, 5), dtype=float32)\n",
      "[<tf.Tensor 'conv5_3/conv5_3:0' shape=(?, ?, ?, 512) dtype=float32>, <tf.Tensor 'rois:0' shape=(?, 5) dtype=float32>]\n",
      "Tensor(\"fc7/fc7:0\", shape=(?, 4096), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "net = get_network(demo_net)"
   ]
  },
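  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The 18- and 36-channel RPN heads printed above follow from the anchor count, assuming the default Faster R-CNN configuration of 3 scales x 3 aspect ratios per feature-map position:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# 9 anchors per position: 2 objectness scores each -> rpn_cls_score channels,\n",
    "# 4 box-regression deltas each -> rpn_bbox_pred channels.\n",
    "num_anchors = 3 * 3\n",
    "print(num_anchors * 2)  # 18\n",
    "print(num_anchors * 4)  # 36\n"
   ]
  },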
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "saver = tf.train.Saver()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Loaded network /home/ubuntu/dev/Faster-RCNN_TF/VGGnet_fast_rcnn_iter_70000.ckpt\n"
     ]
    }
   ],
   "source": [
    "saver.restore(sess, model)\n",
    "\n",
    "print('\\n\\nLoaded network {:s}'.format(model))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Warmup on a dummy image\n",
    "im = 128 * np.ones((300, 300, 3), dtype=np.uint8)"
   ]
  },
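  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above only builds the dummy image. A sketch of the warm-up pass itself, reusing `im_detect` from above, runs a couple of throw-away detections so the first timed image would not pay the one-time setup cost visible in the 1.564 s figure below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Hypothetical warm-up: two throw-away detections on the dummy image so CUDA\n",
    "# kernels and buffers are initialized before the timed demo runs below.\n",
    "for _ in range(2):\n",
    "    _, _ = im_detect(sess, net, im)\n"
   ]
  },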
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "Demo for data/demo/000456.jpg\n",
      "Detection took 1.564s for 300 object proposals\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "Demo for data/demo/000542.jpg\n",
      "Detection took 0.111s for 261 object proposals\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "Demo for data/demo/001150.jpg\n",
      "Detection took 0.105s for 232 object proposals\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "Demo for data/demo/001763.jpg\n",
      "Detection took 0.112s for 265 object proposals\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "Demo for data/demo/004545.jpg\n",
      "Detection took 0.106s for 300 object proposals\n"
     ]
    },
    {
     "data": {