Created
January 22, 2026 10:22
-
-
Save ceshine/d5602adec39e377c938dc59d38e2b47d to your computer and use it in GitHub Desktop.
Notebooks and Code Used in the blog post
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "6acf63e6-176b-409a-8570-e865892c92a2", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "import numpy as np\n", | |
| "from PIL import Image\n", | |
| "from transformers.utils.import_utils import is_flash_attn_2_available\n", | |
| "\n", | |
| "from colpali_engine.models import ColQwen2, ColQwen2Processor\n", | |
| "from colpali_engine.interpretability import get_similarity_maps_from_embeddings\n", | |
| "from matplotlib import pyplot as plt\n", | |
| "\n", | |
| "from plot_utils import plot_similarity_map" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "79730c3b-f038-480d-a892-539eb7122e27", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release.\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "model_name = \"vidore/colqwen2-v1.0\"\n", | |
| "\n", | |
| "processor = ColQwen2Processor.from_pretrained(model_name)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "id": "5e9914ff-1264-4536-aada-3de1cc941ca3", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "images = [\n", | |
| " Image.open(\"../data/attention-is-all-you-need/page-0.png\"),\n", | |
| " Image.open(\"../data/attention-is-all-you-need/page-1.png\"),\n", | |
| " Image.open(\"../data/attention-is-all-you-need/page-2.png\"),\n", | |
| " Image.open(\"../data/attention-is-all-you-need/page-3.png\"),\n", | |
| "]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "7f18d60e-f26d-4edd-b2ee-030b3c59048b", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "torch.Size([4, 2976, 1176])" | |
| ] | |
| }, | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "batch_images = processor.process_images(images)\n", | |
| "batch_images[\"pixel_values\"].shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "f0fe7c1a-153b-43fe-b68d-a871ddfb88b2", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[ 1, 62, 48],\n", | |
| " [ 1, 62, 48],\n", | |
| " [ 1, 62, 48],\n", | |
| " [ 1, 62, 48]])" | |
| ] | |
| }, | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "batch_images[\"image_grid_thw\"]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "id": "1ccae831-9879-4cfc-8d3e-f51610b4a092", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(672, 868)" | |
| ] | |
| }, | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# The shape of the resized image\n", | |
| "48 * 14, 62 * 14" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "2e085490-c428-46f1-b1e6-c130e04cd263", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "`torch_dtype` is deprecated! Use `dtype` instead!\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "b745a5d6ed8e4cc2af0ddf5c636c4e64", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| "Fetching 2 files: 0%| | 0/2 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "45799375af544003b2f8faf08958a997", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "model = ColQwen2.from_pretrained(\n", | |
| " model_name,\n", | |
| " torch_dtype=torch.bfloat16,\n", | |
| " device_map=\"cuda:0\",\n", | |
| " attn_implementation=\"flash_attention_2\" if is_flash_attn_2_available() else None,\n", | |
| ").eval()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "3d8343f8-186d-478d-a39f-1f2a54609737", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "queries = [\n", | |
| " \"What are the key innovations of the Transformer architecture?\",\n", | |
| " \"How do the Encoder and Decoder stacks work together in Transformers?\",\n", | |
| "]\n", | |
| "batch_queries = processor.process_queries(queries).to(model.device)\n", | |
| "batch_images = batch_images.to(model.device)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "id": "1fefd2dc-e6ef-4525-a85a-e94ceef24c2a", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(torch.Size([4, 755, 128]), torch.Size([2, 22, 128]))" | |
| ] | |
| }, | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Forward pass\n", | |
| "with torch.no_grad():\n", | |
| " image_embeddings = model(**batch_images)\n", | |
| " query_embeddings = model(**batch_queries)\n", | |
| "image_embeddings.shape, query_embeddings.shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "id": "8c972025-aa52-4861-b1eb-7706e32662be", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[13.5625, 13.7500, 15.0000, 7.8750],\n", | |
| " [13.5000, 13.8750, 17.0000, 8.9375]])" | |
| ] | |
| }, | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "scores = processor.score_multi_vector(query_embeddings, image_embeddings)\n", | |
| "scores" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "id": "e279323d-82ed-4747-af0a-1bc4b2e23a8e", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(24, 31)" | |
| ] | |
| }, | |
| "execution_count": 16, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Get the number of image patches\n", | |
| "n_patches = processor.get_n_patches(image_size=images[0].size, spatial_merge_size=2)\n", | |
| "n_patches" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "id": "23c8e5df-6382-48d4-a6e3-8c5fad12639f", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(torch.Size([4, 755]), 744)" | |
| ] | |
| }, | |
| "execution_count": 17, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Get the tensor mask to filter out the embeddings that are not related to the image\n", | |
| "image_mask = processor.get_image_mask(batch_images)\n", | |
| "image_mask.shape, image_mask[0].sum().item()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "id": "e0522f5e-c521-45ce-a60f-283daafb7d97", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "torch.Size([22, 24, 31])" | |
| ] | |
| }, | |
| "execution_count": 18, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Generate the similarity maps\n", | |
| "query_idx = 1\n", | |
| "image_idx = 2\n", | |
| "batched_similarity_maps = get_similarity_maps_from_embeddings(\n", | |
| " image_embeddings=image_embeddings[image_idx].unsqueeze(0),\n", | |
| " query_embeddings=query_embeddings[query_idx].unsqueeze(0),\n", | |
| " n_patches=n_patches,\n", | |
| " image_mask=image_mask[image_idx].unsqueeze(0),\n", | |
| ")\n", | |
| "batched_similarity_maps[0].shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "id": "95c95f31-4a99-472f-896b-872c94cfc9b1", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(['How',\n", | |
| " 'Ġdo',\n", | |
| " 'Ġthe',\n", | |
| " 'ĠEncoder',\n", | |
| " 'Ġand',\n", | |
| " 'ĠDecoder',\n", | |
| " 'Ġstacks',\n", | |
| " 'Ġwork',\n", | |
| " 'Ġtogether',\n", | |
| " 'Ġin',\n", | |
| " 'ĠTransformers',\n", | |
| " '?'],\n", | |
| " 12)" | |
| ] | |
| }, | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Tokenize the query\n", | |
| "query_content = processor.decode(batch_queries.input_ids[query_idx]).replace(processor.tokenizer.pad_token, \"\")\n", | |
| "query_content = query_content.replace(processor.query_augmentation_token, \"\").strip()\n", | |
| "\n", | |
| "query_tokens = processor.tokenizer.tokenize(query_content)\n", | |
| "query_tokens, len(query_tokens)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 29, | |
| "id": "e5b261f6-f69f-40c7-8d4b-ee8bafc98739", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Token: `ĠDecoder` Max_Value: 0.84\n" | |
| ] | |
| }, | |
| { | |
| "data": { |
View raw
(Sorry about that, but we can’t show files that are this big right now.)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment