Skip to content

Instantly share code, notes, and snippets.

@bigsnarfdude
Created November 7, 2016 04:12
Show Gist options
  • Save bigsnarfdude/2f7b2144065f6056892a98495644d3e0 to your computer and use it in GitHub Desktop.
Save bigsnarfdude/2f7b2144065f6056892a98495644d3e0 to your computer and use it in GitHub Desktop.
faster rcnn notebook
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Populating the interactive namespace from numpy and matplotlib\n"
]
}
],
"source": [
"import _init_paths\n",
"import tensorflow as tf\n",
"\n",
"#import matplotlib\n",
"#matplotlib.use('Agg') # Must be before importing matplotlib.pyplot or pylab!\n",
"import matplotlib.pyplot as plt\n",
"%pylab inline\n",
"\n",
"from fast_rcnn.config import cfg\n",
"from fast_rcnn.test import im_detect\n",
"from fast_rcnn.nms_wrapper import nms\n",
"from utils.timer import Timer\n",
"import numpy as np\n",
"import os, sys, cv2\n",
"import argparse\n",
"from networks.factory import get_network\n",
"\n",
"CLASSES = ('__background__',\n",
" 'aeroplane', 'bicycle', 'bird', 'boat',\n",
" 'bottle', 'bus', 'car', 'cat', 'chair',\n",
" 'cow', 'diningtable', 'dog', 'horse',\n",
" 'motorbike', 'person', 'pottedplant',\n",
" 'sheep', 'sofa', 'train', 'tvmonitor')\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"\n",
"#CLASSES = ('__background__','person','bike','motorbike','car','bus')\n",
"\n",
"def vis_detections(im, class_name, dets, ax, thresh=0.5):\n",
"    \"\"\"Draw the detections of one class that score at least `thresh`.\n",
"\n",
"    im         -- unused; kept so existing callers need not change.\n",
"    class_name -- label written next to each box and in the title.\n",
"    dets       -- 2-D array with one detection per row: the first four\n",
"                  columns are used as (x1, y1, x2, y2) and the last\n",
"                  column as the score.\n",
"    ax         -- matplotlib axes to draw on.\n",
"    thresh     -- minimum score a detection needs to be drawn.\n",
"\n",
"    Returns None. Nothing is drawn when no detection reaches `thresh`.\n",
"    \"\"\"\n",
"    inds = np.where(dets[:, -1] >= thresh)[0]\n",
"    if len(inds) == 0:\n",
"        return\n",
"\n",
"    for i in inds:\n",
"        bbox = dets[i, :4]\n",
"        score = dets[i, -1]\n",
"\n",
"        # Rectangle takes (x, y), width, height, hence the differences\n",
"        ax.add_patch(\n",
"            plt.Rectangle((bbox[0], bbox[1]),\n",
"                          bbox[2] - bbox[0],\n",
"                          bbox[3] - bbox[1], fill=False,\n",
"                          edgecolor='red', linewidth=3.5)\n",
"        )\n",
"        # Label sits just above the top-left corner of the box\n",
"        ax.text(bbox[0], bbox[1] - 2,\n",
"                '{:s} {:.3f}'.format(class_name, score),\n",
"                bbox=dict(facecolor='blue', alpha=0.5),\n",
"                fontsize=14, color='white')\n",
"\n",
"    ax.set_title(('{} detections with '\n",
"                  'p({} | box) >= {:.1f}').format(class_name, class_name,\n",
"                                                  thresh),\n",
"                 fontsize=14)\n",
"    plt.axis('off')\n",
"    plt.tight_layout()\n",
"    plt.draw()\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"\n",
"def demo(sess, net, image_name):\n",
"    \"\"\"Run detection on one demo image and draw the results.\n",
"\n",
"    Reads `image_name` from the 'demo' folder under cfg.DATA_DIR, runs\n",
"    im_detect on it, prints the elapsed time, then applies per-class NMS\n",
"    and draws every detection that scores at least CONF_THRESH.\n",
"\n",
"    Raises IOError if the image cannot be read.\n",
"    \"\"\"\n",
"\n",
"    # Load the demo image\n",
"    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)\n",
"    im = cv2.imread(im_file)\n",
"    if im is None:\n",
"        # cv2.imread reports failure by returning None instead of raising\n",
"        raise IOError('Could not read image {}'.format(im_file))\n",
"\n",
"    # Detect all object classes and regress object bounds\n",
"    timer = Timer()\n",
"    timer.tic()\n",
"    scores, boxes = im_detect(sess, net, im)\n",
"    timer.toc()\n",
"    # Build the message first so print() behaves the same on Python 2 and 3\n",
"    message = ('Detection took {:.3f}s for '\n",
"               '{:d} object proposals').format(timer.total_time, boxes.shape[0])\n",
"    print(message)\n",
"\n",
"    # Visualize detections for each class\n",
"    # Reverse the channel axis: OpenCV loads BGR, matplotlib expects RGB\n",
"    im = im[:, :, (2, 1, 0)]\n",
"    _, ax = plt.subplots(figsize=(12, 12))\n",
"    ax.imshow(im, aspect='equal')\n",
"\n",
"    CONF_THRESH = 0.8\n",
"    NMS_THRESH = 0.3\n",
"    # Count from 1 because index 0 is the skipped background class\n",
"    for cls_ind, cls in enumerate(CLASSES[1:], 1):\n",
"        # Each class owns four consecutive box columns and one score column\n",
"        cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]\n",
"        cls_scores = scores[:, cls_ind]\n",
"        dets = np.hstack((cls_boxes,\n",
"                          cls_scores[:, np.newaxis])).astype(np.float32)\n",
"        keep = nms(dets, NMS_THRESH)\n",
"        dets = dets[keep, :]\n",
"        vis_detections(im, cls, dets, ax, thresh=CONF_THRESH)\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"cfg.TEST.HAS_RPN = True # Use RPN for proposals\n",
"gpu_id = 1\n",
"demo_net = \"VGGnet_test\"\n",
"model = \"/home/ubuntu/dev/Faster-RCNN_TF/VGGnet_fast_rcnn_iter_70000.ckpt\""
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Tensor(\"Placeholder:0\", shape=(?, ?, ?, 3), dtype=float32)\n",
"Tensor(\"conv5_3/conv5_3:0\", shape=(?, ?, ?, 512), dtype=float32)\n",
"Tensor(\"rpn_conv/3x3/rpn_conv/3x3:0\", shape=(?, ?, ?, 512), dtype=float32)\n",
"Tensor(\"rpn_cls_score/rpn_cls_score:0\", shape=(?, ?, ?, 18), dtype=float32)\n",
"Tensor(\"rpn_cls_prob:0\", shape=(?, ?, ?, ?), dtype=float32)\n",
"Tensor(\"rpn_cls_prob_reshape:0\", shape=(?, ?, ?, 18), dtype=float32)\n",
"Tensor(\"rpn_bbox_pred/rpn_bbox_pred:0\", shape=(?, ?, ?, 36), dtype=float32)\n",
"Tensor(\"Placeholder_1:0\", shape=(?, 3), dtype=float32)\n",
"Tensor(\"conv5_3/conv5_3:0\", shape=(?, ?, ?, 512), dtype=float32)\n",
"Tensor(\"rois:0\", shape=(?, 5), dtype=float32)\n",
"[<tf.Tensor 'conv5_3/conv5_3:0' shape=(?, ?, ?, 512) dtype=float32>, <tf.Tensor 'rois:0' shape=(?, 5) dtype=float32>]\n",
"Tensor(\"fc7/fc7:0\", shape=(?, 4096), dtype=float32)\n"
]
}
],
"source": [
"net = get_network(demo_net)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"saver = tf.train.Saver()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Loaded network /home/ubuntu/dev/Faster-RCNN_TF/VGGnet_fast_rcnn_iter_70000.ckpt\n"
]
}
],
"source": [
"saver.restore(sess, model)\n",
"\n",
"print('\\n\\nLoaded network {:s}'.format(model))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Warmup on a dummy image so that one-time start-up cost is not\n",
"# included in the time printed for the first real image\n",
"im = 128 * np.ones((300, 300, 3), dtype=np.uint8)\n",
"for _ in range(2):\n",
"    im_detect(sess, net, im)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
"Demo for data/demo/000456.jpg\n",
"Detection took 1.564s for 300 object proposals\n",
"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
"Demo for data/demo/000542.jpg\n",
"Detection took 0.111s for 261 object proposals\n",
"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
"Demo for data/demo/001150.jpg\n",
"Detection took 0.105s for 232 object proposals\n",
"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
"Demo for data/demo/001763.jpg\n",
"Detection took 0.112s for 265 object proposals\n",
"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
"Demo for data/demo/004545.jpg\n",
"Detection took 0.106s for 300 object proposals\n"
]
},
{
"data": {
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment