Created August 26, 2021 04:38
Save calebrob6/93b34537d960265eb6c23e5b877bafc5 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "a6bc9db5", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%matplotlib inline\n", | |
"import sys\n", | |
"import os\n", | |
"import time\n", | |
"\n", | |
"import numpy as np\n", | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
"import torch\n", | |
"torch.backends.cudnn.deterministic = False\n", | |
"torch.backends.cudnn.benchmark = True\n", | |
"import torch.nn as nn\n", | |
"import torch.nn.functional as F\n", | |
"import torch.optim as optim\n", | |
"\n", | |
"from torch import Tensor\n", | |
"from torch.nn.modules import Module\n", | |
"\n", | |
"import segmentation_models_pytorch as smp\n", | |
"\n", | |
"import rasterio" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "ec3250e2", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"device = torch.device(\"cuda\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "124cceb8", | |
"metadata": {}, | |
"source": [ | |
"## Data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "c1ca86e7", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"fns = [\n", | |
" '2021-08-04-vv.tif',\n", | |
" '2021-07-23-vh.tif',\n", | |
" '2021-07-23-vv.tif',\n", | |
" '2021-08-04-vh.tif'\n", | |
"]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "18fd96a2", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"with rasterio.open(fns[1]) as f:\n", | |
" y = f.read()\n", | |
" \n", | |
"y = np.log1p(y)\n", | |
"mean = y.mean(dtype=np.float64)\n", | |
"std = y.std(dtype=np.float64)\n", | |
"y = (y - mean) / std " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "dad8e4bd", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"x = np.random.randn(y.shape[0], y.shape[1], y.shape[2])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "7f362ad3", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"x = x[:,2000:4000,2000:4000]\n", | |
"y = y[:,2000:4000,2000:4000]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "9be39ffd", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"x = torch.from_numpy(x).unsqueeze(0).float().to(device)\n", | |
"y = torch.from_numpy(y).unsqueeze(0).float().to(device)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "8d007f11", | |
"metadata": {}, | |
"source": [ | |
"## Training" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "4ea08bda", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def fit(model, device, x, y, optimizer, criterion, epoch, memo=''):\n", | |
" model.train()\n", | |
" \n", | |
" tic = time.time()\n", | |
" \n", | |
" optimizer.zero_grad()\n", | |
" outputs = model(x)\n", | |
" loss = criterion(outputs, y)\n", | |
" epoch_loss = loss.item()\n", | |
" loss.backward()\n", | |
" optimizer.step()\n", | |
" \n", | |
" print('[{}] Training Epoch: {}\\t Time elapsed: {:.2f} seconds\\t Loss: {:.2f}'.format(\n", | |
" memo, epoch, time.time()-tic, epoch_loss)\n", | |
" )\n", | |
" \n", | |
" return epoch_loss" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "1cdecb78", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model = smp.Unet(\n", | |
" encoder_name='resnet18', encoder_depth=3, encoder_weights=None,\n", | |
" decoder_channels=(128, 64, 64), in_channels=1, classes=1\n", | |
")\n", | |
"model = model.to(device)\n", | |
"\n", | |
"optimizer = optim.AdamW(model.parameters(), lr=0.001, amsgrad=True)\n", | |
"criterion = nn.MSELoss()\n", | |
"scheduler = optim.lr_scheduler.StepLR(optimizer, 150, verbose=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "2b612c1f", | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[] Training Epoch: 0\t Time elapsed: 4.78 seconds\t Loss: 1.83\n", | |
"[] Training Epoch: 1\t Time elapsed: 0.35 seconds\t Loss: 1.42\n", | |
"[] Training Epoch: 2\t Time elapsed: 0.66 seconds\t Loss: 1.28\n", | |
"[] Training Epoch: 3\t Time elapsed: 0.43 seconds\t Loss: 1.21\n", | |
"[] Training Epoch: 4\t Time elapsed: 0.48 seconds\t Loss: 1.11\n", | |
"[] Training Epoch: 5\t Time elapsed: 0.48 seconds\t Loss: 0.98\n", | |
"[] Training Epoch: 6\t Time elapsed: 0.48 seconds\t Loss: 0.84\n", | |
"[] Training Epoch: 7\t Time elapsed: 0.48 seconds\t Loss: 0.72\n", | |
"[] Training Epoch: 8\t Time elapsed: 0.48 seconds\t Loss: 0.64\n", | |
"[] Training Epoch: 9\t Time elapsed: 0.48 seconds\t Loss: 0.59\n", | |
"[] Training Epoch: 10\t Time elapsed: 0.48 seconds\t Loss: 0.56\n", | |
"[] Training Epoch: 11\t Time elapsed: 0.49 seconds\t Loss: 0.52\n", | |
"[] Training Epoch: 12\t Time elapsed: 0.48 seconds\t Loss: 0.49\n", | |
"[] Training Epoch: 13\t Time elapsed: 0.48 seconds\t Loss: 0.47\n", | |
"[] Training Epoch: 14\t Time elapsed: 0.48 seconds\t Loss: 0.45\n", | |
"[] Training Epoch: 15\t Time elapsed: 0.48 seconds\t Loss: 0.43\n", | |
"[] Training Epoch: 16\t Time elapsed: 0.48 seconds\t Loss: 0.42\n", | |
"[] Training Epoch: 17\t Time elapsed: 0.48 seconds\t Loss: 0.41\n", | |
"[] Training Epoch: 18\t Time elapsed: 0.48 seconds\t Loss: 0.40\n", | |
"[] Training Epoch: 19\t Time elapsed: 0.48 seconds\t Loss: 0.39\n", | |
"[] Training Epoch: 20\t Time elapsed: 0.48 seconds\t Loss: 0.38\n", | |
"[] Training Epoch: 21\t Time elapsed: 0.48 seconds\t Loss: 0.37\n", | |
"[] Training Epoch: 22\t Time elapsed: 0.49 seconds\t Loss: 0.37\n", | |
"[] Training Epoch: 23\t Time elapsed: 0.48 seconds\t Loss: 0.36\n", | |
"[] Training Epoch: 24\t Time elapsed: 0.49 seconds\t Loss: 0.35\n", | |
"[] Training Epoch: 25\t Time elapsed: 0.48 seconds\t Loss: 0.35\n", | |
"[] Training Epoch: 26\t Time elapsed: 0.48 seconds\t Loss: 0.34\n", | |
"[] Training Epoch: 27\t Time elapsed: 0.48 seconds\t Loss: 0.34\n", | |
"[] Training Epoch: 28\t Time elapsed: 0.49 seconds\t Loss: 0.34\n", | |
"[] Training Epoch: 29\t Time elapsed: 0.48 seconds\t Loss: 0.33\n", | |
"[] Training Epoch: 30\t Time elapsed: 0.48 seconds\t Loss: 0.33\n", | |
"[] Training Epoch: 31\t Time elapsed: 0.48 seconds\t Loss: 0.32\n", | |
"[] Training Epoch: 32\t Time elapsed: 0.48 seconds\t Loss: 0.32\n", | |
"[] Training Epoch: 33\t Time elapsed: 0.48 seconds\t Loss: 0.32\n", | |
"[] Training Epoch: 34\t Time elapsed: 0.48 seconds\t Loss: 0.31\n", | |
"[] Training Epoch: 35\t Time elapsed: 0.48 seconds\t Loss: 0.31\n", | |
"[] Training Epoch: 36\t Time elapsed: 0.48 seconds\t Loss: 0.31\n", | |
"[] Training Epoch: 37\t Time elapsed: 0.48 seconds\t Loss: 0.30\n", | |
"[] Training Epoch: 38\t Time elapsed: 0.48 seconds\t Loss: 0.30\n", | |
"[] Training Epoch: 39\t Time elapsed: 0.48 seconds\t Loss: 0.30\n", | |
"[] Training Epoch: 40\t Time elapsed: 0.48 seconds\t Loss: 0.30\n", | |
"[] Training Epoch: 41\t Time elapsed: 0.48 seconds\t Loss: 0.29\n", | |
"[] Training Epoch: 42\t Time elapsed: 0.48 seconds\t Loss: 0.29\n", | |
"[] Training Epoch: 43\t Time elapsed: 0.48 seconds\t Loss: 0.29\n", | |
"[] Training Epoch: 44\t Time elapsed: 0.48 seconds\t Loss: 0.29\n", | |
"[] Training Epoch: 45\t Time elapsed: 0.48 seconds\t Loss: 0.28\n", | |
"[] Training Epoch: 46\t Time elapsed: 0.48 seconds\t Loss: 0.28\n", | |
"[] Training Epoch: 47\t Time elapsed: 0.48 seconds\t Loss: 0.28\n", | |
"[] Training Epoch: 48\t Time elapsed: 0.48 seconds\t Loss: 0.28\n", | |
"[] Training Epoch: 49\t Time elapsed: 0.48 seconds\t Loss: 0.28\n", | |
"[] Training Epoch: 50\t Time elapsed: 0.48 seconds\t Loss: 0.27\n", | |
"[] Training Epoch: 51\t Time elapsed: 0.48 seconds\t Loss: 0.27\n", | |
"[] Training Epoch: 52\t Time elapsed: 0.48 seconds\t Loss: 0.27\n", | |
"[] Training Epoch: 53\t Time elapsed: 0.48 seconds\t Loss: 0.27\n", | |
"[] Training Epoch: 54\t Time elapsed: 0.48 seconds\t Loss: 0.27\n", | |
"[] Training Epoch: 55\t Time elapsed: 0.48 seconds\t Loss: 0.27\n", | |
"[] Training Epoch: 56\t Time elapsed: 0.48 seconds\t Loss: 0.26\n", | |
"[] Training Epoch: 57\t Time elapsed: 0.48 seconds\t Loss: 0.26\n", | |
"[] Training Epoch: 58\t Time elapsed: 0.48 seconds\t Loss: 0.26\n", | |
"[] Training Epoch: 59\t Time elapsed: 0.48 seconds\t Loss: 0.26\n", | |
"[] Training Epoch: 60\t Time elapsed: 0.48 seconds\t Loss: 0.26\n", | |
"[] Training Epoch: 61\t Time elapsed: 0.48 seconds\t Loss: 0.27\n", | |
"[] Training Epoch: 62\t Time elapsed: 0.48 seconds\t Loss: 0.27\n", | |
"[] Training Epoch: 63\t Time elapsed: 0.48 seconds\t Loss: 0.26\n", | |
"[] Training Epoch: 64\t Time elapsed: 0.48 seconds\t Loss: 0.26\n", | |
"[] Training Epoch: 65\t Time elapsed: 0.48 seconds\t Loss: 0.26\n", | |
"[] Training Epoch: 66\t Time elapsed: 0.48 seconds\t Loss: 0.25\n", | |
"[] Training Epoch: 67\t Time elapsed: 0.48 seconds\t Loss: 0.26\n", | |
"[] Training Epoch: 68\t Time elapsed: 0.48 seconds\t Loss: 0.25\n", | |
"[] Training Epoch: 69\t Time elapsed: 0.48 seconds\t Loss: 0.25\n", | |
"[] Training Epoch: 70\t Time elapsed: 0.48 seconds\t Loss: 0.25\n", | |
"[] Training Epoch: 71\t Time elapsed: 0.48 seconds\t Loss: 0.24\n", | |
"[] Training Epoch: 72\t Time elapsed: 0.48 seconds\t Loss: 0.25\n", | |
"[] Training Epoch: 73\t Time elapsed: 0.48 seconds\t Loss: 0.24\n", | |
"[] Training Epoch: 74\t Time elapsed: 0.48 seconds\t Loss: 0.24\n", | |
"[] Training Epoch: 75\t Time elapsed: 0.48 seconds\t Loss: 0.24\n", | |
"[] Training Epoch: 76\t Time elapsed: 0.48 seconds\t Loss: 0.24\n", | |
"[] Training Epoch: 77\t Time elapsed: 0.48 seconds\t Loss: 0.24\n", | |
"[] Training Epoch: 78\t Time elapsed: 0.48 seconds\t Loss: 0.24\n", | |
"[] Training Epoch: 79\t Time elapsed: 0.48 seconds\t Loss: 0.24\n", | |
"[] Training Epoch: 80\t Time elapsed: 0.48 seconds\t Loss: 0.24\n", | |
"[] Training Epoch: 81\t Time elapsed: 0.48 seconds\t Loss: 0.23\n", | |
"[] Training Epoch: 82\t Time elapsed: 0.48 seconds\t Loss: 0.23\n", | |
"[] Training Epoch: 83\t Time elapsed: 0.48 seconds\t Loss: 0.23\n", | |
"[] Training Epoch: 84\t Time elapsed: 0.48 seconds\t Loss: 0.23\n", | |
"[] Training Epoch: 85\t Time elapsed: 0.48 seconds\t Loss: 0.23\n", | |
"[] Training Epoch: 86\t Time elapsed: 0.48 seconds\t Loss: 0.23\n", | |
"[] Training Epoch: 87\t Time elapsed: 0.48 seconds\t Loss: 0.23\n", | |
"[] Training Epoch: 88\t Time elapsed: 0.48 seconds\t Loss: 0.23\n", | |
"[] Training Epoch: 89\t Time elapsed: 0.48 seconds\t Loss: 0.23\n", | |
"[] Training Epoch: 90\t Time elapsed: 0.48 seconds\t Loss: 0.23\n", | |
"[] Training Epoch: 91\t Time elapsed: 0.48 seconds\t Loss: 0.23\n", | |
"[] Training Epoch: 92\t Time elapsed: 0.48 seconds\t Loss: 0.23\n", | |
"[] Training Epoch: 93\t Time elapsed: 0.49 seconds\t Loss: 0.23\n", | |
"[] Training Epoch: 94\t Time elapsed: 0.48 seconds\t Loss: 0.22\n", | |
"[] Training Epoch: 95\t Time elapsed: 0.48 seconds\t Loss: 0.22\n", | |
"[] Training Epoch: 96\t Time elapsed: 0.48 seconds\t Loss: 0.22\n", | |
"[] Training Epoch: 97\t Time elapsed: 0.48 seconds\t Loss: 0.22\n", | |
"[] Training Epoch: 98\t Time elapsed: 0.48 seconds\t Loss: 0.22\n", | |
"[] Training Epoch: 99\t Time elapsed: 0.48 seconds\t Loss: 0.22\n", | |
"[] Training Epoch: 100\t Time elapsed: 0.48 seconds\t Loss: 0.22\n", | |
"[] Training Epoch: 101\t Time elapsed: 0.48 seconds\t Loss: 0.22\n", | |
"[] Training Epoch: 102\t Time elapsed: 0.48 seconds\t Loss: 0.22\n", | |
"[] Training Epoch: 103\t Time elapsed: 0.48 seconds\t Loss: 0.21\n", | |
"[] Training Epoch: 104\t Time elapsed: 0.48 seconds\t Loss: 0.21\n", | |
"[] Training Epoch: 105\t Time elapsed: 0.48 seconds\t Loss: 0.21\n", | |
"[] Training Epoch: 106\t Time elapsed: 0.48 seconds\t Loss: 0.21\n", | |
"[] Training Epoch: 107\t Time elapsed: 0.48 seconds\t Loss: 0.21\n", | |
"[] Training Epoch: 108\t Time elapsed: 0.48 seconds\t Loss: 0.21\n", | |
"[] Training Epoch: 109\t Time elapsed: 0.48 seconds\t Loss: 0.21\n", | |
"[] Training Epoch: 110\t Time elapsed: 0.48 seconds\t Loss: 0.21\n", | |
"[] Training Epoch: 111\t Time elapsed: 0.48 seconds\t Loss: 0.21\n", | |
"[] Training Epoch: 112\t Time elapsed: 0.48 seconds\t Loss: 0.21\n", | |
"[] Training Epoch: 113\t Time elapsed: 0.48 seconds\t Loss: 0.21\n", | |
"[] Training Epoch: 114\t Time elapsed: 0.48 seconds\t Loss: 0.21\n", | |
"[] Training Epoch: 115\t Time elapsed: 0.48 seconds\t Loss: 0.21\n", | |
"[] Training Epoch: 116\t Time elapsed: 0.48 seconds\t Loss: 0.21\n", | |
"[] Training Epoch: 117\t Time elapsed: 0.48 seconds\t Loss: 0.21\n", | |
"[] Training Epoch: 118\t Time elapsed: 0.48 seconds\t Loss: 0.20\n", | |
"[] Training Epoch: 119\t Time elapsed: 0.48 seconds\t Loss: 0.20\n", | |
"[] Training Epoch: 120\t Time elapsed: 0.48 seconds\t Loss: 0.20\n", | |
"[] Training Epoch: 121\t Time elapsed: 0.48 seconds\t Loss: 0.20\n", | |
"[] Training Epoch: 122\t Time elapsed: 0.49 seconds\t Loss: 0.20\n", | |
"[] Training Epoch: 123\t Time elapsed: 0.48 seconds\t Loss: 0.20\n", | |
"[] Training Epoch: 124\t Time elapsed: 0.48 seconds\t Loss: 0.20\n", | |
"[] Training Epoch: 125\t Time elapsed: 0.48 seconds\t Loss: 0.20\n", | |
"[] Training Epoch: 126\t Time elapsed: 0.48 seconds\t Loss: 0.20\n", | |
"[] Training Epoch: 127\t Time elapsed: 0.48 seconds\t Loss: 0.21\n", | |
"[] Training Epoch: 128\t Time elapsed: 0.48 seconds\t Loss: 0.20\n", | |
"[] Training Epoch: 129\t Time elapsed: 0.48 seconds\t Loss: 0.20\n", | |
"[] Training Epoch: 130\t Time elapsed: 0.48 seconds\t Loss: 0.20\n", | |
"[] Training Epoch: 131\t Time elapsed: 0.48 seconds\t Loss: 0.20\n", | |
"[] Training Epoch: 132\t Time elapsed: 0.48 seconds\t Loss: 0.20\n", | |
"[] Training Epoch: 133\t Time elapsed: 0.48 seconds\t Loss: 0.20\n", | |
"[] Training Epoch: 134\t Time elapsed: 0.48 seconds\t Loss: 0.20\n", | |
"[] Training Epoch: 135\t Time elapsed: 0.48 seconds\t Loss: 0.19\n", | |
"[] Training Epoch: 136\t Time elapsed: 0.48 seconds\t Loss: 0.19\n", | |
"[] Training Epoch: 137\t Time elapsed: 0.48 seconds\t Loss: 0.20\n", | |
"[] Training Epoch: 138\t Time elapsed: 0.48 seconds\t Loss: 0.19\n", | |
"[] Training Epoch: 139\t Time elapsed: 0.48 seconds\t Loss: 0.19\n", | |
"[] Training Epoch: 140\t Time elapsed: 0.48 seconds\t Loss: 0.19\n", | |
"[] Training Epoch: 141\t Time elapsed: 0.48 seconds\t Loss: 0.19\n", | |
"[] Training Epoch: 142\t Time elapsed: 0.48 seconds\t Loss: 0.19\n", | |
"[] Training Epoch: 143\t Time elapsed: 0.48 seconds\t Loss: 0.19\n", | |
"[] Training Epoch: 144\t Time elapsed: 0.48 seconds\t Loss: 0.19\n", | |
"[] Training Epoch: 145\t Time elapsed: 0.48 seconds\t Loss: 0.19\n", | |
"[] Training Epoch: 146\t Time elapsed: 0.48 seconds\t Loss: 0.19\n", | |
"[] Training Epoch: 147\t Time elapsed: 0.48 seconds\t Loss: 0.19\n", | |
"[] Training Epoch: 148\t Time elapsed: 0.48 seconds\t Loss: 0.19\n", | |
"[] Training Epoch: 149\t Time elapsed: 0.48 seconds\t Loss: 0.19\n", | |
"[] Training Epoch: 150\t Time elapsed: 0.49 seconds\t Loss: 0.19\n", | |
"[] Training Epoch: 151\t Time elapsed: 0.48 seconds\t Loss: 0.19\n", | |
"[] Training Epoch: 152\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 153\t Time elapsed: 0.49 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 154\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 155\t Time elapsed: 0.49 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 156\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 157\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 158\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 159\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 160\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 161\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 162\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 163\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 164\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 165\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 166\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 167\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 168\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 169\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 170\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 171\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 172\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 173\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 174\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 175\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 176\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 177\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 178\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 179\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 180\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 181\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 182\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 183\t Time elapsed: 0.49 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 184\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 185\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 186\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 187\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 188\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 189\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 190\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 191\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 192\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 193\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 194\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 195\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 196\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 197\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 198\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 199\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 200\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 201\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 202\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 203\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 204\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 205\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 206\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 207\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 208\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 209\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 210\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 211\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 212\t Time elapsed: 0.49 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 213\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 214\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 215\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 216\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 217\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 218\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 219\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 220\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 221\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 222\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 223\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 224\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 225\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 226\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 227\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 228\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 229\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 230\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 231\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 232\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 233\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 234\t Time elapsed: 0.49 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 235\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 236\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 237\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 238\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 239\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 240\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 241\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 242\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 243\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 244\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 245\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 246\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 247\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 248\t Time elapsed: 0.49 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 249\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 250\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 251\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 252\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 253\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 254\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 255\t Time elapsed: 0.49 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 256\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 257\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 258\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 259\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 260\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 261\t Time elapsed: 0.49 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 262\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 263\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 264\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 265\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 266\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 267\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 268\t Time elapsed: 0.49 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 269\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 270\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 271\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 272\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 273\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 274\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 275\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 276\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 277\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 278\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 279\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 280\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 281\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 282\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 283\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 284\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 285\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 286\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 287\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 288\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 289\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 290\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 291\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 292\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 293\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 294\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 295\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 296\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 297\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 298\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 299\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 300\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 301\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 302\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 303\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 304\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 305\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 306\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 307\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 308\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 309\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 310\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 311\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 312\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 313\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 314\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 315\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 316\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 317\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 318\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 319\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 320\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 321\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 322\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 323\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 324\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 325\t Time elapsed: 0.49 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 326\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 327\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 328\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 329\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 330\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 331\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 332\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 333\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 334\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 335\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 336\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 337\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 338\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 339\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 340\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 341\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 342\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 343\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 344\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 345\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 346\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 347\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 348\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 349\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 350\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 351\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 352\t Time elapsed: 0.49 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 353\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 354\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 355\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 356\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 357\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 358\t Time elapsed: 0.49 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 359\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 360\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 361\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 362\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 363\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 364\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 365\t Time elapsed: 0.49 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 366\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 367\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 368\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 369\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 370\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 371\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 372\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 373\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 374\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 375\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 376\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 377\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 378\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 379\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 380\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 381\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 382\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 383\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 384\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 385\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 386\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 387\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 388\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 389\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 390\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 391\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 392\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 393\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 394\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 395\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 396\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 397\t Time elapsed: 0.49 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 398\t Time elapsed: 0.48 seconds\t Loss: 0.18\n", | |
"[] Training Epoch: 399\t Time elapsed: 0.48 seconds\t Loss: 0.18\n" | |
] | |
} | |
], | |
"source": [ | |
"losses = []\n", | |
"for epoch in range(400):\n", | |
" loss = fit(\n", | |
" model, device, x, y, optimizer, criterion, epoch\n", | |
" )\n", | |
" losses.append(loss)\n", | |
" scheduler.step()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "a3f24a40", | |
"metadata": {}, | |
"source": [ | |
"## Visualize" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "1d770ba4", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAbz0lEQVR4nO3dfXRc9Z3f8fd3ZjSynmzZkmyMn00IYCg2oBoIKeA0S+yExOUc9qzdNN1tk/WSA3va5uw2pOmBNP2j3eVsm83mgbhZL7tNgU0LTtgcwsOGsGweIJbBNjbYYIzBQjaSbWzZerAe5ts/5o481sxII+vhyvd+XsdzdO/93Yev7pE/+uk3d+41d0dERKIrEXYBIiIyuRT0IiIRp6AXEYk4Bb2ISMQp6EVEIi4VdgHFNDY2+tKlS8MuQ0TkgrF9+/aj7t5UrG1aBv3SpUtpaWkJuwwRkQuGmb1Tqk1DNyIiEaegFxGJOAW9iEjEKehFRCJOQS8iEnEKehGRiFPQi4hEXKSC/ps/e5N/eKMj7DJERKaVSAX9d59/i1+8qaAXEckXqaBPJYyBjB6kIiKSL1JBn0wagwp6EZFzRCroUwkFvYjIcJEK+qSCXkSkQKSCPpVIaIxeRGSYSAW9evQiIoVGvR+9mW0Bbgfa3f2qIu1/DHw2b39XAE3uftzMDgKngEFgwN2bJ6rwYnTVjYhIoXJ69A8Ba0s1uvsD7r7K3VcBXwH+wd2P562yJmif1JCHXI8+M9mHERG5oIwa9O7+AnB8tPUCG4FHxlXROCQTxsCgevQiIvkmbIzezKrJ9vwfy1vswDNmtt3MNo2y/SYzazGzlo6O8/t0a0rX0YuIFJjIN2M/Dfxy2LDNTe5+LbAOuNvMbi61sbtvdvdmd29uair6fNtRJXXVjYhIgYkM+g0MG7Zx97bgazuwFVg9gccroA9MiYgUmpCgN7NZwC3Aj/OW1ZhZXW4auA3YPRHHKyWZMAb0ZqyIyDnKubzyEeBWoNHMWoH7gQoAd38wWO0O4Bl378rbdB6w1cxyx3nY3Z+auNILpRJG/6CCXkQk36hB7+4by1jnIbKXYeYvOwCsPN/CzkcyYfT0a+hGRCRfpD4ZqzF6EZFCkQr6ZCKh6+hFRIaJVNCrRy8iUihSQZ9M6qobEZHhIhX06tGLiBSKVNAndfdKEZECkQp69ehFRApFKuh1rxsRkUKRCnr16EVECkUq6LP3o9dVNyIi+SIV9OrRi4gUilTQZ6+jV9CLiOSLVNCrRy8iUihSQZ+76sZdYS8ikhOpoE8lDAB16kVEzopU0CeDoNf9bkREzopU0Od69BqnFxE5K1JBf7ZHr6AXEcmJVNAP9ej18BERkSGRCvpkMvvtqEcvInJWtILeNEYvIjLcqEFvZlvMrN3Mdpdov9XMTprZjuB1X17bWjPbZ2b7zezeiSy8mJSuuhERKVBOj/4hYO0o6/yju68KXl8HMLMk8G1gHbAC2GhmK8ZT7GiSuupGRKTAqEHv7i8Ax89j36uB/e5+wN37gEeB9eexn7KlkrrqRkRkuIkao7/RzHaa2U/N7Mpg2QLgUN46rcGyosxsk5m1mFlLR0fHeRWhHr2ISKGJCPqXgSXuvhL4C+BHwXIrsm7JBHb3ze7e7O7NTU1N51XI0Bi9Lq8UERky7qB39053Px1MPwlUmFkj2R78orxVFwJt4z3eSJKJ7LejHr2IyFnjDnozu8gse12jma0O9nkM2AZcambLzCwNbACeGO/xRjL0gSndvVJEZEhqtBXM7BHgVqDRzFqB+4EKAHd/ELgT+KKZDQA9wAbP3id4wMzuAZ4GksAWd98zKd9F4OwYvS6vFBHJGTXo3X3jKO3fAr5Vou1J4MnzK23sNEYvIlIoWp+M1VU3IiIFIhX0uo5eRKRQpIJeV92IiBSKVNDnxujPDOjNWBGRnEgFfXU6CUBP/0DIlYiITB+RCvqayuxFRN19gyFXIiIyfUQq6K
tyPXoFvYjIkEgFfXVFNui7zijoRURyIhX0qWSCdCpBt8boRUSGRCroIfuGrIZuRETOilzQ16RTGroREckTuaCvSid1eaWISJ7IBX11OqnLK0VE8kQz6DV0IyIyJIJBn9JVNyIieSIY9Bq6ERHJF82g19CNiMiQCAZ9iu4+Dd2IiOREMOg1dCMiki+SQT+Qcfp0T3oRESCCQZ+7VXHXGQ3fiIhAGUFvZlvMrN3Mdpdo/6yZ7QpevzKzlXltB83sVTPbYWYtE1l4KXNq0gAc6+qbisOJiEx75fToHwLWjtD+NnCLu18N/Fdg87D2Ne6+yt2bz6/EsWmsrQTg6OkzU3E4EZFpLzXaCu7+gpktHaH9V3mzLwILJ6Cu86agFxE510SP0X8e+GnevAPPmNl2M9s00oZmtsnMWsyspaOj47wLaKzNDt0cPaWgFxGBMnr05TKzNWSD/qN5i29y9zYzmws8a2Z73f2FYtu7+2aCYZ/m5mY/3zrqq9MkTGP0IiI5E9KjN7Orge8D6939WG65u7cFX9uBrcDqiTjeSJIJY05NpYZuREQC4w56M1sMPA58zt3fyFteY2Z1uWngNqDolTsTrbE2Tccp9ehFRKCMoRszewS4FWg0s1bgfqACwN0fBO4DGoDvmBnAQHCFzTxga7AsBTzs7k9NwvdQoLFWPXoRkZxyrrrZOEr7F4AvFFl+AFhZuMXkm1OT5tAH3WEcWkRk2oncJ2MBZlVVcLKnP+wyRESmhcgGfWdPP5nMeV+8IyISGZEM+vrqCjIOp3S/GxGRaAb9zKoKADo1fCMiEs2gnxUEvcbpRUQU9CIikRfpoD/RraAXEYlk0NdXq0cvIpITyaDX0I2IyFmRDPqqiiQVSVPQi4gQ0aA3M306VkQkEMmgh+xDwvWAcBGRCAd9dTpFd5+CXkQkskFfW5nktHr0IiLRDfqayhTdfYNhlyEiErroBn06pR69iAhRDvrKpN6MFREhwkFfnU7RfUZDNyIikQ362soUXX0DuOvhIyISb5EN+prKFBmH3v5M2KWIiIRq1KA3sy1m1m5mu0u0m5l908z2m9kuM7s2r22tme0L2u6dyMJHU1OZBNAbsiISe+X06B8C1o7Qvg64NHhtAr4LYGZJ4NtB+wpgo5mtGE+xY1GTTgHoQ1MiEnujBr27vwAcH2GV9cDfeNaLQL2ZzQdWA/vd/YC79wGPButOCfXoRUSyJmKMfgFwKG++NVhWavmUqKnM9ui7dOWNiMTcRAS9FVnmIywvvhOzTWbWYmYtHR0d4y6qOhi66dLQjYjE3EQEfSuwKG9+IdA2wvKi3H2zuze7e3NTU9O4i6od6tEr6EUk3iYi6J8A/nVw9c0NwEl3PwxsAy41s2VmlgY2BOtOidwYvT40JSJxlxptBTN7BLgVaDSzVuB+oALA3R8EngQ+CewHuoF/E7QNmNk9wNNAEtji7nsm4XsoKnfVjd6MFZG4GzXo3X3jKO0O3F2i7UmyvwimXI2GbkREgAh/MjadSlCRNLp0q2IRibnIBj3ocYIiIhD1oE+ndHmliMRetINe96QXEYl60Kf0yVgRib1oB72GbkREIh70GroREYl40Kc1dCMiEu2gr9TQjYhI5INe97oRkbiLdtCnk/QNZugb0HNjRSS+oh30ut+NiEjUgz57q2KN04tInEU86PU4QRGReAS9evQiEmPRDvq0xuhFRKId9LkxegW9iMRYtIM+rTF6EZFoB73G6EVEoh70uaEb9ehFJL4iHfRVFUkSpjF6EYm3soLezNaa2T4z229m9xZp/2Mz2xG8dpvZoJnNCdoOmtmrQVvLRH8Do9Ste9KLSOylRlvBzJLAt4HfAlqBbWb2hLu/llvH3R8AHgjW/zTwH9z9eN5u1rj70QmtvEzVuie9iMRcOT361cB+dz/g7n3Ao8D6EdbfCDwyEcVNBD1OUETirpygXwAcyptvDZYVMLNqYC3wWN5iB54xs+1mtqnUQcxsk5m1mFlLR0
dHGWWVR0M3IhJ35QS9FVnmJdb9NPDLYcM2N7n7tcA64G4zu7nYhu6+2d2b3b25qampjLLKo8cJikjclRP0rcCivPmFQFuJdTcwbNjG3duCr+3AVrJDQVOmVkM3IhJz5QT9NuBSM1tmZmmyYf7E8JXMbBZwC/DjvGU1ZlaXmwZuA3ZPROHlqtbQjYjE3KhX3bj7gJndAzwNJIEt7r7HzO4K2h8MVr0DeMbdu/I2nwdsNbPcsR5296cm8hsYzcyqFCd7+qfykCIi08qoQQ/g7k8CTw5b9uCw+YeAh4YtOwCsHFeF49RQU8mJ7n4GBjOkkpH+fJiISFGRT76G2jQAx7v7Qq5ERCQckQ/6OTVB0Hcp6EUknuIT9KcV9CIST5EP+sbaSgCOqUcvIjEV+aDX0I2IxF3kg352dRoz9ehFJL4iH/TJhFFfVcGx02fCLkVEJBSRD3qAuXUzeL9TQS8i8RSLoF80p5rWD7rDLkNEJBSxCPrFc6p593g37qVuuikiEl0xCfoquvsG9YasiMRSPIK+oRqAd49r+EZE4iceQT8nCPpjCnoRiZ9YBP2iOdUkE8ZbHafDLkVEZMrFIugrU0mWN9bw+uFTYZciIjLlYhH0AJddVMe+9zvDLkNEZMrFJugvv6iOQ8d7OK0HhYtIzMQo6GcCsO+Ihm9EJF5iE/SXXVQHKOhFJH5iE/QLZ1dRW5li3xGN04tIvMQm6M2Myy6q43X16EUkZsoKejNba2b7zGy/md1bpP1WMztpZjuC133lbjuVLr+ojr2HO8lkdM8bEYmPUYPezJLAt4F1wApgo5mtKLLqP7r7quD19TFuOyVWLaqns3dAH5wSkVgpp0e/Gtjv7gfcvQ94FFhf5v7Hs+2Eu27JbAC2v/NBWCWIiEy5coJ+AXAob741WDbcjWa208x+amZXjnFbzGyTmbWYWUtHR0cZZY3dssYaGmrStCjoRSRGygl6K7Js+CD3y8ASd18J/AXwozFsm13ovtndm929uampqYyyxs7MuHbJbPXoRSRWygn6VmBR3vxCoC1/BXfvdPfTwfSTQIWZNZaz7VRrXjKbt4926RmyIhIb5QT9NuBSM1tmZmlgA/BE/gpmdpGZWTC9OtjvsXK2nWoapxeRuBk16N19ALgHeBp4Hfihu+8xs7vM7K5gtTuB3Wa2E/gmsMGzim47Gd9Iua5aMIvqdJLn9raHWYaIyJSx6fgc1ebmZm9paZm0/X/phzt49rX32fbVjzOjIjlpxxERmSpmtt3dm4u1xeaTsfnuvG4hp3oH+Mmuw2GXIiIy6WIZ9Dcub+DD82rZ8ou3mY5/0YiITKRYBr2ZsenmS3jtcCc/3X0k7HJERCZVLIMe4I5rFvDhebX86VN76R/MhF2OiMikiW3QJxPGvesu5+Cxbh5+6d2wyxERmTSxDXqANZfN5cblDXzj79/gZE9/2OWIiEyKWAe9mfGfb7+CEz39fOu5N8MuR0RkUsQ66AGuvHgWv33dQv7qlwd58cCxsMsREZlwsQ96gK9+agVLGqq56wfbOaB71YtIxCjogVlVFfzV760macbvbH6R19r0XFkRiQ4FfWBxQzV/+wc3kEoYv/O9X/OShnFEJCIU9Hk+NLeOx774EebNmsHntvyGp/fow1QicuFT0A9zcX0V//cPbmTF/Jl88Qfb+Z/PvsGZgcGwyxIROW8K+iJm16R5+Pev5zMrL+bPf/Ymt3/zF7p/vYhcsBT0JVSnU3xjwzVs+b1mus4McOeDv+LL/28X73f2hl2aiMiYKOhH8bHL5/HMl27h8zct4/FXWrnlgZ/zJ0/t5ageRSgiF4hYPnjkfL17rJsHntnHT3a1UZlKsOGfLub3b17OgvqqsEsTkZgb6cEjCvrz8FbHaR58/i22vvIeGXfWXDaXz96wmFs+PJdkwsIuT0RiSEE/Sd470cMjL73L37YcouPUGRbUV3HHNQu4feV8LptXR/C8dBGRSaegn2T9gxn+/rX3efg37/
LL/UfJOHxobi23Xz2f26+ezyVNtQp9EZlUCvop1HHqDE/tPszf7TrMtoPHcYclDdWsuWwuH7t8Ltcvn0NlqvCB5DsPnaCrb4CPXNIYQtUicqEbd9Cb2Vrgz4Ek8H13/+/D2j8LfDmYPQ180d13Bm0HgVPAIDBQqpB8F3LQ53u/s5en9xzhub3t/PqtY5wZyFCdTnLj8gauXz6HG5Y3cMX8mfzdzja+/Ngu0skEL/6nf87AoDO7Jh12+SJyARlX0JtZEngD+C2gFdgGbHT31/LW+Qjwurt/YGbrgK+5+/VB20Gg2d2PlltwVII+X0/fIL8+cJTn9rbzq/3HOHC065z2yy+qY++RU5hB0ox/+9Fl3PShRm75cBOZjJPQm7wiMoLxBv2NZIP7E8H8VwDc/b+VWH82sNvdFwTzB1HQF2jv7OWlt4/zZvtpFs6uYv2qi7n/x3t48tXDdPYODK23rLGGU70D3L3mEtZcNpeljTUhVi0i09V4g/5OYK27fyGY/xxwvbvfU2L9PwIuz1v/beADwIHvufvmEtttAjYBLF68+Lp33nmnnO8tkp7f184r757ADL77/FucGcg+vDyVMBbMrmJBfRVXzJ/JJ668iCUN1TTWVuqyTpGYG2/Q/zbwiWFBv9rd/7DIumuA7wAfdfdjwbKL3b3NzOYCzwJ/6O4vjHTMOPTox2LnoRMcPtnDjkMn2Xukc2i8P2dBfRVLGqq5flkDF9fPYNWiepY11pBKJujuG+Dg0W4WzamibkZFiN+FiEymkYI+Vcb2rcCivPmFQFuRg1wNfB9Ylwt5AHdvC762m9lWYDUwYtDLuVYuqmflonrWXjUfgN7+QXr7B/nF/qMcOdnLc3vbOdnTzzd+9ga539vJhDGvrpJjXX2cGchQV5li4/WLaahJs7SxhmsW1zNzRgUzKgqvABKRaCkn6LcBl5rZMuA9YAPwL/NXMLPFwOPA59z9jbzlNUDC3U8F07cBX5+o4uNqRkWSGRVJbr/6YgC+8M+WA3Ciu4+jp8/w8rsnePdYN20nephdk+bqhbN4avcRNr9w4Jz9VCSNlQvrWTynmoWzq1g4p5rFc6pZUF/F3JmVRS8DFZELz6hB7+4DZnYP8DTZyyu3uPseM7sraH8QuA9oAL4TfDAodxnlPGBrsCwFPOzuT03KdyLUV6epr07zobl1BW3rVy2g68wADmx7+zitJ3p452gXu947yUtvH+dHO3rIDBvFa6xNc9GsGdRXpekbyNDZ209DbZrPrLyY65bMobd/kMpUguVNtQDsbz/N3iOdXNJUy5UXz9SHxESmCX1gSgDoG8hw+GQPh4730HayhyMnezl8spcjJ3s40dNPZSpB3YwK3uo4zYGOcy8NrapI4ji9/ee+b/DxK+bS1TdIQ02aj6+YR3U6Sd9Ahn+yYBappG6cKjKRxjtGLzGQTiVY0lDDkoaRL990d3a/18mBo6epqkjS2TvAnraTJMy4asFMLps3kz1tJ3ns5VYef/k9aipTHO/q43t5w0b11RUsa6yhMpVgVlUFC2cHQ0ezq5lbV0l9dQX1VWnqZqT0+QGRCaAevUy6U739vPDGURIGg+78fG8H7ad6OdOf4Xh3H+990ENPf/HHNdakk5gZ7k7CjGTSSCWMZMJIBvNJy86nEgkSiWx7ImEkLfumtJmRCKYTlp1PGkPT57adO52w7P4TCYLtsutntzOSibP7yZ9OBB98SyRs6ENw5+wzcXY/+dMF+zzneKXrzB1raJ9WuN7w4yXytjlnP7lXUEep/cj0oh69hKpuRgWfunr+0HzuTeQcd+d4Vx+HPujh2OkznOju54PuPjp7B7LvKziYgTsMZjL0Z5xMxhnI+zoYvAYyTsad/sEM7pDx7HwmAwODGTLuDHr2mLnlQ+s4ZDJ50+7B/LB1ii0ftp+oy/3CShgYRvAPC+ZtaN4wGGont2xYuwUr2TnrBfsedtxz5gvqGvkXUMH2Bfsr3H6sxyzYwxi2n1Od5o
d33VhQw3gp6CV0ZkZDbSUNtZVhlzJhCn5hlPolkTc9mPFzfzl5qTaC+cLp7C+yYD7D2eki+ynY59B2w34ZjnK8jIPjBP9wz+43O51tyx84cPeCttx8dipoG/YL0zl3QWE7o7SPvEGx38/DRzxGP8bYth++oG7G5ESygl5kEiQSRqJI71AkDLr0QUQk4hT0IiIRp6AXEYk4Bb2ISMQp6EVEIk5BLyIScQp6EZGIU9CLiETctLzXjZl1AOf7LMFGoOzn004h1TU2qmtspmtdMH1ri1pdS9y9qVjDtAz68TCzllI39gmT6hob1TU207UumL61xakuDd2IiEScgl5EJOKiGPSbwy6gBNU1NqprbKZrXTB9a4tNXZEboxcRkXNFsUcvIiJ5FPQiIhEXmaA3s7Vmts/M9pvZvSHXctDMXjWzHWbWEiybY2bPmtmbwdfZU1TLFjNrN7PdectK1mJmXwnO4T4z+8QU1/U1M3svOG87zOyTIdS1yMx+bmavm9keM/t3wfJQz9kIdYV6zsxshpn9xsx2BnX9l2B52OerVF2h/4wFx0qa2Stm9pNgfnLPlwePD7uQX0ASeAtYDqSBncCKEOs5CDQOW/anwL3B9L3An0xRLTcD1wK7R6sFWBGcu0pgWXBOk1NY19eAPyqy7lTWNR+4NpiuA94Ijh/qORuhrlDPGdlHoNYG0xXAS8AN0+B8laor9J+x4HhfAh4GfhLMT+r5ikqPfjWw390PuHsf8CiwPuSahlsP/HUw/dfAv5iKg7r7C8DxMmtZDzzq7mfc/W1gP9lzO1V1lTKVdR1295eD6VPA68ACQj5nI9RVylTV5e5+OpitCF5O+OerVF2lTNnPmJktBD4FfH/Y8SftfEUl6BcAh/LmWxn5P8Fkc+AZM9tuZpuCZfPc/TBk/9MCc0OrrnQt0+E83mNmu4Khndyfr6HUZWZLgWvI9ganzTkbVheEfM6CYYgdQDvwrLtPi/NVoi4I/2fsG8B/BDJ5yyb1fEUl6Is9hTnM60ZvcvdrgXXA3WZ2c4i1jEXY5/G7wCXAKuAw8GfB8imvy8xqgceAf+/unSOtWmTZpNVWpK7Qz5m7D7r7KmAhsNrMrhph9bDrCvV8mdntQLu7by93kyLLxlxXVIK+FViUN78QaAupFty9LfjaDmwl+6fW+2Y2HyD42h5WfSPUEup5dPf3g/+cGeB/cfZP1Cmty8wqyIbp/3H3x4PFoZ+zYnVNl3MW1HICeB5YyzQ4X8Xqmgbn6ybgM2Z2kOwQ88fM7AdM8vmKStBvAy41s2VmlgY2AE+EUYiZ1ZhZXW4auA3YHdTzu8Fqvwv8OIz6AqVqeQLYYGaVZrYMuBT4zVQVlftBD9xB9rxNaV1mZsBfAq+7+//Iawr1nJWqK+xzZmZNZlYfTFcBHwf2Ev75KlpX2OfL3b/i7gvdfSnZnHrO3f8Vk32+Jutd5al+AZ8keyXCW8BXQ6xjOdl3yXcCe3K1AA3Az4A3g69zpqieR8j+idpPtnfw+ZFqAb4anMN9wLoprut/A68Cu4If8Pkh1PVRsn8a7wJ2BK9Phn3ORqgr1HMGXA28Ehx/N3DfaD/vIdcV+s9Y3vFu5exVN5N6vnQLBBGRiIvK0I2IiJSgoBcRiTgFvYhIxCnoRUQiTkEvIhJxCnoRkYhT0IuIRNz/B/PvIJHVX9tuAAAAAElFTkSuQmCC\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"plt.figure()\n", | |
"plt.plot(losses)\n", | |
"plt.show()\n", | |
"plt.close()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "94b93168", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"y_viz = y.squeeze().cpu().numpy()\n", | |
"with torch.no_grad():\n", | |
" y_hat = model(x).squeeze().cpu().numpy()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "2b386abe", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment