Last active
March 17, 2022 10:32
-
-
Save moarshy/9bced4d19c4f88b826e5e712d711883e to your computer and use it in GitHub Desktop.
MNUnetModel.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "MNUnetModel.ipynb", | |
"provenance": [], | |
"collapsed_sections": [ | |
"nl_Y1QW9OsZR" | |
], | |
"authorship_tag": "ABX9TyN2SrM5d/WkPyup/Oo/aEly", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/moarshy/9bced4d19c4f88b826e5e712d711883e/mnunetmodel.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pip install timm fastai -Uqq" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "cJ_Y_aOreCYt", | |
"outputId": "017efb0d-a7ea-4cef-8659-1dcdb1b92d9c" | |
}, | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\u001b[K |████████████████████████████████| 431 kB 5.3 MB/s \n", | |
"\u001b[K |████████████████████████████████| 189 kB 42.8 MB/s \n", | |
"\u001b[K |████████████████████████████████| 55 kB 2.4 MB/s \n", | |
"\u001b[?25h" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"id": "ynxvBvjJMH8F" | |
}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.nn.functional as F\n", | |
"\n", | |
"import timm\n", | |
"from timm import create_model\n", | |
"from timm.models.efficientnet_blocks import DepthwiseSeparableConv\n", | |
"\n", | |
"from fastai.vision.all import *" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Parts of this code are based on https://gist.github.com/rwightman/f8b24f4e6f5504aba03e999e02460d31\n", | |
"\n", | |
"class Conv2dBnAct(nn.Module):\n", | |
" def __init__(self, \n", | |
" in_channels, \n", | |
" out_channels, \n", | |
" kernel_size, \n", | |
" padding=0,\n", | |
" stride=1, \n", | |
" act_layer=nn.ReLU, \n", | |
" norm_layer=nn.BatchNorm2d\n", | |
" ):\n", | |
" super().__init__()\n", | |
" self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False)\n", | |
" self.bn = norm_layer(out_channels)\n", | |
" self.act = act_layer(inplace=True)\n", | |
"\n", | |
" def forward(self, x):\n", | |
" x = self.conv(x)\n", | |
" x = self.bn(x)\n", | |
" x = self.act(x)\n", | |
" return x\n", | |
"\n", | |
"\n", | |
"class FeatureFusionModule(nn.Module):\n", | |
" def __init__(self, \n", | |
" enc_in_channels, \n", | |
" enc_out_channels, \n", | |
" dec_in_channels, \n", | |
" out_channels\n", | |
" ):\n", | |
" super().__init__()\n", | |
" #encoderoutput\n", | |
" self.enc_conv1 = nn.Conv2d(enc_in_channels, enc_out_channels, kernel_size=1, stride=1, padding='same')\n", | |
" self.enc_up = nn.ConvTranspose2d(enc_out_channels, enc_out_channels, kernel_size=1)\n", | |
" self.enc_dconv = DepthwiseSeparableConv(enc_out_channels, enc_out_channels)\n", | |
" self.enc_conv2 = nn.Conv2d(enc_out_channels, enc_out_channels, kernel_size=1, stride=1, padding='same')\n", | |
" \n", | |
" #decoderoutput\n", | |
" self.dec_dconv = DepthwiseSeparableConv(enc_out_channels+dec_in_channels, enc_out_channels+dec_in_channels)\n", | |
" self.dec_conv1 = nn.Conv2d(enc_out_channels+dec_in_channels, out_channels, kernel_size=1, stride=1, padding='same')\n", | |
"\n", | |
"\n", | |
" def forward(self, enc_x, dec_x):\n", | |
" enc_x = self.enc_conv1(enc_x)\n", | |
" enc_x = self.enc_up(enc_x)\n", | |
" enc_x = self.enc_dconv(enc_x)\n", | |
" enc_x = self.enc_conv2(enc_x)\n", | |
" \n", | |
" x = torch.cat([enc_x, dec_x], dim=1)\n", | |
"\n", | |
" dec_x = self.dec_dconv(x)\n", | |
" dec_x = self.dec_conv1(dec_x)\n", | |
" \n", | |
" return dec_x\n", | |
"\n", | |
"\n", | |
"class DecoderBlock(nn.Module):\n", | |
" def __init__(self, \n", | |
" enc_channels, \n", | |
" dec_prev_channels, \n", | |
" dec_channels,\n", | |
" act_layer=nn.ReLU, \n", | |
" norm_layer=nn.BatchNorm2d,\n", | |
" ffm=True,\n", | |
" ):\n", | |
" super().__init__()\n", | |
" conv_args = dict(kernel_size=3, padding=1, act_layer=act_layer)\n", | |
" self.ffm = ffm\n", | |
" \n", | |
" if ffm:\n", | |
" self.ffm = FeatureFusionModule(enc_channels, enc_channels, dec_prev_channels, dec_channels)\n", | |
"\n", | |
" self.conv1 = Conv2dBnAct(enc_channels, dec_channels, norm_layer=norm_layer, **conv_args)\n", | |
" self.conv2 = Conv2dBnAct(dec_channels, dec_channels, norm_layer=norm_layer, **conv_args)\n", | |
" \n", | |
"\n", | |
" def forward(self, x_enc, x_dec):\n", | |
" if self.ffm:\n", | |
" x = self.ffm(x_enc, x_dec)\n", | |
"\n", | |
" x = F.interpolate(x_enc, scale_factor=2, mode='nearest')\n", | |
"\n", | |
" x = self.conv1(x)\n", | |
" x = self.conv2(x)\n", | |
"\n", | |
" return x\n", | |
"\n", | |
"\n", | |
"class UnetDecoder(nn.Module):\n", | |
"\n", | |
" def __init__(self,\n", | |
" encoder_channels,\n", | |
" decoder_channels=(256, 128, 64, 32, 16),\n", | |
" final_channels=3,\n", | |
" norm_layer=nn.BatchNorm2d,\n", | |
" ):\n", | |
" super().__init__()\n", | |
"\n", | |
" self.decoders = nn.ModuleList()\n", | |
" for i, (e_ch, d_ch) in enumerate(zip(encoder_channels, decoder_channels)):\n", | |
" if i== 0:\n", | |
" self.decoders.append(DecoderBlock(enc_channels=e_ch, \n", | |
" dec_prev_channels=None, \n", | |
" dec_channels=d_ch,\n", | |
" act_layer=nn.ReLU, \n", | |
" norm_layer=nn.BatchNorm2d,\n", | |
" ffm=False,\n", | |
" ))\n", | |
"\n", | |
" else:\n", | |
" self.decoders.append(DecoderBlock(enc_channels=e_ch, \n", | |
" dec_prev_channels=decoder_channels[i-1], \n", | |
" dec_channels=d_ch,\n", | |
" act_layer=nn.ReLU, \n", | |
" norm_layer=nn.BatchNorm2d,\n", | |
" ffm=True,\n", | |
" ))\n", | |
"\n", | |
" self.final_conv = nn.Conv2d(decoder_channels[-1], final_channels, kernel_size=(1, 1))\n", | |
"\n", | |
" self._init_weight()\n", | |
"\n", | |
" def _init_weight(self):\n", | |
" for m in self.modules():\n", | |
" if isinstance(m, nn.Conv2d):\n", | |
" torch.nn.init.kaiming_normal_(m.weight)\n", | |
" elif isinstance(m, nn.BatchNorm2d):\n", | |
" m.weight.data.fill_(1)\n", | |
" m.bias.data.zero_()\n", | |
"\n", | |
"\n", | |
" def forward(self, x):\n", | |
" enc_outs_r = x\n", | |
" dec_out = None\n", | |
" for i, each in enumerate(self.decoders):\n", | |
" dec_out = each(enc_outs_r[i], dec_out)\n", | |
" x = self.final_conv(dec_out)\n", | |
" return x\n", | |
"\n", | |
"\n", | |
"class Unet(nn.Module):\n", | |
" def __init__(self,\n", | |
" backbone='resnet50',\n", | |
" backbone_kwargs=None,\n", | |
" backbone_indices=None,\n", | |
" decoder_use_batchnorm=True,\n", | |
" decoder_channels=(256, 128, 64, 32, 16),\n", | |
" in_chans=3,\n", | |
" num_classes=3,\n", | |
" norm_layer=nn.BatchNorm2d,\n", | |
" pretrained=True,\n", | |
" ):\n", | |
" super().__init__()\n", | |
" backbone_kwargs = backbone_kwargs or {}\n", | |
" # NOTE some models need different backbone indices specified based on the alignment of features\n", | |
" # and some models won't have a full enough range of feature strides to work properly.\n", | |
" encoder = create_model(\n", | |
" backbone, features_only=True, out_indices=backbone_indices, in_chans=in_chans,\n", | |
" pretrained=pretrained, **backbone_kwargs)\n", | |
" encoder_channels = encoder.feature_info.channels()[::-1]\n", | |
" self.encoder = encoder\n", | |
"\n", | |
" self.decoder = UnetDecoder(\n", | |
" encoder_channels=encoder_channels,\n", | |
" decoder_channels=decoder_channels,\n", | |
" final_channels=num_classes,\n", | |
" norm_layer=norm_layer,\n", | |
" )\n", | |
"\n", | |
" def forward(self, x: torch.Tensor):\n", | |
" x = self.encoder(x)\n", | |
" x.reverse() # torchscript doesn't work with [::-1]\n", | |
" x = self.decoder(x)\n", | |
" return x" | |
], | |
"metadata": { | |
"id": "t0IVmDv-BZsN" | |
}, | |
"execution_count": 3, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Test individual components" | |
], | |
"metadata": { | |
"id": "nl_Y1QW9OsZR" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"model = Unet('mobilenetv3_rw')" | |
], | |
"metadata": { | |
"id": "acc4eR-HNati" | |
}, | |
"execution_count": 17, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"o = model(torch.randn(2,3,128,160))" | |
], | |
"metadata": { | |
"id": "JMxOElmXNyfJ" | |
}, | |
"execution_count": 18, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"o.shape" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "s5ITHIt_N3Za", | |
"outputId": "985af779-f5a7-4c0e-aded-398b5cce175f" | |
}, | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"torch.Size([2, 3, 128, 160])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 10 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"encoder = create_model(\n", | |
" 'mobilenetv3_rw', features_only=True, out_indices=None, in_chans=3,\n", | |
" pretrained=True,)" | |
], | |
"metadata": { | |
"id": "d6nwuci0vEJi" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"enc_outs = encoder(torch.randn(2, 3, 128, 160))\n", | |
"enc_outs_r = enc_outs[::-1]" | |
], | |
"metadata": { | |
"id": "cpIXAnerTZ_v" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"for e in enc_outs:\n", | |
" print(e.shape)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "oAE8LNQQTqwG", | |
"outputId": "512ff956-62a4-4bd1-ad36-25c4bd418ef7" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"torch.Size([2, 16, 64, 80])\n", | |
"torch.Size([2, 24, 32, 40])\n", | |
"torch.Size([2, 40, 16, 20])\n", | |
"torch.Size([2, 112, 8, 10])\n", | |
"torch.Size([2, 960, 4, 5])\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"enc_channels = encoder.feature_info.channels()[::-1]; enc_channels" | |
], | |
"metadata": { | |
"id": "YhvGiHYNNgcH", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "acf5e4d3-7f1e-4332-a7ad-08ae84ee10ef" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"[960, 112, 40, 24, 16]" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 159 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec_channels = [256, 128, 64, 32, 16]" | |
], | |
"metadata": { | |
"id": "yOoZwawcNvaD" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec1 = DecoderBlock(enc_channels[0], \n", | |
" None, \n", | |
" dec_channels[0],\n", | |
" ffm=False)" | |
], | |
"metadata": { | |
"id": "oZLm9YwkODQv" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec1_out = dec1(enc_outs_r[0], None)" | |
], | |
"metadata": { | |
"id": "3h-2jsN2OVNO" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec1_out.shape" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "oE-X8RKNOl8e", | |
"outputId": "51ccf30e-782d-462f-fcfc-28a0b8347e7b" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"torch.Size([2, 256, 8, 10])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 163 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec2 = DecoderBlock(enc_channels[1], \n", | |
" dec_channels[1-1], \n", | |
" dec_channels[1],\n", | |
" ffm=True)" | |
], | |
"metadata": { | |
"id": "uchSrgNlwUg5" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec2_out = dec2(enc_outs_r[1], dec1_out)" | |
], | |
"metadata": { | |
"id": "PTGDO0ooweXJ" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec2_out.shape" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "2iqQyBVvwpFx", | |
"outputId": "3c1e185f-2d82-49ae-92b8-64d52000c875" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"torch.Size([2, 128, 16, 20])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 154 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec3 = DecoderBlock(enc_channels[2], \n", | |
" dec_channels[2-1], \n", | |
" dec_channels[2],\n", | |
" ffm=True)" | |
], | |
"metadata": { | |
"id": "YI38w3tCwq9K" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec3_out = dec3(enc_outs_r[2], dec2_out)" | |
], | |
"metadata": { | |
"id": "Id5PcFlDw4M6" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec3_out.shape" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "QR3fRqsh1Hrx", | |
"outputId": "76152af4-972f-45c9-c5d1-dcc795edec2a" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"torch.Size([2, 64, 32, 40])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 169 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec4 = DecoderBlock(enc_channels[3], \n", | |
" dec_channels[3-1], \n", | |
" dec_channels[3],\n", | |
" ffm=True)" | |
], | |
"metadata": { | |
"id": "ytatkui20UNi" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec4_out = dec4(enc_outs_r[3], dec3_out)" | |
], | |
"metadata": { | |
"id": "0YxQrVnB0dDl" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec4_out.shape" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "2G3Uht6T1Kbi", | |
"outputId": "7df41391-e29d-4c2f-ffa9-439f7033717d" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"torch.Size([2, 32, 64, 80])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 172 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec5 = DecoderBlock(enc_channels[4], \n", | |
" dec_channels[4-1], \n", | |
" dec_channels[4],\n", | |
" ffm=True)" | |
], | |
"metadata": { | |
"id": "dITgOOl81QcL" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec5_out = dec5(enc_outs_r[4], dec4_out)" | |
], | |
"metadata": { | |
"id": "LqHdur5d1WCI" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec5_out.shape" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "s1SuPses1dKL", | |
"outputId": "3c23592b-b428-43df-ea7c-6335f7ddd58b" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"torch.Size([2, 16, 128, 160])" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 175 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"" | |
], | |
"metadata": { | |
"id": "VzKVZSeTPXx0" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dec = UnetDecoder(enc_channels)" | |
], | |
"metadata": { | |
"id": "HCu10Ce_jhrj" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"x = dec(enc_outs_r)" | |
], | |
"metadata": { | |
"id": "QSZ27lM3jr7L" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"" | |
], | |
"metadata": { | |
"id": "9QPdsDft2SrY" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Test Unet training" | |
], | |
"metadata": { | |
"id": "axnPMcaJOvGp" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"path = untar_data(URLs.CAMVID)\n", | |
"path.ls()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 54 | |
}, | |
"id": "QJjxTv5xOxNF", | |
"outputId": "2a0dab31-989c-4a6b-fc31-ca3eaec50595" | |
}, | |
"execution_count": 42, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"\n", | |
"<style>\n", | |
" /* Turns off some styling */\n", | |
" progress {\n", | |
" /* gets rid of default border in Firefox and Opera. */\n", | |
" border: none;\n", | |
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
" background-size: auto;\n", | |
" }\n", | |
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
" background: #F44336;\n", | |
" }\n", | |
"</style>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"\n", | |
" <div>\n", | |
" <progress value='598917120' class='' max='598913237' style='width:300px; height:20px; vertical-align: middle;'></progress>\n", | |
" 100.00% [598917120/598913237 00:15<00:00]\n", | |
" </div>\n", | |
" " | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(#4) [Path('/root/.fastai/data/camvid/labels'),Path('/root/.fastai/data/camvid/images'),Path('/root/.fastai/data/camvid/valid.txt'),Path('/root/.fastai/data/camvid/codes.txt')]" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 42 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"codes = np.loadtxt(path/'codes.txt', dtype=str)\n", | |
"codes" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "PCkr63ykO-LR", | |
"outputId": "7508d0cb-146a-4b69-a06d-c8a88c91510a" | |
}, | |
"execution_count": 43, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array(['Animal', 'Archway', 'Bicyclist', 'Bridge', 'Building', 'Car',\n", | |
" 'CartLuggagePram', 'Child', 'Column_Pole', 'Fence', 'LaneMkgsDriv',\n", | |
" 'LaneMkgsNonDriv', 'Misc_Text', 'MotorcycleScooter', 'OtherMoving',\n", | |
" 'ParkingBlock', 'Pedestrian', 'Road', 'RoadShoulder', 'Sidewalk',\n", | |
" 'SignSymbol', 'Sky', 'SUVPickupTruck', 'TrafficCone',\n", | |
" 'TrafficLight', 'Train', 'Tree', 'Truck_Bus', 'Tunnel',\n", | |
" 'VegetationMisc', 'Void', 'Wall'], dtype='<U17')" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 43 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"fnames = get_image_files(path/\"images\")" | |
], | |
"metadata": { | |
"id": "dKQMJJ80PJwY" | |
}, | |
"execution_count": 44, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def label_func(fn): return path/\"labels\"/f\"{fn.stem}_P{fn.suffix}\"" | |
], | |
"metadata": { | |
"id": "9SZzi4m6PMpS" | |
}, | |
"execution_count": 45, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"name2id = {v:k for k,v in enumerate(codes)}\n", | |
"void_code = name2id['Void']\n", | |
"def acc_camvid(inp, targ):\n", | |
" targ = targ.squeeze(1)\n", | |
" mask = targ != void_code\n", | |
" return np.mean(inp.argmax(dim=1)[mask].cpu().numpy()==targ[mask].cpu().numpy())" | |
], | |
"metadata": { | |
"id": "lU6X1jDGQExa" | |
}, | |
"execution_count": 136, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"dls = SegmentationDataLoaders.from_label_func(path, \n", | |
" bs=8, \n", | |
" fnames = fnames, \n", | |
" label_func = label_func, \n", | |
" codes = codes,\n", | |
" item_tfms=Resize((128, 160)))" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "bj7NRyDCPNjl", | |
"outputId": "05b6d6ea-b39b-46b7-e804-5c5868b9b960" | |
}, | |
"execution_count": 137, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/usr/local/lib/python3.7/dist-packages/torch/_tensor.py:1051: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n", | |
" ret = func(*args, **kwargs)\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"model = Unet('mobilenetv3_rw',\n", | |
" num_classes=32)\n", | |
"\n", | |
"learn = Learner(dls, \n", | |
" model,\n", | |
" loss_func=DiceLoss(),\n", | |
" metrics=acc_camvid)" | |
], | |
"metadata": { | |
"id": "Utq5Utm0PQA0" | |
}, | |
"execution_count": 140, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"learn.fit_one_cycle(10, 1e-3)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 418 | |
}, | |
"id": "RQkxMCbWQTgR", | |
"outputId": "8c9eaa51-5931-4784-b2a3-6a227a904ea0" | |
}, | |
"execution_count": 141, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"\n", | |
"<style>\n", | |
" /* Turns off some styling */\n", | |
" progress {\n", | |
" /* gets rid of default border in Firefox and Opera. */\n", | |
" border: none;\n", | |
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
" background-size: auto;\n", | |
" }\n", | |
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
" background: #F44336;\n", | |
" }\n", | |
"</style>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: left;\">\n", | |
" <th>epoch</th>\n", | |
" <th>train_loss</th>\n", | |
" <th>valid_loss</th>\n", | |
" <th>acc_camvid</th>\n", | |
" <th>time</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td>0</td>\n", | |
" <td>252.818253</td>\n", | |
" <td>247.858124</td>\n", | |
" <td>0.147991</td>\n", | |
" <td>00:26</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>1</td>\n", | |
" <td>247.943344</td>\n", | |
" <td>241.364197</td>\n", | |
" <td>0.215207</td>\n", | |
" <td>00:27</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>2</td>\n", | |
" <td>243.855286</td>\n", | |
" <td>237.792664</td>\n", | |
" <td>0.502635</td>\n", | |
" <td>00:26</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>3</td>\n", | |
" <td>238.966187</td>\n", | |
" <td>232.396576</td>\n", | |
" <td>0.571493</td>\n", | |
" <td>00:26</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>4</td>\n", | |
" <td>235.020081</td>\n", | |
" <td>229.611572</td>\n", | |
" <td>0.596396</td>\n", | |
" <td>00:26</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>5</td>\n", | |
" <td>232.624008</td>\n", | |
" <td>228.144196</td>\n", | |
" <td>0.605240</td>\n", | |
" <td>00:26</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>6</td>\n", | |
" <td>231.181992</td>\n", | |
" <td>227.265793</td>\n", | |
" <td>0.609421</td>\n", | |
" <td>00:26</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>7</td>\n", | |
" <td>230.417694</td>\n", | |
" <td>226.769119</td>\n", | |
" <td>0.612933</td>\n", | |
" <td>00:26</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>8</td>\n", | |
" <td>229.980698</td>\n", | |
" <td>226.630600</td>\n", | |
" <td>0.614192</td>\n", | |
" <td>00:26</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>9</td>\n", | |
" <td>229.766006</td>\n", | |
" <td>226.558304</td>\n", | |
" <td>0.612874</td>\n", | |
" <td>00:26</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/usr/local/lib/python3.7/dist-packages/torch/_tensor.py:1051: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n", | |
" ret = func(*args, **kwargs)\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"learn.summary()" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 1000 | |
}, | |
"id": "L53Vk3fCZKDE", | |
"outputId": "95dfa97c-810f-494f-f062-6faeeeaaa9af" | |
}, | |
"execution_count": 142, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/usr/local/lib/python3.7/dist-packages/torch/_tensor.py:1051: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n", | |
" ret = func(*args, **kwargs)\n" | |
] | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"\n", | |
"<style>\n", | |
" /* Turns off some styling */\n", | |
" progress {\n", | |
" /* gets rid of default border in Firefox and Opera. */\n", | |
" border: none;\n", | |
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
" background-size: auto;\n", | |
" }\n", | |
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
" background: #F44336;\n", | |
" }\n", | |
"</style>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
], | |
"text/html": [ | |
"" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"Unet (Input shape: 8 x 3 x 128 x 160)\n", | |
"============================================================================\n", | |
"Layer (type) Output Shape Param # Trainable \n", | |
"============================================================================\n", | |
" 8 x 16 x 64 x 80 \n", | |
"Conv2d 432 True \n", | |
"BatchNorm2d 32 True \n", | |
"Hardswish \n", | |
"Conv2d 144 True \n", | |
"BatchNorm2d 32 True \n", | |
"ReLU \n", | |
"Identity \n", | |
"Conv2d 256 True \n", | |
"BatchNorm2d 32 True \n", | |
"Identity \n", | |
"____________________________________________________________________________\n", | |
" 8 x 64 x 64 x 80 \n", | |
"Conv2d 1024 True \n", | |
"BatchNorm2d 128 True \n", | |
"ReLU \n", | |
"____________________________________________________________________________\n", | |
" 8 x 64 x 32 x 40 \n", | |
"Conv2d 576 True \n", | |
"BatchNorm2d 128 True \n", | |
"ReLU \n", | |
"Identity \n", | |
"____________________________________________________________________________\n", | |
" 8 x 24 x 32 x 40 \n", | |
"Conv2d 1536 True \n", | |
"BatchNorm2d 48 True \n", | |
"____________________________________________________________________________\n", | |
" 8 x 72 x 32 x 40 \n", | |
"Conv2d 1728 True \n", | |
"BatchNorm2d 144 True \n", | |
"ReLU \n", | |
"Conv2d 648 True \n", | |
"BatchNorm2d 144 True \n", | |
"ReLU \n", | |
"Identity \n", | |
"____________________________________________________________________________\n", | |
" 8 x 24 x 32 x 40 \n", | |
"Conv2d 1728 True \n", | |
"BatchNorm2d 48 True \n", | |
"____________________________________________________________________________\n", | |
" 8 x 72 x 32 x 40 \n", | |
"Conv2d 1728 True \n", | |
"BatchNorm2d 144 True \n", | |
"ReLU \n", | |
"____________________________________________________________________________\n", | |
" 8 x 72 x 16 x 20 \n", | |
"Conv2d 1800 True \n", | |
"BatchNorm2d 144 True \n", | |
"ReLU \n", | |
"____________________________________________________________________________\n", | |
" 8 x 18 x 1 x 1 \n", | |
"Conv2d 1314 True \n", | |
"ReLU \n", | |
"____________________________________________________________________________\n", | |
" 8 x 72 x 1 x 1 \n", | |
"Conv2d 1368 True \n", | |
"Hardsigmoid \n", | |
"____________________________________________________________________________\n", | |
" 8 x 40 x 16 x 20 \n", | |
"Conv2d 2880 True \n", | |
"BatchNorm2d 80 True \n", | |
"____________________________________________________________________________\n", | |
" 8 x 120 x 16 x 20 \n", | |
"Conv2d 4800 True \n", | |
"BatchNorm2d 240 True \n", | |
"ReLU \n", | |
"Conv2d 3000 True \n", | |
"BatchNorm2d 240 True \n", | |
"ReLU \n", | |
"____________________________________________________________________________\n", | |
" 8 x 30 x 1 x 1 \n", | |
"Conv2d 3630 True \n", | |
"ReLU \n", | |
"____________________________________________________________________________\n", | |
" 8 x 120 x 1 x 1 \n", | |
"Conv2d 3720 True \n", | |
"Hardsigmoid \n", | |
"____________________________________________________________________________\n", | |
" 8 x 40 x 16 x 20 \n", | |
"Conv2d 4800 True \n", | |
"BatchNorm2d 80 True \n", | |
"____________________________________________________________________________\n", | |
" 8 x 120 x 16 x 20 \n", | |
"Conv2d 4800 True \n", | |
"BatchNorm2d 240 True \n", | |
"ReLU \n", | |
"Conv2d 3000 True \n", | |
"BatchNorm2d 240 True \n", | |
"ReLU \n", | |
"____________________________________________________________________________\n", | |
" 8 x 30 x 1 x 1 \n", | |
"Conv2d 3630 True \n", | |
"ReLU \n", | |
"____________________________________________________________________________\n", | |
" 8 x 120 x 1 x 1 \n", | |
"Conv2d 3720 True \n", | |
"Hardsigmoid \n", | |
"____________________________________________________________________________\n", | |
" 8 x 40 x 16 x 20 \n", | |
"Conv2d 4800 True \n", | |
"BatchNorm2d 80 True \n", | |
"____________________________________________________________________________\n", | |
" 8 x 240 x 16 x 20 \n", | |
"Conv2d 9600 True \n", | |
"BatchNorm2d 480 True \n", | |
"Hardswish \n", | |
"____________________________________________________________________________\n", | |
" 8 x 240 x 8 x 10 \n", | |
"Conv2d 2160 True \n", | |
"BatchNorm2d 480 True \n", | |
"Hardswish \n", | |
"Identity \n", | |
"____________________________________________________________________________\n", | |
" 8 x 80 x 8 x 10 \n", | |
"Conv2d 19200 True \n", | |
"BatchNorm2d 160 True \n", | |
"____________________________________________________________________________\n", | |
" 8 x 200 x 8 x 10 \n", | |
"Conv2d 16000 True \n", | |
"BatchNorm2d 400 True \n", | |
"Hardswish \n", | |
"Conv2d 1800 True \n", | |
"BatchNorm2d 400 True \n", | |
"Hardswish \n", | |
"Identity \n", | |
"____________________________________________________________________________\n", | |
" 8 x 80 x 8 x 10 \n", | |
"Conv2d 16000 True \n", | |
"BatchNorm2d 160 True \n", | |
"____________________________________________________________________________\n", | |
" 8 x 184 x 8 x 10 \n", | |
"Conv2d 14720 True \n", | |
"BatchNorm2d 368 True \n", | |
"Hardswish \n", | |
"Conv2d 1656 True \n", | |
"BatchNorm2d 368 True \n", | |
"Hardswish \n", | |
"Identity \n", | |
"____________________________________________________________________________\n", | |
" 8 x 80 x 8 x 10 \n", | |
"Conv2d 14720 True \n", | |
"BatchNorm2d 160 True \n", | |
"____________________________________________________________________________\n", | |
" 8 x 184 x 8 x 10 \n", | |
"Conv2d 14720 True \n", | |
"BatchNorm2d 368 True \n", | |
"Hardswish \n", | |
"Conv2d 1656 True \n", | |
"BatchNorm2d 368 True \n", | |
"Hardswish \n", | |
"Identity \n", | |
"____________________________________________________________________________\n", | |
" 8 x 80 x 8 x 10 \n", | |
"Conv2d 14720 True \n", | |
"BatchNorm2d 160 True \n", | |
"____________________________________________________________________________\n", | |
" 8 x 480 x 8 x 10 \n", | |
"Conv2d 38400 True \n", | |
"BatchNorm2d 960 True \n", | |
"Hardswish \n", | |
"Conv2d 4320 True \n", | |
"BatchNorm2d 960 True \n", | |
"Hardswish \n", | |
"____________________________________________________________________________\n", | |
" 8 x 120 x 1 x 1 \n", | |
"Conv2d 57720 True \n", | |
"Hardswish \n", | |
"____________________________________________________________________________\n", | |
" 8 x 480 x 1 x 1 \n", | |
"Conv2d 58080 True \n", | |
"Hardsigmoid \n", | |
"____________________________________________________________________________\n", | |
" 8 x 112 x 8 x 10 \n", | |
"Conv2d 53760 True \n", | |
"BatchNorm2d 224 True \n", | |
"____________________________________________________________________________\n", | |
" 8 x 672 x 8 x 10 \n", | |
"Conv2d 75264 True \n", | |
"BatchNorm2d 1344 True \n", | |
"Hardswish \n", | |
"Conv2d 6048 True \n", | |
"BatchNorm2d 1344 True \n", | |
"Hardswish \n", | |
"____________________________________________________________________________\n", | |
" 8 x 168 x 1 x 1 \n", | |
"Conv2d 113064 True \n", | |
"Hardswish \n", | |
"____________________________________________________________________________\n", | |
" 8 x 672 x 1 x 1 \n", | |
"Conv2d 113568 True \n", | |
"Hardsigmoid \n", | |
"____________________________________________________________________________\n", | |
" 8 x 112 x 8 x 10 \n", | |
"Conv2d 75264 True \n", | |
"BatchNorm2d 224 True \n", | |
"____________________________________________________________________________\n", | |
" 8 x 672 x 8 x 10 \n", | |
"Conv2d 75264 True \n", | |
"BatchNorm2d 1344 True \n", | |
"Hardswish \n", | |
"____________________________________________________________________________\n", | |
" 8 x 672 x 4 x 5 \n", | |
"Conv2d 16800 True \n", | |
"BatchNorm2d 1344 True \n", | |
"Hardswish \n", | |
"____________________________________________________________________________\n", | |
" 8 x 168 x 1 x 1 \n", | |
"Conv2d 113064 True \n", | |
"Hardswish \n", | |
"____________________________________________________________________________\n", | |
" 8 x 672 x 1 x 1 \n", | |
"Conv2d 113568 True \n", | |
"Hardsigmoid \n", | |
"____________________________________________________________________________\n", | |
" 8 x 160 x 4 x 5 \n", | |
"Conv2d 107520 True \n", | |
"BatchNorm2d 320 True \n", | |
"____________________________________________________________________________\n", | |
" 8 x 960 x 4 x 5 \n", | |
"Conv2d 153600 True \n", | |
"BatchNorm2d 1920 True \n", | |
"Hardswish \n", | |
"Conv2d 24000 True \n", | |
"BatchNorm2d 1920 True \n", | |
"Hardswish \n", | |
"____________________________________________________________________________\n", | |
" 8 x 240 x 1 x 1 \n", | |
"Conv2d 230640 True \n", | |
"Hardswish \n", | |
"____________________________________________________________________________\n", | |
" 8 x 960 x 1 x 1 \n", | |
"Conv2d 231360 True \n", | |
"Hardsigmoid \n", | |
"____________________________________________________________________________\n", | |
" 8 x 160 x 4 x 5 \n", | |
"Conv2d 153600 True \n", | |
"BatchNorm2d 320 True \n", | |
"____________________________________________________________________________\n", | |
" 8 x 960 x 4 x 5 \n", | |
"Conv2d 153600 True \n", | |
"BatchNorm2d 1920 True \n", | |
"Hardswish \n", | |
"Conv2d 24000 True \n", | |
"BatchNorm2d 1920 True \n", | |
"Hardswish \n", | |
"____________________________________________________________________________\n", | |
" 8 x 240 x 1 x 1 \n", | |
"Conv2d 230640 True \n", | |
"Hardswish \n", | |
"____________________________________________________________________________\n", | |
" 8 x 960 x 1 x 1 \n", | |
"Conv2d 231360 True \n", | |
"Hardsigmoid \n", | |
"____________________________________________________________________________\n", | |
" 8 x 160 x 4 x 5 \n", | |
"Conv2d 153600 True \n", | |
"BatchNorm2d 320 True \n", | |
"____________________________________________________________________________\n", | |
" 8 x 960 x 4 x 5 \n", | |
"Conv2d 153600 True \n", | |
"BatchNorm2d 1920 True \n", | |
"Hardswish \n", | |
"____________________________________________________________________________\n", | |
" 8 x 256 x 8 x 10 \n", | |
"Conv2d 2211840 True \n", | |
"BatchNorm2d 512 True \n", | |
"ReLU \n", | |
"Conv2d 589824 True \n", | |
"BatchNorm2d 512 True \n", | |
"ReLU \n", | |
"Conv2d 12656 True \n", | |
"ConvTranspose2d 12656 True \n", | |
"Conv2d 1008 True \n", | |
"BatchNorm2d 224 True \n", | |
"ReLU \n", | |
"Identity \n", | |
"Conv2d 12544 True \n", | |
"BatchNorm2d 224 True \n", | |
"Identity \n", | |
"Conv2d 12656 True \n", | |
"Conv2d 3312 True \n", | |
"BatchNorm2d 736 True \n", | |
"ReLU \n", | |
"Identity \n", | |
"Conv2d 135424 True \n", | |
"BatchNorm2d 736 True \n", | |
"Identity \n", | |
"____________________________________________________________________________\n", | |
" 8 x 128 x 8 x 10 \n", | |
"Conv2d 47232 True \n", | |
"____________________________________________________________________________\n", | |
" 8 x 128 x 16 x 20 \n", | |
"Conv2d 129024 True \n", | |
"BatchNorm2d 256 True \n", | |
"ReLU \n", | |
"Conv2d 147456 True \n", | |
"BatchNorm2d 256 True \n", | |
"ReLU \n", | |
"Conv2d 1640 True \n", | |
"ConvTranspose2d 1640 True \n", | |
"Conv2d 360 True \n", | |
"BatchNorm2d 80 True \n", | |
"ReLU \n", | |
"Identity \n", | |
"Conv2d 1600 True \n", | |
"BatchNorm2d 80 True \n", | |
"Identity \n", | |
"Conv2d 1640 True \n", | |
"Conv2d 1512 True \n", | |
"BatchNorm2d 336 True \n", | |
"ReLU \n", | |
"Identity \n", | |
"Conv2d 28224 True \n", | |
"BatchNorm2d 336 True \n", | |
"Identity \n", | |
"____________________________________________________________________________\n", | |
" 8 x 64 x 16 x 20 \n", | |
"Conv2d 10816 True \n", | |
"____________________________________________________________________________\n", | |
" 8 x 64 x 32 x 40 \n", | |
"Conv2d 23040 True \n", | |
"BatchNorm2d 128 True \n", | |
"ReLU \n", | |
"Conv2d 36864 True \n", | |
"BatchNorm2d 128 True \n", | |
"ReLU \n", | |
"Conv2d 600 True \n", | |
"ConvTranspose2d 600 True \n", | |
"Conv2d 216 True \n", | |
"BatchNorm2d 48 True \n", | |
"ReLU \n", | |
"Identity \n", | |
"Conv2d 576 True \n", | |
"BatchNorm2d 48 True \n", | |
"Identity \n", | |
"Conv2d 600 True \n", | |
"Conv2d 792 True \n", | |
"BatchNorm2d 176 True \n", | |
"ReLU \n", | |
"Identity \n", | |
"Conv2d 7744 True \n", | |
"BatchNorm2d 176 True \n", | |
"Identity \n", | |
"____________________________________________________________________________\n", | |
" 8 x 32 x 32 x 40 \n", | |
"Conv2d 2848 True \n", | |
"____________________________________________________________________________\n", | |
" 8 x 32 x 64 x 80 \n", | |
"Conv2d 6912 True \n", | |
"BatchNorm2d 64 True \n", | |
"ReLU \n", | |
"Conv2d 9216 True \n", | |
"BatchNorm2d 64 True \n", | |
"ReLU \n", | |
"Conv2d 272 True \n", | |
"ConvTranspose2d 272 True \n", | |
"Conv2d 144 True \n", | |
"BatchNorm2d 32 True \n", | |
"ReLU \n", | |
"Identity \n", | |
"Conv2d 256 True \n", | |
"BatchNorm2d 32 True \n", | |
"Identity \n", | |
"Conv2d 272 True \n", | |
"Conv2d 432 True \n", | |
"BatchNorm2d 96 True \n", | |
"ReLU \n", | |
"Identity \n", | |
"Conv2d 2304 True \n", | |
"BatchNorm2d 96 True \n", | |
"Identity \n", | |
"____________________________________________________________________________\n", | |
" 8 x 16 x 64 x 80 \n", | |
"Conv2d 784 True \n", | |
"Conv2d 2304 True \n", | |
"BatchNorm2d 32 True \n", | |
"ReLU \n", | |
"Conv2d 2304 True \n", | |
"BatchNorm2d 32 True \n", | |
"ReLU \n", | |
"____________________________________________________________________________\n", | |
" 8 x 32 x 128 x 160 \n", | |
"Conv2d 544 True \n", | |
"____________________________________________________________________________\n", | |
"\n", | |
"Total params: 6,438,518\n", | |
"Total trainable params: 6,438,518\n", | |
"Total non-trainable params: 0\n", | |
"\n", | |
"Optimizer used: <function Adam at 0x7f9a487a6e60>\n", | |
"Loss function: <fastai.losses.DiceLoss object at 0x7f9a43e606d0>\n", | |
"\n", | |
"Model unfrozen\n", | |
"\n", | |
"Callbacks:\n", | |
" - TrainEvalCallback\n", | |
" - Recorder\n", | |
" - ProgressCallback" | |
], | |
"application/vnd.google.colaboratory.intrinsic+json": { | |
"type": "string" | |
} | |
}, | |
"metadata": {}, | |
"execution_count": 142 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"" | |
], | |
"metadata": { | |
"id": "sy8eP1xjbI4h" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment