Skip to content

Instantly share code, notes, and snippets.

@ShairozS
Created December 6, 2021 01:25
Show Gist options
  • Save ShairozS/e52620f021690ebde4d866650a246fe5 to your computer and use it in GitHub Desktop.
Save ShairozS/e52620f021690ebde4d866650a246fe5 to your computer and use it in GitHub Desktop.
Inspecting Embedding Distances
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "irish-representation",
"metadata": {},
"source": [
"## Testing Semantic Distances in the Feature Spaces of Deep Learning Models\n",
"------------------------------"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "large-blanket",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"D:\\Research\\ContrastiveRepresentationLearning\n"
]
}
],
"source": [
"from torch.utils.data import DataLoader, SubsetRandomSampler\n",
"from torch import nn\n",
"import torch\n",
"from torchvision.datasets import ImageFolder\n",
"from torchvision.utils import make_grid\n",
"from torchvision import models\n",
"from torchvision import transforms\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import torchsummary\n",
"import tqdm\n",
"import os\n",
"\n",
"# The notebook sits two directories below the repository root; work from the root so the\n",
"# local `Code` package is importable. REPO_ROOT is computed only on the first run, so\n",
"# re-running this cell does not keep climbing the directory tree.\n",
"REPO_ROOT = globals().get('REPO_ROOT') or os.path.abspath(os.path.join('..', '..'))\n",
"os.chdir(REPO_ROOT)\n",
"print(os.getcwd()) # Should be .\\ContrastiveRepresentationLearning\n",
"from Code.dataloaders import LabeledContrastiveDataset\n",
"\n",
"# Seed the RNGs behind SubsetRandomSampler / DataLoader shuffling so runs are repeatable.\n",
"SEED = 0\n",
"torch.manual_seed(SEED)\n",
"np.random.seed(SEED)\n",
"\n",
"N = 3000  # number of `lcd` indices, split 90% / 10% into train / test samplers\n",
"# Fall back to the CPU so the notebook still runs on machines without a GPU.\n",
"DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"MODEL_WEIGHTS = os.path.join(REPO_ROOT, 'Outputs', 'Weights', 'MNIST_NTXent_LOSS_3000_32_0.0005_10_1.0.pth')\n",
"TITLE = 'NTXENT'"
]
},
{
"cell_type": "markdown",
"id": "overall-international",
"metadata": {},
"source": [
"## Load Dataset\n",
"---------------------------"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "closed-mixture",
"metadata": {},
"outputs": [],
"source": [
"# mean / std passed to Normalize to standardize the single MNIST channel.\n",
"mean, std = 0.1307, 0.3081\n",
"tfms = transforms.Compose([\n",
"    transforms.ToTensor(),\n",
"    transforms.Normalize((mean,), (std,))\n",
"])\n",
"# Dataset and loader get separate names so MNIST_DL always refers to the DataLoader.\n",
"# NOTE(review): MNIST_DL is not used by any later cell in this notebook.\n",
"mnist_dataset = ImageFolder(root = r'D:\\Data\\Imagery\\MNIST\\MNIST', transform = tfms)\n",
"MNIST_DL = DataLoader(mnist_dataset, batch_size = 64, shuffle=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "statistical-swift",
"metadata": {},
"outputs": [],
"source": [
"root = r'D:\\Data\\Imagery\\MNIST\\MNIST'\n",
"mean, std = 0.1307, 0.3081\n",
"tfms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((mean,), (std,))])\n",
"\n",
"lcd = LabeledContrastiveDataset(root, transforms=tfms)\n",
"\n",
"# Indices [0, 0.9*N) feed training and [0.9*N, N) feed testing; each sampler\n",
"# draws its own index range in random order.\n",
"split_at = int(N * 0.9)\n",
"train_sampler = SubsetRandomSampler(range(split_at))\n",
"test_sampler = SubsetRandomSampler(range(split_at, N))\n",
"\n",
"# batch_size=None disables automatic batching: every item of `lcd` (already a dict\n",
"# of tensors) is yielded unchanged.\n",
"siamese_train_loader = DataLoader(lcd, batch_size=None, sampler=train_sampler)\n",
"siamese_test_loader = DataLoader(lcd, batch_size=None, sampler=test_sampler)\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "utility-columbia",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\Shair\\.conda\\envs\\pytorch\\lib\\site-packages\\torchvision\\transforms\\functional.py:114: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ..\\torch\\csrc\\utils\\tensor_numpy.cpp:143.)\n",
" img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([0., 7., 4., 6., 2., 9., 8., 3., 5., 1.])\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# One item of `lcd`: a dict holding image tensors \"x1\", \"x2\" and their \"labels\".\n",
"batch = next(iter(lcd))\n",
"batch_x = batch[\"x1\"]; batch_y = batch[\"x2\"]; batch_labels = batch[\"labels\"]\n",
"# The images are standardized (values outside [0, 1]); rescale the grid to [0, 1]\n",
"# for display instead of letting imshow clip the out-of-range pixels.\n",
"img_grid = make_grid(batch_x, normalize=True).permute((1, 2, 0))\n",
"plt.imshow(img_grid); plt.axis('off')\n",
"print(batch_labels)"
]
},
{
"cell_type": "markdown",
"id": "present-donna",
"metadata": {},
"source": [
"## Load Model\n",
"-----------------"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "rural-kelly",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Note: if a model was saved with torch.save(model.state_dict(), ...), the architecture must be\n",
"# rebuilt before loading the weights with model.load_state_dict(torch.load('saved_state_dict.pth')).\n",
"# If the entire model was saved with torch.save(model, ...), torch.load(path) returns it directly.\n",
"# The PyTorch documentation recommends saving the state_dict, so the architecture is rebuilt here.\n",
"model = models.resnet18()\n",
"# Single-channel input stem. Keep nn.Conv2d's default bias: this exact layer definition\n",
"# matches the saved state_dict (all keys match in the loading cell).\n",
"model.conv1 = nn.Conv2d(1, 64, (7,7), (2,2), (3,3))\n",
"# 32-dimensional embedding head in place of the default classifier.\n",
"model.fc = nn.Linear(512, 32)\n",
"# Optional: torchsummary.summary(model, input_size = [(1, 28, 28)], device=DEVICE)\n",
"# The notebook only computes embeddings, so use eval mode (BatchNorm uses its running stats).\n",
"model.eval();"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "understood-dimension",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# map_location='cpu' lets weights saved from a GPU run load on a CPU-only machine;\n",
"# load_state_dict then copies them into the model's existing parameters.\n",
"# torch.load unpickles the file, so only load trusted checkpoints.\n",
"model.load_state_dict(torch.load(MODEL_WEIGHTS, map_location='cpu'))"
]
},
{
"cell_type": "markdown",
"id": "continuing-logan",
"metadata": {},
"source": [
"## Calculate Distances\n",
"-----------------------"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "afraid-feedback",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|████████████████████████████████████████████████████████████████████████████████| 300/300 [03:03<00:00, 1.63it/s]\n"
]
}
],
"source": [
"# Mean squared (per-dimension) distance between x1 and x2 embeddings, averaged per\n",
"# (digit, digit) pair. Assumes batch[\"labels\"][k] is the digit of both x1[k] and x2[k]\n",
"# -- TODO confirm against LabeledContrastiveDataset.\n",
"NUM_CLASSES = 10  # MNIST digits 0-9\n",
"dist_sum = np.zeros((NUM_CLASSES, NUM_CLASSES))\n",
"pair_count = np.zeros((NUM_CLASSES, NUM_CLASSES))\n",
"\n",
"# Inference only: eval() freezes BatchNorm statistics. The model must sit on the same\n",
"# device as its inputs; Tensor.to is not in-place, so its result has to be used.\n",
"model = model.to(DEVICE).eval()\n",
"\n",
"with torch.no_grad():\n",
"    for batch in tqdm.tqdm(siamese_test_loader, total=len(siamese_test_loader)):\n",
"        emb_x1 = model(batch[\"x1\"].to(DEVICE))\n",
"        emb_x2 = model(batch[\"x2\"].to(DEVICE))\n",
"        labels = batch[\"labels\"].long().cpu().numpy()\n",
"        # pair_dist[i, j] = mean over embedding dims of (emb_x1[i] - emb_x2[j]) ** 2\n",
"        pair_dist = ((emb_x1[:, None, :] - emb_x2[None, :, :]) ** 2).mean(dim=-1).cpu().numpy()\n",
"        # Add each (i, j) pair into the cell for (labels[i], labels[j]); np.add.at\n",
"        # accumulates correctly even when a label repeats within the batch.\n",
"        rows, cols = labels[:, None], labels[None, :]\n",
"        np.add.at(dist_sum, (rows, cols), pair_dist)\n",
"        np.add.at(pair_count, (rows, cols), 1)\n",
"\n",
"# Per-pair mean; NaN marks digit pairs that never appeared together in a batch.\n",
"with np.errstate(invalid='ignore', divide='ignore'):\n",
"    avg_dist_array = dist_sum / pair_count"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "robust-jumping",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<AxesSubplot:title={'center':'MNIST Distances - NTXENT loss'}>"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import seaborn as sns\n",
"\n",
"# Explicit figure/axes instead of pyplot's implicit current-axes state.\n",
"fig, ax = plt.subplots()\n",
"ax.set_title(\"MNIST Distances - \" + TITLE + \" loss\")\n",
"sns.heatmap(avg_dist_array, ax=ax)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment