Skip to content

Instantly share code, notes, and snippets.

@calebrob6
Created May 9, 2026 05:47
Show Gist options
  • Select an option

  • Save calebrob6/9ce42c29dc956d2b09148e2ff5ca5b17 to your computer and use it in GitHub Desktop.

Select an option

Save calebrob6/9ce42c29dc956d2b09148e2ff5ca5b17 to your computer and use it in GitHub Desktop.
S2 super-resolution with Gaussian splats
#!/usr/bin/env python3
"""Download a Sentinel-2 time series from the Microsoft Planetary Computer.
Companion script for the blog post "Super-Resolving Sentinel-2 with Gaussian
Splats" (https://geospatialml.com/posts/sentinel2-superresolution/).
Pulls N cloud-free Sentinel-2 L2A scenes over a configurable AOI, crops to a
common pixel window, and writes one 4-band GeoTIFF per scene (B02 Blue,
B03 Green, B04 Red, B08 NIR; uint16 reflectance in 0..10000).
Requirements:
pip install planetary-computer pystac-client rasterio numpy
Run:
python download_s2.py # uses defaults below
python download_s2.py --lat 47.674 --lon -122.121 --n-scenes 32 --out data
"""
import argparse
import json
import sys
from pathlib import Path
import numpy as np
import planetary_computer
import pystac_client
import rasterio
from rasterio.crs import CRS
from rasterio.warp import transform_bounds
# Asset keys of the Sentinel-2 bands read for every scene; the order here is
# the band order of each output GeoTIFF.
BANDS = ["B02", "B03", "B04", "B08"] # Blue, Green, Red, NIR @ 10m
# Band descriptions written into each output file, index-aligned with BANDS.
BAND_DESCS = ["Blue (B02)", "Green (B03)", "Red (B04)", "NIR (B08)"]
def parse_args(argv=None):
    """Parse the command-line options of the Sentinel-2 download script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, which is what the script entry
            point relies on.

    Returns:
        argparse.Namespace with ``lat``, ``lon``, ``half_deg_lat``,
        ``half_deg_lon``, ``n_scenes``, ``max_cloud``, ``date_range`` and
        ``out`` (a ``Path``).
    """
    # The module docstring is None when docstrings are stripped (python -OO);
    # fall back to no description instead of crashing on None.split().
    doc_lines = (__doc__ or "").strip().splitlines()
    p = argparse.ArgumentParser(description=doc_lines[0] if doc_lines else None)
    p.add_argument("--lat", type=float, default=47.674,
                   help="AOI center latitude (default: Redmond, WA)")
    p.add_argument("--lon", type=float, default=-122.121,
                   help="AOI center longitude")
    p.add_argument("--half-deg-lat", type=float, default=0.045,
                   help="Half AOI extent in latitude degrees (~5km at default lat)")
    p.add_argument("--half-deg-lon", type=float, default=0.065,
                   help="Half AOI extent in longitude degrees (~5km at default lat)")
    p.add_argument("--n-scenes", type=int, default=32,
                   help="Number of cloud-free scenes to download")
    p.add_argument("--max-cloud", type=float, default=10.0,
                   help="Maximum eo:cloud_cover percentage")
    p.add_argument("--date-range", type=str, default="2025-04-01/2025-10-31",
                   help="STAC datetime range, e.g. 2025-04-01/2025-10-31")
    p.add_argument("--out", type=Path, default=Path("data"),
                   help="Output directory")
    return p.parse_args(argv)
def search_scenes(bbox, date_range, max_cloud, n_scenes):
    """Query the Planetary Computer STAC API for Sentinel-2 L2A candidates.

    The search is restricted to `bbox` and `date_range`, keeps items whose
    `eo:cloud_cover` is below `max_cloud`, and asks the server to sort by
    ascending cloud cover. Every match is returned (not just `n_scenes`) so
    the caller can discard unusable scenes and still reach its target count.
    A warning goes to stderr when fewer than `n_scenes` items match.
    """
    stac = pystac_client.Client.open(
        "https://planetarycomputer.microsoft.com/api/stac/v1",
        modifier=planetary_computer.sign_inplace,
    )
    cloud_filter = {"eo:cloud_cover": {"lt": max_cloud}}
    least_cloudy_first = [{"field": "eo:cloud_cover", "direction": "asc"}]
    results = stac.search(
        collections=["sentinel-2-l2a"],
        bbox=bbox,
        datetime=date_range,
        query=cloud_filter,
        sortby=least_cloudy_first,
    )
    candidates = list(results.items())
    n_found = len(candidates)
    if n_found < n_scenes:
        print(f"Warning: only {n_found} scenes match (asked for {n_scenes}).",
              file=sys.stderr)
    return candidates
def compute_target_window(item, bbox_4326):
    """Derive the pixel window of a reference scene that covers the AOI.

    The lon/lat bbox is reprojected into the CRS of the scene's B04 raster,
    converted to a window on that raster's grid, and snapped to whole pixels.

    Returns:
        (window, transform, crs): the snapped window, the affine transform of
        the window's upper-left corner, and the raster CRS. The caller reuses
        this window for every other scene.
    """
    signed_red = planetary_computer.sign(item.assets["B04"])
    wgs84 = CRS.from_epsg(4326)
    with rasterio.open(signed_red.href) as ref:
        scene_crs = ref.crs
        left, bottom, right, top = transform_bounds(wgs84, scene_crs, *bbox_4326)
        snapped = rasterio.windows.from_bounds(
            left, bottom, right, top, transform=ref.transform
        )
        snapped = snapped.round_offsets().round_lengths()
        snapped_transform = rasterio.windows.transform(snapped, ref.transform)
    return snapped, snapped_transform, scene_crs
MIN_VALID_FRACTION = 0.99  # drop scenes where less than this fraction of the AOI is in-granule


def download_scene(item, idx, window, transform, crs, out_dir):
    """Read all 4 bands at the same window and write one tiled multi-band GeoTIFF.

    Args:
        item: STAC item exposing `datetime`, `properties`, `id` and the band
            assets listed in BANDS.
        idx: Running index used in the output file name.
        window: Pixel window computed on the reference scene.
        transform: Affine transform of that window on the reference scene.
        crs: CRS of the reference scene.
        out_dir: Directory (a `Path`) that receives the GeoTIFF.

    Returns:
        (path, status). Status is "downloaded", "exists", or "skipped:<reason>".
        Skipped scenes are not written to disk:
          * "skipped:grid-mismatch" -- the scene's rasters are not on the
            reference pixel grid (different CRS or origin), so the shared
            pixel window would crop a different piece of ground.
          * "skipped:partial-coverage(...)" -- the AOI sits at a granule edge
            and is mostly NoData; such scenes break phase correlation and
            bias the SR fit.
    """
    date_str = item.datetime.strftime("%Y%m%d")
    cloud = item.properties.get("eo:cloud_cover", -1)
    out_path = out_dir / f"s2_{idx:03d}_{date_str}_cc{cloud:.0f}.tif"
    if out_path.exists():
        return out_path, "exists"

    bands = []
    for band in BANDS:
        asset = planetary_computer.sign(item.assets[band])
        with rasterio.open(asset.href) as src:
            # The window is expressed in pixels of the reference scene. It only
            # selects the same ground area when this raster shares that grid,
            # i.e. the window's transform here equals the reference one.
            on_reference_grid = (
                src.crs == crs
                and rasterio.windows.transform(window, src.transform)
                .almost_equals(transform)
            )
            if not on_reference_grid:
                return out_path, "skipped:grid-mismatch"
            bands.append(src.read(1, window=window))
    stack = np.stack(bands, axis=0).astype(np.uint16)

    # Zero is treated as NoData; coverage is judged on the first band only.
    valid_frac = float((stack[0] > 0).sum()) / stack[0].size
    if valid_frac < MIN_VALID_FRACTION:
        return out_path, f"skipped:partial-coverage({valid_frac:.2%})"

    profile = {
        "driver": "GTiff", "dtype": "uint16",
        "height": stack.shape[1], "width": stack.shape[2], "count": stack.shape[0],
        "crs": crs, "transform": transform,
        "compress": "deflate", "tiled": True, "blockxsize": 256, "blockysize": 256,
    }
    # Write to a temporary name and rename on success, so an interrupted run
    # never leaves a truncated file that a later run would report as "exists".
    tmp_path = out_path.with_name(out_path.name + ".part")
    try:
        with rasterio.open(tmp_path, "w", **profile) as dst:
            dst.write(stack)
            for i, desc in enumerate(BAND_DESCS, start=1):
                dst.set_band_description(i, desc)
            dst.update_tags(
                datetime=item.datetime.isoformat(),
                cloud_cover=str(cloud),
                scene_id=item.id,
            )
        tmp_path.replace(out_path)
    finally:
        # No-op after a successful rename; removes the partial file otherwise.
        tmp_path.unlink(missing_ok=True)
    return out_path, "downloaded"
def main():
    """Download the requested number of scenes and write metadata.json.

    Builds the lon/lat bbox from the CLI options, searches for candidates,
    derives the shared pixel window from the first candidate, then walks the
    candidates in order until `--n-scenes` files are kept. Skipped and failed
    scenes do not consume an index. Exits with a message when the search
    returns nothing.
    """
    args = parse_args()
    out_dir = args.out
    out_dir.mkdir(parents=True, exist_ok=True)

    west = args.lon - args.half_deg_lon
    south = args.lat - args.half_deg_lat
    east = args.lon + args.half_deg_lon
    north = args.lat + args.half_deg_lat
    bbox = [west, south, east, north]
    print(f"AOI bbox (lon/lat): {bbox}")
    print(f"Searching {args.date_range} for <{args.max_cloud}% cloud cover...")

    items = search_scenes(bbox, args.date_range, args.max_cloud, args.n_scenes)
    if not items:
        sys.exit("No matching scenes found.")
    print(f"Found {len(items)} candidate scenes (lowest cloud cover first); "
          f"keeping the first {args.n_scenes} that fully cover the AOI.")

    window, transform, crs = compute_target_window(items[0], bbox)
    height, width = int(window.height), int(window.width)
    print(f"Target window: {height} x {width} px in {crs}")

    saved = []
    for item in items:
        # The next file index is the number of scenes kept so far.
        idx = len(saved)
        if idx >= args.n_scenes:
            break
        try:
            path, status = download_scene(item, idx, window, transform, crs, out_dir)
            print(f" [{idx:02d}] {status:25s} {path.name}")
            if not status.startswith("skipped"):
                saved.append({
                    "file": path.name,
                    "date": item.datetime.isoformat(),
                    "cloud_cover": item.properties.get("eo:cloud_cover", -1),
                    "scene_id": item.id,
                })
        except Exception as e:
            # Best effort: report the failure and move on to the next candidate.
            print(f" [{idx:02d}] {'FAILED':25s} {e}")

    kept = len(saved)
    if kept < args.n_scenes:
        print(f"Warning: kept {kept} scenes, fewer than the requested {args.n_scenes}.",
              file=sys.stderr)

    meta = {
        "bbox_4326": bbox,
        "crs": str(crs),
        "transform": list(transform)[:6],
        "height": height,
        "width": width,
        "pixel_size_m": 10.0,
        "bands": BAND_DESCS,
        "scenes": saved,
    }
    with open(out_dir / "metadata.json", "w") as f:
        f.write(json.dumps(meta, indent=2))
    print(f"Wrote {len(saved)} scenes and metadata.json to {args.out}/")


if __name__ == "__main__":
    main()
#!/usr/bin/env python3
"""Multi-temporal Sentinel-2 super-resolution with Gaussian splats.
Companion script for the blog post "Super-Resolving Sentinel-2 with Gaussian
Splats" (https://geospatialml.com/posts/sentinel2-superresolution/).
The scene is represented as a continuous field of 2D Gaussians on a regular
grid. Each Sentinel-2 observation is the analytic integral of that field
over shifted 10m pixel footprints convolved with the sensor PSF. Because
both the splats and the PSF are Gaussian, the integral over each pixel
factors into a product of `erf` differences along x and y, with no
discretization of the latent image and no finite-difference shift gradients.
LBFGS jointly recovers the splat weights and the per-observation sub-pixel
shifts.
Pair this with `download_s2.py` for the input data.
Requirements:
pip install torch numpy rasterio scikit-image matplotlib
Run:
python gaussian_splat_sr.py # uses defaults
python gaussian_splat_sr.py --data data --out out --crop 256 --target-m 2.5
"""
import argparse
import math
import time
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
import torch.nn as nn
import torch.nn.functional as F
# Side length in meters of one low-resolution input pixel; used to convert
# pixel shifts to meters and to lay out the observation pixel footprints.
PIXEL_M = 10.0 # native Sentinel-2 pixel size for B02/B03/B04/B08
# Torch device for the fit: CUDA when available, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# -----------------------------------------------------------------------------
# Data loading
# -----------------------------------------------------------------------------
def load_stack(data_dir: Path, ref_band: int = 2):
    """Load every s2_*.tif from a directory and stack as [T, C, H, W] in 0..1.

    Args:
        data_dir: Directory searched (non-recursively) for ``s2_*.tif``.
        ref_band: Unused; kept so existing callers keep working.

    Returns:
        (obs, crs, transform, files): float32 tensor of shape [T, C, H, W]
        with values clipped to 0..10000 and divided by 10000, the CRS and
        affine transform of the first file, and the sorted list of paths.

    Raises:
        FileNotFoundError: no matching file exists in `data_dir`.
        ValueError: a file's array shape differs from the first file's.
    """
    files = sorted(data_dir.glob("s2_*.tif"))
    if not files:
        raise FileNotFoundError(f"No s2_*.tif files in {data_dir}")

    obs_list = []
    crs = transform = None
    for fp in files:
        with rasterio.open(fp) as src:
            arr = src.read().astype(np.float32)
            if not obs_list:
                # Georeferencing is taken from the first file, read in the
                # same pass instead of reopening it afterwards.
                crs, transform = src.crs, src.transform
        if obs_list and arr.shape != obs_list[0].shape:
            # Name the offending file instead of failing later in np.stack.
            raise ValueError(
                f"{fp.name} has shape {arr.shape}, expected {obs_list[0].shape} "
                f"(from {files[0].name}); all scenes must share one pixel grid"
            )
        obs_list.append(arr)

    obs = np.stack(obs_list, axis=0)  # [T, C, H, W] uint16-as-float
    obs = np.clip(obs, 0, 10000) / 10000.0  # crude reflectance normalization
    return torch.from_numpy(obs), crs, transform, files
def estimate_shifts_phase_corr(obs_crop: torch.Tensor, ref_band: int = 1):
    """Seed per-observation shifts (in meters) by phase cross-correlation.

    Every observation is registered against observation 0 on band `ref_band`.
    skimage's `phase_cross_correlation(reference, moving)` reports the shift
    needed to bring `moving` onto `reference`; the forward model instead wants
    the offset of the observation relative to obs 0 (positive dx moves
    rendered features to higher column index), so the result is negated.

    Returns:
        shifts: [T, 2] tensor with columns (dx, dy) in meters. Row 0 is zero.
    """
    from skimage.registration import phase_cross_correlation

    n_obs = obs_crop.shape[0]
    reference = obs_crop[0, ref_band].numpy()
    shifts = torch.zeros(n_obs, 2)
    for t in range(1, n_obs):
        moving = obs_crop[t, ref_band].numpy()
        offset_rc, _, _ = phase_cross_correlation(
            reference, moving, upsample_factor=100,
        )
        # offset_rc is (row, col) in pixels; store as (dx, dy) in meters.
        shifts[t, 0] = -float(offset_rc[1]) * PIXEL_M
        shifts[t, 1] = -float(offset_rc[0]) * PIXEL_M
    return shifts
# -----------------------------------------------------------------------------
# Splat scene with analytic erf forward model
# -----------------------------------------------------------------------------
class SplatScene(nn.Module):
    """Learnable field of isotropic 2D Gaussian splats on a regular grid.

    Splat centers sit at ``(k + 0.5) * spacing`` along each axis and each
    splat carries one weight per channel (`w`, shape [C, ny, nx]). A pixel
    value is the weighted sum of every splat integrated over the pixel's
    footprint. For one axis that integral is

        I(p_lo, p_hi; mu, sigma) = 0.5 * (erf((p_hi - mu) / (sigma * sqrt(2)))
                                          - erf((p_lo - mu) / (sigma * sqrt(2))))

    and the 2D integral is the product of the x and y terms, so a pixel is
    sum_k w_k * I_x(...) * I_y(...).
    """

    def __init__(self, extent_y, extent_x, spacing, sigma_splat, n_channels):
        super().__init__()
        self.spacing = spacing
        self.sigma_splat = sigma_splat
        self.n_channels = n_channels
        # At least one splat per axis, even for extents smaller than spacing.
        self.ny = max(1, int(round(extent_y / spacing)))
        self.nx = max(1, int(round(extent_x / spacing)))
        for buffer_name, count in (("cy", self.ny), ("cx", self.nx)):
            centers = (torch.arange(count, dtype=torch.float32) + 0.5) * spacing
            self.register_buffer(buffer_name, centers)
        self.w = nn.Parameter(torch.zeros(n_channels, self.ny, self.nx))

    @staticmethod
    def _erf_integral(centers, edges_lo, edges_hi, sigma):
        """Integrate K unit 1D Gaussians over P intervals. Returns [K, P]."""
        scale = sigma * math.sqrt(2.0)
        mu = centers.unsqueeze(1)  # [K, 1], broadcast against [1, P] edges
        upper = torch.erf((edges_hi.unsqueeze(0) - mu) / scale)
        lower = torch.erf((edges_lo.unsqueeze(0) - mu) / scale)
        return 0.5 * (upper - lower)

    def _render_at(self, edges_y_lo, edges_y_hi, edges_x_lo, edges_x_hi,
                   centers_y, centers_x, sigma):
        """Render [C, H, W] pixels from per-axis footprint edges and centers."""
        rows = self._erf_integral(centers_y, edges_y_lo, edges_y_hi, sigma)  # [Ky, H]
        cols = self._erf_integral(centers_x, edges_x_lo, edges_x_hi, sigma)  # [Kx, W]
        # pred[c, i, j] = sum_{ky, kx} w[c, ky, kx] * rows[ky, i] * cols[kx, j]
        return torch.einsum("yi,cyx,xj->cij", rows, self.w, cols)

    def predict_obs(self, n_h, n_w, sigma_psf, shifts):
        """Predict T low-resolution observations of the splat field.

        Args:
            n_h, n_w: pixel grid size of each LR observation.
            sigma_psf: scalar sensor PSF std in meters.
            shifts: [T, 2] per-observation (dx, dy) shifts in meters.
        Returns:
            [T, C, n_h, n_w] predicted reflectance in [0, 1] (approximately).
        """
        # Splat and PSF widths combine in quadrature into one effective sigma.
        sigma_eff = math.sqrt(self.sigma_splat ** 2 + sigma_psf ** 2)
        device = self.w.device
        half = PIXEL_M / 2.0
        py = (torch.arange(n_h, dtype=torch.float32, device=device) + 0.5) * PIXEL_M
        px = (torch.arange(n_w, dtype=torch.float32, device=device) + 0.5) * PIXEL_M
        y_lo, y_hi = py - half, py + half
        x_lo, x_hi = px - half, px + half
        # A shift moves the splat centers; the pixel footprints stay fixed.
        per_obs = [
            self._render_at(
                y_lo, y_hi, x_lo, x_hi,
                self.cy + shifts[t, 1], self.cx + shifts[t, 0], sigma_eff,
            )
            for t in range(shifts.shape[0])
        ]
        return torch.stack(per_obs, dim=0)

    @torch.no_grad()
    def render(self, target_pixel_m, n_h, n_w):
        """Render the latent scene at a target resolution (no PSF)."""
        device = self.w.device
        half = target_pixel_m / 2.0
        py = (torch.arange(n_h, dtype=torch.float32, device=device) + 0.5) * target_pixel_m
        px = (torch.arange(n_w, dtype=torch.float32, device=device) + 0.5) * target_pixel_m
        return self._render_at(
            py - half, py + half, px - half, px + half,
            self.cy, self.cx, self.sigma_splat,
        )
# -----------------------------------------------------------------------------
# Training loop: Adam warmup, then LBFGS
# -----------------------------------------------------------------------------
def fit(scene, obs_dev, sigma_psf, init_shifts,
        adam_steps=200, lbfgs_steps=50, lr_adam=5e-3, lr_lbfgs=1.0, verbose=True):
    """Jointly fit splat weights and per-observation shifts to the observations.

    Minimizes the MSE between `scene.predict_obs(...)` and `obs_dev`, first
    with Adam and then with LBFGS. `scene.w` is updated in place.

    Args:
        scene: Module exposing a parameter `w` and `predict_obs(n_h, n_w,
            sigma_psf, shifts)`.
        obs_dev: Observations; the last two dims give the LR grid size and
            its device is where the shifts are optimized.
        sigma_psf: Sensor PSF std passed through to `predict_obs`.
        init_shifts: [T, 2] starting shifts; not modified.
        adam_steps, lbfgs_steps: Number of outer steps of each optimizer
            (0 skips that phase).
        lr_adam, lr_lbfgs: Learning rates of the two optimizers.
        verbose: Print progress lines.

    Returns:
        (shifts, losses): the optimized shifts as a detached CPU tensor and
        the list of loss values, one per Adam step followed by one per LBFGS
        outer step.
    """
    n_h, n_w = obs_dev.shape[-2:]
    shifts = nn.Parameter(init_shifts.detach().clone().to(obs_dev.device))
    # CPU copy for the drift report below; init_shifts may live on another
    # device than the CPU tensor it is compared with.
    init_cpu = init_shifts.detach().cpu()
    params = [scene.w, shifts]
    losses = []

    def compute_loss():
        pred = scene.predict_obs(n_h, n_w, sigma_psf, shifts)
        return F.mse_loss(pred, obs_dev)

    if adam_steps > 0:
        opt = torch.optim.Adam(params, lr=lr_adam)
        for step in range(adam_steps):
            opt.zero_grad()
            loss = compute_loss()
            loss.backward()
            opt.step()
            losses.append(loss.item())
            if verbose and (step == 0 or (step + 1) % 50 == 0):
                print(f" adam {step+1:4d}: loss={loss.item():.6f}")

    if lbfgs_steps > 0:
        opt = torch.optim.LBFGS(
            params, lr=lr_lbfgs, max_iter=20,
            history_size=10, line_search_fn="strong_wolfe",
        )

        def closure():
            opt.zero_grad()
            loss = compute_loss()
            loss.backward()
            return loss

        for step in range(lbfgs_steps):
            # NOTE(review): LBFGS.step appears to return the loss from the
            # first closure call of the step, so the logged value would lag
            # the parameters by one outer step -- confirm in the PyTorch docs.
            loss = opt.step(closure)
            losses.append(loss.item())
            if verbose and (step == 0 or (step + 1) % 10 == 0):
                shift_drift = (shifts.detach().cpu() - init_cpu).abs().mean().item()
                print(f" lbfgs {step+1:3d}: loss={loss.item():.6f} "
                      f"mean|Δshift|={shift_drift:.3f}m")

    return shifts.detach().cpu(), losses
# -----------------------------------------------------------------------------
# Visualization
# -----------------------------------------------------------------------------
def to_rgb(t, lo_pct=2, hi_pct=98):
    """Turn a [C, H, W] reflectance tensor into an [H, W, 3] display array.

    Channels 2, 1, 0 are taken as R, G, B. The image is stretched between the
    `lo_pct` and `hi_pct` percentiles, estimated on at most 200,000 randomly
    chosen values, and clipped to [0, 1].
    """
    rgb = t[[2, 1, 0]].clamp(0, 1.5)  # R, G, B from B04, B03, B02
    values = rgb.reshape(-1)
    n_sample = min(values.numel(), 200_000)
    picked = values[torch.randperm(values.numel())[:n_sample]]
    lo, hi = (torch.quantile(picked, pct / 100) for pct in (lo_pct, hi_pct))
    stretched = (rgb - lo) / (hi - lo + 1e-6)
    return stretched.clamp(0, 1).permute(1, 2, 0).numpy()
def save_comparison(obs_one, render_native, render_target, target_m, out_path):
    """Save a three-panel figure: one observation, SR at 10 m, SR at target.

    Each image is a [C, H, W] tensor converted for display with `to_rgb`.
    The figure is written to `out_path` with a transparent background.
    """
    panels = (
        (obs_one, f"Single S2 obs ({PIXEL_M:.0f}m)"),
        (render_native, f"SR rendered @ {PIXEL_M:.0f}m"),
        (render_target, f"SR rendered @ {target_m:.1f}m"),
    )
    fig, axes = plt.subplots(1, 3, figsize=(18, 6.5))
    for ax, (img, title) in zip(axes, panels):
        ax.imshow(to_rgb(img))
        ax.set_title(title, fontsize=14, color="#555555")
        ax.axis("off")
    fig.tight_layout()
    fig.savefig(out_path, dpi=180, bbox_inches="tight",
                facecolor="none", transparent=True)
    plt.close(fig)
def save_loss_curve(losses, adam_steps, out_path):
    """Plot the loss history on a log axis and save it to `out_path`.

    A dashed vertical line marks the switch from Adam to LBFGS when both
    phases contributed entries to `losses`.
    """
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.semilogy(losses)
    has_both_phases = 0 < adam_steps < len(losses)
    if has_both_phases:
        # Draw the marker halfway between the last Adam and first LBFGS entry.
        boundary = adam_steps - 0.5
        ax.axvline(boundary, color="grey", linestyle="--", alpha=0.7,
                   label="Adam -> LBFGS")
        ax.legend()
    ax.set_xlabel("optimizer step")
    ax.set_ylabel("MSE loss")
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(out_path, dpi=120, bbox_inches="tight")
    plt.close(fig)
def save_shift_scatter(measured, learned, out_path):
    """Scatter the phase-correlation shifts against the optimized shifts.

    Both inputs are indexed as [:, 0] for dx and [:, 1] for dy (meters).
    The plot uses equal axis scaling and is saved to `out_path`.
    """
    series = (
        (measured, {"alpha": 0.7, "label": "phase correlation"}),
        (learned, {"marker": "x", "alpha": 0.9, "label": "LBFGS recovered"}),
    )
    fig, ax = plt.subplots(figsize=(6, 6))
    for points, style in series:
        ax.scatter(points[:, 0], points[:, 1], s=40, **style)
    ax.set_xlabel("dx (m)")
    ax.set_ylabel("dy (m)")
    ax.set_aspect("equal")
    ax.grid(True, alpha=0.3)
    ax.legend()
    fig.tight_layout()
    fig.savefig(out_path, dpi=120, bbox_inches="tight")
    plt.close(fig)
# -----------------------------------------------------------------------------
# Output
# -----------------------------------------------------------------------------
def save_geotiff(arr, out_path, crs, transform, descriptions=None):
    """Write a [C, H, W] tensor to `out_path` as a float32 multi-band GeoTIFF.

    The file is deflate-compressed and tiled in 256x256 blocks. When
    `descriptions` is given, its entries are set as band descriptions
    starting at band 1.
    """
    data = arr.cpu().numpy().astype(np.float32)
    shape = data.shape
    profile = dict(
        driver="GTiff", dtype="float32",
        height=shape[1], width=shape[2], count=shape[0],
        crs=crs, transform=transform,
        compress="deflate", tiled=True, blockxsize=256, blockysize=256,
    )
    with rasterio.open(out_path, "w", **profile) as dst:
        dst.write(data)
        for band_index, text in enumerate(descriptions or (), start=1):
            dst.set_band_description(band_index, text)
# -----------------------------------------------------------------------------
# Main
# -----------------------------------------------------------------------------
def parse_args(argv=None):
    """Parse the command-line options of the super-resolution script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, which is what the script entry
            point relies on.

    Returns:
        argparse.Namespace with ``data``, ``out``, ``crop``, ``spacing``,
        ``sigma_splat``, ``sigma_psf``, ``target_m``, ``adam_steps``,
        ``lbfgs_steps`` and ``seed``.
    """
    # The module docstring is None when docstrings are stripped (python -OO);
    # fall back to no description instead of crashing on None.split().
    doc_lines = (__doc__ or "").strip().splitlines()
    p = argparse.ArgumentParser(description=doc_lines[0] if doc_lines else None)
    p.add_argument("--data", type=Path, default=Path("data"),
                   help="Directory containing s2_*.tif files from download_s2.py")
    p.add_argument("--out", type=Path, default=Path("output"),
                   help="Output directory for GeoTIFF + figures")
    p.add_argument("--crop", type=int, default=256,
                   help="Centered LR pixel crop side. Use 0 to fit the full scene")
    p.add_argument("--spacing", type=float, default=3.0,
                   help="Splat grid spacing in meters")
    p.add_argument("--sigma-splat", type=float, default=3.0,
                   help="Splat width (sigma) in meters")
    p.add_argument("--sigma-psf", type=float, default=3.0,
                   help="Assumed sensor PSF width (sigma) in meters")
    p.add_argument("--target-m", type=float, default=2.5,
                   help="Target rendering pixel size in meters")
    p.add_argument("--adam-steps", type=int, default=200,
                   help="Number of Adam warmup steps (0 skips Adam)")
    p.add_argument("--lbfgs-steps", type=int, default=50,
                   help="Number of LBFGS outer steps (0 skips LBFGS)")
    p.add_argument("--seed", type=int, default=42,
                   help="Seed passed to torch.manual_seed")
    return p.parse_args(argv)
def main():
    """Run the full pipeline: load, crop, seed shifts, fit, render, save.

    Loads the observation stack, optionally takes a centered square crop,
    seeds the shifts by phase correlation, initializes and fits the splat
    scene, renders it at the native and target pixel sizes, and writes the
    SR GeoTIFF plus three diagnostic figures into the output directory.
    """
    args = parse_args()
    args.out.mkdir(parents=True, exist_ok=True)
    torch.manual_seed(args.seed)
    print(f"Device: {DEVICE}")

    obs_all, crs, transform, _files = load_stack(args.data)
    T, C, H, W = obs_all.shape
    print(f"Loaded {T} observations, {C} bands, {H}x{W} pixels")

    # Crop only for a positive size smaller than the scene; 0 (documented) or
    # a negative/oversized value falls back to the full scene.
    if 0 < args.crop < min(H, W):
        cy = H // 2 - args.crop // 2
        cx = W // 2 - args.crop // 2
        obs = obs_all[:, :, cy:cy + args.crop, cx:cx + args.crop]
        crop_origin = (cy, cx)
    else:
        obs = obs_all
        crop_origin = (0, 0)
    _, Cc, Hc, Wc = obs.shape
    extent_y, extent_x = Hc * PIXEL_M, Wc * PIXEL_M
    print(f"Working crop: {Hc}x{Wc} px ({extent_y:.0f}m x {extent_x:.0f}m)")

    print("Estimating sub-pixel shifts via phase correlation...")
    init_shifts = estimate_shifts_phase_corr(obs, ref_band=1)
    print(f" std dx={init_shifts[:, 0].std():.2f}m "
          f"std dy={init_shifts[:, 1].std():.2f}m "
          f"max |s|={init_shifts.abs().max():.2f}m")

    scene = SplatScene(
        extent_y=extent_y, extent_x=extent_x,
        spacing=args.spacing, sigma_splat=args.sigma_splat, n_channels=Cc,
    ).to(DEVICE)
    print(f"Splat grid: {scene.ny} x {scene.nx} = {scene.ny * scene.nx} splats "
          f"(spacing={args.spacing}m, sigma_splat={args.sigma_splat}m, "
          f"sigma_psf={args.sigma_psf}m)")

    # Bicubic-init the splat weights from a single observation to give the
    # optimizer a sensible neighborhood to start from. Each LR pixel value is
    # the sum of contributions from ~(PIXEL_M/spacing)^2 splats, so the raw
    # bicubic-resampled values overshoot by that factor and must be scaled
    # down. Without this scaling, the optimizer spends most of its budget
    # undoing the over-bright init and converges to a worse shift solution.
    # NOTE(review): the observation is chosen by highest mean reflectance,
    # a crude proxy (cloud and haze are bright too) -- confirm whether the
    # lowest-cloud scene was intended instead.
    best_t = int(obs.mean(dim=(1, 2, 3)).argmax())
    init = F.interpolate(obs[best_t].unsqueeze(0), size=(scene.ny, scene.nx),
                         mode="bicubic", align_corners=False).squeeze(0)
    init_scale = (args.spacing / PIXEL_M) ** 2
    scene.w.data.copy_((init * init_scale).clamp(0, 1.5).to(DEVICE))

    obs_dev = obs.to(DEVICE)
    print(f"Fitting with Adam ({args.adam_steps} steps) -> LBFGS ({args.lbfgs_steps} steps)...")
    t0 = time.time()
    learned_shifts, losses = fit(
        scene, obs_dev, sigma_psf=args.sigma_psf, init_shifts=init_shifts,
        adam_steps=args.adam_steps, lbfgs_steps=args.lbfgs_steps,
    )
    elapsed = time.time() - t0
    # With both step counts at 0 no loss is recorded; avoid indexing an empty list.
    final_loss = f"{losses[-1]:.6f}" if losses else "n/a"
    print(f"Done in {elapsed:.1f}s, final loss={final_loss}")

    # Render at native and target resolutions
    rh = int(round(extent_y / args.target_m))
    rw = int(round(extent_x / args.target_m))
    render_native = scene.render(PIXEL_M, Hc, Wc).cpu().clamp(0, 1.5)
    render_target = scene.render(args.target_m, rh, rw).cpu().clamp(0, 1.5)
    print(f"Rendered SR @ {args.target_m}m -> {rh}x{rw} px")

    # Save SR GeoTIFF with georeferencing inherited from the input crop: the
    # origin is the crop's upper-left corner, the pixel size is the target size.
    cy, cx = crop_origin
    target_transform = rasterio.Affine(
        args.target_m, 0, transform.c + cx * transform.a,
        0, -args.target_m, transform.f + cy * transform.e,
    )
    save_geotiff(
        render_target, args.out / f"sr_{args.target_m:.1f}m.tif",
        crs, target_transform, descriptions=["Blue", "Green", "Red", "NIR"],
    )

    # Plots
    save_comparison(obs[best_t], render_native, render_target,
                    args.target_m, args.out / "comparison.png")
    save_loss_curve(losses, args.adam_steps, args.out / "loss.png")
    save_shift_scatter(init_shifts, learned_shifts, args.out / "shifts.png")

    # Quantitative agreement between learned and phase-correlation shifts
    dx_corr = float(np.corrcoef(init_shifts[:, 0], learned_shifts[:, 0])[0, 1])
    dy_corr = float(np.corrcoef(init_shifts[:, 1], learned_shifts[:, 1])[0, 1])
    print(f"Shift agreement vs phase correlation: r_dx={dx_corr:.3f}, r_dy={dy_corr:.3f}")
    print(f"Outputs in {args.out.absolute()}")


if __name__ == "__main__":
    main()
@paolodep36

Copy link
Copy Markdown

where is the implementation with c2f? thanks

@calebrob6

Copy link
Copy Markdown
Author

Hi @paolodep36 C2F means coarse to fine. If this script isn't working for you, I just open sourced the full repo https://github.com/calebrob6/s2-superres

@paolodep36

Copy link
Copy Markdown

Thank you so much! Going to study it!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment