Skip to content

Instantly share code, notes, and snippets.

@moarshy
Created March 23, 2022 07:32
Show Gist options
  • Save moarshy/0b1edde8afd538e5073fb771b2753315 to your computer and use it in GitHub Desktop.
Save moarshy/0b1edde8afd538e5073fb771b2753315 to your computer and use it in GitHub Desktop.
MNUnetModel.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "MNUnetModel.ipynb",
"provenance": [],
"collapsed_sections": [
"nl_Y1QW9OsZR"
],
"authorship_tag": "ABX9TyNFaFW9sZHD/tLhto6ay9vQ",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU",
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"8eb841b4c62b4042bd9a5af15342c012": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_dd011ba65aa04664bd63eb8f79ca29c8",
"IPY_MODEL_46490715f4934745aa454ae8e119f40d",
"IPY_MODEL_f7444e1bb42a4c6f913e05514445ce59"
],
"layout": "IPY_MODEL_cd7643d9744340ec99319a52c2faf2c9"
}
},
"dd011ba65aa04664bd63eb8f79ca29c8": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_a1860ff990b042a5b50f87066e30bb9b",
"placeholder": "",
"style": "IPY_MODEL_18da3006e55d432d9038da4f74558e4f",
"value": "100%"
}
},
"46490715f4934745aa454ae8e119f40d": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_5ca90a8d4e0c4be3888dbccafe6aafce",
"max": 178793939,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_56aa6084019b49b1ad09c0f1aaabdf01",
"value": 178793939
}
},
"f7444e1bb42a4c6f913e05514445ce59": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_e43806d23dad4c86afb2428ce56504b3",
"placeholder": "",
"style": "IPY_MODEL_4eec62679381415ab632253d8f043108",
"value": " 171M/171M [00:03<00:00, 44.6MB/s]"
}
},
"cd7643d9744340ec99319a52c2faf2c9": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"a1860ff990b042a5b50f87066e30bb9b": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"18da3006e55d432d9038da4f74558e4f": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"5ca90a8d4e0c4be3888dbccafe6aafce": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"56aa6084019b49b1ad09c0f1aaabdf01": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"e43806d23dad4c86afb2428ce56504b3": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"4eec62679381415ab632253d8f043108": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
}
}
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/moarshy/0b1edde8afd538e5073fb771b2753315/mnunetmodel.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"!pip install timm fastai -Uqq"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "cJ_Y_aOreCYt",
"outputId": "bdc1576e-a1fa-4618-9123-e69fd9b926be"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[K |████████████████████████████████| 431 kB 4.0 MB/s \n",
"\u001b[K |████████████████████████████████| 189 kB 45.0 MB/s \n",
"\u001b[K |████████████████████████████████| 55 kB 3.9 MB/s \n",
"\u001b[?25h"
]
}
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"id": "ynxvBvjJMH8F"
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"import timm\n",
"from timm import create_model\n",
"from timm.models.efficientnet_blocks import DepthwiseSeparableConv\n",
"\n",
"from fastai.vision.all import *\n",
"from fastai.callback.all import *"
]
},
{
"cell_type": "markdown",
"source": [
"In this notebook, we:\n",
"- [x] Implement the model from the paper [A Simple Baseline for Fast and Accurate Depth Estimation on Mobile Devices](https://openaccess.thecvf.com/content/CVPR2021W/MAI/papers/Zhang_A_Simple_Baseline_for_Fast_and_Accurate_Depth_Estimation_on_CVPRW_2021_paper.pdf)\n",
"- [x] Train using knowledge distillation\n"
],
"metadata": {
"id": "QBoa9XfUjYsV"
}
},
{
"cell_type": "markdown",
"source": [
"**The paper abstract**\n",
"\n",
"![image.png]() "
],
"metadata": {
"id": "UseLzzkckySY"
}
},
{
"cell_type": "markdown",
"source": [
"**The architecture**\n",
"![image.png]()"
],
"metadata": {
"id": "k5lA8zzHk-z9"
}
},
{
"cell_type": "code",
"source": [
"# Parts of this code are based on https://gist.github.com/rwightman/f8b24f4e6f5504aba03e999e02460d31\n",
"\n",
"class Conv2dBnAct(nn.Module):\n",
" def __init__(self, \n",
" in_channels, \n",
" out_channels, \n",
" kernel_size, \n",
" padding=0,\n",
" stride=1, \n",
" act_layer=nn.ReLU, \n",
" norm_layer=nn.BatchNorm2d\n",
" ):\n",
" super().__init__()\n",
" self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False)\n",
" self.bn = norm_layer(out_channels)\n",
" self.act = act_layer(inplace=True)\n",
"\n",
" def forward(self, x):\n",
" x = self.conv(x)\n",
" x = self.bn(x)\n",
" x = self.act(x)\n",
" return x\n",
"\n",
"\n",
"class FeatureFusionModule(nn.Module):\n",
" def __init__(self, \n",
" enc_in_channels, # Encoder in channels\n",
" enc_out_channels, # Encoder out channels\n",
" dec_in_channels, # Decoder in channels\n",
" out_channels, # Final out channels\n",
" ):\n",
" super().__init__()\n",
" #encoderoutput\n",
" self.enc_conv1 = nn.Conv2d(enc_in_channels, enc_out_channels, kernel_size=1, stride=1, padding='same')\n",
" self.enc_up = nn.ConvTranspose2d(enc_out_channels, enc_out_channels, kernel_size=1)\n",
" self.enc_dconv = DepthwiseSeparableConv(enc_out_channels, enc_out_channels)\n",
" self.enc_conv2 = nn.Conv2d(enc_out_channels, enc_out_channels, kernel_size=1, stride=1, padding='same')\n",
" \n",
" #decoderoutput\n",
" self.dec_dconv = DepthwiseSeparableConv(enc_out_channels+dec_in_channels, enc_out_channels+dec_in_channels)\n",
" self.dec_conv1 = nn.Conv2d(enc_out_channels+dec_in_channels, out_channels, kernel_size=1, stride=1, padding='same')\n",
"\n",
"\n",
" def forward(self, enc_x, dec_x):\n",
" enc_x = self.enc_conv1(enc_x)\n",
" enc_x = self.enc_up(enc_x)\n",
" enc_x = self.enc_dconv(enc_x)\n",
" enc_x = self.enc_conv2(enc_x)\n",
" \n",
" x = torch.cat([enc_x, dec_x], dim=1)\n",
"\n",
" dec_x = self.dec_dconv(x)\n",
" dec_x = self.dec_conv1(dec_x)\n",
" \n",
" return dec_x\n",
"\n",
"\n",
"class DecoderBlock(nn.Module):\n",
" def __init__(self, \n",
" enc_channels,\n",
" dec_prev_channels, \n",
" dec_channels,\n",
" act_layer=nn.ReLU, \n",
" norm_layer=nn.BatchNorm2d,\n",
" ffm=True,\n",
" ):\n",
" super().__init__()\n",
" conv_args = dict(kernel_size=3, padding=1, act_layer=act_layer)\n",
" self.ffm = ffm\n",
" \n",
" if ffm:\n",
" self.ffm = FeatureFusionModule(enc_channels, enc_channels, dec_prev_channels, dec_channels)\n",
"\n",
" self.conv1 = Conv2dBnAct(enc_channels, dec_channels, norm_layer=norm_layer, **conv_args)\n",
" self.conv2 = Conv2dBnAct(dec_channels, dec_channels, norm_layer=norm_layer, **conv_args)\n",
" \n",
"\n",
" def forward(self, x_enc, x_dec):\n",
" if self.ffm:\n",
" x = self.ffm(x_enc, x_dec)\n",
"\n",
" x = F.interpolate(x_enc, scale_factor=2, mode='nearest')\n",
"\n",
" x = self.conv1(x)\n",
" x = self.conv2(x)\n",
"\n",
" return x\n",
"\n",
"\n",
"class UnetDecoder(nn.Module):\n",
"\n",
" def __init__(self,\n",
" encoder_channels,\n",
" decoder_channels=(256, 128, 64, 32, 16),\n",
" final_channels=3,\n",
" norm_layer=nn.BatchNorm2d,\n",
" ):\n",
" super().__init__()\n",
"\n",
" self.decoders = nn.ModuleList()\n",
" for i, (e_ch, d_ch) in enumerate(zip(encoder_channels, decoder_channels)):\n",
" if i== 0:\n",
" self.decoders.append(DecoderBlock(enc_channels=e_ch, \n",
" dec_prev_channels=None, \n",
" dec_channels=d_ch,\n",
" act_layer=nn.ReLU, \n",
" norm_layer=nn.BatchNorm2d,\n",
" ffm=False,\n",
" ))\n",
"\n",
" else:\n",
" self.decoders.append(DecoderBlock(enc_channels=e_ch, \n",
" dec_prev_channels=decoder_channels[i-1], \n",
" dec_channels=d_ch,\n",
" act_layer=nn.ReLU, \n",
" norm_layer=nn.BatchNorm2d,\n",
" ffm=True,\n",
" ))\n",
"\n",
" self.final_conv = nn.Conv2d(decoder_channels[-1], final_channels, kernel_size=(1, 1))\n",
" self.tensor_base = ToTensorBase()\n",
" self._init_weight()\n",
"\n",
" def _init_weight(self):\n",
" for m in self.modules():\n",
" if isinstance(m, nn.Conv2d):\n",
" torch.nn.init.kaiming_normal_(m.weight)\n",
" elif isinstance(m, nn.BatchNorm2d):\n",
" m.weight.data.fill_(1)\n",
" m.bias.data.zero_()\n",
"\n",
"\n",
" def forward(self, x):\n",
" enc_outs_r = x\n",
" dec_out = None\n",
" for i, each in enumerate(self.decoders):\n",
" dec_out = each(enc_outs_r[i], dec_out)\n",
" x = self.final_conv(dec_out)\n",
" x = self.tensor_base(x)\n",
" return x\n",
"\n",
"\n",
"class Unet(nn.Module):\n",
" def __init__(self,\n",
" backbone='resnet50',\n",
" backbone_kwargs=None,\n",
" backbone_indices=None,\n",
" decoder_use_batchnorm=True,\n",
" decoder_channels=(256, 128, 64, 32, 16),\n",
" in_chans=3,\n",
" num_classes=3,\n",
" norm_layer=nn.BatchNorm2d,\n",
" pretrained=True,\n",
" ):\n",
" super().__init__()\n",
" backbone_kwargs = backbone_kwargs or {}\n",
" # NOTE some models need different backbone indices specified based on the alignment of features\n",
" # and some models won't have a full enough range of feature strides to work properly.\n",
" encoder = create_model(\n",
" backbone, features_only=True, out_indices=backbone_indices, in_chans=in_chans,\n",
" pretrained=pretrained, **backbone_kwargs)\n",
" encoder_channels = encoder.feature_info.channels()[::-1]\n",
" self.encoder = encoder\n",
"\n",
" self.decoder = UnetDecoder(\n",
" encoder_channels=encoder_channels,\n",
" decoder_channels=decoder_channels,\n",
" final_channels=num_classes,\n",
" norm_layer=norm_layer,\n",
" )\n",
"\n",
" def forward(self, x: torch.Tensor):\n",
" x = self.encoder(x)\n",
" x.reverse() # torchscript doesn't work with [::-1]\n",
" x = self.decoder(x)\n",
" return x"
],
"metadata": {
"id": "t0IVmDv-BZsN"
},
"execution_count": 24,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Test individual components"
],
"metadata": {
"id": "nl_Y1QW9OsZR"
}
},
{
"cell_type": "code",
"source": [
"model = Unet('mobilenetv3_rw')"
],
"metadata": {
"id": "acc4eR-HNati"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"o = model(torch.randn(2,3,128,160))"
],
"metadata": {
"id": "JMxOElmXNyfJ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"o.shape"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "s5ITHIt_N3Za",
"outputId": "985af779-f5a7-4c0e-aded-398b5cce175f"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"torch.Size([2, 3, 128, 160])"
]
},
"metadata": {},
"execution_count": 10
}
]
},
{
"cell_type": "code",
"source": [
"encoder = create_model(\n",
" 'mobilenetv3_rw', features_only=True, out_indices=None, in_chans=3,\n",
" pretrained=True,)"
],
"metadata": {
"id": "d6nwuci0vEJi"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"enc_outs = encoder(torch.randn(2, 3, 128, 160))\n",
"enc_outs_r = enc_outs[::-1]"
],
"metadata": {
"id": "cpIXAnerTZ_v"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"for e in enc_outs:\n",
" print(e.shape)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "oAE8LNQQTqwG",
"outputId": "512ff956-62a4-4bd1-ad36-25c4bd418ef7"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"torch.Size([2, 16, 64, 80])\n",
"torch.Size([2, 24, 32, 40])\n",
"torch.Size([2, 40, 16, 20])\n",
"torch.Size([2, 112, 8, 10])\n",
"torch.Size([2, 960, 4, 5])\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"enc_channels = encoder.feature_info.channels()[::-1]; enc_channels"
],
"metadata": {
"id": "YhvGiHYNNgcH",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "acf5e4d3-7f1e-4332-a7ad-08ae84ee10ef"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[960, 112, 40, 24, 16]"
]
},
"metadata": {},
"execution_count": 159
}
]
},
{
"cell_type": "code",
"source": [
"dec_channels = [256, 128, 64, 32, 16]"
],
"metadata": {
"id": "yOoZwawcNvaD"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"dec1 = DecoderBlock(enc_channels[0], \n",
" None, \n",
" dec_channels[0],\n",
" ffm=False)"
],
"metadata": {
"id": "oZLm9YwkODQv"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"dec1_out = dec1(enc_outs_r[0], None)"
],
"metadata": {
"id": "3h-2jsN2OVNO"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"dec1_out.shape"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "oE-X8RKNOl8e",
"outputId": "51ccf30e-782d-462f-fcfc-28a0b8347e7b"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"torch.Size([2, 256, 8, 10])"
]
},
"metadata": {},
"execution_count": 163
}
]
},
{
"cell_type": "code",
"source": [
"dec2 = DecoderBlock(enc_channels[1], \n",
" dec_channels[1-1], \n",
" dec_channels[1],\n",
" ffm=True)"
],
"metadata": {
"id": "uchSrgNlwUg5"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"dec2_out = dec2(enc_outs_r[1], dec1_out)"
],
"metadata": {
"id": "PTGDO0ooweXJ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"dec2_out.shape"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "2iqQyBVvwpFx",
"outputId": "3c1e185f-2d82-49ae-92b8-64d52000c875"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"torch.Size([2, 128, 16, 20])"
]
},
"metadata": {},
"execution_count": 154
}
]
},
{
"cell_type": "code",
"source": [
"dec3 = DecoderBlock(enc_channels[2], \n",
" dec_channels[2-1], \n",
" dec_channels[2],\n",
" ffm=True)"
],
"metadata": {
"id": "YI38w3tCwq9K"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"dec3_out = dec3(enc_outs_r[2], dec2_out)"
],
"metadata": {
"id": "Id5PcFlDw4M6"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"dec3_out.shape"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "QR3fRqsh1Hrx",
"outputId": "76152af4-972f-45c9-c5d1-dcc795edec2a"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"torch.Size([2, 64, 32, 40])"
]
},
"metadata": {},
"execution_count": 169
}
]
},
{
"cell_type": "code",
"source": [
"dec4 = DecoderBlock(enc_channels[3], \n",
" dec_channels[3-1], \n",
" dec_channels[3],\n",
" ffm=True)"
],
"metadata": {
"id": "ytatkui20UNi"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"dec4_out = dec4(enc_outs_r[3], dec3_out)"
],
"metadata": {
"id": "0YxQrVnB0dDl"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"dec4_out.shape"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "2G3Uht6T1Kbi",
"outputId": "7df41391-e29d-4c2f-ffa9-439f7033717d"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"torch.Size([2, 32, 64, 80])"
]
},
"metadata": {},
"execution_count": 172
}
]
},
{
"cell_type": "code",
"source": [
"dec5 = DecoderBlock(enc_channels[4], \n",
" dec_channels[4-1], \n",
" dec_channels[4],\n",
" ffm=True)"
],
"metadata": {
"id": "dITgOOl81QcL"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"dec5_out = dec5(enc_outs_r[4], dec4_out)"
],
"metadata": {
"id": "LqHdur5d1WCI"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"dec5_out.shape"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "s1SuPses1dKL",
"outputId": "3c23592b-b428-43df-ea7c-6335f7ddd58b"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"torch.Size([2, 16, 128, 160])"
]
},
"metadata": {},
"execution_count": 175
}
]
},
{
"cell_type": "code",
"source": [
""
],
"metadata": {
"id": "VzKVZSeTPXx0"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"dec = UnetDecoder(enc_channels)"
],
"metadata": {
"id": "HCu10Ce_jhrj"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"x = dec(enc_outs_r)"
],
"metadata": {
"id": "QSZ27lM3jr7L"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
""
],
"metadata": {
"id": "9QPdsDft2SrY"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Test Unet training"
],
"metadata": {
"id": "axnPMcaJOvGp"
}
},
{
"cell_type": "code",
"source": [
"path = untar_data(URLs.CAMVID)\n",
"path.ls()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "QJjxTv5xOxNF",
"outputId": "a72a4983-fd95-44e6-dd81-b3eddd77291b"
},
"execution_count": 25,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(#4) [Path('/root/.fastai/data/camvid/codes.txt'),Path('/root/.fastai/data/camvid/images'),Path('/root/.fastai/data/camvid/valid.txt'),Path('/root/.fastai/data/camvid/labels')]"
]
},
"metadata": {},
"execution_count": 25
}
]
},
{
"cell_type": "code",
"source": [
"codes = np.loadtxt(path/'codes.txt', dtype=str)\n",
"codes"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "PCkr63ykO-LR",
"outputId": "b1c7586b-9e84-4d89-9c19-7cf98384e74a"
},
"execution_count": 26,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array(['Animal', 'Archway', 'Bicyclist', 'Bridge', 'Building', 'Car',\n",
" 'CartLuggagePram', 'Child', 'Column_Pole', 'Fence', 'LaneMkgsDriv',\n",
" 'LaneMkgsNonDriv', 'Misc_Text', 'MotorcycleScooter', 'OtherMoving',\n",
" 'ParkingBlock', 'Pedestrian', 'Road', 'RoadShoulder', 'Sidewalk',\n",
" 'SignSymbol', 'Sky', 'SUVPickupTruck', 'TrafficCone',\n",
" 'TrafficLight', 'Train', 'Tree', 'Truck_Bus', 'Tunnel',\n",
" 'VegetationMisc', 'Void', 'Wall'], dtype='<U17')"
]
},
"metadata": {},
"execution_count": 26
}
]
},
{
"cell_type": "code",
"source": [
"fnames = get_image_files(path/\"images\")"
],
"metadata": {
"id": "dKQMJJ80PJwY"
},
"execution_count": 27,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def label_func(fn): return path/\"labels\"/f\"{fn.stem}_P{fn.suffix}\""
],
"metadata": {
"id": "9SZzi4m6PMpS"
},
"execution_count": 28,
"outputs": []
},
{
"cell_type": "code",
"source": [
"name2id = {v:k for k,v in enumerate(codes)}\n",
"void_code = name2id['Void']\n",
"def acc_camvid(inp, targ):\n",
" targ = targ.squeeze(1)\n",
" mask = targ != void_code\n",
" return np.mean(inp.argmax(dim=1)[mask].cpu().numpy()==targ[mask].cpu().numpy())"
],
"metadata": {
"id": "lU6X1jDGQExa"
},
"execution_count": 29,
"outputs": []
},
{
"cell_type": "code",
"source": [
"dls = SegmentationDataLoaders.from_label_func(path, \n",
" bs=8, \n",
" fnames = fnames, \n",
" label_func = label_func, \n",
" codes = codes,\n",
" item_tfms=Resize((128, 160)))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bj7NRyDCPNjl",
"outputId": "939b58d5-9db2-429e-dc88-c12f2a1220c3"
},
"execution_count": 30,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/torch/_tensor.py:1051: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
" ret = func(*args, **kwargs)\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## with Dice Loss"
],
"metadata": {
"id": "vWuMb0ExqbSM"
}
},
{
"cell_type": "code",
"source": [
"model = Unet('mobilenetv3_rw',\n",
" num_classes=32)\n",
"\n",
"learn = Learner(dls, \n",
" model,\n",
" loss_func=DiceLoss(),\n",
" metrics=acc_camvid)"
],
"metadata": {
"id": "Utq5Utm0PQA0",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "2b494288-aaf0-4463-a33d-86ace3ea2ce9"
},
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Downloading: \"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth\" to /root/.cache/torch/hub/checkpoints/mobilenetv3_100-35495452.pth\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"learn.fit_one_cycle(10, 1e-3)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 418
},
"id": "RQkxMCbWQTgR",
"outputId": "6d31edfb-ed01-4f7c-a434-8fac551ec433"
},
"execution_count": 13,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>acc_camvid</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>252.215271</td>\n",
" <td>246.689651</td>\n",
" <td>0.183169</td>\n",
" <td>00:28</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>245.648544</td>\n",
" <td>237.859009</td>\n",
" <td>0.438296</td>\n",
" <td>00:26</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>240.166595</td>\n",
" <td>233.821426</td>\n",
" <td>0.527989</td>\n",
" <td>00:31</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>236.032440</td>\n",
" <td>230.155685</td>\n",
" <td>0.579975</td>\n",
" <td>00:26</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>232.820297</td>\n",
" <td>227.664886</td>\n",
" <td>0.598176</td>\n",
" <td>00:26</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>230.536667</td>\n",
" <td>226.120468</td>\n",
" <td>0.612691</td>\n",
" <td>00:27</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>229.252609</td>\n",
" <td>225.270096</td>\n",
" <td>0.618059</td>\n",
" <td>00:26</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>228.376007</td>\n",
" <td>224.750214</td>\n",
" <td>0.621125</td>\n",
" <td>00:26</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>227.940201</td>\n",
" <td>224.663452</td>\n",
" <td>0.622177</td>\n",
" <td>00:26</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>227.806488</td>\n",
" <td>224.604568</td>\n",
" <td>0.624282</td>\n",
" <td>00:27</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/torch/_tensor.py:1051: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
" ret = func(*args, **kwargs)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"learn.summary()"
],
"metadata": {
"id": "L53Vk3fCZKDE"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## with CrossEntropy loss"
],
"metadata": {
"id": "tF_pW798qdgI"
}
},
{
"cell_type": "code",
"source": [
"model = Unet('mobilenetv3_rw',\n",
" num_classes=32)\n",
"\n",
"learn = Learner(dls, \n",
" model,\n",
" metrics=acc_camvid)"
],
"metadata": {
"id": "ZSLhn68ZmkmZ"
},
"execution_count": 31,
"outputs": []
},
{
"cell_type": "code",
"source": [
"learn.fit_one_cycle(25, 1e-3)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 888
},
"id": "qLJF8V6zmrs3",
"outputId": "5adf8515-9496-4f3a-be10-2ccbe7952a43"
},
"execution_count": 32,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>acc_camvid</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>3.952452</td>\n",
" <td>3.837103</td>\n",
" <td>0.009524</td>\n",
" <td>00:33</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>3.651547</td>\n",
" <td>3.323963</td>\n",
" <td>0.034026</td>\n",
" <td>00:27</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>3.141333</td>\n",
" <td>2.725803</td>\n",
" <td>0.219234</td>\n",
" <td>00:27</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>2.604731</td>\n",
" <td>2.217971</td>\n",
" <td>0.475758</td>\n",
" <td>00:27</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>2.168982</td>\n",
" <td>1.879862</td>\n",
" <td>0.516642</td>\n",
" <td>00:29</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>1.846897</td>\n",
" <td>1.635621</td>\n",
" <td>0.553690</td>\n",
" <td>00:27</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>1.600007</td>\n",
" <td>1.454400</td>\n",
" <td>0.619151</td>\n",
" <td>00:27</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>1.441168</td>\n",
" <td>1.355009</td>\n",
" <td>0.629975</td>\n",
" <td>00:28</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>1.343089</td>\n",
" <td>1.295351</td>\n",
" <td>0.633226</td>\n",
" <td>00:27</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>1.287046</td>\n",
" <td>1.254252</td>\n",
" <td>0.639729</td>\n",
" <td>00:27</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>1.246888</td>\n",
" <td>1.224334</td>\n",
" <td>0.643034</td>\n",
" <td>00:27</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>1.220407</td>\n",
" <td>1.203212</td>\n",
" <td>0.646837</td>\n",
" <td>00:27</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>1.192617</td>\n",
" <td>1.186723</td>\n",
" <td>0.649282</td>\n",
" <td>00:27</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>1.178188</td>\n",
" <td>1.172160</td>\n",
" <td>0.654529</td>\n",
" <td>00:27</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>1.167311</td>\n",
" <td>1.163162</td>\n",
" <td>0.654595</td>\n",
" <td>00:27</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>1.166651</td>\n",
" <td>1.157911</td>\n",
" <td>0.654737</td>\n",
" <td>00:27</td>\n",
" </tr>\n",
" <tr>\n",
" <td>16</td>\n",
" <td>1.148694</td>\n",
" <td>1.148224</td>\n",
" <td>0.660711</td>\n",
" <td>00:27</td>\n",
" </tr>\n",
" <tr>\n",
" <td>17</td>\n",
" <td>1.145139</td>\n",
" <td>1.146453</td>\n",
" <td>0.659014</td>\n",
" <td>00:27</td>\n",
" </tr>\n",
" <tr>\n",
" <td>18</td>\n",
" <td>1.135504</td>\n",
" <td>1.138023</td>\n",
" <td>0.661375</td>\n",
" <td>00:27</td>\n",
" </tr>\n",
" <tr>\n",
" <td>19</td>\n",
" <td>1.139854</td>\n",
" <td>1.135780</td>\n",
" <td>0.661414</td>\n",
" <td>00:28</td>\n",
" </tr>\n",
" <tr>\n",
" <td>20</td>\n",
" <td>1.132565</td>\n",
" <td>1.136258</td>\n",
" <td>0.660869</td>\n",
" <td>00:28</td>\n",
" </tr>\n",
" <tr>\n",
" <td>21</td>\n",
" <td>1.126904</td>\n",
" <td>1.133178</td>\n",
" <td>0.661520</td>\n",
" <td>00:27</td>\n",
" </tr>\n",
" <tr>\n",
" <td>22</td>\n",
" <td>1.126309</td>\n",
" <td>1.132679</td>\n",
" <td>0.663100</td>\n",
" <td>00:27</td>\n",
" </tr>\n",
" <tr>\n",
" <td>23</td>\n",
" <td>1.129516</td>\n",
" <td>1.131822</td>\n",
" <td>0.663083</td>\n",
" <td>00:27</td>\n",
" </tr>\n",
" <tr>\n",
" <td>24</td>\n",
" <td>1.132021</td>\n",
" <td>1.131812</td>\n",
" <td>0.662620</td>\n",
" <td>00:27</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/torch/_tensor.py:1051: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
" ret = func(*args, **kwargs)\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# fastai unet-resnet101 performance"
],
"metadata": {
"id": "ZlkdkgOAjTGP"
}
},
{
"cell_type": "code",
"source": [
"teacher_learn = unet_learner(dls, \n",
" resnet101, \n",
" metrics=acc_camvid)"
],
"metadata": {
"id": "sy8eP1xjbI4h",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 121,
"referenced_widgets": [
"8eb841b4c62b4042bd9a5af15342c012",
"dd011ba65aa04664bd63eb8f79ca29c8",
"46490715f4934745aa454ae8e119f40d",
"f7444e1bb42a4c6f913e05514445ce59",
"cd7643d9744340ec99319a52c2faf2c9",
"a1860ff990b042a5b50f87066e30bb9b",
"18da3006e55d432d9038da4f74558e4f",
"5ca90a8d4e0c4be3888dbccafe6aafce",
"56aa6084019b49b1ad09c0f1aaabdf01",
"e43806d23dad4c86afb2428ce56504b3",
"4eec62679381415ab632253d8f043108"
]
},
"outputId": "f3d18777-eb58-4522-e37d-04532869eb1c"
},
"execution_count": 34,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/torch/_tensor.py:1051: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
" ret = func(*args, **kwargs)\n",
"Downloading: \"https://download.pytorch.org/models/resnet101-63fe2227.pth\" to /root/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
" 0%| | 0.00/171M [00:00<?, ?B/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "8eb841b4c62b4042bd9a5af15342c012"
}
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"teacher_learn.fit_one_cycle(10, 1e-3)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 418
},
"id": "ojFW8MxuouxA",
"outputId": "06e9580e-f19b-4a31-ce60-63b522b03020"
},
"execution_count": 35,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>acc_camvid</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.079316</td>\n",
" <td>1.945745</td>\n",
" <td>0.791928</td>\n",
" <td>02:54</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.672722</td>\n",
" <td>3.414593</td>\n",
" <td>0.467281</td>\n",
" <td>02:33</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>2.836250</td>\n",
" <td>1.068983</td>\n",
" <td>0.708052</td>\n",
" <td>02:31</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>1.262702</td>\n",
" <td>1.077650</td>\n",
" <td>0.819915</td>\n",
" <td>02:31</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.755539</td>\n",
" <td>0.741171</td>\n",
" <td>0.839445</td>\n",
" <td>02:30</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>0.586151</td>\n",
" <td>0.615320</td>\n",
" <td>0.857607</td>\n",
" <td>02:31</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>0.495401</td>\n",
" <td>0.540893</td>\n",
" <td>0.869873</td>\n",
" <td>02:31</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>0.429804</td>\n",
" <td>0.520235</td>\n",
" <td>0.879431</td>\n",
" <td>02:31</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>0.395290</td>\n",
" <td>0.510423</td>\n",
" <td>0.879006</td>\n",
" <td>02:31</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>0.374698</td>\n",
" <td>0.482728</td>\n",
" <td>0.882320</td>\n",
" <td>02:31</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/torch/_tensor.py:1051: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
" ret = func(*args, **kwargs)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"teacher_learn.loss_func"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "_ZMICwZzo0MN",
"outputId": "659ea2c1-8cc6-4aaa-8776-95e17d8c9e02"
},
"execution_count": 36,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"FlattenedLoss of CrossEntropyLoss()"
]
},
"metadata": {},
"execution_count": 36
}
]
},
{
"cell_type": "markdown",
"source": [
"# Knowledge Distillation"
],
"metadata": {
"id": "IVvuN2gEqjv6"
}
},
{
"cell_type": "code",
"source": [
"class DistillationLoss(nn.Module):\n",
"    \"\"\"Temperature-softened KL divergence between student and teacher logits.\n",
"\n",
"    Both prediction tensors are raw logits with the class scores on dim 1,\n",
"    e.g. (batch, classes, H, W) for segmentation or (batch, classes).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super(DistillationLoss, self).__init__()\n",
"        # Fed with one row per pixel, 'batchmean' sums the KL terms over the\n",
"        # classes and averages over the rows, i.e. the mean per-pixel KL.\n",
"        self.distillation_loss = nn.KLDivLoss(reduction='batchmean')\n",
"        \n",
"    def forward(self,\n",
"                student_preds, \n",
"                teacher_preds, \n",
"                acutal_target, \n",
"                T, \n",
"                alpha\n",
"                ):\n",
"        \"\"\"Return T**2 * KL(teacher || student) computed at temperature T.\n",
"\n",
"        Args:\n",
"            student_preds: student logits, classes on dim 1.\n",
"            teacher_preds: teacher logits, same shape as `student_preds`.\n",
"            acutal_target: ground-truth labels; unused, kept for call compatibility.\n",
"            T: softmax temperature (> 0).\n",
"            alpha: loss mixing weight; unused here, the caller applies it.\n",
"        \"\"\"\n",
"        n_classes = student_preds.shape[1]\n",
"        # nn.KLDivLoss takes log-probabilities as input and probabilities as target;\n",
"        # passing plain softmax output for the student does not yield a KL divergence.\n",
"        student_log_probs = F.log_softmax(student_preds / T, dim=1)\n",
"        # The teacher distribution is a fixed target: no gradient should reach it.\n",
"        teacher_probs = F.softmax(teacher_preds.detach() / T, dim=1)\n",
"        # Put the class dim last and flatten to (rows, classes). Flattening to 1-D\n",
"        # would make 'batchmean' divide by the total number of elements instead.\n",
"        student_log_probs = student_log_probs.transpose(1, -1).reshape(-1, n_classes)\n",
"        teacher_probs = teacher_probs.transpose(1, -1).reshape(-1, n_classes)\n",
"        # Soft-target gradients scale as 1/T**2 (Hinton et al., 2015); multiplying by\n",
"        # T**2 keeps their magnitude comparable to the hard-label loss.\n",
"        return self.distillation_loss(student_log_probs, teacher_probs) * (T * T)\n",
" \n",
"\n",
"\n",
"class KnowledgeDistillation(Callback):\n",
"    \"\"\"Blend the student's own loss with a distillation loss from a frozen teacher.\n",
"\n",
"    On every training batch the loss used for back-propagation becomes\n",
"    `a * student_loss + (1 - a) * distillation_loss`.\n",
"\n",
"    Args:\n",
"        teacher: trained Learner whose `.model` produces the soft targets.\n",
"        T: softmax temperature passed to the distillation loss.\n",
"        a: weight of the student's original loss; the distillation term gets `1 - a`.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, \n",
"                 teacher:Learner, \n",
"                 T:float=20., \n",
"                 a:float=0.7):\n",
"        super(KnowledgeDistillation, self).__init__()\n",
"        self.teacher = teacher\n",
"        self.T, self.a = T, a\n",
"        self.distillation_loss = DistillationLoss()\n",
"    \n",
"    def before_fit(self):\n",
"        # The teacher only supplies targets: keep BatchNorm/Dropout in inference\n",
"        # mode so its outputs are deterministic and its running stats stay untouched.\n",
"        self.teacher.model.eval()\n",
"\n",
"    def after_loss(self):\n",
"        # `after_loss` also fires on validation batches, where nothing is\n",
"        # back-propagated; skip the expensive teacher forward pass there.\n",
"        if not self.training: return\n",
"        import torch  # local import so the cell does not depend on a star-import\n",
"        # No autograd graph through the teacher: saves memory/time and keeps\n",
"        # gradients from accumulating on the teacher's parameters.\n",
"        with torch.no_grad():\n",
"            teacher_preds = self.teacher.model(self.learn.xb[0])\n",
"        student_loss = self.learn.loss_grad * self.a\n",
"        distillation_loss = self.distillation_loss(self.learn.pred,   # Student preds\n",
"                                                   teacher_preds,     # Teacher preds\n",
"                                                   self.learn.yb,     # Ground truth\n",
"                                                   self.T, \n",
"                                                   self.a) * (1 - self.a)\n",
"        self.learn.loss_grad = student_loss + distillation_loss"
],
"metadata": {
"id": "f1IO5Vw4tPXX"
},
"execution_count": 37,
"outputs": []
},
{
"cell_type": "code",
"source": [
"model = Unet('mobilenetv3_rw',\n",
" num_classes=32)\n",
"\n",
"student_learn = Learner(dls, \n",
" model,\n",
" metrics=acc_camvid,\n",
" cbs=[KnowledgeDistillation(teacher=teacher_learn)])"
],
"metadata": {
"id": "Si1Zl_2uy43B"
},
"execution_count": 38,
"outputs": []
},
{
"cell_type": "code",
"source": [
"student_learn.fit_one_cycle(25, 1e-3)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 888
},
"id": "QjXcKthW0B4t",
"outputId": "3aaf9f7b-5598-433e-8d36-f20d44a06b57"
},
"execution_count": 39,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"\n",
"<style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
"</style>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.HTML object>"
],
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>acc_camvid</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>4.149306</td>\n",
" <td>4.083494</td>\n",
" <td>0.012385</td>\n",
" <td>02:23</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>3.843522</td>\n",
" <td>3.565912</td>\n",
" <td>0.096796</td>\n",
" <td>02:23</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>3.263084</td>\n",
" <td>2.797809</td>\n",
" <td>0.369143</td>\n",
" <td>02:23</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>2.626706</td>\n",
" <td>2.134217</td>\n",
" <td>0.512857</td>\n",
" <td>02:24</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>2.120926</td>\n",
" <td>1.793929</td>\n",
" <td>0.577740</td>\n",
" <td>02:23</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>1.782979</td>\n",
" <td>1.582902</td>\n",
" <td>0.600909</td>\n",
" <td>02:23</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>1.565306</td>\n",
" <td>1.450105</td>\n",
" <td>0.618120</td>\n",
" <td>02:23</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>1.431738</td>\n",
" <td>1.363392</td>\n",
" <td>0.626684</td>\n",
" <td>02:23</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>1.350202</td>\n",
" <td>1.306182</td>\n",
" <td>0.632427</td>\n",
" <td>02:23</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>1.291344</td>\n",
" <td>1.265584</td>\n",
" <td>0.640540</td>\n",
" <td>02:23</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>1.253604</td>\n",
" <td>1.234486</td>\n",
" <td>0.646430</td>\n",
" <td>02:24</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>1.225404</td>\n",
" <td>1.220467</td>\n",
" <td>0.650743</td>\n",
" <td>02:24</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>1.201750</td>\n",
" <td>1.195458</td>\n",
" <td>0.652415</td>\n",
" <td>02:24</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>1.185998</td>\n",
" <td>1.175060</td>\n",
" <td>0.658728</td>\n",
" <td>02:24</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>1.171920</td>\n",
" <td>1.162316</td>\n",
" <td>0.664814</td>\n",
" <td>02:23</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>1.160010</td>\n",
" <td>1.154163</td>\n",
" <td>0.666426</td>\n",
" <td>02:24</td>\n",
" </tr>\n",
" <tr>\n",
" <td>16</td>\n",
" <td>1.153085</td>\n",
" <td>1.149265</td>\n",
" <td>0.667741</td>\n",
" <td>02:23</td>\n",
" </tr>\n",
" <tr>\n",
" <td>17</td>\n",
" <td>1.141003</td>\n",
" <td>1.143118</td>\n",
" <td>0.670634</td>\n",
" <td>02:23</td>\n",
" </tr>\n",
" <tr>\n",
" <td>18</td>\n",
" <td>1.138742</td>\n",
" <td>1.137632</td>\n",
" <td>0.671683</td>\n",
" <td>02:23</td>\n",
" </tr>\n",
" <tr>\n",
" <td>19</td>\n",
" <td>1.129604</td>\n",
" <td>1.136192</td>\n",
" <td>0.672162</td>\n",
" <td>02:23</td>\n",
" </tr>\n",
" <tr>\n",
" <td>20</td>\n",
" <td>1.125707</td>\n",
" <td>1.133340</td>\n",
" <td>0.672295</td>\n",
" <td>02:23</td>\n",
" </tr>\n",
" <tr>\n",
" <td>21</td>\n",
" <td>1.130975</td>\n",
" <td>1.133293</td>\n",
" <td>0.670490</td>\n",
" <td>02:23</td>\n",
" </tr>\n",
" <tr>\n",
" <td>22</td>\n",
" <td>1.134533</td>\n",
" <td>1.132475</td>\n",
" <td>0.670508</td>\n",
" <td>02:23</td>\n",
" </tr>\n",
" <tr>\n",
" <td>23</td>\n",
" <td>1.129584</td>\n",
" <td>1.130723</td>\n",
" <td>0.672191</td>\n",
" <td>02:23</td>\n",
" </tr>\n",
" <tr>\n",
" <td>24</td>\n",
" <td>1.126617</td>\n",
" <td>1.130564</td>\n",
" <td>0.672605</td>\n",
" <td>02:23</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/torch/_tensor.py:1051: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
" ret = func(*args, **kwargs)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
""
],
"metadata": {
"id": "sgCb1pswUtv5"
},
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment