Skip to content

Instantly share code, notes, and snippets.

@calebrob6
Created November 14, 2025 21:22
Show Gist options
  • Select an option

  • Save calebrob6/d6c0fdda8788d5ed821988a8ce8f39b3 to your computer and use it in GitHub Desktop.

Select an option

Save calebrob6/d6c0fdda8788d5ed821988a8ce8f39b3 to your computer and use it in GitHub Desktop.
Notebook showing how to use the AnyUp model to upsample low-spatial-resolution embeddings (here, DINOv3 features of Sentinel-2 imagery) to the resolution of the input image.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "866ff8b4",
"metadata": {},
"outputs": [],
"source": [
"from typing import Optional, Sequence\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from sklearn.decomposition import PCA\n",
"\n",
"from ftw_tools.training.datasets import FTW"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "59e16ffd",
"metadata": {},
"outputs": [],
"source": [
"class AnyUpUpsampler:\n",
"    \"\"\"Wrapper around the pretrained AnyUp model from torch.hub, with Sentinel-2 pseudo-RGB normalization support.\"\"\"\n",
"\n",
"    def __init__(self, device: Optional[int] = 0, norm_constant=3000.0):\n",
"        \"\"\"Initialize the AnyUp upsampler.\n",
"\n",
"        Args:\n",
"            device (Optional[int]): CUDA device index. If None, or if CUDA is not available,\n",
"                the model runs on the CPU.\n",
"            norm_constant (float): Normalization constant for Sentinel-2 reflectance values. The\n",
"                `sentinel_2_rgb` input to the `upsample` method is divided by this constant\n",
"                before ImageNet-style normalization.\n",
"        \"\"\"\n",
"        if device is None:\n",
"            torch_device = torch.device(\"cpu\")\n",
"        elif torch.cuda.is_available():\n",
"            torch_device = torch.device(f\"cuda:{device}\")\n",
"        else:\n",
"            print(\"WARNING: CUDA not available, using CPU instead.\")\n",
"            torch_device = torch.device(\"cpu\")\n",
"\n",
"        # Remember the resolved device so `upsample` can move its inputs next to the model.\n",
"        self.device = torch_device\n",
"        self.norm_constant = norm_constant\n",
"        # ImageNet channel statistics, shaped (1, 3, 1, 1) to broadcast over [B, 3, H, W].\n",
"        self.mean = torch.tensor([0.485, 0.456, 0.406], device=torch_device).view(1, 3, 1, 1)\n",
"        self.std = torch.tensor([0.229, 0.224, 0.225], device=torch_device).view(1, 3, 1, 1)\n",
"        self.model = torch.hub.load('wimmerth/anyup', 'anyup', trust_repo=True).eval().to(torch_device)\n",
"\n",
"    def upsample(self, sentinel_2_rgb: torch.Tensor, lr_features: torch.Tensor) -> torch.Tensor:\n",
"        \"\"\"Upsample low-resolution features to match the high-resolution image size.\n",
"\n",
"        Both inputs are first moved to the upsampler's device, so CPU tensors are accepted.\n",
"\n",
"        Args:\n",
"            sentinel_2_rgb (torch.Tensor): High-res Sentinel-2 patch [B, C, H, W] where the first\n",
"                3 channels are assumed to be (B4, B3, B2)\n",
"            lr_features (torch.Tensor): Low-res feature map [B, D, h, w]\n",
"\n",
"        Returns:\n",
"            hr_features (torch.Tensor): Upsampled features [B, D, H, W] on the upsampler's device\n",
"        \"\"\"\n",
"        assert sentinel_2_rgb.dim() == 4, \"sentinel_2_rgb must be a 4D tensor [B, C, H, W]\"\n",
"        assert sentinel_2_rgb.size(1) >= 3, \"sentinel_2_rgb must have at least 3 channels in (B4, B3, B2) order\"\n",
"        # self.mean/self.std and the model live on self.device; align the inputs with them.\n",
"        sentinel_2_rgb = sentinel_2_rgb.to(self.device)\n",
"        lr_features = lr_features.to(self.device)\n",
"\n",
"        # Scale reflectances into [0, 1], then apply ImageNet-style normalization.\n",
"        rgb = torch.clamp(sentinel_2_rgb[:, :3, :, :] / self.norm_constant, 0, 1)\n",
"        hr_image_norm = (rgb - self.mean) / self.std\n",
"\n",
"        # Upsample\n",
"        with torch.inference_mode():\n",
"            hr_features = self.model(hr_image_norm, lr_features)\n",
"        return hr_features\n",
"\n",
"\n",
"class BatchedDinoWrapper(nn.Module):\n",
"    \"\"\"DINOv3 ViT-L/16 backbone (SAT-493M weights) that returns per-patch token features.\"\"\"\n",
"\n",
"    def __init__(self, layers: Optional[Sequence[int]] = None):\n",
"        \"\"\"Load the DINOv3 backbone from torch.hub.\n",
"\n",
"        Args:\n",
"            layers (Optional[Sequence[int]]): Transformer block indices to request from the\n",
"                backbone; `forward` returns only the last one. Defaults to the final block\n",
"                (index 23 of the 24 blocks in ViT-L/16).\n",
"        \"\"\"\n",
"        super().__init__()\n",
"        # Default: last layer only. `forward` keeps just the final requested layer, so\n",
"        # requesting more blocks only builds feature maps that are thrown away.\n",
"        self.layers = [23] if layers is None else list(layers)\n",
"\n",
"        # keep the backbone in self.backbone; don't wrap it yet\n",
"        self.backbone = torch.hub.load(\n",
"            \"facebookresearch/dinov3\",\n",
"            \"dinov3_vitl16\",\n",
"            source=\"github\",\n",
"            weights=\"dinov3_vitl16_pretrain_sat493m-eadcf0ff.pth\",\n",
"        ).eval()\n",
"\n",
"    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
"        \"\"\"Extract patch-token features.\n",
"\n",
"        x: (N, 3, H, W), or a single (3, H, W) image\n",
"        returns: (N, h*w, C) features from the last selected layer, where (h, w) is the\n",
"            backbone's patch grid (tokens in row-major order)\n",
"        \"\"\"\n",
"        if x.dim() == 3:\n",
"            x = x.unsqueeze(0)  # -> (1, 3, H, W)\n",
"        assert x.dim() == 4 and x.size(1) == 3, \"Expected (N,3,H,W)\"\n",
"\n",
"        with torch.inference_mode():\n",
"            feats_list = self.backbone.get_intermediate_layers(\n",
"                x, n=self.layers, reshape=True, norm=True\n",
"            )\n",
"        feats = feats_list[-1]  # take the last requested layer\n",
"        assert feats.dim() == 4 and feats.size(0) == x.size(0), (\n",
"            f\"Unexpected feats shape: {feats.shape}\"\n",
"        )\n",
"\n",
"        # (N, C, h, w) -> (N, h*w, C). flatten() also handles non-contiguous maps, where\n",
"        # view() would raise.\n",
"        feats = feats.flatten(2).transpose(1, 2).contiguous()\n",
"\n",
"        return feats"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "36120fb9",
"metadata": {},
"outputs": [],
"source": [
"# FTW dataset: Austrian training split, rooted at data/ftw/.\n",
"ds = FTW(\n",
"    root=\"data/ftw/\",\n",
"    split=\"train\",\n",
"    countries=\"austria\",\n",
"    verbose=False,\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a91fd38a",
"metadata": {},
"outputs": [],
"source": [
"# Use the first GPU when one is available; otherwise fall back to the CPU so the later\n",
"# `.to(device)` calls do not fail on CPU-only machines.\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "29c9c741",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using cache found in /home/davrob/.cache/torch/hub/facebookresearch_dinov3_main\n",
"Using cache found in /home/davrob/.cache/torch/hub/wimmerth_anyup_main\n"
]
}
],
"source": [
"model = BatchedDinoWrapper().to(device)\n",
"# Place AnyUp on the same device as the backbone: None selects the CPU, and a bare cuda\n",
"# device (no explicit index) is treated as index 0.\n",
"upsampler = AnyUpUpsampler(\n",
"    device=(device.index or 0) if device.type == \"cuda\" else None,\n",
"    norm_constant=1.0,  # images are already scaled to [0, 1] before they reach the upsampler\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5700f177",
"metadata": {},
"outputs": [],
"source": [
"# First three channels of the first sample, scaled by 3000 and clipped to [0, 1].\n",
"sample = ds[0]\n",
"img = (sample[\"image\"][:3] / 3000.0).clamp(0, 1)\n",
"\n",
"# The same image enlarged 16x bilinearly: with 16x16-pixel patches this yields one token\n",
"# per original pixel, a native high-resolution baseline for comparison.\n",
"img_big = F.interpolate(img[None], scale_factor=16, mode='bilinear', align_corners=False)[0]\n",
"\n",
"# (H, W, 3) array for plotting\n",
"img_viz = img.permute(1, 2, 0).numpy()\n",
"\n",
"image_batch = img[None].to(device)\n",
"embedding = model(image_batch)\n",
"\n",
"image_big_batch = img_big[None].to(device)\n",
"embedding_big = model(image_big_batch)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "33418fb8",
"metadata": {},
"outputs": [],
"source": [
"# Turn the (1, h*w, C) token sequences back into (1, C, h, w) feature maps; the backbone\n",
"# emits one token per 16x16-pixel patch.\n",
"patch_size = 16\n",
"embedding = (\n",
"    embedding.squeeze()\n",
"    .reshape(img.shape[1] // patch_size, img.shape[2] // patch_size, -1)\n",
"    .permute(2, 0, 1)\n",
"    .unsqueeze(0)\n",
")\n",
"embedding_big = (\n",
"    embedding_big.squeeze()\n",
"    .reshape(img_big.shape[1] // patch_size, img_big.shape[2] // patch_size, -1)\n",
"    .permute(2, 0, 1)\n",
"    .unsqueeze(0)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "de3a03ba",
"metadata": {},
"outputs": [],
"source": [
"# Upsample the low-res DINO feature map, guided by the [0, 1]-scaled image.\n",
"hr_embedding = upsampler.upsample(sentinel_2_rgb=image_batch, lr_features=embedding)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "c95ce9bb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 1024, 256, 256])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Expected layout: (1, C, H, W), with H and W matching the input image.\n",
"hr_embedding.size()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "f6b9c1e8",
"metadata": {},
"outputs": [],
"source": [
"def get_pca(embedding, n_components=3, sample_step=10):\n",
"    \"\"\"Project a (C, H, W) embedding onto its top principal components for display.\n",
"\n",
"    PCA is fit on every `sample_step`-th pixel (to keep fitting cheap) and then applied to\n",
"    all pixels. The projection is rescaled with the global 1st/99th percentiles and clipped.\n",
"\n",
"    Args:\n",
"        embedding (np.ndarray): Feature map of shape (C, H, W); singleton axes are squeezed.\n",
"        n_components (int): Number of principal components to keep (3 gives an RGB image).\n",
"        sample_step (int): Stride used to subsample pixels when fitting the PCA.\n",
"\n",
"    Returns:\n",
"        np.ndarray: Array of shape (H, W, n_components) with values in [0, 1].\n",
"    \"\"\"\n",
"    embedding = embedding.squeeze()\n",
"    channels, h, w = embedding.shape\n",
"    # (C, H, W) -> (H*W, C), one row per pixel. Use the actual channel count instead of\n",
"    # assuming 1024 (the ViT-L width) so embeddings from other backbones also work.\n",
"    pixels = embedding.reshape(channels, -1).T.copy()\n",
"\n",
"    print(pixels.shape)\n",
"\n",
"    pca = PCA(n_components=n_components)\n",
"    x_pca = pca.fit(pixels[::sample_step]).transform(pixels)\n",
"    x_pca = x_pca.reshape(h, w, -1)\n",
"\n",
"    p1 = np.percentile(x_pca, 1)\n",
"    p99 = np.percentile(x_pca, 99)\n",
"    # Fall back to a unit scale when the range collapses (e.g. a constant embedding)\n",
"    # instead of dividing by zero.\n",
"    scale = p99 - p1 if p99 > p1 else 1.0\n",
"    x_pca = (x_pca - p1) / scale\n",
"    x_pca = np.clip(x_pca, 0, 1)\n",
"    return x_pca"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "505a16f3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(256, 1024)\n",
"(65536, 1024)\n",
"(65536, 1024)\n"
]
}
],
"source": [
"# PCA views of the low-res, AnyUp-upsampled and 16x-input embeddings, computed in that order.\n",
"embedding_lr_pca, embedding_hr_pca, embedding_big_pca = [\n",
"    get_pca(features.cpu().numpy().squeeze())\n",
"    for features in (embedding, hr_embedding, embedding_big)\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "681726f1",
"metadata": {},
"outputs": [
{
"data": {
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment