Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save ginrou/56c7e4721c4e6697ddee2e9dc954f17c to your computer and use it in GitHub Desktop.
Save ginrou/56c7e4721c4e6697ddee2e9dc954f17c to your computer and use it in GitHub Desktop.
ObjectDetection_KITTI_ssd_resnet50_v1_fpn_640x640
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "ObjectDetection_KITTI_ssd_resnet50_v1_fpn_640x640",
"provenance": [],
"collapsed_sections": [],
"mount_file_id": "1r4sim-bDsp8kQnUyFInm2yxrY2jTvAF8",
"authorship_tag": "ABX9TyMsU1MSxVV0xoK9Yymx8BfR",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/ginrou/56c7e4721c4e6697ddee2e9dc954f17c/objectdetection_kitti_ssd_resnet50_v1_fpn_640x640.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "rCSVpSRgHgUf",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 731
},
"outputId": "d41e0334-54de-4f67-cc06-454891c40693"
},
"source": [
"!pip -q install tfds-nightly matplotlib opencv-python\n",
"import os, pathlib, time, cv2\n",
"import tensorflow as tf\n",
"import tensorflow_datasets as tfds\n",
"import cv2\n",
"from IPython.display import display, clear_output\n",
"from PIL import Image\n",
"from typing import *\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import pprint as pprint_lib\n",
"pp = pprint_lib.PrettyPrinter(width=200)\n",
"%matplotlib inline\n",
"\n",
"\n",
"if os.path.exists(\"/content/models/research/\"):\n",
" %cd /content/models/research/\n",
"else:\n",
" !git clone --quiet --depth 1 https://github.com/tensorflow/models\n",
" %cd /content/models/research/\n",
" !protoc object_detection/protos/*.proto --python_out=.\n",
" !cp object_detection/packages/tf2/setup.py . && python -m pip -q install .\n",
" !wget http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz\n",
" !tar -xf ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz\n",
" !mv ssd_resnet50_v1_fpn_640x640_coco17_tpu-8/checkpoint /content/models/research/object_detection/test_data/\n",
"\n",
"pp.pprint([f\"{d.name}, {d.device_type} {d.physical_device_desc}\" for d in tf.python.client.device_lib.list_local_devices()])\n",
"\n",
"from object_detection.utils import config_util\n",
"from object_detection.builders import model_builder"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"\u001b[K |████████████████████████████████| 3.5MB 4.6MB/s \n",
"\u001b[?25h/content/models/research\n",
"\u001b[K |████████████████████████████████| 8.3MB 11.8MB/s \n",
"\u001b[K |████████████████████████████████| 358kB 49.9MB/s \n",
"\u001b[K |████████████████████████████████| 849kB 57.5MB/s \n",
"\u001b[K |████████████████████████████████| 63.8MB 47kB/s \n",
"\u001b[K |████████████████████████████████| 153kB 57.7MB/s \n",
"\u001b[K |████████████████████████████████| 81kB 10.9MB/s \n",
"\u001b[K |████████████████████████████████| 51kB 7.6MB/s \n",
"\u001b[K |████████████████████████████████| 61kB 8.5MB/s \n",
"\u001b[K |████████████████████████████████| 1.4MB 55.0MB/s \n",
"\u001b[K |████████████████████████████████| 829kB 42.6MB/s \n",
"\u001b[K |████████████████████████████████| 1.1MB 46.6MB/s \n",
"\u001b[K |████████████████████████████████| 174kB 55.0MB/s \n",
"\u001b[K |████████████████████████████████| 36.6MB 83kB/s \n",
"\u001b[K |████████████████████████████████| 102kB 12.8MB/s \n",
"\u001b[K |████████████████████████████████| 112kB 56.8MB/s \n",
"\u001b[?25h Building wheel for object-detection (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Building wheel for avro-python3 (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Building wheel for dill (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Building wheel for oauth2client (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Building wheel for hdfs (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Building wheel for future (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Building wheel for py-cpuinfo (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
"\u001b[31mERROR: pydrive 1.3.1 has requirement oauth2client>=4.0.0, but you'll have oauth2client 3.0.0 which is incompatible.\u001b[0m\n",
"\u001b[31mERROR: multiprocess 0.70.10 has requirement dill>=0.3.2, but you'll have dill 0.3.1.1 which is incompatible.\u001b[0m\n",
"\u001b[31mERROR: apache-beam 2.23.0 has requirement avro-python3!=1.9.2,<1.10.0,>=1.8.1; python_version >= \"3.0\", but you'll have avro-python3 1.10.0 which is incompatible.\u001b[0m\n",
"--2020-09-03 19:41:33-- http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz\n",
"Resolving download.tensorflow.org (download.tensorflow.org)... 74.125.142.128, 2607:f8b0:400e:c08::80\n",
"Connecting to download.tensorflow.org (download.tensorflow.org)|74.125.142.128|:80... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 244817203 (233M) [application/x-tar]\n",
"Saving to: ‘ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz’\n",
"\n",
"ssd_resnet50_v1_fpn 100%[===================>] 233.48M 231MB/s in 1.0s \n",
"\n",
"2020-09-03 19:41:35 (231 MB/s) - ‘ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz’ saved [244817203/244817203]\n",
"\n",
"['/device:CPU:0, CPU ',\n",
" '/device:XLA_CPU:0, XLA_CPU device: XLA_CPU device',\n",
" '/device:XLA_GPU:0, XLA_GPU device: XLA_GPU device',\n",
" '/device:GPU:0, GPU device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5']\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ijmPnxFyt5lA",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"outputId": "08fd16bd-0a73-48a1-b0ec-9ac1c98e807b"
},
"source": [
"#tfds_data_dir = \"/content/drive/My Drive/Colab Notebooks/tensorflow_datasets\"\n",
"tfds_data_dir = \"/content/tensorflow_datasets\" \n",
"pipeline_config = \"/content/models/research/object_detection/configs/tf2/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.config\"\n",
"checkpoint_path = \"/content/models/research/object_detection/test_data/checkpoint/ckpt-0\"\n",
"\n",
"import tensorflow_datasets as tfds\n",
"kitti, ds_info = tfds.load(\"kitti\", with_info=True, data_dir=tfds_data_dir)\n",
"LABELS = ['Car', 'Van', 'Truck', 'Pedestrian', 'Person_sitting', 'Cyclist', 'Tram', 'Misc', 'DontCare']\n",
"NUM_CLASSES = len(LABELS)\n",
"LABEL_COLORS = 256 * np.array([plt.get_cmap(\"tab10\")(i) for i in range(NUM_CLASSES)], dtype=np.float32)\n",
"\n",
"\n",
"#%%\n",
"# util functions for tfds kitti datasets\n",
"def take_kitti_example(split_type: str,\n",
" to_one_hot: bool = True,\n",
" batch: int = 0) -> Tuple[List[tf.Tensor], List[tf.Tensor], List[tf.Tensor]]:\n",
" images, bboxes, classes = [], [], [] # type: List[tf.Tensor], List[tf.Tensor], List[tf.Tensor]\n",
" for example in kitti[split_type].shuffle(10).take(max(batch, 1)):\n",
" image, bbox, klass = example[\"image\"], example[\"objects\"][\"bbox\"], example[\"objects\"][\"type\"]\n",
"\n",
" images.append(tf.cast(image, tf.float32))\n",
" classes.append(tf.one_hot(klass, depth=len(LABELS)) if to_one_hot else klass)\n",
"\n",
" # bounding box order of kitti dataset differ from tensorflow manner\n",
" bbox = bbox.numpy()\n",
" ymin, ymax = np.array(bbox[:, 0]), np.array(bbox[:, 2])\n",
" bbox[:, 0] = 1.0 - ymax\n",
" bbox[:, 2] = 1.0 - ymin\n",
" bboxes.append(tf.convert_to_tensor(bbox))\n",
"\n",
" if batch < 1:\n",
" return images[0], bboxes[0], classes[0]\n",
" else:\n",
" return images, bboxes, classes\n",
"\n",
"\n",
"def render_kitti(image: tf.Tensor, boxes: tf.Tensor, classes: tf.Tensor) -> Image:\n",
" assert tf.rank(image) == 3 and tf.rank(boxes) == 2\n",
" classes_np = classes.numpy() if tf.rank(classes) == 1 else np.argmax(classes.numpy(), axis=-1)\n",
" colors = tf.convert_to_tensor(LABEL_COLORS[classes_np.astype(\"int\")])\n",
" rendered_image = tf.image.draw_bounding_boxes(tf.expand_dims(image, axis=0), tf.expand_dims(boxes, axis=0),\n",
" colors)[0]\n",
"\n",
" return Image.fromarray(rendered_image.numpy().astype(np.uint8))\n",
"\n",
"\n",
"img, bbox, klass = take_kitti_example(\"train\", batch=0)\n",
"display(render_kitti(img, bbox, klass))\n",
"\n",
"# %% Load Model\n",
"print(\"Start building and restoring model\")\n",
"\n",
"\n",
"# setup model configuration;\n",
"configs = config_util.get_configs_from_pipeline_file(pipeline_config)\n",
"model_config = configs[\"model\"]\n",
"model_config.ssd.freeze_batchnorm = False\n",
"model_config.ssd.num_classes = NUM_CLASSES\n",
"\n",
"# build model with configuration and run once\n",
"detection_model = model_builder.build(model_config=model_config, is_training=True)\n",
"image, shapes = detection_model.preprocess(tf.zeros([1, 640, 640, 3]))\n",
"detection_model.postprocess(detection_model.predict(image, shapes), shapes)\n",
"\n",
"# Load model\n",
"dammy_model = tf.train.Checkpoint(\n",
" _feature_extractor=detection_model._feature_extractor,\n",
" _box_predictor=tf.train.Checkpoint(\n",
" _base_tower_layer_for_heads=detection_model._box_predictor._base_tower_layers_for_heads,\n",
" _box_prediction_head=detection_model._box_predictor._box_prediction_head))\n",
"\n",
"ckpt = tf.compat.v2.train.Checkpoint(model=dammy_model)\n",
"ckpt.restore(checkpoint_path).expect_partial()\n",
"\n",
"print(\"Weights restored!\")\n",
"\n",
"#%% Define training step\n",
"\n",
"train_var_prefixes = [\n",
" 'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalBoxHead',\n",
" 'WeightSharedConvolutionalBoxPredictor/BoxPredictionTower',\n",
" 'WeightSharedConvolutionalBoxPredictor/WeightSharedConvolutionalClassHead',\n",
" 'WeightSharedConvolutionalBoxPredictor/ClassPredictionTower',\n",
"]\n",
"\n",
"train_vars = [\n",
" var for var in detection_model.trainable_variables\n",
" if any([var.name.startswith(prefix) for prefix in train_var_prefixes])\n",
"]\n",
"print(\"train vars = \")\n",
"pp.pprint([f\"{v.name}, shape = {v.get_shape().as_list()}\" for v in train_vars])\n",
"print(\"number of parameters = \", np.sum([np.prod(v.get_shape().as_list()) for v in train_vars]))\n",
"\n",
"optimizer = tf.keras.optimizers.SGD(\n",
" learning_rate=tf.keras.optimizers.schedules.InverseTimeDecay(0.01, 1000, 0.5),\n",
" momentum=0.9)\n",
"\n",
"\n",
"def train_step_fn(image_tensors, gt_box_list, gt_class_list):\n",
" shapes = tf.constant(len(image_tensors) * [[640, 640, 3]], dtype=tf.int32)\n",
" detection_model.provide_groundtruth(groundtruth_boxes_list=gt_box_list, groundtruth_classes_list=gt_class_list)\n",
" with tf.GradientTape() as tape:\n",
" preprocessed_image = tf.concat([detection_model.preprocess(tensor)[0] for tensor in image_tensors], axis=0)\n",
" prediction_dict = detection_model.predict(preprocessed_image, shapes)\n",
" losses_dict = detection_model.loss(prediction_dict, shapes)\n",
" total_loss = losses_dict['Loss/localization_loss'] + losses_dict['Loss/classification_loss']\n",
" gradients = tape.gradient(total_loss, train_vars)\n",
" optimizer.apply_gradients(zip(gradients, train_vars))\n",
" return losses_dict\n",
"\n",
"\n",
"#%% Inference\n",
"def run_inference(model, image_tensor: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:\n",
" preprocessed_image, shape = model.preprocess(tf.expand_dims(image_tensor, axis=0))\n",
" prediction_dict = model.predict(preprocessed_image, shape)\n",
" #print(prediction_dict[\"class_predictions_with_background\"])\n",
" detection_result = model.postprocess(prediction_dict, shape)\n",
" return detection_result[\"detection_boxes\"][0, :], detection_result[\"detection_classes\"][\n",
" 0, :], detection_result[\"detection_scores\"][0, :]\n",
"\n",
"\n",
"img, box, klass = take_kitti_example(\"validation\", batch=0)\n",
"det_boxes, det_classes, det_scores = run_inference(detection_model, img)\n",
"display(render_kitti(img, det_boxes[:10], det_classes[:10]))\n",
"\n"
],
"execution_count": 2,
"outputs": [
{
"output_type": "display_data",
"data": {
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment