Created
May 9, 2026 05:47
-
-
Save calebrob6/9ce42c29dc956d2b09148e2ff5ca5b17 to your computer and use it in GitHub Desktop.
S2 super-resolution with gaussian splats
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """Download a Sentinel-2 time series from the Microsoft Planetary Computer. | |
| Companion script for the blog post "Super-Resolving Sentinel-2 with Gaussian | |
| Splats" (https://geospatialml.com/posts/sentinel2-superresolution/). | |
| Pulls the N least-cloudy Sentinel-2 L2A scenes (below --max-cloud) over a configurable AOI, crops to a | |
| common pixel window, and writes one 4-band GeoTIFF per scene (B02 Blue, | |
| B03 Green, B04 Red, B08 NIR; uint16 reflectance in 0..10000). | |
| Requirements: | |
| pip install planetary-computer pystac-client rasterio numpy | |
| Run: | |
| python download_s2.py # uses defaults below | |
| python download_s2.py --lat 47.674 --lon -122.121 --n-scenes 32 --out data | |
| """ | |
| import argparse | |
| import json | |
| import sys | |
| from pathlib import Path | |
| import numpy as np | |
| import planetary_computer | |
| import pystac_client | |
| import rasterio | |
| from rasterio.crs import CRS | |
| from rasterio.warp import transform_bounds | |
# Sentinel-2 asset keys read for every scene, in output band order.
BANDS = ["B02", "B03", "B04", "B08"]  # Blue, Green, Red, NIR @ 10m
# Human-readable labels in the same order as BANDS; written as GeoTIFF band
# descriptions and into metadata.json.
BAND_DESCS = ["Blue (B02)", "Green (B03)", "Red (B04)", "NIR (B08)"]
def parse_args():
    """Parse the command-line options of the Sentinel-2 downloader.

    Returns:
        argparse.Namespace with ``lat``, ``lon``, ``half_deg_lat``,
        ``half_deg_lon``, ``n_scenes``, ``max_cloud``, ``date_range`` and
        ``out`` (a ``pathlib.Path``).
    """
    # __doc__ is None when docstrings are stripped (e.g. ``python -OO``);
    # fall back to no description instead of crashing on ``None.split``.
    summary = (__doc__ or "").split("\n")[0] or None
    p = argparse.ArgumentParser(description=summary)
    p.add_argument("--lat", type=float, default=47.674,
                   help="AOI center latitude (default: Redmond, WA)")
    p.add_argument("--lon", type=float, default=-122.121,
                   help="AOI center longitude")
    p.add_argument("--half-deg-lat", type=float, default=0.045,
                   help="Half AOI extent in latitude degrees (~5km at default lat)")
    p.add_argument("--half-deg-lon", type=float, default=0.065,
                   help="Half AOI extent in longitude degrees (~5km at default lat)")
    p.add_argument("--n-scenes", type=int, default=32,
                   help="Number of cloud-free scenes to download")
    p.add_argument("--max-cloud", type=float, default=10.0,
                   help="Maximum eo:cloud_cover percentage")
    p.add_argument("--date-range", type=str, default="2025-04-01/2025-10-31",
                   help="STAC datetime range, e.g. 2025-04-01/2025-10-31")
    p.add_argument("--out", type=Path, default=Path("data"),
                   help="Output directory")
    return p.parse_args()
def search_scenes(bbox, date_range, max_cloud, n_scenes):
    """Query the Planetary Computer STAC API for Sentinel-2 L2A candidates.

    The search is restricted to `bbox` and `date_range`, keeps only items whose
    ``eo:cloud_cover`` is below `max_cloud`, and asks the server to sort them by
    ascending cloud cover. Every match is returned (not just `n_scenes`), so the
    caller can discard unusable scenes and keep walking the list. `n_scenes` is
    only used to warn on stderr when fewer items matched than requested.

    Returns:
        list of STAC items, lowest cloud cover first.
    """
    stac_url = "https://planetarycomputer.microsoft.com/api/stac/v1"
    # sign_inplace attaches access tokens to asset hrefs as items are fetched.
    client = pystac_client.Client.open(
        stac_url, modifier=planetary_computer.sign_inplace,
    )
    cloud_filter = {"eo:cloud_cover": {"lt": max_cloud}}
    cloud_order = [{"field": "eo:cloud_cover", "direction": "asc"}]
    results = client.search(
        collections=["sentinel-2-l2a"],
        bbox=bbox,
        datetime=date_range,
        query=cloud_filter,
        sortby=cloud_order,
    )
    items = list(results.items())
    n_found = len(items)
    if n_found < n_scenes:
        print(f"Warning: only {len(items)} scenes match (asked for {n_scenes}).",
              file=sys.stderr)
    return items
def compute_target_window(item, bbox_4326):
    """Derive the pixel window of a reference scene that covers the AOI.

    The lon/lat bounding box is reprojected into the CRS of the scene's B04
    asset, converted to a raster window on that asset's grid, and snapped to
    whole pixels. The caller reuses this window for every other scene so the
    resulting stack shares one pixel grid.

    Returns:
        (window, transform, crs): the snapped window, the affine transform of
        that window, and the CRS of the reference raster.
    """
    red = planetary_computer.sign(item.assets["B04"])
    wgs84 = CRS.from_epsg(4326)
    with rasterio.open(red.href) as src:
        left, bottom, right, top = transform_bounds(wgs84, src.crs, *bbox_4326)
        snapped = rasterio.windows.from_bounds(
            left, bottom, right, top, transform=src.transform,
        ).round_offsets().round_lengths()
        window_transform = rasterio.windows.transform(snapped, src.transform)
        return snapped, window_transform, src.crs
# Minimum share of non-zero pixels (measured on the first band of the cropped
# window) a scene needs in order to be kept by download_scene.
MIN_VALID_FRACTION = 0.99  # drop scenes where less than this fraction of the AOI is in-granule
def download_scene(item, idx, window, transform, crs, out_dir):
    """Read all bands at the shared window and write one tiled multi-band GeoTIFF.

    Args:
        item: STAC item providing ``datetime``, ``properties``, ``id`` and one
            asset per entry in BANDS.
        idx: running index used in the output file name.
        window: pixel window computed on the reference scene.
        transform: affine transform of `window` on the reference scene.
        crs: CRS of the reference scene.
        out_dir: directory (``pathlib.Path``) that receives the GeoTIFF.

    Returns:
        (path, status). Status is "downloaded", "exists", or "skipped:<reason>".
        Skipped scenes are not written to disk. Reasons are:
          * ``grid-mismatch`` -- the scene's CRS or pixel grid differs from the
            reference, so the shared pixel window would select a different
            ground footprint and the output georeferencing would be wrong.
          * ``partial-coverage(..)`` -- the AOI sits at a granule edge and is
            mostly NoData; those scenes break phase correlation and bias the
            SR fit.
    """
    date_str = item.datetime.strftime("%Y%m%d")
    cloud = item.properties.get("eo:cloud_cover", -1)
    out_path = out_dir / f"s2_{idx:03d}_{date_str}_cc{cloud:.0f}.tif"
    if out_path.exists():
        return out_path, "exists"

    bands = []
    for band in BANDS:
        asset = planetary_computer.sign(item.assets[band])
        with rasterio.open(asset.href) as src:
            # The window is expressed in reference-scene pixels. It only maps to
            # the same ground area if this raster has the same CRS and the same
            # windowed transform as the reference.
            same_grid = src.crs == crs and rasterio.windows.transform(
                window, src.transform
            ).almost_equals(transform)
            if not same_grid:
                return out_path, "skipped:grid-mismatch"
            bands.append(src.read(1, window=window))
    stack = np.stack(bands, axis=0).astype(np.uint16)

    # Zero is treated as NoData; coverage is measured on the first band only.
    valid_frac = float((stack[0] > 0).sum()) / stack[0].size
    if valid_frac < MIN_VALID_FRACTION:
        return out_path, f"skipped:partial-coverage({valid_frac:.2%})"

    profile = {
        "driver": "GTiff", "dtype": "uint16",
        "height": stack.shape[1], "width": stack.shape[2],
        "count": stack.shape[0],
        "crs": crs, "transform": transform,
        "compress": "deflate", "tiled": True, "blockxsize": 256, "blockysize": 256,
    }
    with rasterio.open(out_path, "w", **profile) as dst:
        dst.write(stack)
        for i, desc in enumerate(BAND_DESCS, start=1):
            dst.set_band_description(i, desc)
        dst.update_tags(
            datetime=item.datetime.isoformat(),
            cloud_cover=str(cloud),
            scene_id=item.id,
        )
    return out_path, "downloaded"
def main():
    """Command-line entry point: search, download, and describe the stack.

    Builds the AOI bounding box from the CLI arguments, searches for candidate
    scenes, derives the shared pixel window from the first candidate, and then
    walks the candidates until `--n-scenes` scenes have been kept. Scenes whose
    status starts with "skipped" and scenes that raise are passed over. Finally
    writes ``metadata.json`` next to the GeoTIFFs.
    """
    args = parse_args()
    out_dir = args.out
    out_dir.mkdir(parents=True, exist_ok=True)

    west, east = args.lon - args.half_deg_lon, args.lon + args.half_deg_lon
    south, north = args.lat - args.half_deg_lat, args.lat + args.half_deg_lat
    bbox = [west, south, east, north]
    print(f"AOI bbox (lon/lat): {bbox}")
    print(f"Searching {args.date_range} for <{args.max_cloud}% cloud cover...")

    items = search_scenes(bbox, args.date_range, args.max_cloud, args.n_scenes)
    if len(items) == 0:
        sys.exit("No matching scenes found.")
    print(f"Found {len(items)} candidate scenes (lowest cloud cover first); "
          f"keeping the first {args.n_scenes} that fully cover the AOI.")

    # The first (lowest cloud cover) candidate defines the common pixel grid.
    window, transform, crs = compute_target_window(items[0], bbox)
    print(f"Target window: {int(window.height)} x {int(window.width)} px in {crs}")

    saved = []
    for item in items:
        # The output index is the number of scenes kept so far, so skipped or
        # failed candidates do not leave gaps in the numbering.
        idx = len(saved)
        if idx >= args.n_scenes:
            break
        try:
            path, status = download_scene(item, idx, window, transform, crs, args.out)
            print(f" [{idx:02d}] {status:25s} {path.name}")
            if not status.startswith("skipped"):
                saved.append({
                    "file": path.name,
                    "date": item.datetime.isoformat(),
                    "cloud_cover": item.properties.get("eo:cloud_cover", -1),
                    "scene_id": item.id,
                })
        except Exception as e:
            # Best effort: report the failure and move on to the next candidate.
            print(f" [{idx:02d}] {'FAILED':25s} {e}")

    kept = len(saved)
    if kept < args.n_scenes:
        print(f"Warning: kept {kept} scenes, fewer than the requested {args.n_scenes}.",
              file=sys.stderr)

    meta = dict(
        bbox_4326=bbox,
        crs=str(crs),
        transform=list(transform)[:6],
        height=int(window.height),
        width=int(window.width),
        pixel_size_m=10.0,
        bands=BAND_DESCS,
        scenes=saved,
    )
    with open(args.out / "metadata.json", "w") as f:
        json.dump(meta, f, indent=2)
    print(f"Wrote {len(saved)} scenes and metadata.json to {args.out}/")
# Run the downloader only when this file is executed as a script.
if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """Multi-temporal Sentinel-2 super-resolution with Gaussian splats. | |
| Companion script for the blog post "Super-Resolving Sentinel-2 with Gaussian | |
| Splats" (https://geospatialml.com/posts/sentinel2-superresolution/). | |
| The scene is represented as a continuous field of 2D Gaussians on a regular | |
| grid. Each Sentinel-2 observation is the analytic integral of that field | |
| over shifted 10m pixel footprints convolved with the sensor PSF. Because | |
| both the splats and the PSF are Gaussian, the integral over each pixel | |
| factors into a product of `erf` differences along x and y, with no | |
| discretization of the latent image and no finite-difference shift gradients. | |
| LBFGS jointly recovers the splat weights and the per-observation sub-pixel | |
| shifts. | |
| Pair this with `download_s2.py` for the input data. | |
| Requirements: | |
| pip install torch numpy rasterio scikit-image matplotlib | |
| Run: | |
| python gaussian_splat_sr.py # uses defaults | |
| python gaussian_splat_sr.py --data data --out out --crop 256 --target-m 2.5 | |
| """ | |
| import argparse | |
| import math | |
| import time | |
| from pathlib import Path | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import rasterio | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
PIXEL_M = 10.0  # native Sentinel-2 pixel size for B02/B03/B04/B08
# Use a CUDA device when PyTorch reports one as available, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
| # ----------------------------------------------------------------------------- | |
| # Data loading | |
| # ----------------------------------------------------------------------------- | |
def load_stack(data_dir: Path, ref_band: int = 2):
    """Read all ``s2_*.tif`` rasters in `data_dir` into one normalized tensor.

    Files are taken in sorted name order. Pixel values are clipped to
    0..10000 and divided by 10000, giving values in 0..1.

    Args:
        data_dir: directory to scan for ``s2_*.tif`` files.
        ref_band: accepted for interface compatibility; not used here.

    Returns:
        (obs, crs, transform, files): a float32 tensor stacked as
        [T, C, H, W], the CRS and affine transform of the first file, and the
        sorted list of file paths.

    Raises:
        FileNotFoundError: if no matching file exists.
    """
    files = sorted(data_dir.glob("s2_*.tif"))
    if len(files) == 0:
        raise FileNotFoundError(f"No s2_*.tif files in {data_dir}")

    def _read_all_bands(path):
        with rasterio.open(path) as src:
            return src.read().astype(np.float32)

    raw = np.stack([_read_all_bands(fp) for fp in files], axis=0)  # [T, C, H, W]
    obs = np.clip(raw, 0, 10000) / 10000.0  # crude reflectance normalization

    # Georeferencing is taken from the first file only.
    with rasterio.open(files[0]) as first:
        crs = first.crs
        transform = first.transform
    return torch.from_numpy(obs), crs, transform, files
def estimate_shifts_phase_corr(obs_crop: torch.Tensor, ref_band: int = 1):
    """Seed per-observation shifts (in meters) by phase cross-correlation.

    Every observation is registered against observation 0 on band `ref_band`
    with a 1/100-pixel upsampling factor. skimage's
    `phase_cross_correlation(reference, moving)` reports the shift needed to
    register moving onto reference, in (row, col) order; the value is negated
    here so that it expresses the offset of the observation relative to
    observation 0, which is the convention the forward model uses, and is
    converted from pixels to meters.

    Returns:
        shifts: [T, 2] tensor with columns (dx, dy) in meters; row 0 is zero.
    """
    from skimage.registration import phase_cross_correlation

    n_obs = obs_crop.shape[0]
    reference = obs_crop[0, ref_band].numpy()
    shifts = torch.zeros(n_obs, 2)
    for t in range(1, n_obs):
        moving = obs_crop[t, ref_band].numpy()
        shift_rc, _, _ = phase_cross_correlation(
            reference, moving, upsample_factor=100,
        )
        row_shift = float(shift_rc[0])
        col_shift = float(shift_rc[1])
        shifts[t, 0] = -col_shift * PIXEL_M  # dx (column shift)
        shifts[t, 1] = -row_shift * PIXEL_M  # dy (row shift)
    return shifts
| # ----------------------------------------------------------------------------- | |
| # Splat scene with analytic erf forward model | |
| # ----------------------------------------------------------------------------- | |
class SplatScene(nn.Module):
    """Regular grid of isotropic 2D Gaussian splats with learnable weights.

    A pixel value is the weighted sum of the mass each splat deposits inside
    the pixel footprint. For an axis-aligned footprint that mass is separable:

        I(p_lo, p_hi; mu, sigma) = 0.5 * (erf((p_hi - mu) / (sigma * sqrt(2)))
                                        - erf((p_lo - mu) / (sigma * sqrt(2))))

    so the full pixel value is sum_k w_k * I_x(...) * I_y(...).
    """

    def __init__(self, extent_y, extent_x, spacing, sigma_splat, n_channels):
        """Lay out splat centers every `spacing` meters over the given extent.

        Centers sit at (k + 0.5) * spacing along each axis; weights start at
        zero with shape [n_channels, ny, nx].
        """
        super().__init__()
        self.spacing = spacing
        self.sigma_splat = sigma_splat
        self.n_channels = n_channels
        # At least one splat per axis, even for extents smaller than spacing.
        self.ny = max(1, int(round(extent_y / spacing)))
        self.nx = max(1, int(round(extent_x / spacing)))
        for buffer_name, count in (("cy", self.ny), ("cx", self.nx)):
            centers = (torch.arange(count, dtype=torch.float32) + 0.5) * spacing
            self.register_buffer(buffer_name, centers)
        self.w = nn.Parameter(torch.zeros(n_channels, self.ny, self.nx))

    @staticmethod
    def _erf_integral(centers, edges_lo, edges_hi, sigma):
        """Mass of 1D unit Gaussians inside [edges_lo, edges_hi]. Returns [K, P]."""
        scale = sigma * math.sqrt(2.0)
        mu = centers.unsqueeze(1)  # [K, 1] against [1, P] edges
        upper = torch.erf((edges_hi.unsqueeze(0) - mu) / scale)
        lower = torch.erf((edges_lo.unsqueeze(0) - mu) / scale)
        return 0.5 * (upper - lower)

    def _render_at(self, edges_y_lo, edges_y_hi, edges_x_lo, edges_x_hi,
                   centers_y, centers_x, sigma):
        """Integrate the weighted splats over a grid of footprints -> [C, H, W]."""
        mass_y = self._erf_integral(centers_y, edges_y_lo, edges_y_hi, sigma)  # [Ky, H]
        mass_x = self._erf_integral(centers_x, edges_x_lo, edges_x_hi, sigma)  # [Kx, W]
        # out[c, i, j] = sum_{ky, kx} w[c, ky, kx] * mass_y[ky, i] * mass_x[kx, j]
        return torch.einsum("yi,cyx,xj->cij", mass_y, self.w, mass_x)

    def predict_obs(self, n_h, n_w, sigma_psf, shifts):
        """Simulate the low-resolution observations, one per row of `shifts`.

        The splat width and the PSF width are combined in quadrature, the splat
        centers are translated by each observation's (dx, dy), and the field is
        integrated over PIXEL_M-sized footprints.

        Args:
            n_h, n_w: pixel grid size of each LR observation.
            sigma_psf: scalar sensor PSF std in meters.
            shifts: [T, 2] per-observation (dx, dy) shifts in meters.

        Returns:
            [T, C, n_h, n_w] tensor of predicted observations.
        """
        sigma_eff = math.sqrt(self.sigma_splat ** 2 + sigma_psf ** 2)
        device = self.w.device
        half = PIXEL_M / 2.0
        py = (torch.arange(n_h, dtype=torch.float32, device=device) + 0.5) * PIXEL_M
        px = (torch.arange(n_w, dtype=torch.float32, device=device) + 0.5) * PIXEL_M
        y_lo, y_hi = py - half, py + half
        x_lo, x_hi = px - half, px + half
        frames = [
            self._render_at(
                y_lo, y_hi, x_lo, x_hi,
                self.cy + shifts[t, 1], self.cx + shifts[t, 0], sigma_eff,
            )
            for t in range(shifts.shape[0])
        ]
        return torch.stack(frames, dim=0)

    @torch.no_grad()
    def render(self, target_pixel_m, n_h, n_w):
        """Integrate the unshifted field over `target_pixel_m` footprints (no PSF)."""
        device = self.w.device
        half = target_pixel_m / 2.0
        py = (torch.arange(n_h, dtype=torch.float32, device=device) + 0.5) * target_pixel_m
        px = (torch.arange(n_w, dtype=torch.float32, device=device) + 0.5) * target_pixel_m
        return self._render_at(
            py - half, py + half, px - half, px + half,
            self.cy, self.cx, self.sigma_splat,
        )
| # ----------------------------------------------------------------------------- | |
| # Training loop: Adam warmup, then LBFGS | |
| # ----------------------------------------------------------------------------- | |
def fit(scene, obs_dev, sigma_psf, init_shifts,
        adam_steps=200, lbfgs_steps=50, lr_adam=5e-3, lr_lbfgs=1.0, verbose=True):
    """Jointly optimize the splat weights and the per-observation shifts.

    Minimizes the MSE between `scene.predict_obs(...)` and `obs_dev`, first
    with Adam for `adam_steps` iterations and then with LBFGS (strong-Wolfe line
    search, up to 20 inner iterations per outer step) for `lbfgs_steps` steps.
    The shifts start from a copy of `init_shifts` moved to the device of
    `obs_dev`; `scene.w` is updated in place.

    Returns:
        (shifts, losses): the optimized [T, 2] shifts as a detached CPU tensor
        and the list of recorded loss values, one per Adam iteration followed
        by one per LBFGS outer step.
    """
    n_h, n_w = obs_dev.shape[-2:]
    shifts = nn.Parameter(init_shifts.clone().to(obs_dev.device))
    params = [scene.w, shifts]
    losses = []

    def objective():
        return F.mse_loss(scene.predict_obs(n_h, n_w, sigma_psf, shifts), obs_dev)

    if adam_steps > 0:
        opt = torch.optim.Adam(params, lr=lr_adam)
        for step in range(adam_steps):
            opt.zero_grad()
            loss = objective()
            loss.backward()
            opt.step()
            losses.append(loss.item())
            report = step == 0 or (step + 1) % 50 == 0
            if verbose and report:
                print(f" adam {step+1:4d}: loss={loss.item():.6f}")

    if lbfgs_steps > 0:
        opt = torch.optim.LBFGS(
            params, lr=lr_lbfgs, max_iter=20,
            history_size=10, line_search_fn="strong_wolfe",
        )

        def closure():
            # LBFGS re-evaluates the objective several times per outer step.
            opt.zero_grad()
            value = objective()
            value.backward()
            return value

        for step in range(lbfgs_steps):
            loss = opt.step(closure)
            losses.append(loss.item())
            report = step == 0 or (step + 1) % 10 == 0
            if verbose and report:
                # Mean absolute movement of the shifts away from their seed.
                shift_drift = (shifts.detach().cpu() - init_shifts).abs().mean().item()
                print(f" lbfgs {step+1:3d}: loss={loss.item():.6f} "
                      f"mean|Δshift|={shift_drift:.3f}m")

    return shifts.detach().cpu(), losses
| # ----------------------------------------------------------------------------- | |
| # Visualization | |
| # ----------------------------------------------------------------------------- | |
def to_rgb(t, lo_pct=2, hi_pct=98):
    """Turn a [C, H, W] tensor into a percentile-stretched [H, W, 3] RGB array.

    Channels 2, 1, 0 are used as R, G, B. Values are clamped to 0..1.5, the
    `lo_pct` / `hi_pct` percentiles are estimated on a random sample of at most
    200,000 values, and the result is linearly stretched to 0..1.
    """
    rgb = t[[2, 1, 0]].clamp(0, 1.5)  # R, G, B from B04, B03, B02
    values = rgb.reshape(-1)
    n_total = values.numel()
    n_keep = min(n_total, 200_000)
    # Subsample so the quantile computation stays small on large images.
    picked = values[torch.randperm(n_total)[:n_keep]]
    lo = torch.quantile(picked, lo_pct / 100)
    hi = torch.quantile(picked, hi_pct / 100)
    stretched = (rgb - lo) / (hi - lo + 1e-6)
    return stretched.clamp(0, 1).permute(1, 2, 0).numpy()
def save_comparison(obs_one, render_native, render_target, target_m, out_path):
    """Save a three-panel figure: one observation, SR at native and target size.

    Each panel is passed through `to_rgb`. The figure is written to `out_path`
    with a transparent background.
    """
    panels = (
        (obs_one, f"Single S2 obs ({PIXEL_M:.0f}m)"),
        (render_native, f"SR rendered @ {PIXEL_M:.0f}m"),
        (render_target, f"SR rendered @ {target_m:.1f}m"),
    )
    fig, axes = plt.subplots(1, 3, figsize=(18, 6.5))
    for ax, (img, title) in zip(axes, panels):
        ax.imshow(to_rgb(img))
        ax.set_title(title, fontsize=14, color="#555555")
        ax.axis("off")
    fig.tight_layout()
    fig.savefig(out_path, dpi=180, bbox_inches="tight",
                facecolor="none", transparent=True)
    plt.close(fig)
def save_loss_curve(losses, adam_steps, out_path):
    """Plot the recorded losses on a log y-axis and save the figure.

    A dashed vertical line marks the hand-over from Adam to LBFGS when both
    phases contributed entries to `losses`.
    """
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.semilogy(losses)
    both_phases_ran = 0 < adam_steps < len(losses)
    if both_phases_ran:
        # Place the marker between the last Adam and the first LBFGS entry.
        boundary = adam_steps - 0.5
        ax.axvline(boundary, color="grey", linestyle="--", alpha=0.7,
                   label="Adam -> LBFGS")
        ax.legend()
    ax.set_xlabel("optimizer step")
    ax.set_ylabel("MSE loss")
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(out_path, dpi=120, bbox_inches="tight")
    plt.close(fig)
def save_shift_scatter(measured, learned, out_path):
    """Scatter the seeded shifts against the optimized shifts and save the plot.

    Both inputs are indexed as [:, 0] for dx and [:, 1] for dy, in meters.
    """
    fig, ax = plt.subplots(figsize=(6, 6))
    series = (
        (measured, dict(s=40, alpha=0.7, label="phase correlation")),
        (learned, dict(s=40, marker="x", alpha=0.9, label="LBFGS recovered")),
    )
    for points, style in series:
        ax.scatter(points[:, 0], points[:, 1], **style)
    ax.set_xlabel("dx (m)")
    ax.set_ylabel("dy (m)")
    ax.set_aspect("equal")
    ax.grid(True, alpha=0.3)
    ax.legend()
    fig.tight_layout()
    fig.savefig(out_path, dpi=120, bbox_inches="tight")
    plt.close(fig)
| # ----------------------------------------------------------------------------- | |
| # Output | |
| # ----------------------------------------------------------------------------- | |
def save_geotiff(arr, out_path, crs, transform, descriptions=None):
    """Write a [C, H, W] tensor to disk as a float32, deflate-compressed GeoTIFF.

    Args:
        arr: tensor indexed as [band, row, col]; moved to CPU before writing.
        out_path: destination file path.
        crs, transform: georeferencing stored in the output file.
        descriptions: optional per-band description strings, band 1 first.
    """
    data = arr.cpu().numpy().astype(np.float32)
    profile = dict(
        driver="GTiff", dtype="float32",
        height=data.shape[1], width=data.shape[2], count=data.shape[0],
        crs=crs, transform=transform,
        compress="deflate", tiled=True, blockxsize=256, blockysize=256,
    )
    with rasterio.open(out_path, "w", **profile) as dst:
        dst.write(data)
        # Band numbers are 1-based.
        for band_no, text in enumerate(descriptions or (), start=1):
            dst.set_band_description(band_no, text)
| # ----------------------------------------------------------------------------- | |
| # Main | |
| # ----------------------------------------------------------------------------- | |
def parse_args():
    """Parse the command-line options of the super-resolution script.

    Returns:
        argparse.Namespace with ``data``, ``out`` (both ``pathlib.Path``),
        ``crop``, ``spacing``, ``sigma_splat``, ``sigma_psf``, ``target_m``,
        ``adam_steps``, ``lbfgs_steps`` and ``seed``.
    """
    # __doc__ is None when docstrings are stripped (e.g. ``python -OO``);
    # fall back to no description instead of crashing on ``None.split``.
    summary = (__doc__ or "").split("\n")[0] or None
    p = argparse.ArgumentParser(description=summary)
    p.add_argument("--data", type=Path, default=Path("data"),
                   help="Directory containing s2_*.tif files from download_s2.py")
    p.add_argument("--out", type=Path, default=Path("output"),
                   help="Output directory for GeoTIFF + figures")
    p.add_argument("--crop", type=int, default=256,
                   help="Centered LR pixel crop side. Use 0 to fit the full scene")
    p.add_argument("--spacing", type=float, default=3.0,
                   help="Splat grid spacing in meters")
    p.add_argument("--sigma-splat", type=float, default=3.0,
                   help="Splat width (sigma) in meters")
    p.add_argument("--sigma-psf", type=float, default=3.0,
                   help="Assumed sensor PSF width (sigma) in meters")
    p.add_argument("--target-m", type=float, default=2.5,
                   help="Target rendering pixel size in meters")
    p.add_argument("--adam-steps", type=int, default=200)
    p.add_argument("--lbfgs-steps", type=int, default=50)
    p.add_argument("--seed", type=int, default=42)
    return p.parse_args()
def main():
    """Command-line entry point: fit the splat scene and write all outputs.

    Loads the observation stack, optionally takes a centered crop, seeds the
    shifts by phase correlation, initializes and fits the splat scene, then
    writes the super-resolved GeoTIFF, a comparison figure, the loss curve,
    and a shift scatter plot into the output directory.
    """
    args = parse_args()
    args.out.mkdir(parents=True, exist_ok=True)
    torch.manual_seed(args.seed)
    print(f"Device: {DEVICE}")

    obs_all, crs, transform, files = load_stack(args.data)
    T, C, H, W = obs_all.shape
    print(f"Loaded {T} observations, {C} bands, {H}x{W} pixels")

    if args.crop and args.crop < min(H, W):
        cy = H // 2 - args.crop // 2
        cx = W // 2 - args.crop // 2
        obs = obs_all[:, :, cy:cy + args.crop, cx:cx + args.crop]
        crop_origin = (cy, cx)
    else:
        obs = obs_all
        crop_origin = (0, 0)
    _, Cc, Hc, Wc = obs.shape
    extent_y, extent_x = Hc * PIXEL_M, Wc * PIXEL_M
    print(f"Working crop: {Hc}x{Wc} px ({extent_y:.0f}m x {extent_x:.0f}m)")

    print("Estimating sub-pixel shifts via phase correlation...")
    init_shifts = estimate_shifts_phase_corr(obs, ref_band=1)
    print(f" std dx={init_shifts[:, 0].std():.2f}m "
          f"std dy={init_shifts[:, 1].std():.2f}m "
          f"max |s|={init_shifts.abs().max():.2f}m")

    scene = SplatScene(
        extent_y=extent_y, extent_x=extent_x,
        spacing=args.spacing, sigma_splat=args.sigma_splat, n_channels=Cc,
    ).to(DEVICE)
    print(f"Splat grid: {scene.ny} x {scene.nx} = {scene.ny * scene.nx} splats "
          f"(spacing={args.spacing}m, sigma_splat={args.sigma_splat}m, "
          f"sigma_psf={args.sigma_psf}m)")

    # Bicubic-init the splat weights from a single observation to give the
    # optimizer a sensible neighborhood to start from. Each LR pixel value is
    # the sum of contributions from ~(PIXEL_M/spacing)^2 splats, so the raw
    # bicubic-resampled values overshoot by that factor and must be scaled
    # down. Without this scaling, the optimizer spends most of its budget
    # undoing the over-bright init and converges to a worse shift solution.
    # NOTE(review): the observation is chosen as the one with the highest mean
    # value (brightest), which is not necessarily the least cloudy scene.
    best_t = int(obs.mean(dim=(1, 2, 3)).argmax())  # crude proxy: brightest scene
    init = F.interpolate(obs[best_t].unsqueeze(0), size=(scene.ny, scene.nx),
                         mode="bicubic", align_corners=False).squeeze(0)
    init_scale = (args.spacing / PIXEL_M) ** 2
    scene.w.data.copy_((init * init_scale).clamp(0, 1.5).to(DEVICE))

    obs_dev = obs.to(DEVICE)
    print(f"Fitting with Adam ({args.adam_steps} steps) -> LBFGS ({args.lbfgs_steps} steps)...")
    t0 = time.time()
    learned_shifts, losses = fit(
        scene, obs_dev, sigma_psf=args.sigma_psf, init_shifts=init_shifts,
        adam_steps=args.adam_steps, lbfgs_steps=args.lbfgs_steps,
    )
    elapsed = time.time() - t0
    print(f"Done in {elapsed:.1f}s, final loss={losses[-1]:.6f}")

    # Render at native and target resolutions
    rh = int(round(extent_y / args.target_m))
    rw = int(round(extent_x / args.target_m))
    render_native = scene.render(PIXEL_M, Hc, Wc).cpu().clamp(0, 1.5)
    # SplatScene.render integrates the field over each output footprint, so a
    # target pixel only collects (target_m / PIXEL_M)^2 of what a native pixel
    # collects. Rescale by the footprint area ratio so the SR raster stays in
    # the same units as the observations the weights were fitted to.
    area_ratio = (PIXEL_M / args.target_m) ** 2
    render_target = (
        scene.render(args.target_m, rh, rw) * area_ratio
    ).cpu().clamp(0, 1.5)
    print(f"Rendered SR @ {args.target_m}m -> {rh}x{rw} px")

    # Save SR GeoTIFF with georeferencing inherited from the input crop: the
    # origin is the upper-left corner of the crop, the pixel size is target_m.
    cy, cx = crop_origin
    target_transform = rasterio.Affine(
        args.target_m, 0, transform.c + cx * transform.a,
        0, -args.target_m, transform.f + cy * transform.e,
    )
    save_geotiff(
        render_target, args.out / f"sr_{args.target_m:.1f}m.tif",
        crs, target_transform, descriptions=["Blue", "Green", "Red", "NIR"],
    )

    # Plots
    save_comparison(obs[best_t], render_native, render_target,
                    args.target_m, args.out / "comparison.png")
    save_loss_curve(losses, args.adam_steps, args.out / "loss.png")
    save_shift_scatter(init_shifts, learned_shifts, args.out / "shifts.png")

    # Quantitative agreement between learned and phase-correlation shifts
    dx_corr = float(np.corrcoef(init_shifts[:, 0], learned_shifts[:, 0])[0, 1])
    dy_corr = float(np.corrcoef(init_shifts[:, 1], learned_shifts[:, 1])[0, 1])
    print(f"Shift agreement vs phase correlation: r_dx={dx_corr:.3f}, r_dy={dy_corr:.3f}")
    print(f"Outputs in {args.out.absolute()}")
# Run the super-resolution pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Thank you so much! Going to study it!