Skip to content

Instantly share code, notes, and snippets.

@ShairozS
Created December 6, 2021 01:25
Show Gist options
  • Save ShairozS/e52620f021690ebde4d866650a246fe5 to your computer and use it in GitHub Desktop.
Save ShairozS/e52620f021690ebde4d866650a246fe5 to your computer and use it in GitHub Desktop.
Inspecting Embedding Distances
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "irish-representation",
"metadata": {},
"source": [
"## Testing Semantic Distances in Feature Spaces of DL Models\n",
"------------------------------"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "large-blanket",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"D:\\Research\\ContrastiveRepresentationLearning\n"
]
}
],
"source": [
"from torch.utils.data import DataLoader, SubsetRandomSampler\n",
"from torch import nn\n",
"import torch\n",
"from torchvision.datasets import ImageFolder\n",
"from torchvision.utils import make_grid\n",
"from torchvision import models\n",
"from torchvision import transforms\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import torchsummary\n",
"import tqdm\n",
"import os\n",
"os.chdir('..'); os.chdir('..')\n",
"print(os.getcwd()) # Should be .\\ContrastiveRepresentationLearning\n",
"from Code.dataloaders import LabeledContrastiveDataset\n",
"\n",
"N = 3000\n",
"DEVICE = 'cuda'\n",
"MODEL_WEIGHTS = 'D:/Research/ContrastiveRepresentationLearning/Outputs/Weights/MNIST_NTXent_LOSS_3000_32_0.0005_10_1.0.pth'\n",
"TITLE = 'NTXENT'"
]
},
{
"cell_type": "markdown",
"id": "overall-international",
"metadata": {},
"source": [
"## Load Dataset\n",
"---------------------------"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "closed-mixture",
"metadata": {},
"outputs": [],
"source": [
"mean, std = 0.1307, 0.3081\n",
"tfms = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((mean,), (std,))\n",
" ])\n",
"MNIST_DL = ImageFolder(root = r'D:\\Data\\Imagery\\MNIST\\MNIST', transform = tfms)\n",
"MNIST_DL = DataLoader(MNIST_DL, batch_size = 64, shuffle=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "statistical-swift",
"metadata": {},
"outputs": [],
"source": [
"root = r'D:\\Data\\Imagery\\MNIST\\MNIST'\n",
"mean, std = 0.1307, 0.3081\n",
"tfms = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((mean,), (std,))\n",
" ])\n",
"\n",
"\n",
"lcd = LabeledContrastiveDataset(root, transforms=tfms)\n",
"\n",
"\n",
"train_sampler = SubsetRandomSampler(range(int(N*0.9)))\n",
"test_sampler = SubsetRandomSampler(range(int(N*0.9), N))\n",
"\n",
"siamese_train_loader = torch.utils.data.DataLoader(lcd, batch_size=None, sampler=train_sampler)\n",
"siamese_test_loader = torch.utils.data.DataLoader(lcd, batch_size=None, sampler=test_sampler)\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "utility-columbia",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\Shair\\.conda\\envs\\pytorch\\lib\\site-packages\\torchvision\\transforms\\functional.py:114: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ..\\torch\\csrc\\utils\\tensor_numpy.cpp:143.)\n",
" img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()\n",
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([0., 7., 4., 6., 2., 9., 8., 3., 5., 1.])\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAB4CAYAAADrPanmAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAArwUlEQVR4nO2deXyU1fX/P/eZLctksgJZSSKQBAhUAmUxSNnB6EtUcEGlUBT3L/VLFXFppdXfV9QWqWuL0pZaK/JSgoCoQIiCLIFAQoCEhISErGSfyTr7+f2R5JFhJsmQzJIh9/16nVfy3Gc7c+fOee49z7nnMiICh8PhcDwPwd0KcDgcDqdvcAPO4XA4Hgo34BwOh+OhcAPO4XA4Hgo34BwOh+OhcAPO4XA4Hkq/DDhjbCFjLJ8xVsgYW+copTgcDofTO6yvceCMMQmAAgDzAJQDOAlgKRHlOk49DofD4XRHf3rgkwEUEtElItID2AZgkWPU4nA4HE5vSPtxbgSAsqu2ywFM6ekExhif9snhcDjXTx0RDbm2sD8GnNkoszLQjLHHADzWj/twOBzOYOeyrcL+GPByAFFXbUcCqLz2ICLaDGAzwHvgHA6H40j64wM/CWAUYyyWMSYH8ACAXY5Ri8PhcDi90WcDTkRGAM8A+B5AHoDtRHTeUYpxBh8SiQSrV6/Gzp07kZ+fj/z8fMTFxblbLQ5nwNIfFwqIaC+AvQ7ShTPIEQQBc+fORVJSEiIiIgAAq1atwnfffYe0tDQ3a8fhDDz6ZcA5HEciCAKSk5MRFBQklj333HPw8fHhBpzDsQGfSs/hcDgeisf2wNesWYNZs2aJ21VVVSgrK0NbWxu+/PJLXL5sM+qG42E888wz+PHHH92tRp8RBAFyuRxardbdqtwQCIKAF198ESNHjkRISIhYfuzYMXz88ceoq6uDM1cZY4whJCQEzz77LMaPH9/jsSaTCVlZWdixYwfOnj3rFH08xoDHxcUhKCgIKpUKAJCSkoI5c+aI+ysrK1FaWoq2tjY0Nzfj5MmTyMnJgclkcpfKA5ZJkyYhKCgIRASNRoOamho0NDSgqanJ3apZcejQIZw7d87danQLYwwymQwBAQEYPXo0TCYT9Ho9FAoFFAoFZDIZZDKZaMC1Wi1KSkpQVVUFg8HgZu0HNoIgwNvbG9OmTYMgCGJZSkoKRo0ahSFDfp7XEhgYiEuXLuHLL7+E0Wh0ij4xMTFISEhASEgIUlJSEB0djby8PLS2tsJsNkMulyMmJgbh4eFQKBQwmUwICwtDRkaG0ww4iMhlgo6JPtctEomENm7cSFlZWWQve/fuJZVKRRKJhDrjz7l0yoEDB4iIyGQy0bFjx+jFF1+kKVOmuF0vhUJB9fX1Ft/juHHj3K5XdyKRSEihUFBERAQtWbKEKisrqaioiI4ePUrFxcXU3t5u1S7Ly8tpzZo1NHToUN4uexBBEEipVFJiYiLp9Xq7fvMtLS2kVCqdptNzzz1HZrOZDAYDGQwG+vHHH2n8+PEklUoJAIWEhNCGDRuorKyMjEajqNevf/1rR9w/06ZNvV4j3B/pi+IjRoyg8vJyam5utvuLJCJqb2+n8vJy+vzzz2nhwoVub5ADSboMuNlsJp1ORxqNhg4fPux2vTzJgAcEBNClS5eooqKCqqqqqL6+noxGIxkMBtLpdGQwGMhsNlu1S6PRSGq1miorK+m1115z++cYqPLjjz9SRUUFVVdX26xHWzjbgK9YsYIOHz5MUVFRFB4eTsHBwaLxBjoeOiqVioYNG0a//e1vRb3eeOMNmjhxYn/vb9OAD2gXypAhQzBq1CiEh4eDMVsz97vHy8sL4eHhmDJlCqRSKRISErBp0yaH6hceHo5x48Zh+vTpEAQBWq0W5eXl+PLLL9Hc3NzjuYIgwNfXF7/5zW/Q3t6OzMxMZGdnO9V/dy2MMcjlcsjlcgQEBLjsvp7OyJEjMWPGDERGRkImk13XuRKJBP7+/vD398fkyZOxZMkS7NixA2az2UnaAn5+fli6dClCQ0OhUCjE8tTUVGRmZjrtvn0hIiICTz31FMaOHYvg4GCr/SaTCW+++SZaWloQFBSE3/3ud2CMoaSkBIcOHXKa+wQAMjMzodPpUF5ebvN3ajab0dTUhKamJqjVarFcq9Wira3NKToNWAOuVCoRHx+PpKSkbo/pegpptVpIpVLI5XKL/YwxxMbGIjY2FjfffDO++OIL1NXVOcz3GBISgilTpmDt2rWQy+VoampCTk4OcnNzUVtb2+O5EokEQUFBeOqpp6DRaKBSqdDS0oLKykq0trY6RL/u0Gg00Gg08Pf3d+p9rpfrfUi7i4iICPGhfS1NTU1oa2uDj48PfHx8IJV2/xNLSEjA/fffj507dzrNgAcGBiI+Ph6rVq1CQkIClEqluO/SpUvIyclBZGSkWGYymVBWVubUB0pPhIaG4qWXXhK3zWYzamtrodVqYTQaYTKZ8M4776Curg7Dhw/H3XffDT8/P5w7dw5ffPGFU98rnDt3zq73MeHh4Rb+eY1Gg4aGBuco1Zvbw5ECO4cLgiDQsmXLaN++fT0On4xGI7W2ttL+/fspNze31+FVWloaxcXFOWxINXToUJo7d67o6zSbzWQ2m8lkMtktXed0nffAAw84fXi6YMECeu+99yzq5+zZs24dMjPGSKlUUkNDg4VeA9GFEhYWRikpKaTT6aza2ZYtW+iOO+6gLVu20JUrV3psk2azmWpra0kulztN19dee82inV3NI488QklJSRbtsba2lgIDA0kikbilbidOnGiho1qtpgcffJDCw8OJMWb13kAul9Ovf/1ruuOOO9zeLrrk2LFjFnW9bNkyR1zXc1wogiDgmWeeQVxcXI+9sqKiIvz+979Hfn4+BEGAj48PAGDjxo2YPHmyxbFeXl6YOHGieIwjkEgkFkPSLl372pPsimiQyWRO7UlERkYiLCzMadfvC0QEo9FoNTRVqVTw9vZGe3u7mzSzpq6uDseOHcOsWbOsvuvq6mrU19fjwoUL+Nvf/mYxKnz++eexaNEicZsxBn9/f6Snp2PdunU4fPiww3RkjGHDhg2YN2+e1UjBYDDgww8/RE5ODgBY7Pf398fevXuxdu1ah+pjD+vWrcP9998vbm/btg2bNm1CUVERNBqNTbeFwWDAt99+61TXib1ERkbi888/x9ixY102mhyQBpwxhpiYGAu/rFqtRl1dHSoqKnDlyhUQkej3qq2ttQgXbGxstLpml+8xKioKZWVlqK+v77eeWq3W4XGn3t7e8PX1tfChOQqJRILIyEgkJiYiNjbW4dfvL7bqcf78+QCAI0eOuFqdbjEYDGhsbMTRo0e7PcZWG1y2bJlVmVQqxZQpUyxmnzoCxhhuueUWREdHi2VVVVWif3b//v2oqKiAQqHAtm3bkJKSApVKBZlMhqlTp+LOO++EIAguicGXSCRYtGgR5s+fjxEjRuDw4cOorq7GN998g4yMjB7PJSIrd6UgCGCMuTSEeOTIkbjllluQnJwsGm+DwYCvv/4aJSUlTrvvgDTgtnyLhYWFOHbsGPbs2YO0tLQevxy9Xg+dTgeZTAbGmMXT8Je//CUaGxvx008/9VvPxsZG5Obmii8oJBKJlR/+elGpVAgICHCKAVcoFJg9ezYWLFiA0aNHi+VGo9Fmj9/LywuMMZjNZphMJrf0cl555RUEBwcPKAPeF2Qymc12DXT80B3pc+56OR0bG4ugoCCYzWZotVqcOHECFy9eRFFREfbv3w+9Xg8AWLp0KXJycjBmzBhIJBIAHSkMfvGLXzjdgEskEgQEBOCf//wnVCoVCgsL8fbbb+PgwYPX9S7o6t9e16jYGb+h7pg/fz5WrlwpbhuNRqjVaqxYscK577Rs+VWcJbDD1xMSEkILFy6kxsZGC1/YjBkzyNvb2yJspzvx9fWllJQUysnJIYPBYHGdtrY2+vDDDx3q8/L396cxY8bQ448/3qPP0x4efPBBEgTBKb658PBwampqsohRJSL617/+RbfffrvFsSqViqqqqqixsZGys7Np9erV5Ofn51Tfoa0wwry8PFq3bp1T7+tskUql9Pbbb9O5c+esvm+DwUA5OTk0a9Ysh91v1KhR9M4771BzczMREZWUlFBAQAD5+vqSl5eXTZ/7q6++SkeOHLHQbd++fU6vm/vuu48aGxtFn3FmZqZdv/HurtMlBw8edOl3vGHDBovQ0c8++4z8/f0deQ/P8IF7e3sjOjraqrfS1tZmtx+0tbUVWVlZeP3117F582aLaAtvb+9+95KvRaPRiL3Yxx9/vNfj7777bsycORMKhcLKV2Y2m50SAZCSkoKlS5dCqVRa3bO4uBhnzpyBr68v1qxZg4iICMjlcoSEhEAqlUIQBCxduhTbt2/vNTzS0QwdOhSBgYF2Hz9x4kSsWrVK3G5ubkZGRgb27dvntJmmAQEBmDBhAgRBwNixYzFmzBiL/YIg4JZbbhEzLF67Lzw8HN7e3g7TJzg4GPPnz4dCoRDfLXTnQ+4iKSnJyq0WGBiImTNn4ujRo2Jv3dF0zWIFgH//+9/YuXOn3SO9gIAArFq1ChERERg3bpx4nc8++wy7d+92ir7dceTIEQwbNgzLly8H0OEF0Gg0Tr/vgDTgkZGR4lCur1RVVSE1NRXvvfeegzTrmdbWVly8eBEXL17s9dgRI0bg1ltvtSgzmUy4ePGiU4xMQkICUlJS8PDDD9vc7+XlhSFDhiA2NhbLli3DqFGjLParVCpMnjzZ4Q8+W7S1tcHPz0+Mrw4KCrIIfeuJkSNHYtasWRYP0cbGRgwbNgwNDQ3Iz89HRUWFQ/WNjo5GXFwcZs+eDYlEgunTp2PatGl2ny8IAoKDgxEXF4f8/HwUFxf3+wHu6+srPkTa29t7fOjKZDIMGTIEo0ePtnqx3fUgOHXqlNMM+NUcOnQIX3/9da/HBQUFISAgADfddBMefPBBDB8+HAaDAWfOnAEA7Ny50+UGPDs7GwqFQjTgUqkUXl5ezs+BY6tb7iyBHUOFiRMnUmpqKmm1Wovh3KRJk6572CGTyai6utpq2PrJJ5+4dHh1rWzatMlKJ41GQ/Hx8eTj4+PQezHGqLi42Op+14vJZKLhw4c7tV5kMhnt2bOHKioqLO79wQcf2HV+enp6j5/B3utcj3z11Vf9rtsu0tPTyd/fv99T7OfMmSMO5XNzc2nz5s3dHhsZGUl//etfRdfVtaGGBoOBwsPDnfadP/TQQ+K9nn766V5nUjLGaNWqVZSWlibqm5ubS3/605+c2jbtkVGjRon1t2vXLrr55psd6Q71jKn0CxYsII1GQyaTyaIhTZ069borQyKR0J/+9CfKzMy0uJa7DLhUKqVDhw5RbW2t1Y9XrVaTSqVyeH4MTzLgACg6Opo+/fRTi3v3ZnhDQ0PpzJkz1NLS0uNncJQBj46OpsWLF1NWVhap1ep+120XGo2GTp48SUOGDOmXfrGxsfTqq6+SRqMhrVZLOTk5NtvV+vXrKS8vj2pra8lgMNClS5fo+++/t/jtudKAV1RU0Pbt27s9duzYsfTxxx9TcXExNTc3U1lZGU2ePJni4+Np6NChTm+bvYlcLqfExETKyckhjUZDZ86cIW9vb0dd3zN84FKpVMw4eDX080PAbsxmM06fPt3jbE5XwhjD6NGjLdJgAh2xwzk5OdDpdA4NSQQ66u3TTz/F/PnzMWXKlH5dKywsDGq12qlZCy9fvmx1/cTERKxcuRJbt261GX0kk8mQmJjYbZRHF97e3ggODu5XCKmPjw/Gjh2Le+65B+PGjRNdfUSEtrY2qNVqNDc3QxAE1NXVISAgwMon3h1KpRKJiYlITk7GqVOnUFZW1icdGxoakJaWhqeffhoqlQphYWF48sknrdrW3LlzkZCQACJCeno6srOzUVxcjNmzZ0MQBFRXV+OHH35wWQx+eHg4xowZg3HjxqGgoAA6nQ5eXl5Yvnw5BEFAdHQ0pk+fjsjISJw8eRL79+/HmTNnoNfrHf676Qt6vR7nzp3Dp59+igULFiA5ORmPP/449u7di4KCAqfcc8AZ8O4wmUx9+pIKCwttxuS6GplMhsDAQJsB/kVFRdixY4fT4lb/8Ic/oKmpCfHx8TYfjleHWra3t8NgMMBkMiEgIMBC3/Hjx6OlpQUFBQVOnWik1WrR2toKX19fAMCMGTMwbtw47N69G42Njb2+5CIi6PV6SKVSi3cpXdPKjx071ucffGRkJKZPn44HH3xQLDOZTNDpdCgtLUVhYSEuX74MuVyOc+fOISEhwaYB12g0kEql4sQyxpiYO3zx4sXQ6/V9NuAajQZHjx5FTU0NlEolQkJC8MEHH9g81mg0orm5GVu2bEFGRgbUajX+8pe/AOh4mG7cuBEtLS190sMe9Ho9Ghsbxbbm7++PlJQUtLS0oLGxEcHBwXjvvffEyW3Nzc1oamrCtm3b8O677zpNr/7w9ttvo6qqCsnJydi4cSO0Wi2qqqqcEwBgq1tOlm6PKADp6Fi4+DyA33aWBwHYD+Bi599AO67V61Dh9ttvtzm87IsPXBAEWrlypegv68IdLpQ77riD2tvbbaYG+Pjjj50+dVkul9OoUaOoqqqK1Go1NTQ0UGVlJZWVlYnhZkREb775Js2bN4+GDRtGbW1tYnlX5sIvv/yS7rnnHqfqmpiYSE8//bRFXRkMBsrIyKDZs2dbHR8VFWUx7DcajfT555/ThQsXLOpZrVbTqVOn+jV1PTMz0yorZmFhIb3wwgsUHBxMMpmMpFIpyWQykkgk9MADD1h932azmSZMmEDPPPMMNTc3W3zOrnSlL730Ur/rUaFQ0IYNG2z+nro4c+YMDR06lGQyGTHGKDg4WEwRcOHCBXr22WfJ19fXad+1IAgUGBhIGo1G/Px6vZ7a29upqqqKLly4INbP999/TwqFghQKhdum+tsrU6ZMob///e+k1+tJp9PRjz/+2N9r9tmFYgTwOyI6zRjzA3CKMbYfwAoAaUS0gTG2DsA6AC/Ycb0e0Wg0OHXqFMaPH2+R6e2NN97Ajh078NFHH9l1na7JDA899BASExP7q1a/WL16NW677TZ4eXlZ7Vu7di0OHjzo9Fljer0eFRUVWLVqFSQSCYgIZrMZRASZTCYmXbpw4QKqq6uhVqvx0EMP4bnnnsMtt9wi1ue0adOgUCiQmprqtGHr5cuXcf78eYsyQRAQFxeHV199FU888QSICN988w1yc3OteuSMMUybNs0q/PDs2bP4z3/+06e6Hj58ODZs2IARI0ZYtMstW7YgPT0dJ0+eFMNJhwwZgvfffx+MMYtEUVej1+uxd+9eaDQabNmyRbwmY0wM3ewvOp0On376KU6fPg0fHx/cf//9GDp0KJqbm7F582bo9Xqo1Wo0NjbaHFG1t7ejqKjIqW2zK4Pf8uXL8T//8z+YPXu2mE4iKCgIfn5+Nj/XQKegoABbt27Fww8/DB8fH/j5+SEqKgpVVVUOnRDXqwEnoioAVZ3/NzPG8gBEAFgEYGbnYVsB/AAHGPDa2locOHAACQkJFj+UuXPnQq/X4+zZs6isrERNTU2PQ7vg4GAkJSUhKSnJYkp+Xl6ey5Zbk0qliI+Px8KFCzF16lSbxxw4cABZWVku0aetrQ179uyx+/g9e/ZYuAqADj+lyWSCIAhO+2E3NzejtLQUhw8fRlJSEpRKJQRBQEBAAGbMmAEAYg8kIiLCasZdl7/0WsrKypCent6nMD1/f38sWbLEKrtgc3MzGhsb0dbWhsmTJ0MikSAsLAxLliyxaYRbWlpw+vRptLa2oqSkRHyn4O/v32Pmwr5y/vx55ObmQqlUIiAgAKGhoWhqasJXX31lZQgFQbD4zbW3t6OkpMTpM3BNJhN27tyJuLg4BAQEiO+sulIddxEUFIRf/epX0Gq1uHTpEsrLy52qV3+Qy+Xw8/OzcEE64/u93iiSGAClAFQA1Nfsa+zmnMcAZHaKXcMFmUxGV65cselu0Ol0tGnTJkpKSur2fMYY3XnnnTaHjPfddx9FRES4ZBgVFBREu3btshl10pV98Oabb3b7cK+7OgwNDaWvv/7aSvfS0lJSKBQuWVEmOzvbInNjXzGZTPT+++/3WY/x48d3e+2LFy/S888/bzV72BZZWVkW1/X396cjR45YZWF85ZVXXP6de3t7U0JCguhCccVMzGtl4sSJdn3Xa9euFbMTumNlo+7u21W+YsUKUVez2UxHjx7t7z37F0YIQAngFIB7OrftMuDXHGOXsoIg0IoVK8SVY67GbDaTRqOhRx99lKKjo2nmzJnk5eUlnhcaGkrbtm3rNpXnlClTXOY/CwsLszl1nYiooKCAYmJinJpKtL8NNCgoiFJTU610Ly0tddp0/2slPDycJkyYQPfee6/N9K32YDabaf78+RQUFNRnPXoy4Hq9nurr661CX21xrQGXSCQ0ceJEqxh2dxjwCRMm0Msvvyymn3CHAZfL5RQTE0OFhYU91mN9fT0VFxdTcXExvf/++3T33Xe7VM8tW7ZQeno6rV69Wpy7ERISQm+++SadOnWKampqRF23b99Oy5cv7+89+x5GyBiTAfgKwGdEtKOzuJoxFkZEVYyxMAA19lzLHsxmM44dO4Zp06Zh6tSp8PHxsUjVqlKpsGjRIowZMwbBwcGYNWsWDAYDGGPw8/PDlClTMGzYMItrmkwmqNVqaLVal2QpGz9+PObMmQNfX1+roXRaWhq+/fZbp2Yp6y9EhPb2dpt11VXPLS0tTq/LyspKqNVqNDQ04P/+7/+gUqkQGxuLu+++u9dzS0pKUFRUhKNHjyI7O9tpSfW7/LXdYTAYsGnTJmi1Wly5csVin9lsRmFhoegONJvNePfdd12eyhUAYmNjMX/+fAiCgH379uHbb791uQ56vR4lJSW9RjkFBQWJdf6rX/0Kw4YNQ1xcHN555x2nzhqVSCQYPXo0Ro4cieHDhyMmJgYSiQRz5szBvHnzMHfuXMTHx4sRVEDHQhAnTpxwij69GnDWYTm3AMgjoo1X7doFYDmADZ1/e58Dex3k5+fjwoULKCkpsRmGdccdd9h9LYPBgJaWFpw7d85lMa2TJk3CypUrLXxgXUbxm2++webNm12iR39ob29HXV0d6uvrLZa3kslkiI+PR15enktyo7S1teHy5cv44x//iJCQECQnJ+OXv/xlr+edOHECP/zwg90vvnvCaDSitrYWgYGBdvsyuzoMLS0teP31123GzxMRNBoNrly5gvLycnHJsGsNvSuIiooS3zHs3r0bu3btcrkOfSExMRGJiYlISUnBzp07UVZW5rwlzKRSjB07Fn5+fvDy8kJERAQiIyOxZMkSPPHEE+JxOp0Ora2taGtrw/nz55GXl+cUfexxnUxHRxc+B0B2p6QACAaQho4wwjQAQXZc67qGDYwxGj16dL98n0Qd04n//ve/u2zYD4DWrVtntQqK0Wik//73vzR37lyXD037KqGhobR06VKL+jSbzWQ0Gh2aQe96RRAEu8RR/lGlUknz5s2jS5cu2d3uVq9eTYmJiXa1O8aYqLO76vTqhXgXLVrk1naXl5d33b/zrnbpAHdFtyKTyWjq1Kl05MgR8T2W0Wi0cp8dOXKEVqxY4ch3RX1zoRDRTwC6W15iTm/n9wciwuXLlzFv3jx88MEHiI+Pv67zTSYT3nvvPRw9ehRnz5512Tp/jz76KGbMmGEz0+D777+PCxcuuEQPR9DQ0GCVMJ8xBolE4tY1LF29ZmPXwtMrV67Evffei6eeegpAR1RTZmYmtm7danVOfn4+Ghsb7dKVfu7kuIWIiAiLUZY7dQGA3/zmN7j//vuxYsUK+Pv729XWDAYDPvroI6sQVEdiMplw5coV6HQ6q7UGjEYj6urqkJqaip9++gnp6elOnyU64GditrW1IS0tDbt370ZpaSlCQkIwfvz4HrMVHj16VFyl57vvvsO5c+ccnoWuJ2bMmGGV0a+qqkp8kLg6JWt/6FocY7BjMpnEhUD8/PzE1LAFBQXIyspCWlqamzXsH2PHju02Zt0dHD9+HP7+/ggODoafnx9Gjx6NYcOGWYQEd9Ha2oqKigrk5OTgu+++c+pvnYigVqtx+PBheHl5iZknz58/j/z8fNTX1+O7775Dbm4uqqqqnKaHhUKuEvRzGBEbG0vLly+npqYm0ul03Yo7XRRSqZSKioqshnc7d+5065C0PzJ9+nTS6XSiO6jLNeROFwoXx8o777xDWVlZYnu988473a7T1bJ+/XrKyMiw+XvPzc2l1157zeU6LViwQNThf//3f519P89IZtUTly9fxrZt2/DNN9/0eJwrEqnb4qabbsKaNWsshqI3AhkZGYiKikJmZiaioqJgMplQXV3tkhzRHNeg0+mcmt+mv7z11lvYtGmTzRfIXbloXM3BgwfFkZhTl03rAY8y4GazGTqdbsAO6eVyOcLDw60a2datW90SkuUoDAYDampqsH79evj5+YGI0NraikuXLrlbNY6DEATBIdP3nUVbW5vTIkv6isFgQF1dnVt18CgDPtAxGo2or68XY6OJOpaz2rVrFw4cOOBm7frPP/7xD3erwHESXT1wk8mE4uJit/UoOdfHwH3keiCFhYV47LHHxKgNk8mEiooK1NTUODWHNofTX06ePImioiKo1WrceuutOHTokLtV4tgBc2W4UGc85A1PYmKiuKCswWBAcXGxU3Mqczj9xd/fH0FBQVCpVDh//nyf8+9znMYpIpp0bSE34BwOhzPwsWnAuQuFw+FwPBRuwDkcDsdD4Qacw+FwPBRuwDkcDsdD4Qacw+FwPBRuwDkcDsdD4Qacw+FwPBRuwDkcDsdD4Qacw+FwPJQbJpmVUqlEVFQUNmzYYDOrmk6nQ35+Pj788EOXLu5gD4GBgRgxYgReeuklyGQy5Obm4oUXXnC3WhwOZ4DjkQZcKpVizJgxCA0NFcv8/PwwfPhw3H777TZX69FqtRg5ciTOnDmDzMzMAZMKNTIyEvHx8Zg2bRpuv/12yOVyxMbG4sCBA/jpp59ctggzh8PxQK5jNR0JgCwAezq3gwDsR8eixvsBBDp7RR5BEEgqlVJISAht377d9sqm1LF4sMFgsBKj0Ujnz5+nV155xe0rjAAdC9k++eSTtHv3bgv9zWYz6fV6iomJcbuOXLhwGRDS7xV5fgsgD4Cqc3sdgDQi2sAYW9e57bRx/6233orbbrsNy5cvB2PM5tp4XSxbtgxHjhyBt7c3WlpaQERQKpWYNWsWkpOT3bZiz9UEBgbivvvuw+rVqxETE2Ox7+TJk7jrrrtQU1PjHuU4HI5HYJcBZ4xFArgdwP8DsKazeBGAmZ3/bwXwA5xgwENCQrB69WqMGDECCQkJCA8Pt9hPRKirq8PJkydx/PhxGI1GHD9+HOXl5ZDJZOLqPb6+vqioqEBlZSXUarWj1bxuVCoV7r33XkRERMDLy8tin16vd82CqBwOx6Oxtwe+CcBaAH5XlQ0joioAIKIqxthQWycyxh4D8FhflJPL5YiKisILL7wAqVRq8XKyvr4era2tYr7t1NRU/Pvf/7ZYp/HqpdfMZjMaGhpw6dIlVFdX90Udh8AYg4+PD8LDwzF79mwwxiz2V1dXo7Ky0k3acTgcj8IOv/UdAD7s/H8mfvaBq685rtHRPvCJEyfS888/L66CfjVr166l2NhY6swx7jHi5+dHS5YsoS1btlh9JiKipUuXul1HLly4DDjpsw88GcCdjLEUAF4AVIyx/wCoZoyFdfa+wwA43GE7efJkLF68GEBHz1Wj0aC4uBhr1qxBQUEB6urqPGrVEKlUiu3btyMmJgaBgYEW+4xGI7Kysty+SCqHw/EcejXgRPQigBcBgDE2E8BzRPQwY+xtAMsBbOj8+7WjlVMoFPD19RW3a2trceLECRw5csTCVTLQCQsLQ3h4OBISEpCcnAw/Pz+rY4xGI3bv3j3gYtQ5HM7ApT9x4BsAbGeMPQKgFMC9jlHpZ2pra1FSUoKxY8cCACoqKnDgwAEIggDGmMf0vidNmoSUlBQ88cQTNvcbjUao1Wr8+c9/5nHfHA7HfuyNA3eE4Dr9PkqlkqZNmyb6ivV6PTU0NNAf/vAHGjt2rLt9UnbL+vXr6fjx4zZj1omI/vWvf5FKpXK7nly4cBmw0u84cJfT1taGpqYmcVsmk0GlUuGuu+7CuHHjcPLkSbz11ltu1LBnVCoVXn31Vdx6661Wsd5dvP7669i3b5/F5+RwOBy7GMg9cAAUExND2dnZ1NraatVzzc7OpvHjx1NwcDApFAp3PyGtJDQ0lBobG8lgMFjo3draSjU1NZSdnU0JCQlu15MLFy4DXmz2wAe8Ae+SEydOiOGEV4ffmc1meuKJJ2jMmDHurmArCQ8PtzLeZrOZTp06RX/+85/drh8XLlw8RjzbgI8YMYISExNp4sSJlJaWRtXV1aJBLC8vp4KCAjp48CAJguDuiiYA9Oijj1JeXp7Fw8ZkMtHhw4fpkUceodDQULfryIULF48Rz/OBX01RURGADj/49u3bkZGRgbi4OCxevBgREREAgKFDh+LJJ59EZWUlLl++jNOnT7tN3yFDhiAhIUHc1mq1qKurw44dO5CRkYErV664TTcOh3OD4Ck9cFsyadIkqq+vJ51ORyaTSezpHj9+nH7/+9+77WmpVCpp/fr1Fq6TK1eu0O7du3m0CRcuXPoinu1CsSWMMfLy8qI9e/ZQZWWlhasiNTXVbZX91VdfUWlpqYUB37VrF8nlcnc3Ai5cuHim2DTgHr2kGhFBq9XijTfewLPPPouXX34ZZrMZgiBg8uTJ+O9//wuVStX7hRwIYww33XQTQkJCLMrNZrNHzR7lcDgDH4824F0cOXIEe/fuxd69e2E2mwEA4eHhWLx4MRQKhcv08PX1RVJSEkJCQuDt7S2W19TU8BwnHA7H4XjMS8zeaGlpcfsCCKNGjcInn3yCIUOGAECX2whpaWn44Ycf3KgZh8O5EbkheuAA8Mgjj+DgwYM218N0Fd7e3hg9ejTkcjmAjoUZ0tLS8Le//Q2pqalu04vD4dyYeHQPXCqVIjAwEEuXLsXChQsRGxsr7tPpdKivrxddKq5AEAQLl41Op0N6ejpKS0vR2trqMj04HM7gwOMMOGMMYWFhYIzB29sbMTExeOGFFyyWWtNqtaivr8f58+dhMpncpqvRaERBQQEkEgmUSiVaWlrcpguHw7kB8bQwQqVSSRqNhoxGoyjXrmxz8OBBeu6551w+KzM5OdlCD7PZTEajkdLS0mjNmjXuDkPiwoWL54pnz8QEAKVSiejoaEilUpu+boPBgNzcXGzevBlHjx51qftEoVBYLU7MGINEIsGECRNQVlbmMl04HM7gwKMM+FU9eauywsJC5OTkID8/H5mZmSgtLXWpboIgWD1UjEYjysvLUVRUhPz8fJfqw+Fwbnw8yoC3traitLQUOp0OUunPquv1enz++edYv36923Qzm80wGo0Wk3VaW1uxd+9evPvuu9yAczgch8Ou7dE69WYdK8j39xoICgoCY8yivK2tDW1tbf29fL+QyWTw9/cXt4kIOp0O7e3tbn2ZyuFwPJ5TRDTp2kK7DDhjLADAJwAS0eFQXwkgH8AXAGIAlAC4j4gae7mO654WHA6Hc+Ng04DbO5HnrwC+I6IEAL8AkAdgHYA0IhoFIK1zm8PhcDguotceOGNMBeAMgJvoqoMZY/kAZhJRFWMsDMAPRBTfy7V4D5zD4XCunz73wG8CUAvgn4yxLMbYJ4wxXwDDiKgKADr/DnWouhwOh8PpEXsMuBRAEoCPiGgCgFZch7uEMfYYYyyTMZbZRx05HA6HYwN7DHg5gHIiyujc/hIdBr2603WCzr82UwES0WYimmSr+8/hcDicvtNrHDgRXWGMlTHG4okoH8AcALmdshzAhs6/X9txvzp09OB5cmxLQsDr5Fp4nVjD68SawVIn0bYK7Q0jvBkdYYRyAJcA/AYdvfftAIYDKAVwLxE12HGtTN4bt4TXiTW8TqzhdWLNYK8Tu2ZiElE2AFuVNMeh2nA4HA7Hbm6YBR04HA5nsOEOA77ZDfcc6PA6sYbXiTW8TqwZ1HXi0lwoHA6Hw3Ec3IXC4XA4HorLDDhjbCFjLJ8xVsgYG7R5UxhjJYyxs4yx7K7JTYyxIMbYfsbYxc6/ge7W09kwxv7BGKthjJ27qqzbemCMvdjZdvIZYwvco7Vz6aZO1jPGKjrbSzZjLOWqfYOhTqIYY+mMsTzG2HnG2G87ywd1WxFx0VJqEgBF6JiWL0dHbpUxrlzObaAIOjI3hlxT9haAdZ3/rwPwprv1dEE9zEDHhLBzvdUDgDGdbUYBILazLUnc/RlcVCfrATxn49jBUidhAJI6//cDUND52Qd1W+kSV/XAJwMoJKJLRKQHsA3AIhfd2xNYBGBr5/9bAdzlPlVcAxEdAnDtvIHu6mERgG1EpCOiYgCF6GhTNxTd1El3DJY6qSKi053/N6MjE2oEBnlb6cJVBjwCwNWLQpZ3lg1GCMA+xtgpxthjnWU8MVgH3dXDYG8/zzDGcjpdLF2ugkFXJ4yxGAATAGSAtxUArjPgzEbZYA1/SSaiJAC3AXiaMTbD3Qp5AIO5/XwEYASAmwFUAfhLZ/mgqhPGmBLAVwCeJaKmng61UXbD1ourDHg5gKirtiMBVLro3gMKIqrs/FsDIBUdwzu7EoMNArqrh0HbfoiomohMRGQG8DF+dgcMmjphjMnQYbw/I6IdncW8rcB1BvwkgFGMsVjGmBzAAwB2uejeAwbGmC9jzK/rfwDzAZxDR10s7zzM3sRgNyLd1cMuAA8wxhSMsVgAowCccIN+LqfLSHVyNzraCzBI6oR1LH67BUAeEW28ahdvK4BrolA63w6noOMNchGAl9399tYdgo4onDOdcr6rHgAEo2NZuoudf4PcrasL6uJzdLgEDOjoNT3SUz0AeLmz7eQDuM3d+ruwTj4FcBZADjqMU9ggq5Pp6HCB5ADI7pSUwd5WuoTPxORwOBwPhc/E5HA4HA+FG3AOh8PxULgB53A4HA+FG3AOh8PxULgB53A4HA+FG3AOh8PxULgB53A4HA+FG3AOh8PxUP4/qSoxaQLGAZ4AAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"batch = next(iter(lcd))\n",
"batch_x = batch[\"x1\"]; batch_y = batch[\"x2\"]; batch_labels = batch[\"labels\"]\n",
"img_grid = make_grid(batch_x); img_grid = img_grid.permute((1, 2, 0))\n",
"plt.imshow(img_grid); print(batch_labels)"
]
},
{
"cell_type": "markdown",
"id": "present-donna",
"metadata": {},
"source": [
"## Load Model\n",
"-----------------"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "rural-kelly",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"'''\n",
"Note: If a model was saved using torch.save(model.state_dict(), ...) then the architecture needs to be reinitialized before loading the model weights \n",
"using model.load_state_dict(torch.load('saved_state_dict.pth')). However, if we saved the entire model using torch.save(model,...) then we can\n",
"simply use torch.load(model) directly. The Pytorch documentation reccomends saving the state_dict. \n",
"'''\n",
"model = models.resnet18()\n",
"model.conv1 = nn.Conv2d(1, 64, (7,7), (2,2), (3,3))\n",
"model.fc = nn.Linear(512, 32)\n",
"model.train(); print() #; torchsummary.summary(model, input_size = [(1, 28, 28)], device=DEVICE)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "understood-dimension",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.load_state_dict(torch.load(MODEL_WEIGHTS))"
]
},
{
"cell_type": "markdown",
"id": "continuing-logan",
"metadata": {},
"source": [
"## Calculate Distances\n",
"-----------------------"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "afraid-feedback",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|████████████████████████████████████████████████████████████████████████████████| 300/300 [03:03<00:00, 1.63it/s]\n"
]
}
],
"source": [
"avg_dist_array = np.zeros((9,9))\n",
"\n",
"for idx, batch in tqdm.tqdm(enumerate(siamese_test_loader), total=len(siamese_test_loader)):#, total=int(N*0.1)):\n",
" \n",
" x1 = batch[\"x1\"]; x2 = batch[\"x2\"]\n",
" labels = batch[\"labels\"]\n",
" x1.to(DEVICE); x2.to(DEVICE)\n",
" emb_x1 = model(x1)\n",
" emb_x2 = model(x2)\n",
" N = 0\n",
" for i in range(len(emb_x1) - 1):\n",
" for j in range(len(emb_x2) - 1):\n",
" label1 = labels[i]\n",
" label2 = labels[j]\n",
" dist_ij = (emb_x1[i] - emb_x2[j])**2#; print(dist_ij)\n",
" N += 1\n",
" avg_dist_array[i,j] = ((N*avg_dist_array[i,j]) + dist_ij.mean()) / (N+1)\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "robust-jumping",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<AxesSubplot:title={'center':'MNIST Distances - NTXENT loss'}>"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAEICAYAAABhxi57AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAf8ElEQVR4nO3de5xcZZ3n8c83IXJJkIBAzM0BJSIgGsYYHXEEBSQoGthdluCIkUHj7IDCjjMrMDsvQDaMuoij63oJF4mAxHDTiIyKWRFY5RIwXEISCCSQJiHhDhEI6e7f/HGehkPSdenuqjpVJ993XufVVc+5/J6q7vzqqec85zyKCMzMrPWGFV0BM7OtlROwmVlBnIDNzAriBGxmVhAnYDOzgjgBm5kVxAm4RCT9taTlRdfDhkbSHpJC0jZF18WaywkYkLRK0iuSdt2sfHH6j7BHen5Jej41t81ekiL3/EZJn8s9P0PSSkkbJHVJ+mkqX5LKNkjqkfRy7vkZ/dTxLEmbJL2QlgckfVfS2L5tIuLmiNi7jtd7lqTLBvg2tS1JB6ffy//drPwWSZ9Nv4O+9/bl9H73PV8iaYKkZyR9MLfvxFT2vvT8xs1+Rxsk/aIR8VvxHll7cgJ+zUrguL4nkvYHtu9nu6eB/1XPASXNBI4HDo2IUcAUYCFAROwXEaNS+c3AyX3PI+LcCof8aUTsCOwCHA28Gbgzn4S3Yn8GPtP3YZkXEefm3uu/A/6Ye6/3i4gu4CvAhZK2S7v9EPhRRNyWO1T+dzQqIj7RiPiNePHWmZyAX3Mp8Jnc85nAj/vZbi7wLkkH1XHM9wK/joiHACLi8YiYM9SKRsSmiFgCHAs8AXwZXm2JdfVtJ+krkh5LLeblkg6RNA04Azg2tcDuTtueIGlp2vZhSV/IHefg1Hr/sqT1ktZKOiG3fntJ35T0iKTnUstv+7Tu/ZL+IOlZSXdLOji332dTrBfSt4S/GcLb8ixwCXDmIPe/AFgLnJk+OPcG/mcL41ckaZykBZKelrRC0udz66ZKWiTpeUnrJJ2fyreTdJmkp9J7f4ekMY2umw2N+5hecytwvKR9gAfIktsH2bK1+yJwLjA7ra91zO9Iegz4HfCniOhpVIUjokfSz4HDN18naW/gZOC9EbEmtcyGR8RDks4F9oqIT+d2WQ8cCTwMfAj4d0l3RMRdaf2bgZ2A8cBhwFWSfhYRzwDnAfsBHwAeB94H9EoaD/yS7FvAr4BDgKslvYPsffxOqt/y1IrfZYhvyWzgAUlfi4gB9YVHRCjrOroL6AX+c0S82Kr4NVwBLAHGAe8AbpD0cEQsBL4NfDsiLpU0Cnhn2mcm2e9rIrARmAy81MA6WQO4Bfx6fa3gw4BlwGMVtvsh8BZJR1Q7WERcBnyRLEH+Hlgv6bTGVReANfSfuHqAbYF9JY2IiFV9LfEKdf1lRDwUmd8DvwH+OrfJJuCrqfV9PbAB2FvSMOBvgVMi4rGI6ImIP0TERuDTwPURcX1E9EbEDcAi4GPpmL3AOyVtHxFrU6t+0CLiceAHwFcHeYhHyN7P54Gb+ln/ndSa7FvOaXD8LUiaSPZB/5WIeDkiFgMXkn2oQfZ72UvSrhGxISJuzZW/ieyDtici7oyI5xtVL2sMJ+DXuxT4FPBZ+u9+ACAll3PSomoHjIjLI+JQYDRZ/99XJW3RYh2C8WT90pvHXQGcCpxFlvjnSRpX6SCSjpB0a/qa+yxZksyflHwqIrpzz18ERqVttgP6S+5/ARyTT1pkyWRsRPyZ7FvG3wFrJf0ytYz7q1v+xNdbKr2G5OvA4ZLeXWO7/pwGPEX2beAf+1n/pYgYnVv+pcHx+zMOeDoiXsiVPUL2ewc4EXg7sCx1MxyZyi8Ffg3Mk7RG0jckjWhQnaxBnIBzIuIRspNxHwOuqbH5j8i+4h1d57E3RcSVwD289jVxSFLr8xNkJ/H6i/mTiPggWSIMsuRAepw/zrbA1WRdCWMiYjRwPTU+XJIngZeBt/WzbjVw6WZJa2REfC3V79cRcRgwluwbxwUVXkf+xNej1SoTEU8B/0b24Vg3SfsC/wR8jiypnSFp0kCOMZT4VawBdpG0Y67sLaRvZxHxYEQcB+xO9vu9StLI9Pd2dkTsS9Y1dCSvP8dhbcAJeEsnAh9JLbSKUmvwLLKz5/1KJ5k+LmlHScNSl8V+wG2V9qmHpBGpr/oKsr7Z8/vZZm9JH0nJ9WWy/r++/ud1wB4pgQO8gay74gmgO9Xzo/XUJSJ6gYuB89PJouGS/irFvQz4hKTDU/l26YTeBEljJH1S0kiyPsoNufoN1flkSWefejZO78NFwDciYllE3EPWPz1HUj0fQkOKX01ErAb+APxrev/eRfY3enmq+6cl7ZZ+D8+m3XokfVjS/pKGk3WpbKJx7681iBPwZlI/6KI6N7+C7Mx5Jc+TjTh4lOw/xzeA/xYRtwyyesdK2pCOtYDs6/J7ImJNP9tuC3yNrIX6OFkLqW988ZXp51OS7kpfb78EzAeeIeuGWTCAev0jcC9wB1l3yNeBYSl5TE9xnyBrEf8T2d/dMLLRG2vSPgcBfz+AmBWlvs5vUP9JvVOAHdI+fc4h+3D7XK7su5t1h9zZoPi1HAfsQfZeXQucmfrTAaYBS9LfxbeBGRHxcqr7VWR/g0vJzkGUZux3WSh8Q3Yzs0K4BWxmVhAnYDOzgjgBm5kVxAnYzKwgTb8UeZs3jG/5Wb51h+3V6pAA9LxcSFhG/af9C4n755/dW0jcbd/yhpbHfPHBTS2PCTB69omFxH35uxcXEnf0T383mGF/r7PpyYfrzjkjdn3rkOMNhVvAZmYF8c14zKxcejvnehMnYDMrl57u2tu0CSdgMyuV7KrszuAEbGbl0usEbGZWDLeAzcwK4pNwZmYFcQvYzKwY4VEQZmYFKdNJuDRP13SyOaiC7KbQCyJiaZPrZmY2cB3UBVH1UmRJXwHmkc0NdjvZjAcCrqg2u6+kWZIWSVrU21t1Zh8zs8bq7al/KVitFvCJwH4R8bo7kUg6H1hCNuXNFiJiDjAHirkZj5ltxTqoBVwrAfeSTYv9yGblY9M6M7P2UqKTcKcCCyU9SDahImRTYu8FnNzEepmZDU5ZTsJFxK8kvR2YSnYSTkAXcEdEFN+BYma2mUalJknbATeRzTC+DXBVRJwp6Szg82QzfQOcERHXp31OJ+u67QG+FBG/rhaj5iiIyO5scetgX4SZWUs1rg94I/CRiNggaQRwi6R/T+u+FRHn5TeWtC8wA9iPrOv2t5LeXq2x6huym1m59PbWv1QRmQ3p6Yi0VBtUMB2YFxEbI2IlsIKs96AiJ2AzK5forXvJD5lNy6z8oSQNl7QYWA/cEBG3pVUnS7pH0sWSdk5l43ntXBlk3bXjq1XVCdjMyqVnU91LRMyJiCm5ZU7+UBHRExGTgQnAVEnvBL4PvA2YDKwFvpk2729+uarDcJ2AzaxcGtQFkRcRzwI3AtMiYl1KzL3ABbzWzdAFTMztNoHsyuGKnIDNrFwG0AVRjaTdJI1Oj7cHDgWWSRqb2+xo4L70eAEwQ9K2kvYEJpFdQVyRb8ZjZuXSuHHAY4G5koaTNVbnR8R1ki6VNJmse2EV8AWAiFgiaT5wP9ANnFRruG7TE/C6w/ZqdogtjF+4suUxARaNe1chcXdYubr2Rk2w7qEdC4n7xmdeannMrsfe1PKYAAesebiQuM8uL6ZtNroRB2lQAo6Ie4AD+ik/vso+s4HZ9cZwC9jMSiV6NtXeqE04AZtZuZToZjxmZp2lLPeCMDPrOG4Bm5kVxC1gM7OCuAVsZlaQ7vLckN3MrLO4BWxmVhD3AZuZFcQtYDOzgnRQC3jQd0OTdEIjK2Jm1hANuhtaKwzldpRnV1qRv8v83EfXDiGEmdkAdXfXvxSsaheEpHsqrQLGVNov3VV+DsBTHz+o6h3hzcwaKjon5dTqAx4DHA48s1m5gD80pUZmZkPRQX3AtRLwdcCoiFi8+QpJNzajQmZmQ1KWBBwRJ1ZZ96nGV8fMbIja4ORavTwMzczKpafqLEBtxQnYzMqlg7ogPCuymZVLg6all7SdpNsl3S1piaSzU/kukm6Q9GD6uXNun9MlrZC0XNLhtarqBGxm5dK4CzE2Ah+JiHcDk4Fpkt4PnAYsjIhJwML0HEn7AjOA/YBpwPfSjMoVOQGbWalEb9S9VD1OZkN6OiItAUwH5qbyucBR6fF0YF5EbIyIlcAKYGq1GE7AZlYuA+iCyF+1m5ZZ+UNJGi5pMbAeuCEibgPGRMRagPRz97T5eGB1bveuVFZR00/CFTEiZPFb3tn6oMC7HllcSNynug8oJG5vrwqJO3xE66902uENBU113gaXy3acAYyCyF+1W2F9DzBZ0mjgWknVkkt//yGq/rF6FISZlUsTRkFExLPp4rNpwDpJYyNiraSxZK1jyFq8E3O7TQDWVDuuuyDMrFwaNwpit9TyRdL2wKHAMmABMDNtNhP4eXq8AJghaVtJewKTgNurxXAL2MzKpXE34xkLzE0jGYYB8yPiOkl/BOZLOhF4FDgmCxtLJM0H7ge6gZNSF0ZFTsBmVi4N6oKIiHuALU6wRMRTwCEV9pkNzK43hhOwmZVLjeFl7cQJ2MzKxfeCMDMrRnTQvSCcgM2sXNwFYWZWEN8P2MysIB3UAq55IYakd0g6RNKozcqnNa9aZmaD1N1T/1KwqglY0pfIrvL4InCfpOm51ec2s2JmZoPSuNtRNl2tFvDngfdExFHAwcC/SDolrat4J5b8HYZ+vHptQypqZlaX3qh/KVitPuDhfffDjIhVkg4GrpL0F1RJwPk7DD15xEHFv0oz22p00jC0Wi3gxyVN7nuSkvGRwK7A/k2sl5nZ4JSoBfwZsptKvCoiuoHPSPph02plZjZYbZBY61U1AUdEV5V1/7/x1TEzGyJfimxmVoxac721EydgMysXJ2Azs4J00CgIJ2AzKxe3gM3MCuIEbGZWjOhxF8Srdpg+udkhtjBhxaqWxwR46pXJhcTd+Xt3FRL3mVnvLiTusPETa2/UYG9cvb72Rk2wzeEnFBJ3t2XLC4nbEB3UAva09GZWKtEbdS/VSJoo6XeSlkpa0ncfHElnSXpM0uK0fCy3z+mSVkhaLunwWnV1F4SZlUvjWsDdwJcj4i5JOwJ3SrohrftWRJyX31jSvsAMYD9gHPBbSW+vNjW9W8BmVi69A1iqiIi1EXFXevwCsBQYX2WX6cC8iNgYESuBFcDUajGcgM2sVKK7t+4lf+vctMzq75iS9gAOAG5LRSdLukfSxZJ2TmXjgdW53bqonrCdgM2sZAbQAo6IORExJbfM2fxwaTagq4FTI+J54PvA24DJwFrgm32b9lObqv0h7gM2s1Jp5L0gJI0gS76XR8Q1ABGxLrf+AuC69LQLyA/RmQCsqXZ8t4DNrFwa1AcsScBFwNKIOD9XPja32dHAfenxAmCGpG0l7QlMAm6vFsMtYDMrlQa2gA8EjgfulbQ4lZ0BHJcmqghgFfAFgIhYImk+cD/ZCIqTqo2AACdgMyubBl0IFxG30H+/7vVV9pkNzK43hhOwmZVKdNfepl3UTMCSpgIREXekgcbTgGURUfFTwMysKG0w23zdqiZgSWcCRwDbpCtA3gfcCJwm6YDU3DYzax8dlIBrjYL4L2Qd0R8CTgKOioivAocDx1baKT+4+eKb76u0mZlZw0Vv/UvRanVBdKezeC9KeigNQiYiXpJUsfppMPMcgBd/cErn3JrIzDpeOyTWetVKwK9I2iEiXgTe01coaSc6qqFvZluL6Olv4EJ7qpWAPxQRGwEiXve5MgKY2bRamZkNUmlawH3Jt5/yJ4Enm1IjM7MhiN7ytIDNzDpKaVrAZmadJsItYDOzQrgFbGZWkN4SjYIwM+soPglnZlYQJ2Azs4JEB1172/QE/MK8u5sdYgvrVu3Y8pgA3T3DC4n7zN/uX0jcN11wTyFxl79jn5bHvPnJMS2PCXDMe39QSNyuK18oJO7e5wz9GG4Bm5kVxMPQzMwK0uNREGZmxeikFrBnRTazUole1b1UI2mipN9JWippiaRTUvkukm6Q9GD6uXNun9MlrZC0XNLhterqBGxmpRJR/1JDN/DliNgHeD9wUpqW7TRgYURMAham56R1M4D9yKZu+56kqmfmnYDNrFQa1QKOiLURcVd6/AKwFBgPTAfmps3mAkelx9OBeRGxMSJWAiuAqdViuA/YzEqlp7f+dqWkWcCsXNGcNKPP5tvtARwA3AaMiYi1kCVpSbunzcYDt+Z260plFTkBm1mpDORCjPz0aZVIGgVcDZwaEc9LFVvO/a2oWhsnYDMrld4GjoKQNIIs+V4eEdek4nWSxqbW71hgfSrvAibmdp8ArKl2/AH3AUv68UD3MTNrlQjVvVSjrKl7EbA0Is7PrVrAa1OyzQR+niufIWlbSXsCk4Dbq8Wo2gKWtGDzIuDDkkZnLzQ+WfUVmJm1WAPvBXEgcDxwr6TFqewM4GvAfEknAo8Cx2RxY4mk+cD9ZCMoTkqzyldUqwtiQjrYhWR9GQKmAN+stlO+Y/t/T5rE8ePG1QhjZtYYjeqCiIhb6L9fF+CQCvvMBmbXG6NWF8QU4E7gn4HnIuJG4KWI+H1E/L7SThExJyKmRMQUJ18za6We3mF1L0WrNStyL/AtSVemn+tq7WNmVqQOuhtlfck0IrqAYyR9HHi+uVUyMxu8Ro6CaLYBtWYj4pfAL5tUFzOzIeukm/G4O8HMSqWDJkV2AjazcomKAxfajxOwmZVKt7sgzMyK4RawmVlB3AdsZlYQt4DNzAriFnDOdnu2Psfv9PxLLY8JMGx4MdfgDJtY9Z7PTfPAPvsUEnfvZctaHnPJWwu6bHXiBwsJO3rsbYXEbYQet4DNzIpRY6ahtuIEbGal0usWsJlZMUp3Mx4zs07hk3BmZgXprTxpZttxAjazUqk6B1CbcQI2s1LppFEQxc/JYWbWQL2o7qUWSRdLWi/pvlzZWZIek7Q4LR/LrTtd0gpJyyUdXuv4TsBmVioxgKUOlwDT+in/VkRMTsv1AJL2BWYA+6V9vidpeLWDDygBS/qgpH+Q9NGB7Gdm1iq9qn+pJSJuAp6uM/R0YF5EbIyIlcAKYGq1HaomYEm35x5/HvgusCNwpqTT6qyUmVnL9A5gGYKTJd2Tuih2TmXjgdW5bbpSWUW1WsAjco9nAYdFxNnAR4G/qbSTpFmSFkladMnyx2qEMDNrnB7Vv+RzVVpm1RHi+8DbgMnAWuCbqby/NnXVno5aoyCGpew+DFBEPAEQEX+W1F1pp4iYA8wBeO6EQzvpwhQz63ADadnmc9UA9lnX91jSBcB16WkXMDG36QRgTbVj1WoB7wTcCSwCdpH05hR0FP1nezOzQjW7C0LS2NzTo4G+ERILgBmStpW0JzAJuH3z/fOqtoAjYo8Kq3pTYDOzttLIKeEkXQEcDOwqqQs4EzhY0mSy7oVVwBcAImKJpPnA/UA3cFJEVL0uZFAXYkTEi8DKwexrZtZMjbwXREQc10/xRVW2nw3Mrvf4vhLOzErFlyKbmRWkky5FdgI2s1Lx7SjNzAriBGxmVpBOuvDACdjMSsV9wGZmBfEoiJyXHqp4xXLTPLpm15bHBBg5YlMhcUevXl9I3JueGFNI3CVvbf1dVPdfuazlMQHWL7iu9kZN8HTXyELiNuIvqreDOiHcAjazUvFJODOzgnRO+9cJ2MxKxi1gM7OCdKtz2sBOwGZWKp2Tfp2Azaxk3AVhZlYQD0MzMytI56RfJ2AzKxl3QZiZFaSng9rAVa/plPQ+SW9Mj7eXdLakX0j6uqSdWlNFM7P6NXtSzkaqdVH9xcCL6fG3yWZJ/noq+1GlnSTNkrRI0qJL11adldnMrKFiAP+KVisBD4uIvrvpTImIUyPilog4G3hrpZ0iYk5ETImIKcePHdewypqZ1dLIFrCkiyWtl3RfrmwXSTdIejD93Dm37nRJKyQtl3R4rePXSsD3STohPb5b0pQU5O1AMbf+MjOropeoe6nDJcC0zcpOAxZGxCRgYXqOpH2BGcB+aZ/vSRpe7eC1EvDngIMkPQTsC/xR0sPABWmdmVlbiQEsNY8VcRPw9GbF04G56fFc4Khc+byI2BgRK4EVwNRqx686CiIingM+K2lHsi6HbYCuiFhXR93NzFquewB9u5JmAbNyRXMiYk6N3cZExFqAiFgrafdUPh64NbddVyqrqK5haBHxAnB3PduamRVpICfXUrKtlXDr1d9kSFUr0/qpBczMmqgFw9DWSRoLkH72TUnTBUzMbTcBqDoMzAnYzEqlBcPQFgAz0+OZwM9z5TMkbStpT2AScHu1A/lKODMrlUZeYCHpCuBgYFdJXcCZwNeA+ZJOBB4FjgGIiCWS5gP3A93ASRFRdY5QJ2AzK5WeaNwFFhFxXIVVh1TYfjYwu97jOwGbWan4dpRmZgVph0uM69X0BDz6vJOaHWILUx68p+UxAegu5uLA4YcdX0jcY95b8XYgzTWu6tj2plj/i2LOV7/ph4sLifvcvJMLidsI7XCTnXq5BWxmpeIuCDOzgrgLwsysII0cBdFsTsBmVirugjAzK4hPwpmZFcR9wGZmBXEXhJlZQcIn4czMitFJ09I7AZtZqXRSF0TV6yslfUnSxGrbmJm1k4ioeylarQvczwFuk3SzpL+XtFsrKmVmNlgNnhW5qWol4IfJptU4B3gPcL+kX0mamSbq7JekWZIWSVp00bW/bWB1zcyqa8GMGA1Tqw84IqIX+A3wG0kjgCOA44DzgH5bxPmJ7l6+/criX6WZbTXKdCny62b5jIhNZPMeLZC0fdNqZWY2SO3QtVCvWgn42EorIuKlBtfFzGzISpOAI+KBVlXEzKwRGjm6QdIq4AWgB+iOiCmSdgF+CuwBrAL+a0Q8M5jje1p6MyuVJoyC+HBETI6IKen5acDCiJgELEzPB8UJ2MxKpQWjIKYDc9PjucBRgz2QE7CZlUpP9Na95IfMpmXWZocLshFgd+bWjYmItQDp5+6DrasvRTazUhlIH3B+yGwFB0bEGkm7AzdIWjbU+uW5BWxmpdLIPuCIWJN+rgeuBaYC6ySNBUg/1w+2rk7AZlYqjeoDljSy74pfSSOBjwL3kV0LMTNtNhP4+WDr2vQuiBf/9bvNDrGFZx7cruUxAaKguVDefN/yQuKu/tnGQuLuPHZRy2M+2TWq5TEBnpt3ciFxdzz2/xQSt/uVQQ8oeFVv44ahjQGulQRZrvxJRPxK0h3AfEknAo8Cxww2gPuAzaxUGnWPh4h4GHh3P+VPAYc0IoYTsJmVSk9RX0UHwQnYzEqlgV0QTecEbGal0g63mayXE7CZlYpbwGZmBXEL2MysID3RU3QV6uYEbGal0g6TbdbLCdjMSqU0N2Q3M+s0pWkBS3oDMANYExG/lfQp4APAUmBOmiPOzKxtlGkUxI/SNjtImgmMAq4huwxvKq/dkMLMrC2UaRTE/hHxLknbAI8B4yKiR9JlwN2Vdko3Lp4FcP7kSczcY2zDKmxmVk2ZLkUelrohRgI7ADsBTwPbAiMq7ZS/yfHTRx/UOR9HZtbxStMHDFwELAOGA/8MXCnpYeD9wLwm183MbMBK0wccEd+S9NP0eI2kHwOHAhdExO2tqKCZ2UCUqQX86pQc6fGzwFXNrJCZ2VB4HLCZWUFK1QI2M+skZRoFYWbWUUpzEs7MrNO4C8LMrCBluhLOzKyjuAVsZlaQTuoDVjt/WkialS5rdtwSxXTc8sYsMm4nGlZ0BWqY5biljOm45Y1ZZNyO0+4J2MystJyAzcwK0u4JuKh+pK0p7tb0Wre2uFvTa+1IbX0SzsyszNq9BWxmVlpOwGZmBWnbBCxpmqTlklZIOq1FMS+WtF7Sfa2Il2JOlPQ7SUslLZF0Sovibifpdkl3p7hntyJuij1c0p8kXdeqmCnuKkn3SlosaVGLYo6WdJWkZel3/FctiLl3eo19y/OSTm123BT7v6e/p/skXSFpu1bE7VRt2QcsaTjwAHAY0AXcARwXEfc3Oe6HgA3AjyPinc2MlYs5FhgbEXdJ2hG4EziqBa9VwMiI2CBpBHALcEpE3NrMuCn2PwBTgDdGxJHNjpeLuwqYEhFPtjDmXODmiLgwza+4Q5rYoFXxh5NNqPu+iHikybHGk/0d7RsRL0maD1wfEZc0M24na9cW8FRgRUQ8HBGvkM0/N73ZQSPiJrJJR1smItZGxF3p8QvAUmB8C+JGRGxIT0ekpemfxpImAB8HLmx2rKJJeiPwIbK5FYmIV1qZfJNDgIeanXxztgG2TzOp7wCsqbH9Vq1dE/B4YHXueRctSEpFk7QHcABwW4viDZe0GFgP3BARrYj7b8D/AIq4a3YAv5F0p6RWXK31VuAJ4Eepy+VCSSNbEDdvBnBFKwJFxGPAecCjwFrguYj4TStid6p2TcDqp6z9+koaSNIo4Grg1Ih4vhUxI6InIiYDE4Cpkpra7SLpSGB9RNzZzDhVHBgRfwkcAZyUupyaaRvgL4HvR8QBwJ+BlpzPAEhdHp8ErmxRvJ3JvqnuCYwDRkr6dCtid6p2TcBdwMTc8wmU+KtM6oO9Grg8Iq5pdfz0tfhGYFqTQx0IfDL1xc4DPiLpsibHfFXfBLMRsR64lqyrq5m6gK7cN4uryBJyqxwB3BUR61oU71BgZUQ8ERGbgGuAD7Qodkdq1wR8BzBJ0p7pU3wGsKDgOjVFOhl2EbA0Is5vYdzdJI1Oj7cn+8+zrJkxI+L0iJgQEXuQ/U7/X0S0pIUkaWQ6yUnqBvgo0NTRLhHxOLBa0t6p6BCgqSdXN3McLep+SB4F3i9ph/R3fQjZOQ2roC3vBxwR3ZJOBn4NDAcujoglzY4r6QrgYGBXSV3AmRFxUZPDHggcD9yb+mMBzoiI65scdywwN50lHwbMj4iWDgtrsTHAtVleYBvgJxHxqxbE/SJweWpIPAyc0IKYSNqBbBTRF1oRDyAibpN0FXAX0A38CV+WXFVbDkMzM9satGsXhJlZ6TkBm5kVxAnYzKwgTsBmZgVxAjYzK4gTsJlZQZyAzcwK8h+/ru90qD54qQAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import seaborn as sns\n",
"plt.title(\"MNIST Distances - \" + TITLE + \" loss\")\n",
"sns.heatmap(avg_dist_array)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment