Linear probe on Pairwise-Contrastive Model
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "greenhouse-shock",
   "metadata": {},
   "source": [
"## Testing Linear Seperability in Feature Spaces of DL Models\n", | |
"----------------" | |
] | |
}, | |
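  {
   "cell_type": "markdown",
   "id": "probe-overview",
   "metadata": {},
   "source": [
    "In this setup the contrastively trained embedding network $f_\\theta$ is used as a fixed feature extractor (its parameters are never passed to the optimizer), and only a single linear layer from the embeddings to the 10 digit classes is trained:\n",
    "\n",
    "$$\\hat{y} = \\operatorname{softmax}\\left(W f_\\theta(x) + b\\right).$$\n",
    "\n",
    "The test accuracy of this linear classifier is then a measure of how linearly separable the digit classes are in the learned embedding space."
   ]
  },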
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "german-madison",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "D:\\Research\\ContrastiveRepresentationLearning\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "#import torchsummary\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import numpy as np\n",
    "import cv2\n",
    "import tqdm\n",
    "from torch import nn\n",
    "from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler\n",
    "from torchvision import transforms, models\n",
    "from torch.optim import Adam\n",
    "\n",
    "os.chdir('..'); os.chdir('..')\n",
"print(os.getcwd()) # Should be .\\ContrastiveLearning\n", | |
"from Code.trainers import Trainer\n", | |
"from Code.losses import form_triplets, ContrastiveLoss\n", | |
"from Code.dataloaders import LabeledContrastiveDataset\n", | |
"from Code.utils import extract_embeddings, plot_embeddings\n", | |
"\n", | |
"\n", | |
"# Hyperparameters\n", | |
"N = 3000\n", | |
"EMB_SIZE = 32\n", | |
"DEVICE = 'cuda'\n", | |
"LR = 0.0005\n", | |
"EPOCHS = 5\n", | |
"MARGIN = 1.0\n", | |
"N = 3000\n", | |
"DEVICE = 'cuda'\n", | |
"MODEL_WEIGHTS = 'D:/Research/ContrastiveRepresentationLearning/Outputs/Weights/MNIST_PAIR_LOSS_3000_32_0.0005_10_1.0.pth'\n", | |
"TITLE = 'ContrastivePairs'\n", | |
"\n", | |
"# Reproduciblity\n", | |
"SEED = 911\n", | |
"torch.manual_seed(SEED)\n", | |
"torch.backends.cudnn.deterministic = True\n", | |
"torch.backends.cudnn.benchmark = False\n", | |
"np.random.seed(SEED)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "attended-heating", | |
"metadata": {}, | |
"source": [ | |
"## Loading Data\n", | |
"------------------" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "harmful-tolerance", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"root = r'D:\\Data\\Imagery\\MNIST\\MNIST'\n", | |
"mean, std = 0.1307, 0.3081\n", | |
"tfms = transforms.Compose([\n", | |
" transforms.ToTensor(),\n", | |
" transforms.Normalize((mean,), (std,))\n", | |
" ])\n", | |
"\n", | |
"\n", | |
"lcd = LabeledContrastiveDataset(root, transforms=tfms)\n", | |
"\n", | |
"\n", | |
"\n", | |
"train_sampler = SubsetRandomSampler(range(int(N*0.9)))\n", | |
"test_sampler = SubsetRandomSampler(range(int(N*0.9), N))\n", | |
"\n", | |
"siamese_train_loader = torch.utils.data.DataLoader(lcd, batch_size=None, sampler=train_sampler)\n", | |
"siamese_test_loader = torch.utils.data.DataLoader(lcd, batch_size=None, sampler=test_sampler)\n" | |
] | |
}, | |
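  {
   "cell_type": "markdown",
   "id": "batch-format-note",
   "metadata": {},
   "source": [
    "Optional sanity check (not part of the original run): the training and evaluation loops below assume each item yielded by these loaders is a dict with keys `x1`, `x2` and `labels`, where `x1`/`x2` are batched image tensors and `labels` is used as the class target for both. The cell below just peeks at one item to confirm that structure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "batch-format-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at a single item from the train loader to confirm the keys and\n",
    "# tensor shapes the linear-probe loops below rely on.\n",
    "sample = next(iter(siamese_train_loader))\n",
    "print(sample.keys())\n",
    "print(sample['x1'].shape, sample['x2'].shape, sample['labels'].shape)"
   ]
  },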
  {
   "cell_type": "markdown",
   "id": "historic-artist",
   "metadata": {},
   "source": [
    "## Loading Model\n",
    "----------------"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "narrow-sarah",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "embedding_net = models.resnet18()\n",
    "embedding_net.conv1 = nn.Conv2d(1, 64, (7,7), (2,2), (3,3))\n",
    "embedding_net.fc = nn.Linear(512, EMB_SIZE)\n",
    "model = embedding_net\n",
    "model.eval(); print()  # torchsummary.summary(model, input_size=[(1,28,28),(1,28,28)], device=DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ethical-vermont",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.load_state_dict(torch.load(MODEL_WEIGHTS))"
   ]
  },
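  {
   "cell_type": "markdown",
   "id": "embedding-shape-note",
   "metadata": {},
   "source": [
    "Optional check (not part of the original run): the backbone is a ResNet-18 with `conv1` adapted to 1-channel inputs and `fc` replaced so it emits `EMB_SIZE`-dimensional embeddings. A quick forward pass on a dummy MNIST-shaped input should therefore return a tensor of shape `(1, EMB_SIZE)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "embedding-shape-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forward a dummy 1x28x28 image through the eval-mode backbone and\n",
    "# confirm the embedding dimensionality matches EMB_SIZE (= 32 here).\n",
    "with torch.no_grad():\n",
    "    dummy = torch.randn(1, 1, 28, 28)\n",
    "    print(model(dummy).shape)  # expected: torch.Size([1, 32])"
   ]
  },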
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "tribal-receipt",
   "metadata": {},
   "outputs": [],
   "source": [
    "linear_clf = nn.Linear(EMB_SIZE, 10).to(DEVICE)"
   ]
  },
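  {
   "cell_type": "markdown",
   "id": "freeze-backbone-note",
   "metadata": {},
   "source": [
    "Optional tweak (not part of the original run): the optimizer below only updates `linear_clf`, so the backbone weights never change during the probe. Freezing the backbone parameters as well simply avoids building gradients through the ResNet on each forward/backward pass, which saves memory and time without affecting the probe result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "freeze-backbone",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Freeze the embedding network; only the linear classifier head remains trainable.\n",
    "for p in model.parameters():\n",
    "    p.requires_grad_(False)"
   ]
  },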
  {
   "cell_type": "markdown",
   "id": "premium-algebra",
   "metadata": {},
   "source": [
    "## Training Linear Classifier\n",
    "----------------"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "administrative-optimization",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/2700 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "------------ EPOCH: 0 ---------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [08:15<00:00,  5.45it/s]\n",
      "  0%|          | 0/2700 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch loss: 0.4060716205697369\n",
      "------------ EPOCH: 2 ---------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [06:56<00:00,  6.49it/s]\n",
      "  0%|          | 0/2700 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch loss: 0.23509877727263503\n",
      "------------ EPOCH: 3 ---------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [05:42<00:00,  7.89it/s]\n",
      "  0%|          | 2/2700 [00:00<03:56, 11.43it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch loss: 0.19861162111818514\n",
      "------------ EPOCH: 4 ---------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████| 2700/2700 [04:52<00:00,  9.22it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch loss: 0.1862925045737238\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "epoch_losses = []\n",
    "model = model.to(DEVICE)\n",
    "optimizer = Adam(linear_clf.parameters(), lr=LR)\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "\n",
    "    batch_losses = []\n",
    "    print(\"------------ EPOCH: \" + str(epoch) + \" ---------------\")\n",
    "    for idx, batch in tqdm.tqdm(enumerate(siamese_train_loader), total=len(siamese_train_loader)):\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # Extract batched data elements\n",
    "        x1 = batch[\"x1\"]; x2 = batch[\"x2\"]\n",
    "        y = batch[\"labels\"]\n",
    "\n",
    "        # Send to device\n",
    "        x1 = x1.to(DEVICE); x2 = x2.to(DEVICE)\n",
    "        y = y.to(DEVICE)\n",
    "\n",
    "        # Get embeddings of data elements\n",
    "        emb_x1 = model(x1); emb_x2 = model(x2)\n",
    "\n",
    "        # Make predictions based on embeddings\n",
    "        pred_x1 = linear_clf(emb_x1); pred_x2 = linear_clf(emb_x2)\n",
    "\n",
    "        # Use embeddings and labels to train the linear classifier\n",
    "        loss = loss_fn(pred_x1, y) + loss_fn(pred_x2, y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        batch_losses.append(loss.item())\n",
    "\n",
    "    epoch_losses.append(np.mean(batch_losses))\n",
    "    print(\"Epoch loss: \" + str(np.mean(batch_losses)))"
   ]
  },
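  {
   "cell_type": "markdown",
   "id": "loss-curve-note",
   "metadata": {},
   "source": [
    "Optional visualization (not part of the original run): `epoch_losses` collected above can be plotted with the already-imported `matplotlib` to confirm the probe's cross-entropy loss decreases over the epochs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "loss-curve-plot",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the mean per-epoch training loss of the linear probe.\n",
    "plt.plot(range(len(epoch_losses)), epoch_losses, marker='o')\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel('Mean cross-entropy loss')\n",
    "plt.title('Linear probe training loss (' + TITLE + ')')\n",
    "plt.show()"
   ]
  },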
  {
   "cell_type": "markdown",
   "id": "lucky-north",
   "metadata": {},
   "source": [
    "## Measuring Performance on Test Embeddings\n",
    "-----------------------"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "excessive-breast",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████| 300/300 [00:11<00:00, 26.41it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: tensor(0.9775, device='cuda:0')%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "corr = 0\n",
    "tot = 0\n",
    "\n",
    "for idx, batch in tqdm.tqdm(enumerate(siamese_test_loader), total=len(siamese_test_loader)):\n",
    "\n",
    "    x1 = batch[\"x1\"]; x2 = batch[\"x2\"]\n",
    "    y = batch[\"labels\"]\n",
    "\n",
    "    # Send to device\n",
    "    x1 = x1.to(DEVICE); x2 = x2.to(DEVICE)\n",
    "    y = y.to(DEVICE)\n",
    "\n",
    "    # Get embeddings of data elements\n",
    "    emb_x1 = model(x1); emb_x2 = model(x2)\n",
    "    emb_x1 = emb_x1.to(DEVICE); emb_x2 = emb_x2.to(DEVICE)\n",
" # Use embeddings and label to train linear classifier\n", | |
" pred_x1 = linear_clf(emb_x1); pred_x1 = torch.argmax(pred_x1, axis=1)\n", | |
" pred_x2 = linear_clf(emb_x2); pred_x2 = torch.argmax(pred_x2, axis=1)\n", | |
" \n", | |
" tot += 10\n", | |
" corr += torch.sum(pred_x1==y)\n", | |
" \n", | |
" tot += 10\n", | |
" corr += torch.sum(pred_x2==y)\n", | |
" \n", | |
"print(\"Accuracy: \" + str((corr/tot).item()) + \"%\")" | |
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
} |