Created
November 25, 2021 23:11
-
-
Save ShairozS/7d4a33dbb788d96f869f61b98eb35dae to your computer and use it in GitHub Desktop.
Training a Pytorch contrastive backbone with pairwise contrastive loss
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Training on MNIST Dataset with Contrastive Pairs Loss\n", | |
"-------------------------" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"D:\\Research\\ContrastiveRepresentationLearning\n" | |
] | |
} | |
], | |
"source": [ | |
"import os\n", | |
"#import torchsummary\n", | |
"import matplotlib.pyplot as plt\n", | |
"import torch\n", | |
"import numpy as np\n", | |
"import cv2\n", | |
"from torch import nn\n", | |
"from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler\n", | |
"from torchvision import transforms, models\n", | |
"\n", | |
"os.chdir('..'); os.chdir('..')\n", | |
"print(os.getcwd()) # Should be the repository root, i.e. ...\\ContrastiveRepresentationLearning\n", | |
"from Code.trainers import Trainer\n", | |
"#from Code.models import SiameseNet\n", | |
"from Code.losses import form_triplets, ContrastiveLoss\n", | |
"from Code.dataloaders import LabeledContrastiveDataset\n", | |
"from Code.utils import extract_embeddings, plot_embeddings\n", | |
"\n", | |
"\n", | |
"# Hyperparameters\n", | |
"N = 3000        # dataset indices 0..N-1 are split into the train/test samplers\n", | |
"EMB_SIZE = 32   # output width of the backbone's final linear layer\n", | |
"DEVICE = 'cuda'\n", | |
"LR = 0.0005     # learning rate handed to the Trainer\n", | |
"EPOCHS = 10     # number of epochs passed to Trainer.train\n", | |
"MARGIN = 1.0\n", | |
"\n", | |
"# Run identifier encoding the settings above; it is used as the weights file name.\n", | |
"NAME = f'MNIST_PAIR_LOSS_{N}_{EMB_SIZE}_{LR}_{EPOCHS}_{MARGIN}'\n", | |
"\n", | |
"# Reproducibility: seed NumPy and PyTorch, disable cuDNN autotuning and\n", | |
"# request deterministic cuDNN algorithms.\n", | |
"SEED = 911\n", | |
"np.random.seed(SEED)\n", | |
"torch.manual_seed(SEED)\n", | |
"torch.backends.cudnn.benchmark = False\n", | |
"torch.backends.cudnn.deterministic = True" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Create Dataloader and Inspect Data\n", | |
"---------------------" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Dataset location on disk and the mean/std handed to Normalize.\n", | |
"root = r'D:\\Data\\Imagery\\MNIST\\MNIST'\n", | |
"mean = 0.1307\n", | |
"std = 0.3081\n", | |
"\n", | |
"# Convert images to tensors, then standardise the single channel.\n", | |
"normalize = transforms.Normalize((mean,), (std,))\n", | |
"tfms = transforms.Compose([transforms.ToTensor(), normalize])\n", | |
"\n", | |
"lcd = LabeledContrastiveDataset(root, transforms=tfms)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"torch.Size([10, 1, 28, 28])\n", | |
"torch.Size([10, 1, 28, 28])\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"C:\\Users\\Shair\\.conda\\envs\\pytorch\\lib\\site-packages\\torchvision\\transforms\\functional.py:114: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ..\\torch\\csrc\\utils\\tensor_numpy.cpp:143.)\n", | |
" img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()\n" | |
] | |
} | |
], | |
"source": [ | |
"# Pull one item and print the shapes of both sides of the pair; each side is\n", | |
"# already a stack of images (see the recorded output), not a single image.\n", | |
"datadict = lcd[4]\n", | |
"for key in (\"x1\", \"x2\"):\n", | |
"    print(datadict[key].shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Split the first N dataset indices 90/10: the leading 90% feed training and the\n", | |
"# remaining 10% are held out for evaluation.\n", | |
"n_train = int(N * 0.9)\n", | |
"train_sampler = SubsetRandomSampler(range(n_train))\n", | |
"test_sampler = SubsetRandomSampler(range(n_train, N))\n", | |
"\n", | |
"# batch_size=None turns off the DataLoader's automatic batching; each dataset item\n", | |
"# is used as-is (the items inspected above already carry a leading batch dimension).\n", | |
"siamese_train_loader = torch.utils.data.DataLoader(lcd, batch_size=None, sampler=train_sampler)\n", | |
"# The held-out sampler must be passed as `sampler=`. Passing it as `shuffle=` leaves\n", | |
"# the sampler unused, so the loader would draw from the whole dataset, training\n", | |
"# indices included.\n", | |
"siamese_test_loader = torch.utils.data.DataLoader(lcd, batch_size=None, sampler=test_sampler)\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Model\n", | |
"------------" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"# ResNet-18 backbone with two layers swapped out: the stem convolution takes a\n", | |
"# single input channel, and the final linear layer emits EMB_SIZE features.\n", | |
"embedding_net = models.resnet18()\n", | |
"embedding_net.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=7, stride=2, padding=3)\n", | |
"embedding_net.fc = nn.Linear(embedding_net.fc.in_features, EMB_SIZE)\n", | |
"model = embedding_net\n", | |
"\n", | |
"model.train()\n", | |
"print()\n", | |
"#torchsummary.summary(model, input_size = [(1,28,28),(1, 28, 28)], device=DEVICE)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Training\n", | |
"-------------------------" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"\n", | |
"# Pairwise contrastive loss. The margin is taken from the MARGIN hyperparameter so\n", | |
"# the trained configuration always matches the value encoded in NAME.\n", | |
"TL = ContrastiveLoss(margin=MARGIN, mode='pairs')\n", | |
"\n", | |
"# Trainer wiring: model, training loader, learning rate and loss function.\n", | |
"t = Trainer(\n", | |
"    model=model,\n", | |
"    dataloader=siamese_train_loader,\n", | |
"    lr=LR,\n", | |
"    loss_function=TL,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 0%| | 0/2700 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"----- Epoch: 0 -----\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [02:44<00:00, 16.44it/s]\n", | |
" 0%| | 3/2700 [00:00<01:58, 22.82it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Avg train loss: 0.04647024116440055\n", | |
"----- Epoch: 1 -----\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [02:22<00:00, 18.89it/s]\n", | |
" 0%| | 3/2700 [00:00<01:48, 24.83it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Avg train loss: 0.02958299503657515\n", | |
"----- Epoch: 2 -----\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [02:02<00:00, 22.02it/s]\n", | |
" 0%| | 3/2700 [00:00<01:45, 25.64it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Avg train loss: 0.02280590784906603\n", | |
"----- Epoch: 3 -----\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [01:51<00:00, 24.14it/s]\n", | |
" 0%| | 3/2700 [00:00<01:44, 25.86it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Avg train loss: 0.017991772402449355\n", | |
"----- Epoch: 4 -----\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [01:49<00:00, 24.66it/s]\n", | |
" 0%| | 3/2700 [00:00<01:46, 25.41it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Avg train loss: 0.015129790639827759\n", | |
"----- Epoch: 5 -----\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [01:51<00:00, 24.18it/s]\n", | |
" 0%| | 3/2700 [00:00<01:46, 25.21it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Avg train loss: 0.01323084365027221\n", | |
"----- Epoch: 6 -----\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [01:50<00:00, 24.51it/s]\n", | |
" 0%| | 3/2700 [00:00<01:53, 23.80it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Avg train loss: 0.011587587445622171\n", | |
"----- Epoch: 7 -----\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [01:49<00:00, 24.68it/s]\n", | |
" 0%| | 3/2700 [00:00<01:45, 25.63it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Avg train loss: 0.010335876384317311\n", | |
"----- Epoch: 8 -----\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [01:47<00:00, 25.00it/s]\n", | |
" 0%| | 3/2700 [00:00<01:46, 25.42it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Avg train loss: 0.009299262001579317\n", | |
"----- Epoch: 9 -----\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [01:48<00:00, 24.94it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Avg train loss: 0.008486182559181838\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"losses = t.train(EPOCHS, print_every=1)#, writer = writer)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Text(0.5, 0, 'Epochs')" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"plt.plot(losses)\n", | |
"plt.title(\"Training Loss - Contrastive Pair Loss\")\n", | |
"plt.ylabel(\"Train loss\"); plt.xlabel(\"Epochs\")\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[0.04647024116440055,\n", | |
" 0.01269574890874975,\n", | |
" 0.009251733474047796,\n", | |
" 0.00354936606259933,\n", | |
" 0.003681863589341379,\n", | |
" 0.0037361087024944733,\n", | |
" 0.0017280502177219306,\n", | |
" 0.0015738989551832843,\n", | |
" 0.001006346939675369,\n", | |
" 0.001168467577604512]" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"losses" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Inspecting Embeddings\n", | |
"-------------------" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"EMBS_TO_VISUALIZE = N - int(N*0.9)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"performing PCA to reduce embeddings to 2 dimensions\n", | |
"0.45079494 % variance explained using PCA\n" | |
] | |
} | |
], | |
"source": [ | |
"# Switch to inference mode before extracting embeddings: model.train() was called\n", | |
"# when the model was built, and in that mode ResNet's BatchNorm layers normalise\n", | |
"# with per-batch statistics and keep updating their running averages.\n", | |
"# NOTE(review): harmless if extract_embeddings already calls model.eval() itself.\n", | |
"model.eval()\n", | |
"test_embs = extract_embeddings(siamese_test_loader, model, EMBS_TO_VISUALIZE, reduce_to_dimension=2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>Emb</th>\n", | |
" <th>Label</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>[-0.10142087, 0.9861553]</td>\n", | |
" <td>5</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>[-0.29980388, 0.47371897]</td>\n", | |
" <td>6</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>[0.32089403, -0.11234523]</td>\n", | |
" <td>2</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>[-0.10230484, -0.2726765]</td>\n", | |
" <td>9</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>[-0.33229306, -0.032638874]</td>\n", | |
" <td>8</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" Emb Label\n", | |
"0 [-0.10142087, 0.9861553] 5\n", | |
"1 [-0.29980388, 0.47371897] 6\n", | |
"2 [0.32089403, -0.11234523] 2\n", | |
"3 [-0.10230484, -0.2726765] 9\n", | |
"4 [-0.33229306, -0.032638874] 8" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"test_embs.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 401.625x360 with 1 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"\n", | |
"plot_embeddings(test_embs)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Saving Model\n", | |
"-------------------------" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Model saved at: D:/Research/ContrastiveRepresentationLearning/Outputs/Weights/MNIST_PAIR_LOSS_3000_32_0.0005_10_1.0.pth\n" | |
] | |
} | |
], | |
"source": [ | |
"# Destination folder for trained weights; the file name is the run identifier NAME.\n", | |
"weightfol = r'D:/Research/ContrastiveRepresentationLearning/Outputs/Weights'\n", | |
"# torch.save does not create missing parent directories, so make sure the folder exists.\n", | |
"os.makedirs(weightfol, exist_ok=True)\n", | |
"outpath = weightfol + r\"/\" + NAME + '.pth'\n", | |
"\n", | |
"# Persist only the parameters (state_dict), not the whole module object.\n", | |
"torch.save(model.state_dict(), outpath)\n", | |
"print(\"Model saved at: \", outpath)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.8" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment