Visualize positional embeddings in DINOv2 ViT.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyOLOMhi9qrUKPH3Um+2yz6b",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/fepegar/c6163af3fcde2bad59207130ddd127d1/visualize-positional-embeddings-in-dinov2-vit.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"%pip install -q einops"
],
"metadata": {
"id": "DR4_tjlhbQN9"
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "PLM9MJHirgV6"
},
"outputs": [],
"source": [
"import matplotlib as mpl\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn.functional as F\n",
"import torchvision\n",
"from einops import rearrange\n",
"from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
"\n",
"torch.set_grad_enabled(False);"
]
},
{
"cell_type": "code",
"source": [
"pretrained = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "yLWKQVLksHAf",
"outputId": "8bdf7d89-02d6-4dbb-ccbd-9f0ec36869b8"
},
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main\n",
"/root/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/layers/swiglu_ffn.py:51: UserWarning: xFormers is not available (SwiGLU)\n",
" warnings.warn(\"xFormers is not available (SwiGLU)\")\n",
"/root/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/layers/attention.py:33: UserWarning: xFormers is not available (Attention)\n",
" warnings.warn(\"xFormers is not available (Attention)\")\n",
"/root/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/layers/block.py:40: UserWarning: xFormers is not available (Block)\n",
" warnings.warn(\"xFormers is not available (Block)\")\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"all_position_embeddings = pretrained.pos_embed\n",
"# From the code, we know the first embedding corresponds to the class token\n",
"patch_position_embeddings = all_position_embeddings[0][1:]\n",
"patch_position_embeddings.shape"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "29jNxmfiriaH",
"outputId": "c035ac51-326b-430d-e05e-bc917eaaedbf"
},
"execution_count": 4,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"torch.Size([1369, 768])"
]
},
"metadata": {},
"execution_count": 4
}
]
},
{
"cell_type": "code",
"source": [
"num_tokens, dim = patch_position_embeddings.shape\n",
"num_tokens_y = num_tokens_x = int(np.sqrt(num_tokens)) # assume width = height\n",
"num_tokens_y"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "PuQWIp6LrqwJ",
"outputId": "1d26f8e8-5a83-4921-f36e-549b3ae2fc8c"
},
"execution_count": 5,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"37"
]
},
"metadata": {},
"execution_count": 5
}
]
},
{
"cell_type": "code",
"source": [
"patch_pos_embed_grid = rearrange(patch_position_embeddings, '(h w) c -> c h w', h=num_tokens_y)\n",
"patch_pos_embed_grid.shape"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "DEd2iar-cKPv",
"outputId": "050adaa0-217f-4257-a2c4-458b92c0f4d6"
},
"execution_count": 6,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"torch.Size([768, 37, 37])"
]
},
"metadata": {},
"execution_count": 6
}
]
},
{
"cell_type": "code",
"source": [
"def compute_cosine_similarity_matrix(tensor):\n",
" normalized_tensor = F.normalize(tensor, p=2, dim=1)\n",
" cosine_similarity_matrix = torch.mm(normalized_tensor, normalized_tensor.t())\n",
" return cosine_similarity_matrix\n",
"\n",
"similarity_matrix = compute_cosine_similarity_matrix(patch_position_embeddings)\n",
"similarity_matrix.shape"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ZP5nb57rgo6C",
"outputId": "c689ef60-aa01-42e6-e84f-133081302328"
},
"execution_count": 7,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"torch.Size([1369, 1369])"
]
},
"metadata": {},
"execution_count": 7
}
]
},
{
"cell_type": "code",
"source": [
"similarities = []\n",
"for row in similarity_matrix:\n",
" grid = rearrange(row, '(h w) -> h w', h=num_tokens_y)\n",
" similarities.append(grid)\n",
"similarities = torch.stack(similarities)\n",
"similarities.shape"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "osSuZ1-wh8iN",
"outputId": "8ebcc3ce-e1ec-4a4f-9d3b-f618bbe0581e"
},
"execution_count": 8,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"torch.Size([1369, 37, 37])"
]
},
"metadata": {},
"execution_count": 8
}
]
},
{
"cell_type": "code",
"source": [
"similarities_for_grid = rearrange(similarities, 'n h w -> n 1 h w')\n",
"grid = torchvision.utils.make_grid(\n",
" similarities_for_grid,\n",
" nrow=num_tokens_x,\n",
" pad_value=torch.nan,\n",
" padding=int(num_tokens_x * 0.15),\n",
")\n",
"grid.shape"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bpt0JeHNp1dd",
"outputId": "d38ca0ce-5682-4468-f160-0dc0f0954e09"
},
"execution_count": 9,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"torch.Size([3, 1559, 1559])"
]
},
"metadata": {},
"execution_count": 9
}
]
},
{
"cell_type": "code",
"source": [
"image = grid[0]\n",
"\n",
"def plot_position_embeddings(colormap):\n",
" fig, axis = plt.subplots(figsize=(18, 16))\n",
"\n",
" im = axis.imshow(image, interpolation='none', cmap=colormap, vmin=-1, vmax=1)\n",
" divider = make_axes_locatable(axis)\n",
" cax = divider.append_axes('right', size='5%', pad=0.2)\n",
"\n",
" height, width = image.shape\n",
"\n",
" offset = width / num_tokens_x / 2\n",
" tick_positions = np.linspace(offset, width - offset, num_tokens_x)\n",
" tick_labels = [str(i + 1) for i in range(num_tokens_x)]\n",
" axis.set_xticks(tick_positions)\n",
" axis.set_xticklabels(tick_labels)\n",
" axis.tick_params(axis='x', length=0)\n",
"\n",
" offset = height / num_tokens_y / 2\n",
" tick_positions = np.linspace(offset, width - offset, num_tokens_y)\n",
" tick_labels = [str(i + 1) for i in range(num_tokens_y)]\n",
" axis.set_yticks(tick_positions)\n",
" axis.set_yticklabels(tick_labels)\n",
" axis.tick_params(axis='y', length=0)\n",
"\n",
" axis.set_title('Position embedding similarity')\n",
" axis.set_xlabel('Input patch column')\n",
" axis.set_ylabel('Input patch row')\n",
"\n",
" fig.colorbar(im, cax=cax)\n",
" fig.tight_layout()\n",
"\n",
"# Use colors in the paper\n",
"plot_position_embeddings('viridis')\n",
"\n",
"# Use better colors\n",
"cmap = mpl.colormaps.get_cmap('coolwarm')\n",
"cmap.set_bad(color='black')\n",
"plot_position_embeddings(cmap)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "LoWwuwIgqTBp",
"outputId": "dd8afc50-f239-4055-d91e-c2a9b7c0a74e"
},
"execution_count": 10,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1800x1600 with 2 Axes>"
]
},
"metadata": {}
}
]
}
]
}