Embeddings from Prithvi 100M NASA/IBM
import yaml
import numpy as np
import rasterio as rio
import torch

from prithvi import MaskedAutoencoderViT
from huggingface_hub import hf_hub_download

# Constants
REPO_ID = "ibm-nasa-geospatial/Prithvi-100M"
CONFIG = "Prithvi_100M_config.yaml"
CHECKPOINT = "Prithvi_100M.pt"
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
NO_DATA = -9999
NO_DATA_FLOAT = 0.0001


def get_model(config, checkpoint):
    """
    Initializes and returns the MaskedAutoencoderViT model with the provided configuration and checkpoint.

    Args:
        config (dict): The configuration dictionary for the model.
        checkpoint (str): Path to the model checkpoint.

    Returns:
        MaskedAutoencoderViT: The initialized model.
    """
    model_args = config["model_args"]
    model = MaskedAutoencoderViT(**model_args)
    model = model.to(DEVICE)
    state_dict = torch.load(checkpoint, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model.eval()
    return model
def get_data(chips, mean, std):
    """
    Preprocesses the input data chips and returns them as a tensor.

    Args:
        chips (list of np.ndarray): List of chips where each chip is a numpy array of shape (channels, height, width).
        mean (float or sequence of float): Per-band mean value(s) used for normalization.
        std (float or sequence of float): Per-band standard deviation value(s) used for normalization.

    Returns:
        torch.Tensor: The preprocessed data tensor of shape (batch, channels, timesteps, height, width).
    """
    for i in range(len(chips)):
        chips[i] = np.moveaxis(chips[i], 0, -1)  # Rearrange axes to (height, width, channels)
        chips[i] = np.where(chips[i] == NO_DATA, NO_DATA_FLOAT, (chips[i] - mean) / std)  # Normalize and replace NO_DATA values
    chips = np.stack(chips, axis=0)  # Stack chips along a new axis to form (num_frames, height, width, channels)
    chips = np.moveaxis(chips, -1, 0).astype('float32')  # Rearrange axes to (channels, num_frames, height, width)
    chips = np.expand_dims(chips, axis=0)  # Add batch dimension
    chips = torch.tensor(chips, device=DEVICE)
    return chips
if __name__ == "__main__":
    # Download and load configuration and checkpoint
    config_path = hf_hub_download(repo_id=REPO_ID, filename=CONFIG)
    config = yaml.safe_load(open(config_path, "r"))
    checkpoint_path = hf_hub_download(repo_id=REPO_ID, filename=CHECKPOINT)

    # Initialize the model
    model = get_model(config, checkpoint_path)

    # Generate random data chips for testing
    chips = [np.random.randn(6, 224, 224) for _ in range(3)]
    mean = config["train_params"]["data_mean"]
    std = config["train_params"]["data_std"]
    data = get_data(chips, mean, std)

    # Obtain feature embeddings from the model
    embedding = model.forward_encoder(data, mask_ratio=0.0)

    # Extract CLS token embedding
    cls_embedding = embedding[:, 0, :]

    # Print shapes of the embeddings
    print(embedding.shape, cls_embedding.shape)
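
The script keeps only the CLS token. If per-patch features are wanted instead, the remaining patch tokens can be averaged or reshaped back onto the (timesteps, height, width) patch grid. A minimal sketch, assuming the shapes produced by the script above (one sample, 3 timesteps, a 14x14 patch grid, 768-dim tokens from the Prithvi-100M encoder):

# Illustrative only: alternatives to the CLS embedding, using the `embedding` tensor from the script
patch_tokens = embedding[:, 1:, :]                          # (1, 588, 768): drop the CLS token
mean_embedding = patch_tokens.mean(dim=1)                   # (1, 768): average-pooled scene embedding
grid_embedding = patch_tokens.reshape(1, 3, 14, 14, 768)    # (batch, timesteps, h_patches, w_patches, dim)
print(mean_embedding.shape, grid_embedding.shape)

The second listing below is the prithvi.py module that the script imports, taken from the Hugging Face model repository.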
# Code from: https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------

from functools import partial

import torch
import torch.nn as nn
from timm.models.vision_transformer import Block
from timm.models.layers import to_2tuple

import numpy as np
from einops import rearrange
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_3d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: 3d tuple of grid size: t, h, w
    return:
    pos_embed: L, D
    """
    assert embed_dim % 16 == 0

    t_size, h_size, w_size = grid_size

    w_embed_dim = embed_dim // 16 * 6
    h_embed_dim = embed_dim // 16 * 6
    t_embed_dim = embed_dim // 16 * 4

    w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
    h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
    t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))

    w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
    h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
    t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)

    pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)

    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
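

# Shape note (added for illustration, assuming the Prithvi-100M grid): with a
# (num_frames=3, 14, 14) patch grid and embed_dim=768,
# get_3d_sincos_pos_embed(768, (3, 14, 14), cls_token=True) returns an array of
# shape (589, 768): one all-zero row for the CLS slot plus 3*14*14 patch positions,
# each concatenating 288 width, 288 height and 192 temporal sin-cos channels.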
class PatchEmbed(nn.Module):
    """ Frames of 2D Images to Patch Embedding
    The 3D version of timm.models.vision_transformer.PatchEmbed
    """
    def __init__(
            self,
            img_size=224,
            patch_size=16,
            num_frames=3,
            tubelet_size=1,
            in_chans=3,
            embed_dim=768,
            norm_layer=None,
            flatten=True,
            bias=True,
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_frames = num_frames
        self.tubelet_size = tubelet_size
        self.grid_size = (num_frames // tubelet_size, img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
        self.flatten = flatten

        self.proj = nn.Conv3d(in_chans, embed_dim,
                              kernel_size=(tubelet_size, patch_size[0], patch_size[1]),
                              stride=(tubelet_size, patch_size[0], patch_size[1]), bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, T, H, W = x.shape
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # B,C,T,H,W -> B,C,L -> B,L,C
        x = self.norm(x)
        return x
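

# Shape note (added for illustration, assuming the Prithvi-100M settings in_chans=6,
# embed_dim=768, num_frames=3, tubelet_size=1): an input of shape (B, 6, 3, 224, 224)
# is projected by the Conv3d to (B, 768, 3, 14, 14) and flattened to (B, 588, 768),
# i.e. one 768-dim token per 16x16 patch per frame.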
class MaskedAutoencoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self, img_size=224, patch_size=16,
                 num_frames=3, tubelet_size=1,
                 in_chans=3, embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
        super().__init__()

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(img_size, patch_size, num_frames, tubelet_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for i in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, tubelet_size * patch_size * patch_size * in_chans, bias=True)  # decoder to patch
        # --------------------------------------------------------------------------

        self.norm_pix_loss = norm_pix_loss

        self.initialize_weights()
    def initialize_weights(self):
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_3d_sincos_pos_embed(self.pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_3d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
    def patchify(self, imgs):
        """
        imgs: B, C, T, H, W
        x: B, L, D
        """
        p = self.patch_embed.patch_size[0]
        tub = self.patch_embed.tubelet_size
        x = rearrange(imgs, 'b c (t tub) (h p) (w q) -> b (t h w) (tub p q c)', tub=tub, p=p, q=p)

        return x

    def unpatchify(self, x):
        """
        x: B, L, D
        imgs: B, C, T, H, W
        """
        p = self.patch_embed.patch_size[0]
        num_p = self.patch_embed.img_size[0] // p
        tub = self.patch_embed.tubelet_size
        imgs = rearrange(x, 'b (t h w) (tub p q c) -> b c (t tub) (h p) (w q)', h=num_p, w=num_p, tub=tub, p=p, q=p)
        return imgs
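
    # Shape note (added for illustration, assuming a 6-band input): for a
    # (B, 6, 3, 224, 224) batch with patch_size=16 and tubelet_size=1, patchify()
    # returns (B, 588, 1536) since each token flattens a 1x16x16x6 tubelet;
    # unpatchify() is its exact inverse.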
    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore
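
    # Shape note (added for illustration): for x of shape (N, 588, D) and
    # mask_ratio=0.75, len_keep = int(588 * 0.25) = 147, so x_masked is (N, 147, D),
    # while mask and ids_restore are both (N, 588); mask holds 0 for kept and 1 for
    # dropped tokens in the original (unshuffled) order.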
    def forward_encoder(self, x, mask_ratio):
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        if mask_ratio:
            # Apply random masking to the input
            x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # Add the cls token to the input
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Pass the input through Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        if mask_ratio:
            return x, mask, ids_restore
        else:
            return x
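
    # Shape note (added for illustration, assuming the Prithvi-100M dims): for a
    # (B, 6, 3, 224, 224) input, mask_ratio=0 returns the full token sequence
    # (B, 589, 768) with the CLS token first; a non-zero mask_ratio instead returns
    # a tuple of the reduced sequence, the binary mask and ids_restore for the decoder.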
    def forward_decoder(self, x, ids_restore):
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x
    def forward_loss(self, imgs, pred, mask):
        """
        imgs: B, C, T, H, W
        target: B, L, D
        pred: B, L, D
        mask: B, L. 0 is keep, 1 is remove,
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward(self, imgs, mask_ratio=0.75):
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        pred = self.forward_decoder(latent, ids_restore)
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask
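
For completeness, a minimal sketch of the pretraining-style forward pass through this class. The constructor arguments below are assumptions approximating the published Prithvi-100M configuration (6 HLS bands, ViT-Base-sized encoder); the authoritative values come from Prithvi_100M_config.yaml, as used in the embedding script above.

# Illustrative only: masked-autoencoder forward pass with assumed Prithvi-100M-like dims
model = MaskedAutoencoderViT(
    img_size=224, patch_size=16,
    num_frames=3, tubelet_size=1,
    in_chans=6, embed_dim=768, depth=12, num_heads=12,
    decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
    mlp_ratio=4.0, norm_layer=partial(nn.LayerNorm, eps=1e-6),
)
imgs = torch.randn(1, 6, 3, 224, 224)       # (batch, channels, timesteps, height, width)
loss, pred, mask = model(imgs, mask_ratio=0.75)
print(loss.item(), pred.shape, mask.shape)  # pred: (1, 588, 1536), mask: (1, 588)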