Created
November 25, 2021 23:35
-
-
Save ShairozS/5b29ca60acc0845b424dbc0059e0fc61 to your computer and use it in GitHub Desktop.
Train MNIST with contrastive triplet loss
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Training on MNIST Dataset with Triplet Loss\n", | |
"-------------------------" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"D:\\Research\\ContrastiveRepresentationLearning\n" | |
] | |
} | |
], | |
"source": [ | |
"import os\n", | |
"#import torchsummary\n", | |
"import matplotlib.pyplot as plt\n", | |
"import torch\n", | |
"import numpy as np\n", | |
"import cv2\n", | |
"from torch import nn\n", | |
"from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler\n", | |
"from torchvision import transforms, models\n", | |
"\n", | |
"os.chdir('..'); os.chdir('..')\n", | |
"print(os.getcwd()) # Should be .\\ContrastiveLearning\n", | |
"from Code.trainers import Trainer\n", | |
"#from Code.models import SiameseNet\n", | |
"from Code.losses import form_triplets, ContrastiveLoss\n", | |
"from Code.dataloaders import LabeledContrastiveDataset\n", | |
"from Code.utils import extract_embeddings, plot_embeddings\n", | |
"\n", | |
"\n", | |
"\n", | |
"\n", | |
"# Hyperparameters\n", | |
"N = 3000          # dataset indices 0..N-1 are split 90/10 into train/test samplers further down\n", | |
"EMB_SIZE = 32     # output width of the embedding network's final linear layer\n", | |
"DEVICE = 'cuda'\n", | |
"LR = 0.0005       # learning rate handed to the Trainer\n", | |
"EPOCHS = 10       # number of epochs passed to Trainer.train\n", | |
"MARGIN = 1.0\n", | |
"# Run identifier; also used as the file name of the saved weights.\n", | |
"NAME = f'MNIST_TRIPLET_LOSS_{N}_{EMB_SIZE}_{LR}_{EPOCHS}_{MARGIN}'\n", | |
"\n", | |
"# Reproducibility\n", | |
"SEED = 911\n", | |
"np.random.seed(SEED)\n", | |
"torch.manual_seed(SEED)\n", | |
"torch.backends.cudnn.benchmark = False\n", | |
"torch.backends.cudnn.deterministic = True" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Create Dataloader and Inspect Data\n", | |
"---------------------" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Dataset location. Set the MNIST_ROOT environment variable to point at a local copy;\n", | |
"# the fallback is the path this notebook was originally run with.\n", | |
"root = os.environ.get('MNIST_ROOT', r'D:\\Data\\Imagery\\MNIST\\MNIST')\n", | |
"\n", | |
"# Constants fed to transforms.Normalize below (single-channel images).\n", | |
"mean, std = 0.1307, 0.3081\n", | |
"\n", | |
"tfms = transforms.Compose([\n", | |
"    transforms.ToTensor(),\n", | |
"    transforms.Normalize((mean,), (std,)),\n", | |
"])\n", | |
"\n", | |
"lcd = LabeledContrastiveDataset(root, transforms=tfms)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"torch.Size([10, 1, 28, 28])\n", | |
"torch.Size([10, 1, 28, 28])\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"C:\\Users\\Shair\\.conda\\envs\\pytorch\\lib\\site-packages\\torchvision\\transforms\\functional.py:114: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ..\\torch\\csrc\\utils\\tensor_numpy.cpp:143.)\n", | |
" img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()\n" | |
] | |
} | |
], | |
"source": [ | |
"# Pull one item from the dataset and report the shape of each of its two views.\n", | |
"datadict = lcd[4]\n", | |
"for view in (\"x1\", \"x2\"):\n", | |
"    print(datadict[view].shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from torchvision import transforms\n", | |
"\n", | |
"\n", | |
"# 90/10 split over the first N dataset indices: [0, n_train) trains, [n_train, N) is held out.\n", | |
"n_train = int(N * 0.9)\n", | |
"train_sampler = SubsetRandomSampler(range(n_train))\n", | |
"test_sampler = SubsetRandomSampler(range(n_train, N))\n", | |
"\n", | |
"# batch_size=None disables automatic batching, so every dataset item is yielded as-is.\n", | |
"siamese_train_loader = DataLoader(lcd, batch_size=None, sampler=train_sampler)\n", | |
"# The held-out sampler must be passed as `sampler=`. The previous `shuffle=test_sampler`\n", | |
"# merely made `shuffle` truthy, so the loader shuffled the WHOLE dataset (training\n", | |
"# indices included) instead of restricting itself to the held-out range.\n", | |
"siamese_test_loader = DataLoader(lcd, batch_size=None, sampler=test_sampler)\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Model\n", | |
"------------" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"# ResNet-18 from torchvision, built with default arguments and adapted below.\n", | |
"embedding_net = models.resnet18()\n", | |
"\n", | |
"# Swap the stem convolution for one that accepts a single input channel.\n", | |
"embedding_net.conv1 = nn.Conv2d(\n", | |
"    in_channels=1, out_channels=64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3)\n", | |
")\n", | |
"# Swap the final layer so the network emits EMB_SIZE values per input.\n", | |
"embedding_net.fc = nn.Linear(in_features=512, out_features=EMB_SIZE)\n", | |
"\n", | |
"model = embedding_net\n", | |
"model.train()\n", | |
"print()  # presumably here so the cell does not end on an expression that echoes the model\n", | |
"#torchsummary.summary(model, input_size = [(1,28,28),(1, 28, 28)], device=DEVICE)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Training\n", | |
"-------------------------" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class TripletLoss(nn.Module):\n", | |
"    \"\"\"Mine triplets from a batch and score them with nn.TripletMarginLoss.\n", | |
"\n", | |
"    NOTE: this class is not instantiated anywhere in this notebook; the Trainer\n", | |
"    created below is given Code.losses.ContrastiveLoss instead.\n", | |
"\n", | |
"    Args:\n", | |
"        margin: forwarded as the first positional argument of nn.TripletMarginLoss (its margin).\n", | |
"        norm: forwarded as the second positional argument of nn.TripletMarginLoss (its p).\n", | |
"        miner: callable invoked as miner(x, y); must return an (anchor, positive, negative) triple.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, margin, norm, miner):\n", | |
"        super().__init__()\n", | |
"        self.loss = nn.TripletMarginLoss(margin, norm)\n", | |
"        self.miner = miner\n", | |
"\n", | |
"    def forward(self, x, y):\n", | |
"        \"\"\"Return the triplet margin loss over the triplets the miner builds from (x, y).\"\"\"\n", | |
"        a, p, n = self.miner(x, y)\n", | |
"        return self.loss(a, p, n)\n", | |
"\n", | |
"\n", | |
"# Use the MARGIN hyperparameter rather than a literal 1.0, so the value that is\n", | |
"# baked into NAME (and hence the saved weights' file name) is the one actually trained with.\n", | |
"TL = ContrastiveLoss(margin=MARGIN, mode='triplets')\n", | |
"\n", | |
"t = Trainer(model=model,\n", | |
"            dataloader=siamese_train_loader,\n", | |
"            lr=LR,\n", | |
"            loss_function=TL)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 0%| | 0/2700 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"----- Epoch: 0 -----\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [01:52<00:00, 24.00it/s]\n", | |
" 0%| | 3/2700 [00:00<01:51, 24.09it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Avg train loss: 0.21114463511037887\n", | |
"----- Epoch: 1 -----\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [01:51<00:00, 24.11it/s]\n", | |
" 0%| | 3/2700 [00:00<01:49, 24.56it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Avg train loss: 0.13589056040674188\n", | |
"----- Epoch: 2 -----\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [01:53<00:00, 23.71it/s]\n", | |
" 0%| | 3/2700 [00:00<01:49, 24.73it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Avg train loss: 0.10390172110855794\n", | |
"----- Epoch: 3 -----\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [01:52<00:00, 24.07it/s]\n", | |
" 0%| | 3/2700 [00:00<01:50, 24.38it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Avg train loss: 0.0842619685555861\n", | |
"----- Epoch: 4 -----\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [01:52<00:00, 24.05it/s]\n", | |
" 0%| | 3/2700 [00:00<01:49, 24.59it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Avg train loss: 0.07334132929354033\n", | |
"----- Epoch: 5 -----\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [01:52<00:00, 23.94it/s]\n", | |
" 0%| | 3/2700 [00:00<01:52, 23.99it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Avg train loss: 0.06533128698030446\n", | |
"----- Epoch: 6 -----\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [01:51<00:00, 24.15it/s]\n", | |
" 0%| | 3/2700 [00:00<01:51, 24.23it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Avg train loss: 0.05845108959283751\n", | |
"----- Epoch: 7 -----\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [01:51<00:00, 24.16it/s]\n", | |
" 0%| | 3/2700 [00:00<01:51, 24.19it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Avg train loss: 0.05141863693237274\n", | |
"----- Epoch: 8 -----\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [01:51<00:00, 24.13it/s]\n", | |
" 0%| | 3/2700 [00:00<01:50, 24.39it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Avg train loss: 0.04743843640161989\n", | |
"----- Epoch: 9 -----\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [01:51<00:00, 24.15it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Avg train loss: 0.043700399544134225\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"losses = t.train(EPOCHS, print_every=1)#, writer = writer)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Text(0.5, 0, 'Epochs')" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Plot the values returned by Trainer.train against their index.\n", | |
"fig, ax = plt.subplots()\n", | |
"ax.plot(losses)\n", | |
"ax.set_title(\"Training Loss - Triplet Loss\")\n", | |
"ax.set_ylabel(\"Train loss\")\n", | |
"ax.set_xlabel(\"Epochs\")\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[0.21114463511037887,\n", | |
" 0.06063648570310492,\n", | |
" 0.03992404251219004,\n", | |
" 0.025342710896670555,\n", | |
" 0.029658772245357224,\n", | |
" 0.025281075414125182,\n", | |
" 0.017169905268035815,\n", | |
" 0.002191468309119281,\n", | |
" 0.015596832155597103,\n", | |
" 0.010058067826763295]" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"losses" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Inspecting Embeddings\n", | |
"-------------------" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"EMBS_TO_VISUALIZE = N - int(N*0.9)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"performing PCA to reduce embeddings to 2 dimensions\n", | |
"0.5584684 % variance explained using PCA\n" | |
] | |
} | |
], | |
"source": [ | |
"test_embs = extract_embeddings(siamese_test_loader, model, EMBS_TO_VISUALIZE, reduce_to_dimension=2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>Emb</th>\n", | |
" <th>Label</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>[1.9465842, 1.9470881]</td>\n", | |
" <td>5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>[-0.8707288, -1.3512108]</td>\n", | |
" <td>2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>[-0.45471096, 1.8426472]</td>\n", | |
" <td>9</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>[1.1143692, -1.8616936]</td>\n", | |
" <td>8</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>[-1.1564724, 2.1490877]</td>\n", | |
" <td>4</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" Emb Label\n", | |
"0 [1.9465842, 1.9470881] 5\n", | |
"1 [-0.8707288, -1.3512108] 2\n", | |
"2 [-0.45471096, 1.8426472] 9\n", | |
"3 [1.1143692, -1.8616936] 8\n", | |
"4 [-1.1564724, 2.1490877] 4" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"test_embs.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 401.625x360 with 1 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"plot_embeddings(test_embs)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Saving Model\n", | |
"-------------------------" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"model saved to: D:\\Research\\ContrastiveRepresentationLearning\\Outputs\\Weights\\\\MNIST_TRIPLET_LOSS_3000_32_0.0005_10_1.0\n" | |
] | |
} | |
], | |
"source": [ | |
"# Assemble the destination with os.path.join instead of concatenating Windows-only\n", | |
"# separator literals (which also produced a doubled separator before NAME).\n", | |
"weights_dir = os.path.join(os.getcwd(), 'Outputs', 'Weights')\n", | |
"# torch.save does not create missing directories, so make sure the target exists.\n", | |
"os.makedirs(weights_dir, exist_ok=True)\n", | |
"outpath = os.path.join(weights_dir, NAME)\n", | |
"\n", | |
"torch.save(model.state_dict(), outpath)\n", | |
"print(\"model saved to: \" + outpath)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.8" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment