Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save calebrob6/93b34537d960265eb6c23e5b877bafc5 to your computer and use it in GitHub Desktop.
Save calebrob6/93b34537d960265eb6c23e5b877bafc5 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "a6bc9db5",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import sys\n",
"import os\n",
"import time\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import torch\n",
"torch.backends.cudnn.deterministic = False\n",
"torch.backends.cudnn.benchmark = True\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"\n",
"from torch import Tensor\n",
"from torch.nn.modules import Module\n",
"\n",
"import segmentation_models_pytorch as smp\n",
"\n",
"import rasterio"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "ec3250e2",
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda\")"
]
},
{
"cell_type": "markdown",
"id": "124cceb8",
"metadata": {},
"source": [
"## Data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c1ca86e7",
"metadata": {},
"outputs": [],
"source": [
"fns = [\n",
" '2021-08-04-vv.tif',\n",
" '2021-07-23-vh.tif',\n",
" '2021-07-23-vv.tif',\n",
" '2021-08-04-vh.tif'\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "18fd96a2",
"metadata": {},
"outputs": [],
"source": [
"with rasterio.open(fns[1]) as f:\n",
" y = f.read()\n",
" \n",
"y = np.log1p(y)\n",
"mean = y.mean(dtype=np.float64)\n",
"std = y.std(dtype=np.float64)\n",
"y = (y - mean) / std "
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "dad8e4bd",
"metadata": {},
"outputs": [],
"source": [
"x = np.random.randn(y.shape[0], y.shape[1], y.shape[2])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "7f362ad3",
"metadata": {},
"outputs": [],
"source": [
"x = x[:,2000:4000,2000:4000]\n",
"y = y[:,2000:4000,2000:4000]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "9be39ffd",
"metadata": {},
"outputs": [],
"source": [
"x = torch.from_numpy(x).unsqueeze(0).float().to(device)\n",
"y = torch.from_numpy(y).unsqueeze(0).float().to(device)"
]
},
{
"cell_type": "markdown",
"id": "8d007f11",
"metadata": {},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "4ea08bda",
"metadata": {},
"outputs": [],
"source": [
"def fit(model, device, x, y, optimizer, criterion, epoch, memo=''):\n",
" model.train()\n",
" \n",
" tic = time.time()\n",
" \n",
" optimizer.zero_grad()\n",
" outputs = model(x)\n",
" loss = criterion(outputs, y)\n",
" epoch_loss = loss.item()\n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" print('[{}] Training Epoch: {}\\t Time elapsed: {:.2f} seconds\\t Loss: {:.2f}'.format(\n",
" memo, epoch, time.time()-tic, epoch_loss)\n",
" )\n",
" \n",
" return epoch_loss"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "1cdecb78",
"metadata": {},
"outputs": [],
"source": [
"model = smp.Unet(\n",
" encoder_name='resnet18', encoder_depth=3, encoder_weights=None,\n",
" decoder_channels=(128, 64, 64), in_channels=1, classes=1\n",
")\n",
"model = model.to(device)\n",
"\n",
"optimizer = optim.AdamW(model.parameters(), lr=0.001, amsgrad=True)\n",
"criterion = nn.MSELoss()\n",
"scheduler = optim.lr_scheduler.StepLR(optimizer, 150, verbose=False)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "2b612c1f",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[] Training Epoch: 0\t Time elapsed: 4.78 seconds\t Loss: 1.83\n",
"[] Training Epoch: 1\t Time elapsed: 0.35 seconds\t Loss: 1.42\n",
"[] Training Epoch: 2\t Time elapsed: 0.66 seconds\t Loss: 1.28\n",
"[] Training Epoch: 3\t Time elapsed: 0.43 seconds\t Loss: 1.21\n",
"[] Training Epoch: 4\t Time elapsed: 0.48 seconds\t Loss: 1.11\n",
"[] Training Epoch: 5\t Time elapsed: 0.48 seconds\t Loss: 0.98\n",
"[] Training Epoch: 6\t Time elapsed: 0.48 seconds\t Loss: 0.84\n",
"[] Training Epoch: 7\t Time elapsed: 0.48 seconds\t Loss: 0.72\n",
"[] Training Epoch: 8\t Time elapsed: 0.48 seconds\t Loss: 0.64\n",
"[] Training Epoch: 9\t Time elapsed: 0.48 seconds\t Loss: 0.59\n",
"[] Training Epoch: 10\t Time elapsed: 0.48 seconds\t Loss: 0.56\n",
"[] Training Epoch: 11\t Time elapsed: 0.49 seconds\t Loss: 0.52\n",
"[] Training Epoch: 12\t Time elapsed: 0.48 seconds\t Loss: 0.49\n",
"[] Training Epoch: 13\t Time elapsed: 0.48 seconds\t Loss: 0.47\n",
"[] Training Epoch: 14\t Time elapsed: 0.48 seconds\t Loss: 0.45\n",
"[] Training Epoch: 15\t Time elapsed: 0.48 seconds\t Loss: 0.43\n",
"[] Training Epoch: 16\t Time elapsed: 0.48 seconds\t Loss: 0.42\n",
"[] Training Epoch: 17\t Time elapsed: 0.48 seconds\t Loss: 0.41\n",
"[] Training Epoch: 18\t Time elapsed: 0.48 seconds\t Loss: 0.40\n",
"[] Training Epoch: 19\t Time elapsed: 0.48 seconds\t Loss: 0.39\n",
"[] Training Epoch: 20\t Time elapsed: 0.48 seconds\t Loss: 0.38\n",
"[] Training Epoch: 21\t Time elapsed: 0.48 seconds\t Loss: 0.37\n",
"[] Training Epoch: 22\t Time elapsed: 0.49 seconds\t Loss: 0.37\n",
"[] Training Epoch: 23\t Time elapsed: 0.48 seconds\t Loss: 0.36\n",
"[] Training Epoch: 24\t Time elapsed: 0.49 seconds\t Loss: 0.35\n",
"[] Training Epoch: 25\t Time elapsed: 0.48 seconds\t Loss: 0.35\n",
"[] Training Epoch: 26\t Time elapsed: 0.48 seconds\t Loss: 0.34\n",
"[] Training Epoch: 27\t Time elapsed: 0.48 seconds\t Loss: 0.34\n",
"[] Training Epoch: 28\t Time elapsed: 0.49 seconds\t Loss: 0.34\n",
"[] Training Epoch: 29\t Time elapsed: 0.48 seconds\t Loss: 0.33\n",
"[] Training Epoch: 30\t Time elapsed: 0.48 seconds\t Loss: 0.33\n",
"[] Training Epoch: 31\t Time elapsed: 0.48 seconds\t Loss: 0.32\n",
"[] Training Epoch: 32\t Time elapsed: 0.48 seconds\t Loss: 0.32\n",
"[] Training Epoch: 33\t Time elapsed: 0.48 seconds\t Loss: 0.32\n",
"[] Training Epoch: 34\t Time elapsed: 0.48 seconds\t Loss: 0.31\n",
"[] Training Epoch: 35\t Time elapsed: 0.48 seconds\t Loss: 0.31\n",
"[] Training Epoch: 36\t Time elapsed: 0.48 seconds\t Loss: 0.31\n",
"[] Training Epoch: 37\t Time elapsed: 0.48 seconds\t Loss: 0.30\n",
"[] Training Epoch: 38\t Time elapsed: 0.48 seconds\t Loss: 0.30\n",
"[] Training Epoch: 39\t Time elapsed: 0.48 seconds\t Loss: 0.30\n",
"[] Training Epoch: 40\t Time elapsed: 0.48 seconds\t Loss: 0.30\n",
"[] Training Epoch: 41\t Time elapsed: 0.48 seconds\t Loss: 0.29\n",
"[] Training Epoch: 42\t Time elapsed: 0.48 seconds\t Loss: 0.29\n",
"[] Training Epoch: 43\t Time elapsed: 0.48 seconds\t Loss: 0.29\n",
"[] Training Epoch: 44\t Time elapsed: 0.48 seconds\t Loss: 0.29\n",
"[] Training Epoch: 45\t Time elapsed: 0.48 seconds\t Loss: 0.28\n",
"[] Training Epoch: 46\t Time elapsed: 0.48 seconds\t Loss: 0.28\n",
"[] Training Epoch: 47\t Time elapsed: 0.48 seconds\t Loss: 0.28\n",
"[] Training Epoch: 48\t Time elapsed: 0.48 seconds\t Loss: 0.28\n",
"[] Training Epoch: 49\t Time elapsed: 0.48 seconds\t Loss: 0.28\n",
"[] Training Epoch: 50\t Time elapsed: 0.48 seconds\t Loss: 0.27\n",
"[] Training Epoch: 51\t Time elapsed: 0.48 seconds\t Loss: 0.27\n",
"[] Training Epoch: 52\t Time elapsed: 0.48 seconds\t Loss: 0.27\n",
"[] Training Epoch: 53\t Time elapsed: 0.48 seconds\t Loss: 0.27\n",
"[] Training Epoch: 54\t Time elapsed: 0.48 seconds\t Loss: 0.27\n",
"[] Training Epoch: 55\t Time elapsed: 0.48 seconds\t Loss: 0.27\n",
"[] Training Epoch: 56\t Time elapsed: 0.48 seconds\t Loss: 0.26\n",
"[] Training Epoch: 57\t Time elapsed: 0.48 seconds\t Loss: 0.26\n",
"[] Training Epoch: 58\t Time elapsed: 0.48 seconds\t Loss: 0.26\n",
"[] Training Epoch: 59\t Time elapsed: 0.48 seconds\t Loss: 0.26\n",
"[] Training Epoch: 60\t Time elapsed: 0.48 seconds\t Loss: 0.26\n",
"[] Training Epoch: 61\t Time elapsed: 0.48 seconds\t Loss: 0.27\n",
"[] Training Epoch: 62\t Time elapsed: 0.48 seconds\t Loss: 0.27\n",
"[] Training Epoch: 63\t Time elapsed: 0.48 seconds\t Loss: 0.26\n",
"[] Training Epoch: 64\t Time elapsed: 0.48 seconds\t Loss: 0.26\n",
"[] Training Epoch: 65\t Time elapsed: 0.48 seconds\t Loss: 0.26\n",
"[] Training Epoch: 66\t Time elapsed: 0.48 seconds\t Loss: 0.25\n",
"[] Training Epoch: 67\t Time elapsed: 0.48 seconds\t Loss: 0.26\n",
"[] Training Epoch: 68\t Time elapsed: 0.48 seconds\t Loss: 0.25\n",
"[] Training Epoch: 69\t Time elapsed: 0.48 seconds\t Loss: 0.25\n",
"[] Training Epoch: 70\t Time elapsed: 0.48 seconds\t Loss: 0.25\n",
"[] Training Epoch: 71\t Time elapsed: 0.48 seconds\t Loss: 0.24\n",
"[] Training Epoch: 72\t Time elapsed: 0.48 seconds\t Loss: 0.25\n",
"[] Training Epoch: 73\t Time elapsed: 0.48 seconds\t Loss: 0.24\n",
"[] Training Epoch: 74\t Time elapsed: 0.48 seconds\t Loss: 0.24\n",
"[] Training Epoch: 75\t Time elapsed: 0.48 seconds\t Loss: 0.24\n",
"[] Training Epoch: 76\t Time elapsed: 0.48 seconds\t Loss: 0.24\n",
"[] Training Epoch: 77\t Time elapsed: 0.48 seconds\t Loss: 0.24\n",
"[] Training Epoch: 78\t Time elapsed: 0.48 seconds\t Loss: 0.24\n",
"[] Training Epoch: 79\t Time elapsed: 0.48 seconds\t Loss: 0.24\n",
"[] Training Epoch: 80\t Time elapsed: 0.48 seconds\t Loss: 0.24\n",
"[] Training Epoch: 81\t Time elapsed: 0.48 seconds\t Loss: 0.23\n",
"[] Training Epoch: 82\t Time elapsed: 0.48 seconds\t Loss: 0.23\n",
"[] Training Epoch: 83\t Time elapsed: 0.48 seconds\t Loss: 0.23\n",
"[] Training Epoch: 84\t Time elapsed: 0.48 seconds\t Loss: 0.23\n",
"[] Training Epoch: 85\t Time elapsed: 0.48 seconds\t Loss: 0.23\n",
"[] Training Epoch: 86\t Time elapsed: 0.48 seconds\t Loss: 0.23\n",
"[] Training Epoch: 87\t Time elapsed: 0.48 seconds\t Loss: 0.23\n",
"[] Training Epoch: 88\t Time elapsed: 0.48 seconds\t Loss: 0.23\n",
"[] Training Epoch: 89\t Time elapsed: 0.48 seconds\t Loss: 0.23\n",
"[] Training Epoch: 90\t Time elapsed: 0.48 seconds\t Loss: 0.23\n",
"[] Training Epoch: 91\t Time elapsed: 0.48 seconds\t Loss: 0.23\n",
"[] Training Epoch: 92\t Time elapsed: 0.48 seconds\t Loss: 0.23\n",
"[] Training Epoch: 93\t Time elapsed: 0.49 seconds\t Loss: 0.23\n",
"[] Training Epoch: 94\t Time elapsed: 0.48 seconds\t Loss: 0.22\n",
"[] Training Epoch: 95\t Time elapsed: 0.48 seconds\t Loss: 0.22\n",
"[] Training Epoch: 96\t Time elapsed: 0.48 seconds\t Loss: 0.22\n",
"[] Training Epoch: 97\t Time elapsed: 0.48 seconds\t Loss: 0.22\n",
"[] Training Epoch: 98\t Time elapsed: 0.48 seconds\t Loss: 0.22\n",
"[] Training Epoch: 99\t Time elapsed: 0.48 seconds\t Loss: 0.22\n",
"[] Training Epoch: 100\t Time elapsed: 0.48 seconds\t Loss: 0.22\n",
"[] Training Epoch: 101\t Time elapsed: 0.48 seconds\t Loss: 0.22\n",
"[] Training Epoch: 102\t Time elapsed: 0.48 seconds\t Loss: 0.22\n",
"[] Training Epoch: 103\t Time elapsed: 0.48 seconds\t Loss: 0.21\n",
"[] Training Epoch: 104\t Time elapsed: 0.48 seconds\t Loss: 0.21\n",
"[] Training Epoch: 105\t Time elapsed: 0.48 seconds\t Loss: 0.21\n",
"[] Training Epoch: 106\t Time elapsed: 0.48 seconds\t Loss: 0.21\n",
"[] Training Epoch: 107\t Time elapsed: 0.48 seconds\t Loss: 0.21\n",
"[] Training Epoch: 108\t Time elapsed: 0.48 seconds\t Loss: 0.21\n",
"[] Training Epoch: 109\t Time elapsed: 0.48 seconds\t Loss: 0.21\n",
"[] Training Epoch: 110\t Time elapsed: 0.48 seconds\t Loss: 0.21\n",
"[] Training Epoch: 111\t Time elapsed: 0.48 seconds\t Loss: 0.21\n",
"[] Training Epoch: 112\t Time elapsed: 0.48 seconds\t Loss: 0.21\n",
"[] Training Epoch: 113\t Time elapsed: 0.48 seconds\t Loss: 0.21\n",
"[] Training Epoch: 114\t Time elapsed: 0.48 seconds\t Loss: 0.21\n",
"[] Training Epoch: 115\t Time elapsed: 0.48 seconds\t Loss: 0.21\n",
"[] Training Epoch: 116\t Time elapsed: 0.48 seconds\t Loss: 0.21\n",
"[] Training Epoch: 117\t Time elapsed: 0.48 seconds\t Loss: 0.21\n",
"[] Training Epoch: 118\t Time elapsed: 0.48 seconds\t Loss: 0.20\n",
"[] Training Epoch: 119\t Time elapsed: 0.48 seconds\t Loss: 0.20\n",
"[] Training Epoch: 120\t Time elapsed: 0.48 seconds\t Loss: 0.20\n",
"[] Training Epoch: 121\t Time elapsed: 0.48 seconds\t Loss: 0.20\n",
"[] Training Epoch: 122\t Time elapsed: 0.49 seconds\t Loss: 0.20\n",
"[] Training Epoch: 123\t Time elapsed: 0.48 seconds\t Loss: 0.20\n",
"[] Training Epoch: 124\t Time elapsed: 0.48 seconds\t Loss: 0.20\n",
"[] Training Epoch: 125\t Time elapsed: 0.48 seconds\t Loss: 0.20\n",
"[] Training Epoch: 126\t Time elapsed: 0.48 seconds\t Loss: 0.20\n",
"[] Training Epoch: 127\t Time elapsed: 0.48 seconds\t Loss: 0.21\n",
"[] Training Epoch: 128\t Time elapsed: 0.48 seconds\t Loss: 0.20\n",
"[] Training Epoch: 129\t Time elapsed: 0.48 seconds\t Loss: 0.20\n",
"[] Training Epoch: 130\t Time elapsed: 0.48 seconds\t Loss: 0.20\n",
"[] Training Epoch: 131\t Time elapsed: 0.48 seconds\t Loss: 0.20\n",
"[] Training Epoch: 132\t Time elapsed: 0.48 seconds\t Loss: 0.20\n",
"[] Training Epoch: 133\t Time elapsed: 0.48 seconds\t Loss: 0.20\n",
"[] Training Epoch: 134\t Time elapsed: 0.48 seconds\t Loss: 0.20\n",
"[] Training Epoch: 135\t Time elapsed: 0.48 seconds\t Loss: 0.19\n",
"[] Training Epoch: 136\t Time elapsed: 0.48 seconds\t Loss: 0.19\n",
"[] Training Epoch: 137\t Time elapsed: 0.48 seconds\t Loss: 0.20\n",
"[] Training Epoch: 138\t Time elapsed: 0.48 seconds\t Loss: 0.19\n",
"[] Training Epoch: 139\t Time elapsed: 0.48 seconds\t Loss: 0.19\n",
"[] Training Epoch: 140\t Time elapsed: 0.48 seconds\t Loss: 0.19\n",
"[] Training Epoch: 141\t Time elapsed: 0.48 seconds\t Loss: 0.19\n",
"[] Training Epoch: 142\t Time elapsed: 0.48 seconds\t Loss: 0.19\n",
"[] Training Epoch: 143\t Time elapsed: 0.48 seconds\t Loss: 0.19\n",
"[] Training Epoch: 144\t Time elapsed: 0.48 seconds\t Loss: 0.19\n",
"[] Training Epoch: 145\t Time elapsed: 0.48 seconds\t Loss: 0.19\n",
"[] Training Epoch: 146\t Time elapsed: 0.48 seconds\t Loss: 0.19\n",
"[] Training Epoch: 147\t Time elapsed: 0.48 seconds\t Loss: 0.19\n",
"[] Training Epoch: 148\t Time elapsed: 0.48 seconds\t Loss: 0.19\n",
"[] Training Epoch: 149\t Time elapsed: 0.48 seconds\t Loss: 0.19\n",
"[] Training Epoch: 150\t Time elapsed: 0.49 seconds\t Loss: 0.19\n",
"[] Training Epoch: 151\t Time elapsed: 0.48 seconds\t Loss: 0.19\n",
"[] Training Epoch: 152\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 153\t Time elapsed: 0.49 seconds\t Loss: 0.18\n",
"[] Training Epoch: 154\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 155\t Time elapsed: 0.49 seconds\t Loss: 0.18\n",
"[] Training Epoch: 156\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 157\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 158\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 159\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 160\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 161\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 162\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 163\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 164\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 165\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 166\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 167\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 168\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 169\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 170\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 171\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 172\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 173\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 174\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 175\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 176\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 177\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 178\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 179\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 180\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 181\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 182\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 183\t Time elapsed: 0.49 seconds\t Loss: 0.18\n",
"[] Training Epoch: 184\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 185\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 186\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 187\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 188\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 189\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 190\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 191\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 192\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 193\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 194\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 195\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 196\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 197\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 198\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 199\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 200\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 201\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 202\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 203\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 204\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 205\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 206\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 207\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 208\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 209\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 210\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 211\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 212\t Time elapsed: 0.49 seconds\t Loss: 0.18\n",
"[] Training Epoch: 213\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 214\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 215\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 216\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 217\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 218\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 219\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 220\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 221\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 222\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 223\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 224\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 225\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 226\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 227\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 228\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 229\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 230\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 231\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 232\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 233\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 234\t Time elapsed: 0.49 seconds\t Loss: 0.18\n",
"[] Training Epoch: 235\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 236\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 237\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 238\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 239\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 240\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 241\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 242\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 243\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 244\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 245\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 246\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 247\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 248\t Time elapsed: 0.49 seconds\t Loss: 0.18\n",
"[] Training Epoch: 249\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 250\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 251\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 252\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 253\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 254\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 255\t Time elapsed: 0.49 seconds\t Loss: 0.18\n",
"[] Training Epoch: 256\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 257\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 258\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 259\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 260\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 261\t Time elapsed: 0.49 seconds\t Loss: 0.18\n",
"[] Training Epoch: 262\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 263\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 264\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 265\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 266\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 267\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 268\t Time elapsed: 0.49 seconds\t Loss: 0.18\n",
"[] Training Epoch: 269\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 270\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 271\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 272\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 273\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 274\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 275\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 276\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 277\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 278\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 279\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 280\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 281\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 282\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 283\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 284\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 285\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 286\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 287\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 288\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 289\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 290\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 291\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 292\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 293\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 294\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 295\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 296\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 297\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 298\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 299\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 300\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 301\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 302\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 303\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 304\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 305\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 306\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 307\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 308\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 309\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 310\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 311\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 312\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 313\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 314\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 315\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 316\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 317\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 318\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 319\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 320\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 321\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 322\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 323\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 324\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 325\t Time elapsed: 0.49 seconds\t Loss: 0.18\n",
"[] Training Epoch: 326\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 327\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 328\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 329\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 330\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 331\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 332\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 333\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 334\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 335\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 336\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 337\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 338\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 339\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 340\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 341\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 342\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 343\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 344\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 345\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 346\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 347\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 348\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 349\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 350\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 351\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 352\t Time elapsed: 0.49 seconds\t Loss: 0.18\n",
"[] Training Epoch: 353\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 354\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 355\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 356\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 357\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 358\t Time elapsed: 0.49 seconds\t Loss: 0.18\n",
"[] Training Epoch: 359\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 360\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 361\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 362\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 363\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 364\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 365\t Time elapsed: 0.49 seconds\t Loss: 0.18\n",
"[] Training Epoch: 366\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 367\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 368\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 369\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 370\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 371\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 372\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 373\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 374\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 375\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 376\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 377\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 378\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 379\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 380\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 381\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 382\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 383\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 384\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 385\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 386\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 387\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 388\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 389\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 390\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 391\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 392\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 393\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 394\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 395\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 396\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 397\t Time elapsed: 0.49 seconds\t Loss: 0.18\n",
"[] Training Epoch: 398\t Time elapsed: 0.48 seconds\t Loss: 0.18\n",
"[] Training Epoch: 399\t Time elapsed: 0.48 seconds\t Loss: 0.18\n"
]
}
],
"source": [
"losses = []\n",
"for epoch in range(400):\n",
" loss = fit(\n",
" model, device, x, y, optimizer, criterion, epoch\n",
" )\n",
" losses.append(loss)\n",
" scheduler.step()"
]
},
{
"cell_type": "markdown",
"id": "a3f24a40",
"metadata": {},
"source": [
"## Visualize"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "1d770ba4",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAbz0lEQVR4nO3dfXRc9Z3f8fd3ZjSynmzZkmyMn00IYCg2oBoIKeA0S+yExOUc9qzdNN1tk/WSA3va5uw2pOmBNP2j3eVsm83mgbhZL7tNgU0LTtgcwsOGsGweIJbBNjbYYIzBQjaSbWzZerAe5ts/5o481sxII+vhyvd+XsdzdO/93Yev7pE/+uk3d+41d0dERKIrEXYBIiIyuRT0IiIRp6AXEYk4Bb2ISMQp6EVEIi4VdgHFNDY2+tKlS8MuQ0TkgrF9+/aj7t5UrG1aBv3SpUtpaWkJuwwRkQuGmb1Tqk1DNyIiEaegFxGJOAW9iEjEKehFRCJOQS8iEnEKehGRiFPQi4hEXKSC/ps/e5N/eKMj7DJERKaVSAX9d59/i1+8qaAXEckXqaBPJYyBjB6kIiKSL1JBn0wagwp6EZFzRCroUwkFvYjIcJEK+qSCXkSkQKSCPpVIaIxeRGSYSAW9evQiIoVGvR+9mW0Bbgfa3f2qIu1/DHw2b39XAE3uftzMDgKngEFgwN2bJ6rwYnTVjYhIoXJ69A8Ba0s1uvsD7r7K3VcBXwH+wd2P562yJmif1JCHXI8+M9mHERG5oIwa9O7+AnB8tPUCG4FHxlXROCQTxsCgevQiIvkmbIzezKrJ9vwfy1vswDNmtt3MNo2y/SYzazGzlo6O8/t0a0rX0YuIFJjIN2M/Dfxy2LDNTe5+LbAOuNvMbi61sbtvdvdmd29uair6fNtRJXXVjYhIgYkM+g0MG7Zx97bgazuwFVg9gccroA9MiYgUmpCgN7NZwC3Aj/OW1ZhZXW4auA3YPRHHKyWZMAb0ZqyIyDnKubzyEeBWoNHMWoH7gQoAd38wWO0O4Bl378rbdB6w1cxyx3nY3Z+auNILpRJG/6CCXkQk36hB7+4by1jnIbKXYeYvOwCsPN/CzkcyYfT0a+hGRCRfpD4ZqzF6EZFCkQr6ZCKh6+hFRIaJVNCrRy8iUihSQZ9M6qobEZHhIhX06tGLiBSKVNAndfdKEZECkQp69ehFRApFKuh1rxsRkUKRCnr16EVECkUq6LP3o9dVNyIi+SIV9OrRi4gUilTQZ6+jV9CLiOSLVNCrRy8iUihSQZ+76sZdYS8ikhOpoE8lDAB16kVEzopU0CeDoNf9bkREzopU0Od69BqnFxE5K1JBf7ZHr6AXEcmJVNAP9ej18BERkSGRCvpkMvvtqEcvInJWtILeNEYvIjLcqEFvZlvMrN3Mdpdov9XMTprZjuB1X17bWjPbZ2b7zezeiSy8mJSuuhERKVBOj/4hYO0o6/yju68KXl8HMLMk8G1gHbAC2GhmK8ZT7GiSuupGRKTAqEHv7i8Ax89j36uB/e5+wN37gEeB9eexn7KlkrrqRkRkuIkao7/RzHaa2U/N7Mpg2QLgUN46rcGyosxsk5m1mFlLR0fHeRWhHr2ISKGJCPqXgSXuvhL4C+BHwXIrsm7JBHb3ze7e7O7NTU1N51XI0Bi9Lq8UERky7qB39053Px1MPwlUmFkj2R78orxVFwJt4z3eSJKJ7LejHr2IyFnjDnozu8gse12jma0O9nkM2AZcambLzCwNbACeGO/xRjL0gSndvVJEZEhqtBXM7BHgVqDRzFqB+4EKAHd/ELgT+KKZDQA9wAbP3id4wMzuAZ4GksAWd98zKd9F4OwYvS6vFBHJGTXo3X3jKO3fAr5Vou1J4MnzK23sNEYvIlIoWp+M1VU3IiIFIhX0uo5eRKRQpIJeV92IiBSKVNDnxujPDOjNWBGRnEgFfXU6CUBP/0DIlYiITB+RCvqayuxFRN19gyFXIiIyfUQq6K
tyPXoFvYjIkEgFfXVFNui7zijoRURyIhX0qWSCdCpBt8boRUSGRCroIfuGrIZuRETOilzQ16RTGroREckTuaCvSid1eaWISJ7IBX11OqnLK0VE8kQz6DV0IyIyJIJBn9JVNyIieSIY9Bq6ERHJF82g19CNiMiQCAZ9iu4+Dd2IiOREMOg1dCMiki+SQT+Qcfp0T3oRESCCQZ+7VXHXGQ3fiIhAGUFvZlvMrN3Mdpdo/6yZ7QpevzKzlXltB83sVTPbYWYtE1l4KXNq0gAc6+qbisOJiEx75fToHwLWjtD+NnCLu18N/Fdg87D2Ne6+yt2bz6/EsWmsrQTg6OkzU3E4EZFpLzXaCu7+gpktHaH9V3mzLwILJ6Cu86agFxE510SP0X8e+GnevAPPmNl2M9s00oZmtsnMWsyspaOj47wLaKzNDt0cPaWgFxGBMnr05TKzNWSD/qN5i29y9zYzmws8a2Z73f2FYtu7+2aCYZ/m5mY/3zrqq9MkTGP0IiI5E9KjN7Orge8D6939WG65u7cFX9uBrcDqiTjeSJIJY05NpYZuREQC4w56M1sMPA58zt3fyFteY2Z1uWngNqDolTsTrbE2Tccp9ehFRKCMoRszewS4FWg0s1bgfqACwN0fBO4DGoDvmBnAQHCFzTxga7AsBTzs7k9NwvdQoLFWPXoRkZxyrrrZOEr7F4AvFFl+AFhZuMXkm1OT5tAH3WEcWkRk2oncJ2MBZlVVcLKnP+wyRESmhcgGfWdPP5nMeV+8IyISGZEM+vrqCjIOp3S/GxGRaAb9zKoKADo1fCMiEs2gnxUEvcbpRUQU9CIikRfpoD/RraAXEYlk0NdXq0cvIpITyaDX0I2IyFmRDPqqiiQVSVPQi4gQ0aA3M306VkQkEMmgh+xDwvWAcBGRCAd9dTpFd5+CXkQkskFfW5nktHr0IiLRDfqayhTdfYNhlyEiErroBn06pR69iAhRDvrKpN6MFREhwkFfnU7RfUZDNyIikQ362soUXX0DuOvhIyISb5EN+prKFBmH3v5M2KWIiIRq1KA3sy1m1m5mu0u0m5l908z2m9kuM7s2r22tme0L2u6dyMJHU1OZBNAbsiISe+X06B8C1o7Qvg64NHhtAr4LYGZJ4NtB+wpgo5mtGE+xY1GTTgHoQ1MiEnujBr27vwAcH2GV9cDfeNaLQL2ZzQdWA/vd/YC79wGPButOCfXoRUSyJmKMfgFwKG++NVhWavmUqKnM9ui7dOWNiMTcRAS9FVnmIywvvhOzTWbWYmYtHR0d4y6qOhi66dLQjYjE3EQEfSuwKG9+IdA2wvKi3H2zuze7e3NTU9O4i6od6tEr6EUk3iYi6J8A/nVw9c0NwEl3PwxsAy41s2VmlgY2BOtOidwYvT40JSJxlxptBTN7BLgVaDSzVuB+oALA3R8EngQ+CewHuoF/E7QNmNk9wNNAEtji7nsm4XsoKnfVjd6MFZG4GzXo3X3jKO0O3F2i7UmyvwimXI2GbkREgAh/MjadSlCRNLp0q2IRibnIBj3ocYIiIhD1oE+ndHmliMRetINe96QXEYl60Kf0yVgRib1oB72GbkREIh70GroREYl40Kc1dCMiEu2gr9TQjYhI5INe97oRkbiLdtCnk/QNZugb0HNjRSS+oh30ut+NiEjUgz57q2KN04tInEU86PU4QRGReAS9evQiEmPRDvq0xuhFRKId9LkxegW9iMRYtIM+rTF6EZFoB73G6EVEoh70uaEb9ehFJL4iHfRVFUkSpjF6EYm3soLezNaa2T4z229m9xZp/2Mz2xG8dpvZoJnNCdoOmtmrQVvLRH8Do9Ste9KLSOylRlvBzJLAt4HfAlqBbWb2hLu/llvH3R8AHgjW/zTwH9z9eN5u1rj70QmtvEzVuie9iMRcOT361cB+dz/g7n3Ao8D6EdbfCDwyEcVNBD1OUETirpygXwAcyptvDZYVMLNqYC3wWN5iB54xs+1mtqnUQcxsk5m1mFlLR0
dHGWWVR0M3IhJ35QS9FVnmJdb9NPDLYcM2N7n7tcA64G4zu7nYhu6+2d2b3b25qampjLLKo8cJikjclRP0rcCivPmFQFuJdTcwbNjG3duCr+3AVrJDQVOmVkM3IhJz5QT9NuBSM1tmZmmyYf7E8JXMbBZwC/DjvGU1ZlaXmwZuA3ZPROHlqtbQjYjE3KhX3bj7gJndAzwNJIEt7r7HzO4K2h8MVr0DeMbdu/I2nwdsNbPcsR5296cm8hsYzcyqFCd7+qfykCIi08qoQQ/g7k8CTw5b9uCw+YeAh4YtOwCsHFeF49RQU8mJ7n4GBjOkkpH+fJiISFGRT76G2jQAx7v7Qq5ERCQckQ/6OTVB0Hcp6EUknuIT9KcV9CIST5EP+sbaSgCOqUcvIjEV+aDX0I2IxF3kg352dRoz9ehFJL4iH/TJhFFfVcGx02fCLkVEJBSRD3qAuXUzeL9TQS8i8RSLoF80p5rWD7rDLkNEJBSxCPrFc6p593g37qVuuikiEl0xCfoquvsG9YasiMRSPIK+oRqAd49r+EZE4iceQT8nCPpjCnoRiZ9YBP2iOdUkE8ZbHafDLkVEZMrFIugrU0mWN9bw+uFTYZciIjLlYhH0AJddVMe+9zvDLkNEZMrFJugvv6iOQ8d7OK0HhYtIzMQo6GcCsO+Ihm9EJF5iE/SXXVQHKOhFJH5iE/QLZ1dRW5li3xGN04tIvMQm6M2Myy6q43X16EUkZsoKejNba2b7zGy/md1bpP1WMztpZjuC133lbjuVLr+ojr2HO8lkdM8bEYmPUYPezJLAt4F1wApgo5mtKLLqP7r7quD19TFuOyVWLaqns3dAH5wSkVgpp0e/Gtjv7gfcvQ94FFhf5v7Hs+2Eu27JbAC2v/NBWCWIiEy5coJ+AXAob741WDbcjWa208x+amZXjnFbzGyTmbWYWUtHR0cZZY3dssYaGmrStCjoRSRGygl6K7Js+CD3y8ASd18J/AXwozFsm13ovtndm929uampqYyyxs7MuHbJbPXoRSRWygn6VmBR3vxCoC1/BXfvdPfTwfSTQIWZNZaz7VRrXjKbt4926RmyIhIb5QT9NuBSM1tmZmlgA/BE/gpmdpGZWTC9OtjvsXK2nWoapxeRuBk16N19ALgHeBp4Hfihu+8xs7vM7K5gtTuB3Wa2E/gmsMGzim47Gd9Iua5aMIvqdJLn9raHWYaIyJSx6fgc1ebmZm9paZm0/X/phzt49rX32fbVjzOjIjlpxxERmSpmtt3dm4u1xeaTsfnuvG4hp3oH+Mmuw2GXIiIy6WIZ9Dcub+DD82rZ8ou3mY5/0YiITKRYBr2ZsenmS3jtcCc/3X0k7HJERCZVLIMe4I5rFvDhebX86VN76R/MhF2OiMikiW3QJxPGvesu5+Cxbh5+6d2wyxERmTSxDXqANZfN5cblDXzj79/gZE9/2OWIiEyKWAe9mfGfb7+CEz39fOu5N8MuR0RkUsQ66AGuvHgWv33dQv7qlwd58cCxsMsREZlwsQ96gK9+agVLGqq56wfbOaB71YtIxCjogVlVFfzV760macbvbH6R19r0XFkRiQ4FfWBxQzV/+wc3kEoYv/O9X/OShnFEJCIU9Hk+NLeOx774EebNmsHntvyGp/fow1QicuFT0A9zcX0V//cPbmTF/Jl88Qfb+Z/PvsGZgcGwyxIROW8K+iJm16R5+Pev5zMrL+bPf/Ymt3/zF7p/vYhcsBT0JVSnU3xjwzVs+b1mus4McOeDv+LL/28X73f2hl2aiMiYKOhH8bHL5/HMl27h8zct4/FXWrnlgZ/zJ0/t5ageRSgiF4hYPnjkfL17rJsHntnHT3a1UZlKsOGfLub3b17OgvqqsEsTkZgb6cEjCvrz8FbHaR58/i22vvIeGXfWXDaXz96wmFs+PJdkwsIuT0RiSEE/Sd470cMjL73L37YcouPUGRbUV3HHNQu4feV8LptXR/C8dBGRSaegn2T9gxn+/rX3efg37/
LL/UfJOHxobi23Xz2f26+ezyVNtQp9EZlUCvop1HHqDE/tPszf7TrMtoPHcYclDdWsuWwuH7t8Ltcvn0NlqvCB5DsPnaCrb4CPXNIYQtUicqEbd9Cb2Vrgz4Ek8H13/+/D2j8LfDmYPQ180d13Bm0HgVPAIDBQqpB8F3LQ53u/s5en9xzhub3t/PqtY5wZyFCdTnLj8gauXz6HG5Y3cMX8mfzdzja+/Ngu0skEL/6nf87AoDO7Jh12+SJyARlX0JtZEngD+C2gFdgGbHT31/LW+Qjwurt/YGbrgK+5+/VB20Gg2d2PlltwVII+X0/fIL8+cJTn9rbzq/3HOHC065z2yy+qY++RU5hB0ox/+9Fl3PShRm75cBOZjJPQm7wiMoLxBv2NZIP7E8H8VwDc/b+VWH82sNvdFwTzB1HQF2jv7OWlt4/zZvtpFs6uYv2qi7n/x3t48tXDdPYODK23rLGGU70D3L3mEtZcNpeljTUhVi0i09V4g/5OYK27fyGY/xxwvbvfU2L9PwIuz1v/beADwIHvufvmEtttAjYBLF68+Lp33nmnnO8tkp7f184r757ADL77/FucGcg+vDyVMBbMrmJBfRVXzJ/JJ668iCUN1TTWVuqyTpGYG2/Q/zbwiWFBv9rd/7DIumuA7wAfdfdjwbKL3b3NzOYCzwJ/6O4vjHTMOPTox2LnoRMcPtnDjkMn2Xukc2i8P2dBfRVLGqq5flkDF9fPYNWiepY11pBKJujuG+Dg0W4WzamibkZFiN+FiEymkYI+Vcb2rcCivPmFQFuRg1wNfB9Ylwt5AHdvC762m9lWYDUwYtDLuVYuqmflonrWXjUfgN7+QXr7B/nF/qMcOdnLc3vbOdnTzzd+9ga539vJhDGvrpJjXX2cGchQV5li4/WLaahJs7SxhmsW1zNzRgUzKgqvABKRaCkn6LcBl5rZMuA9YAPwL/NXMLPFwOPA59z9jbzlNUDC3U8F07cBX5+o4uNqRkWSGRVJbr/6YgC+8M+WA3Ciu4+jp8/w8rsnePdYN20nephdk+bqhbN4avcRNr9w4Jz9VCSNlQvrWTynmoWzq1g4p5rFc6pZUF/F3JmVRS8DFZELz6hB7+4DZnYP8DTZyyu3uPseM7sraH8QuA9oAL4TfDAodxnlPGBrsCwFPOzuT03KdyLUV6epr07zobl1BW3rVy2g68wADmx7+zitJ3p452gXu947yUtvH+dHO3rIDBvFa6xNc9GsGdRXpekbyNDZ209DbZrPrLyY65bMobd/kMpUguVNtQDsbz/N3iOdXNJUy5UXz9SHxESmCX1gSgDoG8hw+GQPh4730HayhyMnezl8spcjJ3s40dNPZSpB3YwK3uo4zYGOcy8NrapI4ji9/ee+b/DxK+bS1TdIQ02aj6+YR3U6Sd9Ahn+yYBappG6cKjKRxjtGLzGQTiVY0lDDkoaRL990d3a/18mBo6epqkjS2TvAnraTJMy4asFMLps3kz1tJ3ns5VYef/k9aipTHO/q43t5w0b11RUsa6yhMpVgVlUFC2cHQ0ezq5lbV0l9dQX1VWnqZqT0+QGRCaAevUy6U739vPDGURIGg+78fG8H7ad6OdOf4Xh3H+990ENPf/HHNdakk5gZ7k7CjGTSSCWMZMJIBvNJy86nEgkSiWx7ImEkLfumtJmRCKYTlp1PGkPT57adO52w7P4TCYLtsutntzOSibP7yZ9OBB98SyRs6ENw5+wzcXY/+dMF+zzneKXrzB1raJ9WuN7w4yXytjlnP7lXUEep/cj0oh69hKpuRgWfunr+0HzuTeQcd+d4Vx+HPujh2OkznOju54PuPjp7B7LvKziYgTsMZjL0Z5xMxhnI+zoYvAYyTsad/sEM7pDx7HwmAwODGTLuDHr2mLnlQ+s4ZDJ50+7B/LB1ii0ftp+oy/3CShgYRvAPC+ZtaN4wGGont2xYuwUr2TnrBfsedtxz5gvqGvkXUMH2Bfsr3H6sxyzYwxi2n1Od5o
d33VhQw3gp6CV0ZkZDbSUNtZVhlzJhCn5hlPolkTc9mPFzfzl5qTaC+cLp7C+yYD7D2eki+ynY59B2w34ZjnK8jIPjBP9wz+43O51tyx84cPeCttx8dipoG/YL0zl3QWE7o7SPvEGx38/DRzxGP8bYth++oG7G5ESygl5kEiQSRqJI71AkDLr0QUQk4hT0IiIRp6AXEYk4Bb2ISMQp6EVEIk5BLyIScQp6EZGIU9CLiETctLzXjZl1AOf7LMFGoOzn004h1TU2qmtspmtdMH1ri1pdS9y9qVjDtAz68TCzllI39gmT6hob1TU207UumL61xakuDd2IiEScgl5EJOKiGPSbwy6gBNU1NqprbKZrXTB9a4tNXZEboxcRkXNFsUcvIiJ5FPQiIhEXmaA3s7Vmts/M9pvZvSHXctDMXjWzHWbWEiybY2bPmtmbwdfZU1TLFjNrN7PdectK1mJmXwnO4T4z+8QU1/U1M3svOG87zOyTIdS1yMx+bmavm9keM/t3wfJQz9kIdYV6zsxshpn9xsx2BnX9l2B52OerVF2h/4wFx0qa2Stm9pNgfnLPlwePD7uQX0ASeAtYDqSBncCKEOs5CDQOW/anwL3B9L3An0xRLTcD1wK7R6sFWBGcu0pgWXBOk1NY19eAPyqy7lTWNR+4NpiuA94Ijh/qORuhrlDPGdlHoNYG0xXAS8AN0+B8laor9J+x4HhfAh4GfhLMT+r5ikqPfjWw390PuHsf8CiwPuSahlsP/HUw/dfAv5iKg7r7C8DxMmtZDzzq7mfc/W1gP9lzO1V1lTKVdR1295eD6VPA68ACQj5nI9RVylTV5e5+OpitCF5O+OerVF2lTNnPmJktBD4FfH/Y8SftfEUl6BcAh/LmWxn5P8Fkc+AZM9tuZpuCZfPc/TBk/9MCc0OrrnQt0+E83mNmu4Khndyfr6HUZWZLgWvI9ganzTkbVheEfM6CYYgdQDvwrLtPi/NVoi4I/2fsG8B/BDJ5yyb1fEUl6Is9hTnM60ZvcvdrgXXA3WZ2c4i1jEXY5/G7wCXAKuAw8GfB8imvy8xqgceAf+/unSOtWmTZpNVWpK7Qz5m7D7r7KmAhsNrMrhph9bDrCvV8mdntQLu7by93kyLLxlxXVIK+FViUN78QaAupFty9LfjaDmwl+6fW+2Y2HyD42h5WfSPUEup5dPf3g/+cGeB/cfZP1Cmty8wqyIbp/3H3x4PFoZ+zYnVNl3MW1HICeB5YyzQ4X8Xqmgbn6ybgM2Z2kOwQ88fM7AdM8vmKStBvAy41s2VmlgY2AE+EUYiZ1ZhZXW4auA3YHdTzu8Fqvwv8OIz6AqVqeQLYYGaVZrYMuBT4zVQVlftBD9xB9rxNaV1mZsBfAq+7+//Iawr1nJWqK+xzZmZNZlYfTFcBHwf2Ev75KlpX2OfL3b/i7gvdfSnZnHrO3f8Vk32+Jutd5al+AZ8keyXCW8BXQ6xjOdl3yXcCe3K1AA3Az4A3g69zpqieR8j+idpPtnfw+ZFqAb4anMN9wLoprut/A68Cu4If8Pkh1PVRsn8a7wJ2BK9Phn3ORqgr1HMGXA28Ehx/N3DfaD/vIdcV+s9Y3vFu5exVN5N6vnQLBBGRiIvK0I2IiJSgoBcRiTgFvYhIxCnoRUQiTkEvIhJxCnoRkYhT0IuIRNz/B/PvIJHVX9tuAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Draw the values stored in `losses` as a line on a single set of axes.\n",
"fig, ax = plt.subplots()\n",
"ax.plot(losses)\n",
"plt.show()\n",
"plt.close(fig)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "94b93168",
"metadata": {},
"outputs": [],
"source": [
"# Copy the target `y` to the CPU as a squeezed NumPy array.\n",
"y_viz = y.squeeze().cpu().numpy()\n",
"# Run the forward pass with autograd tracking disabled, then convert the\n",
"# output the same way as the target.\n",
"with torch.no_grad():\n",
"    y_hat = model(x)\n",
"y_hat = y_hat.squeeze().cpu().numpy()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "2b386abe",
"metadata": {},
"outputs": [
{
"data": {
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment