Created
March 23, 2022 07:32
-
-
Save moarshy/0b1edde8afd538e5073fb771b2753315 to your computer and use it in GitHub Desktop.
MNUnetModel.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "MNUnetModel.ipynb", | |
"provenance": [], | |
"collapsed_sections": [ | |
"nl_Y1QW9OsZR" | |
], | |
"authorship_tag": "ABX9TyNFaFW9sZHD/tLhto6ay9vQ", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
}, | |
"accelerator": "GPU", | |
"widgets": { | |
"application/vnd.jupyter.widget-state+json": { | |
"8eb841b4c62b4042bd9a5af15342c012": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HBoxModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HBoxModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HBoxView", | |
"box_style": "", | |
"children": [ | |
"IPY_MODEL_dd011ba65aa04664bd63eb8f79ca29c8", | |
"IPY_MODEL_46490715f4934745aa454ae8e119f40d", | |
"IPY_MODEL_f7444e1bb42a4c6f913e05514445ce59" | |
], | |
"layout": "IPY_MODEL_cd7643d9744340ec99319a52c2faf2c9" | |
} | |
}, | |
"dd011ba65aa04664bd63eb8f79ca29c8": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_a1860ff990b042a5b50f87066e30bb9b", | |
"placeholder": "", | |
"style": "IPY_MODEL_18da3006e55d432d9038da4f74558e4f", | |
"value": "100%" | |
} | |
}, | |
"46490715f4934745aa454ae8e119f40d": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "FloatProgressModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "FloatProgressModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "ProgressView", | |
"bar_style": "success", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_5ca90a8d4e0c4be3888dbccafe6aafce", | |
"max": 178793939, | |
"min": 0, | |
"orientation": "horizontal", | |
"style": "IPY_MODEL_56aa6084019b49b1ad09c0f1aaabdf01", | |
"value": 178793939 | |
} | |
}, | |
"f7444e1bb42a4c6f913e05514445ce59": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_dom_classes": [], | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "HTMLModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/controls", | |
"_view_module_version": "1.5.0", | |
"_view_name": "HTMLView", | |
"description": "", | |
"description_tooltip": null, | |
"layout": "IPY_MODEL_e43806d23dad4c86afb2428ce56504b3", | |
"placeholder": "", | |
"style": "IPY_MODEL_4eec62679381415ab632253d8f043108", | |
"value": " 171M/171M [00:03<00:00, 44.6MB/s]" | |
} | |
}, | |
"cd7643d9744340ec99319a52c2faf2c9": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"a1860ff990b042a5b50f87066e30bb9b": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"18da3006e55d432d9038da4f74558e4f": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
}, | |
"5ca90a8d4e0c4be3888dbccafe6aafce": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"56aa6084019b49b1ad09c0f1aaabdf01": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ProgressStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "ProgressStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"bar_color": null, | |
"description_width": "" | |
} | |
}, | |
"e43806d23dad4c86afb2428ce56504b3": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"model_module_version": "1.2.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.2.0", | |
"_model_name": "LayoutModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "LayoutView", | |
"align_content": null, | |
"align_items": null, | |
"align_self": null, | |
"border": null, | |
"bottom": null, | |
"display": null, | |
"flex": null, | |
"flex_flow": null, | |
"grid_area": null, | |
"grid_auto_columns": null, | |
"grid_auto_flow": null, | |
"grid_auto_rows": null, | |
"grid_column": null, | |
"grid_gap": null, | |
"grid_row": null, | |
"grid_template_areas": null, | |
"grid_template_columns": null, | |
"grid_template_rows": null, | |
"height": null, | |
"justify_content": null, | |
"justify_items": null, | |
"left": null, | |
"margin": null, | |
"max_height": null, | |
"max_width": null, | |
"min_height": null, | |
"min_width": null, | |
"object_fit": null, | |
"object_position": null, | |
"order": null, | |
"overflow": null, | |
"overflow_x": null, | |
"overflow_y": null, | |
"padding": null, | |
"right": null, | |
"top": null, | |
"visibility": null, | |
"width": null | |
} | |
}, | |
"4eec62679381415ab632253d8f043108": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"model_module_version": "1.5.0", | |
"state": { | |
"_model_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_model_name": "DescriptionStyleModel", | |
"_view_count": null, | |
"_view_module": "@jupyter-widgets/base", | |
"_view_module_version": "1.2.0", | |
"_view_name": "StyleView", | |
"description_width": "" | |
} | |
} | |
} | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/moarshy/0b1edde8afd538e5073fb771b2753315/mnunetmodel.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pip install timm fastai -Uqq" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "cJ_Y_aOreCYt", | |
"outputId": "bdc1576e-a1fa-4618-9123-e69fd9b926be" | |
}, | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\u001b[K |████████████████████████████████| 431 kB 4.0 MB/s \n", | |
"\u001b[K |████████████████████████████████| 189 kB 45.0 MB/s \n", | |
"\u001b[K |████████████████████████████████| 55 kB 3.9 MB/s \n", | |
"\u001b[?25h" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": { | |
"id": "ynxvBvjJMH8F" | |
}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.nn.functional as F\n", | |
"\n", | |
"import timm\n", | |
"from timm import create_model\n", | |
"from timm.models.efficientnet_blocks import DepthwiseSeparableConv\n", | |
"\n", | |
"from fastai.vision.all import *\n", | |
"from fastai.callback.all import *" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"In this notebook,\n", | |
"- [x] Implement this [paper](https://openaccess.thecvf.com/content/CVPR2021W/MAI/papers/Zhang_A_Simple_Baseline_for_Fast_and_Accurate_Depth_Estimation_on_CVPRW_2021_paper.pdf)\n", | |
"- [x] Train using knowledge distillation\n" | |
], | |
"metadata": { | |
"id": "QBoa9XfUjYsV" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"**The paper abstract**\n", | |
"\n", | |
" " | |
], | |
"metadata": { | |
"id": "UseLzzkckySY" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"**The architecture**\n", | |
"" | |
], | |
"metadata": { | |
"id": "k5lA8zzHk-z9" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Some codes are based off https://gist.github.com/rwightman/f8b24f4e6f5504aba03e999e02460d31\n", | |
"\n", | |
"class Conv2dBnAct(nn.Module):\n", | |
" def __init__(self, \n", | |
" in_channels, \n", | |
" out_channels, \n", | |
" kernel_size, \n", | |
" padding=0,\n", | |
" stride=1, \n", | |
" act_layer=nn.ReLU, \n", | |
" norm_layer=nn.BatchNorm2d\n", | |
" ):\n", | |
" super().__init__()\n", | |
" self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False)\n", | |
" self.bn = norm_layer(out_channels)\n", | |
" self.act = act_layer(inplace=True)\n", | |
"\n", | |
" def forward(self, x):\n", | |
" x = self.conv(x)\n", | |
" x = self.bn(x)\n", | |
" x = self.act(x)\n", | |
" return x\n", | |
"\n", | |
"\n", | |
"class FeatureFusionModule(nn.Module):\n", | |
" def __init__(self, \n", | |
" enc_in_channels, # Encoder in channels\n", | |
" enc_out_channels, # Encoder out channels\n", | |
" dec_in_channels, # Decoder in channels\n", | |
" out_channels, # Final out channels\n", | |
" ):\n", | |
" super().__init__()\n", | |
" #encoderoutput\n", | |
" self.enc_conv1 = nn.Conv2d(enc_in_channels, enc_out_channels, kernel_size=1, stride=1, padding='same')\n", | |
" self.enc_up = nn.ConvTranspose2d(enc_out_channels, enc_out_channels, kernel_size=1)\n", | |
" self.enc_dconv = DepthwiseSeparableConv(enc_out_channels, enc_out_channels)\n", | |
" self.enc_conv2 = nn.Conv2d(enc_out_channels, enc_out_channels, kernel_size=1, stride=1, padding='same')\n", | |
" \n", | |
" #decoderoutput\n", | |
" self.dec_dconv = DepthwiseSeparableConv(enc_out_channels+dec_in_channels, enc_out_channels+dec_in_channels)\n", | |
" self.dec_conv1 = nn.Conv2d(enc_out_channels+dec_in_channels, out_channels, kernel_size=1, stride=1, padding='same')\n", | |
"\n", | |
"\n", | |
" def forward(self, enc_x, dec_x):\n", | |
" enc_x = self.enc_conv1(enc_x)\n", | |
" enc_x = self.enc_up(enc_x)\n", | |
" enc_x = self.enc_dconv(enc_x)\n", | |
" enc_x = self.enc_conv2(enc_x)\n", | |
" \n", | |
" x = torch.cat([enc_x, dec_x], dim=1)\n", | |
"\n", | |
" dec_x = self.dec_dconv(x)\n", | |
" dec_x = self.dec_conv1(dec_x)\n", | |
" \n", | |
" return dec_x\n", | |
"\n", | |
"\n", | |
"class DecoderBlock(nn.Module):\n", | |
" def __init__(self, \n", | |
" enc_channels,\n", | |
" dec_prev_channels, \n", | |
" dec_channels,\n", | |
" act_layer=nn.ReLU, \n", | |
" norm_layer=nn.BatchNorm2d,\n", | |
" ffm=True,\n", | |
" ):\n", | |
" super().__init__()\n", | |
" conv_args = dict(kernel_size=3, padding=1, act_layer=act_layer)\n", | |
" self.ffm = ffm\n", | |
" \n", | |
" if ffm:\n", | |
" self.ffm = FeatureFusionModule(enc_channels, enc_channels, dec_prev_channels, dec_channels)\n", | |
"\n", | |
" self.conv1 = Conv2dBnAct(enc_channels, dec_channels, norm_layer=norm_layer, **conv_args)\n", | |
" self.conv2 = Conv2dBnAct(dec_channels, dec_channels, norm_layer=norm_layer, **conv_args)\n", | |
" \n", | |
"\n", | |
" def forward(self, x_enc, x_dec):\n", | |
" if self.ffm:\n", | |
" x = self.ffm(x_enc, x_dec)\n", | |
"\n", | |
" x = F.interpolate(x_enc, scale_factor=2, mode='nearest')\n", | |
"\n", | |
" x = self.conv1(x)\n", | |
" x = self.conv2(x)\n", | |
"\n", | |
" return x\n", | |
"\n", | |
"\n", | |
"class UnetDecoder(nn.Module):\n", | |
"\n", | |
" def __init__(self,\n", | |
" encoder_channels,\n", | |
" decoder_channels=(256, 128, 64, 32, 16),\n", | |
" final_channels=3,\n", | |
" norm_layer=nn.BatchNorm2d,\n", | |
" ):\n", | |
" super().__init__()\n", | |
"\n", | |
" self.decoders = nn.ModuleList()\n", | |
" for i, (e_ch, d_ch) in enumerate(zip(encoder_channels, decoder_channels)):\n", | |
" if i== 0:\n", | |
" self.decoders.append(DecoderBlock(enc_channels=e_ch, \n", | |
" dec_prev_channels=None, \n", | |
" dec_channels=d_ch,\n", | |
" act_layer=nn.ReLU, \n", | |
" norm_layer=nn.BatchNorm2d,\n", | |
" ffm=False,\n", | |
" ))\n", | |
"\n", | |
" else:\n", | |
" self.decoders.append(DecoderBlock(enc_channels=e_ch, \n", | |
" dec_prev_channels=decoder_channels[i-1], \n", | |
" dec_channels=d_ch,\n", | |
" act_layer=nn.ReLU, \n", | |
" norm_layer=nn.BatchNorm2d,\n", | |
" ffm=True,\n", | |
" ))\n", | |
"\n", | |
" self.final_conv = nn.Conv2d(decoder_channels[-1], final_channels, kernel_size=(1, 1))\n", | |
" self.tensor_base = ToTensorBase()\n", | |
" self._init_weight()\n", | |
"\n", | |
" def _init_weight(self):\n", | |
" for m in self.modules():\n", | |
" if isinstance(m, nn.Conv2d):\n", | |
" torch.nn.init.kaiming_normal_(m.weight)\n", | |
" elif isinstance(m, nn.BatchNorm2d):\n", | |
" m.weight.data.fill_(1)\n", | |
" m.bias.data.zero_()\n", | |
"\n", | |
"\n", | |
" def forward(self, x):\n", | |
" enc_outs_r = x\n", | |
" dec_out = None\n", | |
" for i, each in enumerate(self.decoders):\n", | |
" dec_out = each(enc_outs_r[i], dec_out)\n", | |
" x = self.final_conv(dec_out)\n", | |
" x = self.tensor_base(x)\n", | |
" return x\n", | |
"\n", | |
"\n", | |
"class Unet(nn.Module):\n", | |
" def __init__(self,\n", | |
" backbone='resnet50',\n", | |
" backbone_kwargs=None,\n", | |
" backbone_indices=None,\n", | |
" decoder_use_batchnorm=True,\n", | |
" decoder_channels=(256, 128, 64, 32, 16),\n", | |
" in_chans=3,\n", | |
" num_classes=3,\n", | |
" norm_layer=nn.BatchNorm2d,\n", | |
" pretrained=True,\n", | |
" ):\n", | |
" super().__init__()\n", | |
" backbone_kwargs = backbone_kwargs or {}\n", | |
" # NOTE some models need different backbone indices specified based on the alignment of features\n", | |
" # and some models won't have a full enough range of feature strides to work properly.\n", | |
" encoder = create_model(\n", | |
" backbone, features_only=True, out_indices=backbone_indices, in_chans=in_chans,\n", | |
" pretrained=pretrained, **backbone_kwargs)\n", | |
" encoder_channels = encoder.feature_info.channels()[::-1]\n", | |
" self.encoder = encoder\n", | |
"\n", | |
" self.decoder = UnetDecoder(\n", | |
" encoder_channels=encoder_channels,\n", | |
" decoder_channels=decoder_channels,\n", | |
" final_channels=num_classes,\n", | |
" norm_layer=norm_layer,\n", | |
" )\n", | |
"\n", | |
" def forward(self, x: torch.Tensor):\n", | |
" x = self.encoder(x)\n", | |
" x.reverse() # torchscript doesn't work with [::-1]\n", | |
" x = self.decoder(x)\n", | |
" return x" | |
], | |
"metadata": { | |
"id": "t0IVmDv-BZsN" | |
}, | |
"execution_count": 24, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Test individual components" | |
], | |
"metadata": { | |
"id": "nl_Y1QW9OsZR" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"model = Unet('mobilenetv3_rw')" | |
], | |
"metadata": { | |
"id": "acc4eR-HNati" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"o = model(torch.randn(2,3,128,160))" | |
], | |
"metadata": { | |
"id": "JMxOElmXNyfJ" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"o.shape" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "s5ITHIt_N3Za", | |
"outputId": "985af779-f5a7-4c0e-aded-398b5cce175f" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"torch.Size([2, 3, 128, 160])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 10 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"encoder = create_model(\n", | |
" 'mobilenetv3_rw', features_only=True, out_indices=None, in_chans=3,\n", | |
" pretrained=True,)" | |
], | |
"metadata": { | |
"id": "d6nwuci0vEJi" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"enc_outs = encoder(torch.randn(2, 3, 128, 160))\n", | |
"enc_outs_r = enc_outs[::-1]" | |
], | |
"metadata": { | |
"id": "cpIXAnerTZ_v" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"for e in enc_outs:\n", | |
" print(e.shape)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "oAE8LNQQTqwG", | |
"outputId": "512ff956-62a4-4bd1-ad36-25c4bd418ef7" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"torch.Size([2, 16, 64, 80])\n", | |
"torch.Size([2, 24, 32, 40])\n", | |
"torch.Size([2, 40, 16, 20])\n", | |
"torch.Size([2, 112, 8, 10])\n", | |
"torch.Size([2, 960, 4, 5])\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"enc_channels = encoder.feature_info.channels()[::-1]; enc_channels" | |
], | |
"metadata": { | |
"id": "YhvGiHYNNgcH", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "acf5e4d3-7f1e-4332-a7ad-08ae84ee10ef" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"[960, 112, 40, 24, 16]" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 159 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec_channels = [256, 128, 64, 32, 16]" | |
], | |
"metadata": { | |
"id": "yOoZwawcNvaD" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec1 = DecoderBlock(enc_channels[0], \n", | |
" None, \n", | |
" dec_channels[0],\n", | |
" ffm=False)" | |
], | |
"metadata": { | |
"id": "oZLm9YwkODQv" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec1_out = dec1(enc_outs_r[0], None)" | |
], | |
"metadata": { | |
"id": "3h-2jsN2OVNO" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec1_out.shape" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "oE-X8RKNOl8e", | |
"outputId": "51ccf30e-782d-462f-fcfc-28a0b8347e7b" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"torch.Size([2, 256, 8, 10])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 163 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec2 = DecoderBlock(enc_channels[1], \n", | |
" dec_channels[1-1], \n", | |
" dec_channels[1],\n", | |
" ffm=True)" | |
], | |
"metadata": { | |
"id": "uchSrgNlwUg5" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec2_out = dec2(enc_outs_r[1], dec1_out)" | |
], | |
"metadata": { | |
"id": "PTGDO0ooweXJ" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec2_out.shape" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "2iqQyBVvwpFx", | |
"outputId": "3c1e185f-2d82-49ae-92b8-64d52000c875" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"torch.Size([2, 128, 16, 20])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 154 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec3 = DecoderBlock(enc_channels[2], \n", | |
" dec_channels[2-1], \n", | |
" dec_channels[2],\n", | |
" ffm=True)" | |
], | |
"metadata": { | |
"id": "YI38w3tCwq9K" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec3_out = dec3(enc_outs_r[2], dec2_out)" | |
], | |
"metadata": { | |
"id": "Id5PcFlDw4M6" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec3_out.shape" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "QR3fRqsh1Hrx", | |
"outputId": "76152af4-972f-45c9-c5d1-dcc795edec2a" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"torch.Size([2, 64, 32, 40])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 169 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec4 = DecoderBlock(enc_channels[3], \n", | |
" dec_channels[3-1], \n", | |
" dec_channels[3],\n", | |
" ffm=True)" | |
], | |
"metadata": { | |
"id": "ytatkui20UNi" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec4_out = dec4(enc_outs_r[3], dec3_out)" | |
], | |
"metadata": { | |
"id": "0YxQrVnB0dDl" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec4_out.shape" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "2G3Uht6T1Kbi", | |
"outputId": "7df41391-e29d-4c2f-ffa9-439f7033717d" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"torch.Size([2, 32, 64, 80])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 172 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec5 = DecoderBlock(enc_channels[4], \n", | |
" dec_channels[4-1], \n", | |
" dec_channels[4],\n", | |
" ffm=True)" | |
], | |
"metadata": { | |
"id": "dITgOOl81QcL" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec5_out = dec5(enc_outs_r[4], dec4_out)" | |
], | |
"metadata": { | |
"id": "LqHdur5d1WCI" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec5_out.shape" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "s1SuPses1dKL", | |
"outputId": "3c23592b-b428-43df-ea7c-6335f7ddd58b" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"torch.Size([2, 16, 128, 160])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 175 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"" | |
], | |
"metadata": { | |
"id": "VzKVZSeTPXx0" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec = UnetDecoder(enc_channels)" | |
], | |
"metadata": { | |
"id": "HCu10Ce_jhrj" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"x = dec(enc_outs_r)" | |
], | |
"metadata": { | |
"id": "QSZ27lM3jr7L" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"" | |
], | |
"metadata": { | |
"id": "9QPdsDft2SrY" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Test Unet training" | |
], | |
"metadata": { | |
"id": "axnPMcaJOvGp" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"path = untar_data(URLs.CAMVID)\n", | |
"path.ls()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "QJjxTv5xOxNF", | |
"outputId": "a72a4983-fd95-44e6-dd81-b3eddd77291b" | |
}, | |
"execution_count": 25, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(#4) [Path('/root/.fastai/data/camvid/codes.txt'),Path('/root/.fastai/data/camvid/images'),Path('/root/.fastai/data/camvid/valid.txt'),Path('/root/.fastai/data/camvid/labels')]" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 25 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"codes = np.loadtxt(path/'codes.txt', dtype=str)\n", | |
"codes" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "PCkr63ykO-LR", | |
"outputId": "b1c7586b-9e84-4d89-9c19-7cf98384e74a" | |
}, | |
"execution_count": 26, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array(['Animal', 'Archway', 'Bicyclist', 'Bridge', 'Building', 'Car',\n", | |
" 'CartLuggagePram', 'Child', 'Column_Pole', 'Fence', 'LaneMkgsDriv',\n", | |
" 'LaneMkgsNonDriv', 'Misc_Text', 'MotorcycleScooter', 'OtherMoving',\n", | |
" 'ParkingBlock', 'Pedestrian', 'Road', 'RoadShoulder', 'Sidewalk',\n", | |
" 'SignSymbol', 'Sky', 'SUVPickupTruck', 'TrafficCone',\n", | |
" 'TrafficLight', 'Train', 'Tree', 'Truck_Bus', 'Tunnel',\n", | |
" 'VegetationMisc', 'Void', 'Wall'], dtype='<U17')" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 26 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"fnames = get_image_files(path/\"images\")" | |
], | |
"metadata": { | |
"id": "dKQMJJ80PJwY" | |
}, | |
"execution_count": 27, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def label_func(fn): return path/\"labels\"/f\"{fn.stem}_P{fn.suffix}\"" | |
], | |
"metadata": { | |
"id": "9SZzi4m6PMpS" | |
}, | |
"execution_count": 28, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"name2id = {v:k for k,v in enumerate(codes)}\n", | |
"void_code = name2id['Void']\n", | |
"def acc_camvid(inp, targ):\n", | |
" targ = targ.squeeze(1)\n", | |
" mask = targ != void_code\n", | |
" return np.mean(inp.argmax(dim=1)[mask].cpu().numpy()==targ[mask].cpu().numpy())" | |
], | |
"metadata": { | |
"id": "lU6X1jDGQExa" | |
}, | |
"execution_count": 29, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dls = SegmentationDataLoaders.from_label_func(path, \n", | |
" bs=8, \n", | |
" fnames = fnames, \n", | |
" label_func = label_func, \n", | |
" codes = codes,\n", | |
" item_tfms=Resize((128, 160)))" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "bj7NRyDCPNjl", | |
"outputId": "939b58d5-9db2-429e-dc88-c12f2a1220c3" | |
}, | |
"execution_count": 30, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/usr/local/lib/python3.7/dist-packages/torch/_tensor.py:1051: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n", | |
" ret = func(*args, **kwargs)\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## with Dice Loss" | |
], | |
"metadata": { | |
"id": "vWuMb0ExqbSM" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"model = Unet('mobilenetv3_rw',\n", | |
" num_classes=32)\n", | |
"\n", | |
"learn = Learner(dls, \n", | |
" model,\n", | |
" loss_func=DiceLoss(),\n", | |
" metrics=acc_camvid)" | |
], | |
"metadata": { | |
"id": "Utq5Utm0PQA0", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "2b494288-aaf0-4463-a33d-86ace3ea2ce9" | |
}, | |
"execution_count": 12, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Downloading: \"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth\" to /root/.cache/torch/hub/checkpoints/mobilenetv3_100-35495452.pth\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# One-cycle schedule: 10 epochs with a peak learning rate of 1e-3.\n", | |
"learn.fit_one_cycle(n_epoch=10, lr_max=1e-3)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 418 | |
}, | |
"id": "RQkxMCbWQTgR", | |
"outputId": "6d31edfb-ed01-4f7c-a434-8fac551ec433" | |
}, | |
"execution_count": 13, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"\n", | |
"<style>\n", | |
" /* Turns off some styling */\n", | |
" progress {\n", | |
" /* gets rid of default border in Firefox and Opera. */\n", | |
" border: none;\n", | |
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
" background-size: auto;\n", | |
" }\n", | |
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
" background: #F44336;\n", | |
" }\n", | |
"</style>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: left;\">\n", | |
" <th>epoch</th>\n", | |
" <th>train_loss</th>\n", | |
" <th>valid_loss</th>\n", | |
" <th>acc_camvid</th>\n", | |
" <th>time</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td>0</td>\n", | |
" <td>252.215271</td>\n", | |
" <td>246.689651</td>\n", | |
" <td>0.183169</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>1</td>\n", | |
" <td>245.648544</td>\n", | |
" <td>237.859009</td>\n", | |
" <td>0.438296</td>\n", | |
" <td>00:26</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>2</td>\n", | |
" <td>240.166595</td>\n", | |
" <td>233.821426</td>\n", | |
" <td>0.527989</td>\n", | |
" <td>00:31</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>3</td>\n", | |
" <td>236.032440</td>\n", | |
" <td>230.155685</td>\n", | |
" <td>0.579975</td>\n", | |
" <td>00:26</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>4</td>\n", | |
" <td>232.820297</td>\n", | |
" <td>227.664886</td>\n", | |
" <td>0.598176</td>\n", | |
" <td>00:26</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>5</td>\n", | |
" <td>230.536667</td>\n", | |
" <td>226.120468</td>\n", | |
" <td>0.612691</td>\n", | |
" <td>00:27</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>6</td>\n", | |
" <td>229.252609</td>\n", | |
" <td>225.270096</td>\n", | |
" <td>0.618059</td>\n", | |
" <td>00:26</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>7</td>\n", | |
" <td>228.376007</td>\n", | |
" <td>224.750214</td>\n", | |
" <td>0.621125</td>\n", | |
" <td>00:26</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>8</td>\n", | |
" <td>227.940201</td>\n", | |
" <td>224.663452</td>\n", | |
" <td>0.622177</td>\n", | |
" <td>00:26</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>9</td>\n", | |
" <td>227.806488</td>\n", | |
" <td>224.604568</td>\n", | |
" <td>0.624282</td>\n", | |
" <td>00:27</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/usr/local/lib/python3.7/dist-packages/torch/_tensor.py:1051: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n", | |
" ret = func(*args, **kwargs)\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Print fastai's layer-by-layer summary of the model held by `learn`.\n", | |
"learn.summary()" | |
], | |
"metadata": { | |
"id": "L53Vk3fCZKDE" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## with CrossEntropy loss" | |
], | |
"metadata": { | |
"id": "tF_pW798qdgI" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Same student architecture, one output channel per class in `codes`.\n", | |
"model = Unet('mobilenetv3_rw',\n", | |
"             num_classes=len(codes))\n", | |
"\n", | |
"# No loss_func is passed, so fastai falls back to the default loss for `dls`\n", | |
"# (presumably flattened cross-entropy, as teacher_learn.loss_func reports for the same dls).\n", | |
"learn = Learner(dls,\n", | |
"                model,\n", | |
"                metrics=acc_camvid)" | |
], | |
"metadata": { | |
"id": "ZSLhn68ZmkmZ" | |
}, | |
"execution_count": 31, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# One-cycle schedule: 25 epochs with a peak learning rate of 1e-3.\n", | |
"learn.fit_one_cycle(n_epoch=25, lr_max=1e-3)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 888 | |
}, | |
"id": "qLJF8V6zmrs3", | |
"outputId": "5adf8515-9496-4f3a-be10-2ccbe7952a43" | |
}, | |
"execution_count": 32, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"\n", | |
"<style>\n", | |
" /* Turns off some styling */\n", | |
" progress {\n", | |
" /* gets rid of default border in Firefox and Opera. */\n", | |
" border: none;\n", | |
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
" background-size: auto;\n", | |
" }\n", | |
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
" background: #F44336;\n", | |
" }\n", | |
"</style>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: left;\">\n", | |
" <th>epoch</th>\n", | |
" <th>train_loss</th>\n", | |
" <th>valid_loss</th>\n", | |
" <th>acc_camvid</th>\n", | |
" <th>time</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td>0</td>\n", | |
" <td>3.952452</td>\n", | |
" <td>3.837103</td>\n", | |
" <td>0.009524</td>\n", | |
" <td>00:33</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>1</td>\n", | |
" <td>3.651547</td>\n", | |
" <td>3.323963</td>\n", | |
" <td>0.034026</td>\n", | |
" <td>00:27</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>2</td>\n", | |
" <td>3.141333</td>\n", | |
" <td>2.725803</td>\n", | |
" <td>0.219234</td>\n", | |
" <td>00:27</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>3</td>\n", | |
" <td>2.604731</td>\n", | |
" <td>2.217971</td>\n", | |
" <td>0.475758</td>\n", | |
" <td>00:27</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>4</td>\n", | |
" <td>2.168982</td>\n", | |
" <td>1.879862</td>\n", | |
" <td>0.516642</td>\n", | |
" <td>00:29</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>5</td>\n", | |
" <td>1.846897</td>\n", | |
" <td>1.635621</td>\n", | |
" <td>0.553690</td>\n", | |
" <td>00:27</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>6</td>\n", | |
" <td>1.600007</td>\n", | |
" <td>1.454400</td>\n", | |
" <td>0.619151</td>\n", | |
" <td>00:27</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>7</td>\n", | |
" <td>1.441168</td>\n", | |
" <td>1.355009</td>\n", | |
" <td>0.629975</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>8</td>\n", | |
" <td>1.343089</td>\n", | |
" <td>1.295351</td>\n", | |
" <td>0.633226</td>\n", | |
" <td>00:27</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>9</td>\n", | |
" <td>1.287046</td>\n", | |
" <td>1.254252</td>\n", | |
" <td>0.639729</td>\n", | |
" <td>00:27</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>10</td>\n", | |
" <td>1.246888</td>\n", | |
" <td>1.224334</td>\n", | |
" <td>0.643034</td>\n", | |
" <td>00:27</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>11</td>\n", | |
" <td>1.220407</td>\n", | |
" <td>1.203212</td>\n", | |
" <td>0.646837</td>\n", | |
" <td>00:27</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>12</td>\n", | |
" <td>1.192617</td>\n", | |
" <td>1.186723</td>\n", | |
" <td>0.649282</td>\n", | |
" <td>00:27</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>13</td>\n", | |
" <td>1.178188</td>\n", | |
" <td>1.172160</td>\n", | |
" <td>0.654529</td>\n", | |
" <td>00:27</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>14</td>\n", | |
" <td>1.167311</td>\n", | |
" <td>1.163162</td>\n", | |
" <td>0.654595</td>\n", | |
" <td>00:27</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>15</td>\n", | |
" <td>1.166651</td>\n", | |
" <td>1.157911</td>\n", | |
" <td>0.654737</td>\n", | |
" <td>00:27</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>16</td>\n", | |
" <td>1.148694</td>\n", | |
" <td>1.148224</td>\n", | |
" <td>0.660711</td>\n", | |
" <td>00:27</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>17</td>\n", | |
" <td>1.145139</td>\n", | |
" <td>1.146453</td>\n", | |
" <td>0.659014</td>\n", | |
" <td>00:27</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>18</td>\n", | |
" <td>1.135504</td>\n", | |
" <td>1.138023</td>\n", | |
" <td>0.661375</td>\n", | |
" <td>00:27</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>19</td>\n", | |
" <td>1.139854</td>\n", | |
" <td>1.135780</td>\n", | |
" <td>0.661414</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>20</td>\n", | |
" <td>1.132565</td>\n", | |
" <td>1.136258</td>\n", | |
" <td>0.660869</td>\n", | |
" <td>00:28</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>21</td>\n", | |
" <td>1.126904</td>\n", | |
" <td>1.133178</td>\n", | |
" <td>0.661520</td>\n", | |
" <td>00:27</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>22</td>\n", | |
" <td>1.126309</td>\n", | |
" <td>1.132679</td>\n", | |
" <td>0.663100</td>\n", | |
" <td>00:27</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>23</td>\n", | |
" <td>1.129516</td>\n", | |
" <td>1.131822</td>\n", | |
" <td>0.663083</td>\n", | |
" <td>00:27</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>24</td>\n", | |
" <td>1.132021</td>\n", | |
" <td>1.131812</td>\n", | |
" <td>0.662620</td>\n", | |
" <td>00:27</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/usr/local/lib/python3.7/dist-packages/torch/_tensor.py:1051: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n", | |
" ret = func(*args, **kwargs)\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# fastai unet-resnet101 (teacher) performance" | |
], | |
"metadata": { | |
"id": "ZlkdkgOAjTGP" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Teacher: fastai U-Net built on a pretrained ResNet-101 encoder, same dls and metric.\n", | |
"teacher_learn = unet_learner(dls, arch=resnet101, metrics=acc_camvid)" | |
], | |
"metadata": { | |
"id": "sy8eP1xjbI4h", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 121, | |
"referenced_widgets": [ | |
"8eb841b4c62b4042bd9a5af15342c012", | |
"dd011ba65aa04664bd63eb8f79ca29c8", | |
"46490715f4934745aa454ae8e119f40d", | |
"f7444e1bb42a4c6f913e05514445ce59", | |
"cd7643d9744340ec99319a52c2faf2c9", | |
"a1860ff990b042a5b50f87066e30bb9b", | |
"18da3006e55d432d9038da4f74558e4f", | |
"5ca90a8d4e0c4be3888dbccafe6aafce", | |
"56aa6084019b49b1ad09c0f1aaabdf01", | |
"e43806d23dad4c86afb2428ce56504b3", | |
"4eec62679381415ab632253d8f043108" | |
] | |
}, | |
"outputId": "f3d18777-eb58-4522-e37d-04532869eb1c" | |
}, | |
"execution_count": 34, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/usr/local/lib/python3.7/dist-packages/torch/_tensor.py:1051: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n", | |
" ret = func(*args, **kwargs)\n", | |
"Downloading: \"https://download.pytorch.org/models/resnet101-63fe2227.pth\" to /root/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth\n" | |
] | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
" 0%| | 0.00/171M [00:00<?, ?B/s]" | |
], | |
"application/vnd.jupyter.widget-view+json": { | |
"version_major": 2, | |
"version_minor": 0, | |
"model_id": "8eb841b4c62b4042bd9a5af15342c012" | |
} | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Train the teacher: 10 one-cycle epochs, peak learning rate 1e-3.\n", | |
"teacher_learn.fit_one_cycle(n_epoch=10, lr_max=1e-3)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 418 | |
}, | |
"id": "ojFW8MxuouxA", | |
"outputId": "06e9580e-f19b-4a31-ce60-63b522b03020" | |
}, | |
"execution_count": 35, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"\n", | |
"<style>\n", | |
" /* Turns off some styling */\n", | |
" progress {\n", | |
" /* gets rid of default border in Firefox and Opera. */\n", | |
" border: none;\n", | |
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
" background-size: auto;\n", | |
" }\n", | |
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
" background: #F44336;\n", | |
" }\n", | |
"</style>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: left;\">\n", | |
" <th>epoch</th>\n", | |
" <th>train_loss</th>\n", | |
" <th>valid_loss</th>\n", | |
" <th>acc_camvid</th>\n", | |
" <th>time</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td>0</td>\n", | |
" <td>1.079316</td>\n", | |
" <td>1.945745</td>\n", | |
" <td>0.791928</td>\n", | |
" <td>02:54</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>1</td>\n", | |
" <td>1.672722</td>\n", | |
" <td>3.414593</td>\n", | |
" <td>0.467281</td>\n", | |
" <td>02:33</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>2</td>\n", | |
" <td>2.836250</td>\n", | |
" <td>1.068983</td>\n", | |
" <td>0.708052</td>\n", | |
" <td>02:31</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>3</td>\n", | |
" <td>1.262702</td>\n", | |
" <td>1.077650</td>\n", | |
" <td>0.819915</td>\n", | |
" <td>02:31</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>4</td>\n", | |
" <td>0.755539</td>\n", | |
" <td>0.741171</td>\n", | |
" <td>0.839445</td>\n", | |
" <td>02:30</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>5</td>\n", | |
" <td>0.586151</td>\n", | |
" <td>0.615320</td>\n", | |
" <td>0.857607</td>\n", | |
" <td>02:31</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>6</td>\n", | |
" <td>0.495401</td>\n", | |
" <td>0.540893</td>\n", | |
" <td>0.869873</td>\n", | |
" <td>02:31</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>7</td>\n", | |
" <td>0.429804</td>\n", | |
" <td>0.520235</td>\n", | |
" <td>0.879431</td>\n", | |
" <td>02:31</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>8</td>\n", | |
" <td>0.395290</td>\n", | |
" <td>0.510423</td>\n", | |
" <td>0.879006</td>\n", | |
" <td>02:31</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>9</td>\n", | |
" <td>0.374698</td>\n", | |
" <td>0.482728</td>\n", | |
" <td>0.882320</td>\n", | |
" <td>02:31</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/usr/local/lib/python3.7/dist-packages/torch/_tensor.py:1051: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n", | |
" ret = func(*args, **kwargs)\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Show which loss the teacher ended up with; no loss_func was passed to unet_learner.\n", | |
"teacher_learn.loss_func" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "_ZMICwZzo0MN", | |
"outputId": "659ea2c1-8cc6-4aaa-8776-95e17d8c9e02" | |
}, | |
"execution_count": 36, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"FlattenedLoss of CrossEntropyLoss()" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 36 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Knowledge Distillation" | |
], | |
"metadata": { | |
"id": "IVvuN2gEqjv6" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"class DistillationLoss(nn.Module):\n", | |
"    # Soft-target distillation loss for dense (per-pixel) predictions.\n", | |
"    #\n", | |
"    # student_preds / teacher_preds: raw logits with class scores on dim 1; the two\n", | |
"    #   tensors must have the same shape (assumed (bs, n_classes, H, W) - TODO confirm).\n", | |
"    # acutal_target: unused; kept so existing positional callers keep working.\n", | |
"    # T: softmax temperature, must be positive.\n", | |
"    # alpha: unused here; the caller applies the loss weighting.\n", | |
"    # Returns a scalar: KL(teacher || student) summed over classes, averaged over\n", | |
"    # batch and pixels, multiplied by T**2 so gradient magnitudes do not shrink as T grows.\n", | |
"    def __init__(self):\n", | |
"        super(DistillationLoss, self).__init__()\n", | |
"        # reduction='none' so the class axis can be summed explicitly in forward.\n", | |
"        self.distillation_loss = nn.KLDivLoss(reduction='none')\n", | |
"\n", | |
"    def forward(self,\n", | |
"                student_preds,\n", | |
"                teacher_preds,\n", | |
"                acutal_target,\n", | |
"                T,\n", | |
"                alpha\n", | |
"               ):\n", | |
"        # nn.KLDivLoss expects log-probabilities as input and probabilities as target.\n", | |
"        student_log_probs = F.log_softmax(student_preds / T, dim=1)\n", | |
"        teacher_probs = F.softmax(teacher_preds / T, dim=1)\n", | |
"        kl = self.distillation_loss(student_log_probs, teacher_probs)\n", | |
"        return kl.sum(dim=1).mean() * (T * T)\n", | |
" \n", | |
"\n", | |
"\n", | |
"import torch\n", | |
"\n", | |
"class KnowledgeDistillation(Callback):\n", | |
"    # fastai callback that mixes a soft-target term from a fixed teacher into the\n", | |
"    # student's training loss before backward:\n", | |
"    #   loss_grad = a * student_loss + (1 - a) * DistillationLoss(student, teacher, T)\n", | |
"    # teacher: trained Learner whose .model is run on each training batch.\n", | |
"    # T: distillation temperature. a: weight of the student's own loss, in [0, 1].\n", | |
"    def __init__(self,\n", | |
"                 teacher:Learner,\n", | |
"                 T:float=20.,\n", | |
"                 a:float=0.7):\n", | |
"        super(KnowledgeDistillation, self).__init__()\n", | |
"        self.teacher = teacher\n", | |
"        self.T, self.a = T, a\n", | |
"        self.distillation_loss = DistillationLoss()\n", | |
"\n", | |
"    def before_fit(self):\n", | |
"        # The teacher is only used for inference: eval mode fixes BatchNorm/Dropout behaviour.\n", | |
"        self.teacher.model.eval()\n", | |
"\n", | |
"    def after_loss(self):\n", | |
"        # fastai only back-propagates loss_grad on training batches, so skip the\n", | |
"        # (expensive) teacher forward pass during validation.\n", | |
"        if not self.learn.training: return\n", | |
"        # No autograd graph through the teacher: it is not being trained, and\n", | |
"        # building one would also compute and store teacher gradients on backward.\n", | |
"        with torch.no_grad():\n", | |
"            teacher_preds = self.teacher.model(self.learn.xb[0])\n", | |
"        student_loss = self.learn.loss_grad * self.a\n", | |
"        distillation_loss = self.distillation_loss(self.learn.pred,  # Student preds\n", | |
"                                                   teacher_preds,    # Teacher preds\n", | |
"                                                   self.learn.yb,    # Ground truth (unused by the loss)\n", | |
"                                                   self.T,\n", | |
"                                                   self.a) * (1 - self.a)\n", | |
"        # NOTE(review): fastai clones learn.loss before after_loss, so the recorded\n", | |
"        # train_loss presumably excludes this distillation term - verify for your version.\n", | |
"        self.learn.loss_grad = student_loss + distillation_loss" | |
], | |
"metadata": { | |
"id": "f1IO5Vw4tPXX" | |
}, | |
"execution_count": 37, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Student: same MobileNetV3 U-Net as the baselines, one output channel per class in `codes`.\n", | |
"model = Unet('mobilenetv3_rw',\n", | |
"             num_classes=len(codes))\n", | |
"\n", | |
"# Default (dls) loss for the student, plus the distillation callback driven by the teacher.\n", | |
"student_learn = Learner(dls,\n", | |
"                        model,\n", | |
"                        metrics=acc_camvid,\n", | |
"                        cbs=[KnowledgeDistillation(teacher=teacher_learn)])" | |
], | |
"metadata": { | |
"id": "Si1Zl_2uy43B" | |
}, | |
"execution_count": 38, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Train the distilled student with the same schedule as the cross-entropy baseline.\n", | |
"student_learn.fit_one_cycle(n_epoch=25, lr_max=1e-3)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 888 | |
}, | |
"id": "QjXcKthW0B4t", | |
"outputId": "3aaf9f7b-5598-433e-8d36-f20d44a06b57" | |
}, | |
"execution_count": 39, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"\n", | |
"<style>\n", | |
" /* Turns off some styling */\n", | |
" progress {\n", | |
" /* gets rid of default border in Firefox and Opera. */\n", | |
" border: none;\n", | |
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
" background-size: auto;\n", | |
" }\n", | |
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
" background: #F44336;\n", | |
" }\n", | |
"</style>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: left;\">\n", | |
" <th>epoch</th>\n", | |
" <th>train_loss</th>\n", | |
" <th>valid_loss</th>\n", | |
" <th>acc_camvid</th>\n", | |
" <th>time</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td>0</td>\n", | |
" <td>4.149306</td>\n", | |
" <td>4.083494</td>\n", | |
" <td>0.012385</td>\n", | |
" <td>02:23</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>1</td>\n", | |
" <td>3.843522</td>\n", | |
" <td>3.565912</td>\n", | |
" <td>0.096796</td>\n", | |
" <td>02:23</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>2</td>\n", | |
" <td>3.263084</td>\n", | |
" <td>2.797809</td>\n", | |
" <td>0.369143</td>\n", | |
" <td>02:23</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>3</td>\n", | |
" <td>2.626706</td>\n", | |
" <td>2.134217</td>\n", | |
" <td>0.512857</td>\n", | |
" <td>02:24</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>4</td>\n", | |
" <td>2.120926</td>\n", | |
" <td>1.793929</td>\n", | |
" <td>0.577740</td>\n", | |
" <td>02:23</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>5</td>\n", | |
" <td>1.782979</td>\n", | |
" <td>1.582902</td>\n", | |
" <td>0.600909</td>\n", | |
" <td>02:23</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>6</td>\n", | |
" <td>1.565306</td>\n", | |
" <td>1.450105</td>\n", | |
" <td>0.618120</td>\n", | |
" <td>02:23</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>7</td>\n", | |
" <td>1.431738</td>\n", | |
" <td>1.363392</td>\n", | |
" <td>0.626684</td>\n", | |
" <td>02:23</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>8</td>\n", | |
" <td>1.350202</td>\n", | |
" <td>1.306182</td>\n", | |
" <td>0.632427</td>\n", | |
" <td>02:23</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>9</td>\n", | |
" <td>1.291344</td>\n", | |
" <td>1.265584</td>\n", | |
" <td>0.640540</td>\n", | |
" <td>02:23</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>10</td>\n", | |
" <td>1.253604</td>\n", | |
" <td>1.234486</td>\n", | |
" <td>0.646430</td>\n", | |
" <td>02:24</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>11</td>\n", | |
" <td>1.225404</td>\n", | |
" <td>1.220467</td>\n", | |
" <td>0.650743</td>\n", | |
" <td>02:24</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>12</td>\n", | |
" <td>1.201750</td>\n", | |
" <td>1.195458</td>\n", | |
" <td>0.652415</td>\n", | |
" <td>02:24</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>13</td>\n", | |
" <td>1.185998</td>\n", | |
" <td>1.175060</td>\n", | |
" <td>0.658728</td>\n", | |
" <td>02:24</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>14</td>\n", | |
" <td>1.171920</td>\n", | |
" <td>1.162316</td>\n", | |
" <td>0.664814</td>\n", | |
" <td>02:23</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>15</td>\n", | |
" <td>1.160010</td>\n", | |
" <td>1.154163</td>\n", | |
" <td>0.666426</td>\n", | |
" <td>02:24</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>16</td>\n", | |
" <td>1.153085</td>\n", | |
" <td>1.149265</td>\n", | |
" <td>0.667741</td>\n", | |
" <td>02:23</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>17</td>\n", | |
" <td>1.141003</td>\n", | |
" <td>1.143118</td>\n", | |
" <td>0.670634</td>\n", | |
" <td>02:23</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>18</td>\n", | |
" <td>1.138742</td>\n", | |
" <td>1.137632</td>\n", | |
" <td>0.671683</td>\n", | |
" <td>02:23</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>19</td>\n", | |
" <td>1.129604</td>\n", | |
" <td>1.136192</td>\n", | |
" <td>0.672162</td>\n", | |
" <td>02:23</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>20</td>\n", | |
" <td>1.125707</td>\n", | |
" <td>1.133340</td>\n", | |
" <td>0.672295</td>\n", | |
" <td>02:23</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>21</td>\n", | |
" <td>1.130975</td>\n", | |
" <td>1.133293</td>\n", | |
" <td>0.670490</td>\n", | |
" <td>02:23</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>22</td>\n", | |
" <td>1.134533</td>\n", | |
" <td>1.132475</td>\n", | |
" <td>0.670508</td>\n", | |
" <td>02:23</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>23</td>\n", | |
" <td>1.129584</td>\n", | |
" <td>1.130723</td>\n", | |
" <td>0.672191</td>\n", | |
" <td>02:23</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>24</td>\n", | |
" <td>1.126617</td>\n", | |
" <td>1.130564</td>\n", | |
" <td>0.672605</td>\n", | |
" <td>02:23</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/usr/local/lib/python3.7/dist-packages/torch/_tensor.py:1051: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n", | |
" ret = func(*args, **kwargs)\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"" | |
], | |
"metadata": { | |
"id": "sgCb1pswUtv5" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment