Created
October 17, 2023 22:48
-
-
Save fepegar/c6163af3fcde2bad59207130ddd127d1 to your computer and use it in GitHub Desktop.
Visualize positional embeddings in DINOv2 ViT.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "provenance": [], | |
| "authorship_tag": "ABX9TyOLOMhi9qrUKPH3Um+2yz6b", | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| }, | |
| "language_info": { | |
| "name": "python" | |
| } | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/fepegar/c6163af3fcde2bad59207130ddd127d1/visualize-positional-embeddings-in-dinov2-vit.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "%pip install -q einops" | |
| ], | |
| "metadata": { | |
| "id": "DR4_tjlhbQN9" | |
| }, | |
| "execution_count": 1, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": { | |
| "id": "PLM9MJHirgV6" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import matplotlib as mpl\n", | |
| "import matplotlib.pyplot as plt\n", | |
| "import numpy as np\n", | |
| "import torch\n", | |
| "import torch.nn.functional as F\n", | |
| "import torchvision\n", | |
| "from einops import rearrange\n", | |
| "from mpl_toolkits.axes_grid1 import make_axes_locatable\n", | |
| "\n", | |
| "torch.set_grad_enabled(False);" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "pretrained = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "yLWKQVLksHAf", | |
| "outputId": "8bdf7d89-02d6-4dbb-ccbd-9f0ec36869b8" | |
| }, | |
| "execution_count": 3, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stderr", | |
| "text": [ | |
| "Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main\n", | |
| "/root/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/layers/swiglu_ffn.py:51: UserWarning: xFormers is not available (SwiGLU)\n", | |
| " warnings.warn(\"xFormers is not available (SwiGLU)\")\n", | |
| "/root/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/layers/attention.py:33: UserWarning: xFormers is not available (Attention)\n", | |
| " warnings.warn(\"xFormers is not available (Attention)\")\n", | |
| "/root/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/layers/block.py:40: UserWarning: xFormers is not available (Block)\n", | |
| " warnings.warn(\"xFormers is not available (Block)\")\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Per the DINOv2 code, token 0 of pos_embed belongs to the class token; keep only the patch tokens\n", | |
| "all_position_embeddings = pretrained.pos_embed\n", | |
| "patch_position_embeddings = all_position_embeddings[0, 1:]\n", | |
| "patch_position_embeddings.shape" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "29jNxmfiriaH", | |
| "outputId": "c035ac51-326b-430d-e05e-bc917eaaedbf" | |
| }, | |
| "execution_count": 4, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "torch.Size([1369, 768])" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 4 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "num_tokens, dim = patch_position_embeddings.shape\n", | |
| "num_tokens_y = num_tokens_x = int(np.sqrt(num_tokens)) # assume width = height\n", | |
| "# Fail early with a clear message if the square-grid assumption does not hold\n", | |
| "assert num_tokens_y * num_tokens_x == num_tokens, f'{num_tokens} patch tokens do not form a square grid'\n", | |
| "num_tokens_y" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "PuQWIp6LrqwJ", | |
| "outputId": "1d26f8e8-5a83-4921-f36e-549b3ae2fc8c" | |
| }, | |
| "execution_count": 5, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "37" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 5 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "patch_pos_embed_grid = rearrange(patch_position_embeddings, '(h w) c -> c h w', h=num_tokens_y)\n", | |
| "patch_pos_embed_grid.shape" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "DEd2iar-cKPv", | |
| "outputId": "050adaa0-217f-4257-a2c4-458b92c0f4d6" | |
| }, | |
| "execution_count": 6, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "torch.Size([768, 37, 37])" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 6 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "def compute_cosine_similarity_matrix(tensor):\n", | |
| "    \"\"\"Return the pairwise cosine similarity between the rows of a 2D tensor.\n", | |
| "\n", | |
| "    Each row is scaled to unit L2 norm, so the product of the result with its\n", | |
| "    own transpose holds the cosine of the angle between every pair of rows.\n", | |
| "\n", | |
| "    Args:\n", | |
| "        tensor: 2D tensor with one vector per row.\n", | |
| "\n", | |
| "    Returns:\n", | |
| "        Square tensor whose entry (i, j) is the cosine similarity of rows i and j.\n", | |
| "    \"\"\"\n", | |
| "    unit_rows = F.normalize(tensor, p=2, dim=1)\n", | |
| "    return torch.mm(unit_rows, unit_rows.t())\n", | |
| "\n", | |
| "similarity_matrix = compute_cosine_similarity_matrix(patch_position_embeddings)\n", | |
| "similarity_matrix.shape" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "ZP5nb57rgo6C", | |
| "outputId": "c689ef60-aa01-42e6-e84f-133081302328" | |
| }, | |
| "execution_count": 7, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "torch.Size([1369, 1369])" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 7 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Reshape every row of the similarity matrix into its own (h, w) map in a single call,\n", | |
| "# instead of looping over the rows in Python and stacking the results\n", | |
| "similarities = rearrange(similarity_matrix, 'n (h w) -> n h w', h=num_tokens_y)\n", | |
| "similarities.shape" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "osSuZ1-wh8iN", | |
| "outputId": "8ebcc3ce-e1ec-4a4f-9d3b-f618bbe0581e" | |
| }, | |
| "execution_count": 8, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "torch.Size([1369, 37, 37])" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 8 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# make_grid expects (n, channels, h, w), so add a singleton channel axis first\n", | |
| "similarities_for_grid = similarities.unsqueeze(1)\n", | |
| "grid = torchvision.utils.make_grid(\n", | |
| "    similarities_for_grid,\n", | |
| "    nrow=num_tokens_x,\n", | |
| "    padding=int(num_tokens_x * 0.15),\n", | |
| "    pad_value=torch.nan,  # NaN gaps between tiles can be colored separately when plotting\n", | |
| ")\n", | |
| "grid.shape" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "bpt0JeHNp1dd", | |
| "outputId": "d38ca0ce-5682-4468-f160-0dc0f0954e09" | |
| }, | |
| "execution_count": 9, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "torch.Size([3, 1559, 1559])" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 9 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "image = grid[0]\n", | |
| "\n", | |
| "def plot_position_embeddings(colormap):\n", | |
| "    \"\"\"Plot the tiled similarity maps held in the global `image`.\n", | |
| "\n", | |
| "    Draws `image` with a fixed color range of [-1, 1], labels one tick per\n", | |
| "    tile row and column (numbered from 1), and adds a colorbar on the right.\n", | |
| "\n", | |
| "    Args:\n", | |
| "        colormap: Colormap name or object passed to `imshow` as `cmap`.\n", | |
| "    \"\"\"\n", | |
| "    fig, axis = plt.subplots(figsize=(18, 16))\n", | |
| "\n", | |
| "    im = axis.imshow(image, interpolation='none', cmap=colormap, vmin=-1, vmax=1)\n", | |
| "    divider = make_axes_locatable(axis)\n", | |
| "    cax = divider.append_axes('right', size='5%', pad=0.2)\n", | |
| "\n", | |
| "    height, width = image.shape\n", | |
| "\n", | |
| "    # One tick per tile column, spaced evenly across the image width\n", | |
| "    offset = width / num_tokens_x / 2\n", | |
| "    tick_positions = np.linspace(offset, width - offset, num_tokens_x)\n", | |
| "    tick_labels = [str(i + 1) for i in range(num_tokens_x)]\n", | |
| "    axis.set_xticks(tick_positions)\n", | |
| "    axis.set_xticklabels(tick_labels)\n", | |
| "    axis.tick_params(axis='x', length=0)\n", | |
| "\n", | |
| "    # One tick per tile row; the span must come from the height, not the width\n", | |
| "    offset = height / num_tokens_y / 2\n", | |
| "    tick_positions = np.linspace(offset, height - offset, num_tokens_y)\n", | |
| "    tick_labels = [str(i + 1) for i in range(num_tokens_y)]\n", | |
| "    axis.set_yticks(tick_positions)\n", | |
| "    axis.set_yticklabels(tick_labels)\n", | |
| "    axis.tick_params(axis='y', length=0)\n", | |
| "\n", | |
| "    axis.set_title('Position embedding similarity')\n", | |
| "    axis.set_xlabel('Input patch column')\n", | |
| "    axis.set_ylabel('Input patch row')\n", | |
| "\n", | |
| "    fig.colorbar(im, cax=cax)\n", | |
| "    fig.tight_layout()\n", | |
| "\n", | |
| "# Use colors in the paper\n", | |
| "plot_position_embeddings('viridis')\n", | |
| "\n", | |
| "# Use better colors\n", | |
| "cmap = mpl.colormaps.get_cmap('coolwarm')\n", | |
| "cmap.set_bad(color='black')\n", | |
| "plot_position_embeddings(cmap)" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 1000 | |
| }, | |
| "id": "LoWwuwIgqTBp", | |
| "outputId": "dd8afc50-f239-4055-d91e-c2a9b7c0a74e" | |
| }, | |
| "execution_count": 10, | |
| "outputs": [ | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "<Figure size 1800x1600 with 2 Axes>" | |
| ], |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment