Created
November 14, 2025 21:22
-
-
Save calebrob6/d6c0fdda8788d5ed821988a8ce8f39b3 to your computer and use it in GitHub Desktop.
Notebook showing how to use the AnyUp model to upsample low (spatial) resolution embeddings.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "866ff8b4", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from typing import Optional, Sequence\n", | |
| "\n", | |
| "import matplotlib.pyplot as plt\n", | |
| "import numpy as np\n", | |
| "import torch\n", | |
| "import torch.nn as nn\n", | |
| "import torch.nn.functional as F\n", | |
| "from sklearn.decomposition import PCA\n", | |
| "\n", | |
| "from ftw_tools.training.datasets import FTW" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "59e16ffd", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class AnyUpUpsampler:\n", | |
| " \"\"\"Wrapper around pretrained AnyUp model from torch.hub, with Sentinel-2 pseudo-RGB normalization support.\n", | |
| " \"\"\"\n", | |
| "\n", | |
| " def __init__(self, device: Optional[int]=0, norm_constant=3000.0):\n", | |
| " \"\"\"Initialize AnyUp upsampler.\n", | |
| "\n", | |
| " Args:\n", | |
| " device (Optional[int]): CUDA device index. If None, or if CUDA is not available, the CPU is used.\n", | |
| " norm_constant (float): Normalization constant for Sentinel-2 reflectance values. The `sentinel_2_rgb` input\n", | |
| " to the `upsample` method is divided by this constant and clamped to [0, 1] before ImageNet-style normalization.\n", | |
| " \"\"\"\n", | |
| " if device is not None:\n", | |
| " if torch.cuda.is_available():\n", | |
| " device = torch.device(f\"cuda:{device}\")\n", | |
| " else:\n", | |
| " print(\"WARNING: CUDA not available, using CPU instead.\")\n", | |
| " device = torch.device(\"cpu\")\n", | |
| " else:\n", | |
| " device = torch.device(\"cpu\")\n", | |
| " self.norm_constant = norm_constant\n", | |
| " self.mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)\n", | |
| " self.std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)\n", | |
| " self.model = torch.hub.load('wimmerth/anyup', 'anyup', trust_repo=True).eval().to(device)\n", | |
| "\n", | |
| " def upsample(self, sentinel_2_rgb: torch.Tensor, lr_features: torch.Tensor) -> torch.Tensor:\n", | |
| " \"\"\"Upsample low-resolution features to match high-resolution image size.\n", | |
| "\n", | |
| " Args:\n", | |
| " sentinel_2_rgb (torch.Tensor): High-res Sentinel-2 patch [B, C, H, W] where the first\n", | |
| " 3 channels are assumed to be (B4, B3, B2)\n", | |
| " lr_features (torch.Tensor): Low-res feature map [B, D, h, w]\n", | |
| "\n", | |
| " Returns:\n", | |
| " hr_features (torch.Tensor): Upsampled features [B, D, H, W]\n", | |
| " \"\"\"\n", | |
| " assert sentinel_2_rgb.dim() == 4, \"sentinel_2_rgb must be a 4D tensor [B, C, H, W]\"\n", | |
| " assert sentinel_2_rgb.size(1) >= 3, \"sentinel_2_rgb must have at least 3 channels in (B4, B3, B2) order\"\n", | |
| " rgb = torch.clamp(sentinel_2_rgb[:, :3, :, :] / self.norm_constant, 0, 1)\n", | |
| "\n", | |
| " # ImageNet-style normalization\n", | |
| " hr_image_norm = (rgb - self.mean) / self.std\n", | |
| "\n", | |
| " # Upsample\n", | |
| " with torch.inference_mode():\n", | |
| " hr_features = self.model(hr_image_norm, lr_features)\n", | |
| " return hr_features\n", | |
| "\n", | |
| "\n", | |
| "class BatchedDinoWrapper(nn.Module):\n", | |
| " def __init__(self, layers: Optional[Sequence[int]] = None):\n", | |
| " super().__init__()\n", | |
| " # choose which transformer blocks to read; default: block indices 0-23 (forward() keeps only the last one returned)\n", | |
| " self.layers = list(range(24)) if layers is None else list(layers)\n", | |
| "\n", | |
| " # keep the backbone in self.backbone; don't wrap it yet\n", | |
| " self.backbone = torch.hub.load(\n", | |
| " \"facebookresearch/dinov3\",\n", | |
| " \"dinov3_vitl16\",\n", | |
| " source=\"github\",\n", | |
| " weights=\"dinov3_vitl16_pretrain_sat493m-eadcf0ff.pth\",\n", | |
| " ).eval()\n", | |
| "\n", | |
| " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", | |
| " \"\"\"\n", | |
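| " x: (N, 3, H, W), or (3, H, W), which is treated as a batch of one\n", | |
| " returns: (N, h*w, C) features from the last selected layer, where (h, w) is the spatial size of the feature map\n", | |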
| " \"\"\"\n", | |
| " if x.dim() == 3:\n", | |
| " x = x.unsqueeze(0) # -> (1,3,H,W)\n", | |
| " assert x.dim() == 4 and x.size(1) == 3, \"Expected (N,3,H,W)\"\n", | |
| "\n", | |
| " with torch.inference_mode():\n", | |
| " feats_list = self.backbone.get_intermediate_layers(\n", | |
| " x, n=self.layers, reshape=True, norm=True\n", | |
| " )\n", | |
| " feats = feats_list[-1] # take the last requested layer\n", | |
| " assert feats.dim() == 4 and feats.size(0) == x.size(0), (\n", | |
| " f\"Unexpected feats shape: {feats.shape}\"\n", | |
| " )\n", | |
| "\n", | |
| " N, C, h, w = feats.shape\n", | |
| " feats = feats.view(N, C, h * w).transpose(1, 2).contiguous()\n", | |
| "\n", | |
| " return feats" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "id": "36120fb9", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "ds = FTW(\n", | |
| " root=\"data/ftw/\",\n", | |
| " countries=\"austria\",\n", | |
| " split=\"train\",\n", | |
| " verbose=False\n", | |
| ")\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "a91fd38a", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "device = torch.device(\"cuda:0\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "29c9c741", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Using cache found in /home/davrob/.cache/torch/hub/facebookresearch_dinov3_main\n", | |
| "Using cache found in /home/davrob/.cache/torch/hub/wimmerth_anyup_main\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "model = BatchedDinoWrapper().to(device)\n", | |
| "upsampler = AnyUpUpsampler(device=0, norm_constant=1.0)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "5700f177", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "img = torch.clamp(ds[0][\"image\"][:3] / 3000.0, 0, 1)\n", | |
| "\n", | |
| "img_big = F.interpolate(img.unsqueeze(0), scale_factor=16, mode='bilinear', align_corners=False).squeeze(0)\n", | |
| "\n", | |
| "img_viz = img.permute(1, 2, 0).numpy()\n", | |
| "\n", | |
| "image_batch = img.unsqueeze(0).to(device)\n", | |
| "image_big_batch = img_big.unsqueeze(0).to(device)\n", | |
| "embedding = model(image_batch)\n", | |
| "embedding_big = model(image_big_batch)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "33418fb8", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "embedding = embedding.squeeze().reshape(img.shape[1]//16, img.shape[2]//16, -1).permute(2, 0, 1).unsqueeze(0)\n", | |
| "embedding_big = embedding_big.squeeze().reshape(img_big.shape[1]//16, img_big.shape[2]//16, -1).permute(2, 0, 1).unsqueeze(0)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "de3a03ba", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "hr_embedding = upsampler.upsample(image_batch, embedding)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "id": "c95ce9bb", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "torch.Size([1, 1024, 256, 256])" | |
| ] | |
| }, | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "hr_embedding.shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "id": "f6b9c1e8", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def get_pca(embedding):\n", | |
| " embedding = embedding.squeeze()\n", | |
| " _, h, w = embedding.shape\n", | |
| " pixels = embedding.reshape(1024, -1).T.copy()\n", | |
| "\n", | |
| " print(pixels.shape)\n", | |
| "\n", | |
| " pca = PCA(n_components=3)\n", | |
| " x_pca = pca.fit(pixels[::10]).transform(pixels)\n", | |
| " x_pca = x_pca.reshape(h, w, -1)\n", | |
| "\n", | |
| " p1 = np.percentile(x_pca, 1)\n", | |
| " p99 = np.percentile(x_pca, 99)\n", | |
| " x_pca = (x_pca - p1) / (p99 - p1)\n", | |
| " x_pca = np.clip(x_pca, 0, 1)\n", | |
| " return x_pca" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "id": "505a16f3", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "(256, 1024)\n", | |
| "(65536, 1024)\n", | |
| "(65536, 1024)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "embedding_lr_pca = get_pca(embedding.cpu().numpy().squeeze())\n", | |
| "embedding_hr_pca = get_pca(hr_embedding.cpu().numpy().squeeze())\n", | |
| "embedding_big_pca = get_pca(embedding_big.cpu().numpy().squeeze())" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "id": "681726f1", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment