Script demonstrating how to format Benchmark STAC query results consistently across collections for use in the auto-eval pipeline
#!/usr/bin/env python3
import argparse
import json
import logging
import re
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime, timezone, timedelta
from pathlib import Path
from typing import Any, Callable, Dict, List, Pattern, Tuple

import requests
from pystac_client import Client

logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)


def parse_iso(ts: str) -> datetime:
    """
    Strict ISO-8601 → datetime with tzinfo.
    Replace trailing 'Z' so fromisoformat handles UTC.
    """
    if ts.endswith("Z"):
        ts = ts[:-1] + "+00:00"
    return datetime.fromisoformat(ts)
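
# Quick sanity check of parse_iso (illustrative timestamps, not from the source data):
#   parse_iso("2024-05-01T12:00:00Z")      -> 2024-05-01 12:00:00+00:00
#   parse_iso("2024-05-01T12:00:00+00:00") -> same instant; fromisoformat accepts
#   the explicit offset directly, so only the trailing 'Z' needs rewriting.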

def dictify(o: Any) -> Any:
    """
    Recursively convert defaultdict → dict for JSON serialization.
    (For nested dataclass support you could use dataclasses-json.)
    """
    if isinstance(o, defaultdict):
        o = dict(o)
    if isinstance(o, dict):
        return {k: dictify(v) for k, v in o.items()}
    if isinstance(o, list):
        return [dictify(v) for v in o]
    return o

def compare_versions(v1: str, v2: str) -> int:
    """Compare two flows2fim_version strings (e.g., '0_3_0' vs '0_10_3').

    Returns 1 if v1 > v2, -1 if v1 < v2, 0 if equal."""
    try:
        v1_parts = [int(p) for p in v1.split("_")]
        v2_parts = [int(p) for p in v2.split("_")]
        for i in range(min(len(v1_parts), len(v2_parts))):
            if v1_parts[i] > v2_parts[i]:
                return 1
            elif v1_parts[i] < v2_parts[i]:
                return -1
        # If we've compared all parts and they're equal, the longer one is newer
        if len(v1_parts) > len(v2_parts):
            return 1
        elif len(v1_parts) < len(v2_parts):
            return -1
        return 0
    except (ValueError, AttributeError):
        # If versions cannot be compared, treat them as equal
        return 0
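
# Example (hypothetical version strings): numeric comparison avoids the
# lexicographic trap where "0_10_3" sorts before "0_3_0" as a plain string:
#   compare_versions("0_10_3", "0_3_0") -> 1   (10 > 3 in the second slot)
#   compare_versions("0_3", "0_3_0")    -> -1  (equal prefix, longer is newer)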

BLE_SPEC: List[Tuple[Pattern, str, str]] = [
    (re.compile(r"^(\d+yr)_(extent_raster|flow_file)$"), r"\1", r"\2"),
]

NWS_USGS_MAGS = ["action", "minor", "moderate", "major"]
NWS_USGS_SPEC = [
    (re.compile(rf"^{mag}_(extent_raster|flow_file)$"), mag, r"\1")
    for mag in NWS_USGS_MAGS
]

RIPPLE_SPEC: List[Tuple[Pattern, str, str]] = [
    (re.compile(r"^(\d+yr)_extent$"), r"\1", "extents")
]
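
# Each spec tuple is (compiled pattern, group-id template, asset-type template).
# For example (illustrative asset key): "100yr_extent_raster" matches BLE_SPEC
# and expands to group "100yr" with asset type "extent_raster", while
# "100yr_extent" matches RIPPLE_SPEC and lands in the fixed "extents" bucket.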

CollectionConfig = Dict[
    str,
    Tuple[
        Callable[[Any], str],  # grouping fn → group_id
        Dict[str, Callable[[str, Any], bool]],  # asset_type → test(key, asset)
    ],
]

COLLECTIONS: CollectionConfig = {
    "gfm-collection": (
        lambda item: str(item.properties.get("dfo_event_id", item.id)),
        {
            "extents": lambda k, a: k.endswith("_Observed_Water_Extent"),
            "flowfiles": lambda k, a: k.endswith("_flowfile"),
        },
    ),
    "gfm-expanded-collection": (
        lambda item: group_gfm_expanded_initial(item),
        {
            "extents": lambda k, a: k.endswith("_Observed_Water_Extent"),
            "flowfiles": lambda k, a: k.endswith("_flowfile")
            or k == "NWM_ANA_flowfile",
        },
    ),
    "hwm-collection": (
        lambda item: item.id,
        {
            "points": lambda k, a: k == "data"
            and a.media_type
            and "geopackage" in a.media_type,
            "flowfiles": lambda k, a: k.endswith("-flowfile"),
        },
    ),
}

def group_gfm_expanded_initial(item: Any) -> str:
    start = item.datetime
    if not start:
        sp = item.properties.get("gfm_data_take_start_datetime")
        if sp:
            start = parse_iso(sp)
        else:
            logger.warning(f"No start for {item.id}; using item.id")
            return item.id
    if start.tzinfo is None:
        start = start.replace(tzinfo=timezone.utc)
    end_val = item.properties.get("end_datetime")
    if isinstance(end_val, str):
        end = parse_iso(end_val)
    else:
        end = end_val or start
    # Normalize tz before comparing, so a naive end can't raise against aware start
    if end.tzinfo is None:
        end = end.replace(tzinfo=timezone.utc)
    if end < start:
        ep = item.properties.get("gfm_data_take_end_datetime")
        try:
            end = parse_iso(ep) if ep else start
        except Exception:
            end = start
    if end.tzinfo is None:
        end = end.replace(tzinfo=timezone.utc)
    fmt = "%Y-%m-%dT%H:%M:%SZ"
    return f"{start.strftime(fmt)}/{end.strftime(fmt)}"

@dataclass(order=True)
class Interval:
    start: datetime
    end: datetime
    # Exclude assets from ordering so sort() never tries to compare dicts
    assets: Dict[str, List[str]] = field(
        default_factory=lambda: defaultdict(list), compare=False
    )

def format_results(item_iter: Any) -> Dict[str, Dict[str, Dict[str, List[str]]]]:
    results = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    seen = set()
    ripple_cache: Dict[str, List[str]] = {}
    ripple_best_items = {}
    for idx, item in enumerate(item_iter, start=1):
        if item.collection_id == "ripple-fim-collection":
            try:
                source = item.properties.get("source", "")
                hucs = tuple(item.properties.get("hucs", []))
                flows2fim_version = item.properties.get("flows2fim_version", "")
                if source and hucs and flows2fim_version:
                    key = (source, hucs)
                    # If we haven't seen this source+hucs combo, or this version is better
                    if (
                        key not in ripple_best_items
                        or compare_versions(
                            flows2fim_version, ripple_best_items[key][1]
                        )
                        > 0
                    ):
                        ripple_best_items[key] = (item.id, flows2fim_version)
                    else:
                        # Skip this item - we already have a better version
                        logger.info(
                            f"Skipping {item.id} - newer version exists for {source}+{hucs[0] if hucs else ''}"
                        )
                        continue
            except Exception as e:
                logger.warning(f"Error processing Ripple item {item.id}: {e}")
        if item.id in seen:
            continue
        seen.add(item.id)
        if idx % 100 == 0:
            logger.info(f"Processed {idx} items (last: {item.id})")
        coll = item.collection_id or "<none>"
        short = coll.replace("-collection", "").replace("-fim", "")
        if short == "gfm-expanded":
            short = "gfm_expanded"
        # 1) item-level grouping
        if coll in COLLECTIONS:
            group_fn, tests = COLLECTIONS[coll]
            gid = group_fn(item)
            bucket = results[short][gid]
            for k, a in item.assets.items():
                if not a.href:
                    continue
                for atype, test in tests.items():
                    if test(k, a) and a.href not in bucket[atype]:
                        bucket[atype].append(a.href)
        # 2) BLE/NWS/USGS/Ripple asset-level grouping
        elif coll == "ble-collection" or coll.endswith("-fim-collection"):
            # preload ripple assets once
            if coll == "ripple-fim-collection" and not ripple_cache:
                try:
                    col = item.get_collection()
                    for ak, aa in col.assets.items():
                        m = re.search(r"flows_(\d+)_yr_", ak)
                        ri = f"{m.group(1)}yr" if m else None
                        if ri and aa.media_type == "text/csv":
                            logger.info(f"Caching Ripple flowfile for {ri}: {aa.href}")
                            ripple_cache.setdefault(ri, []).append(aa.href)
                except Exception as e:
                    logger.warning(f"Ripple cache failed: {e}")
            specs = (
                BLE_SPEC
                if coll == "ble-collection"
                else RIPPLE_SPEC if coll == "ripple-fim-collection" else NWS_USGS_SPEC
            )
            found = set()
            for k, a in item.assets.items():
                if not a.href:
                    continue
                for pat, gid_t, at_t in specs:
                    m = pat.match(k)
                    if not m:
                        continue
                    gid = m.expand(gid_t) if "\\" in gid_t else gid_t
                    at = m.expand(at_t)
                    bkt = results[short][gid]
                    if a.href not in bkt[at]:
                        bkt[at].append(a.href)
                    found.add(gid)
            # append ripple flowfiles
            if coll == "ripple-fim-collection":
                for gid in found:
                    if gid in ripple_cache:
                        logger.info(f"Adding cached flowfiles for {gid}")
                        for href in ripple_cache[gid]:
                            if href not in results[short][gid]["flowfiles"]:
                                results[short][gid]["flowfiles"].append(href)
        # 3) fallback
        else:
            logger.warning(f"Unknown coll '{coll}'; grouping by item.id")
            gid = item.id
            bkt = results[short][gid]
            for k, a in item.assets.items():
                if not a.href:
                    continue
                if "extent" in k and a.media_type and "tiff" in a.media_type:
                    bkt["extents"].append(a.href)
                elif "flow" in k and a.media_type and "csv" in a.media_type:
                    bkt["flowfiles"].append(a.href)
    logger.info(f"Finished formatting {len(seen)} items.")
    return results
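
# Illustrative shape of the returned mapping (hypothetical hrefs):
#   {"ble": {"100yr": {"extent_raster": ["s3://bucket/ble/100yr_extent.tif"],
#                      "flow_file": ["s3://bucket/ble/100yr_flows.csv"]}}}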

# Set the tolerance for observation merging to 3 days: under normal operating
# conditions, that is the outer limit for how long the two Sentinel-1 satellites'
# swaths should take to cover the continental US while producing flood maps from
# both the ascending and descending orbits (which is what GFM does). Even if an
# event continues past those 3 days, we would still probably want to subset it
# as another FIM grouping, since the flow conditions will likely have changed
# dramatically by then. (A worked merging example follows the function below.)
def merge_gfm_expanded(
    groups: Dict[str, Dict[str, List[str]]], tolerance_days: int = 3
) -> Dict[str, Dict[str, List[str]]]:
    if not groups:
        return {}
    tol = timedelta(days=tolerance_days)
    ivs: List[Interval] = []
    for key, data in groups.items():
        try:
            s, e = key.split("/")
            st = parse_iso(s)
            en = parse_iso(e)
            iv = Interval(start=st, end=en)
            for at, hs in data.items():
                for h in hs:
                    if h not in iv.assets[at]:
                        iv.assets[at].append(h)
            ivs.append(iv)
        except Exception as ex:
            logger.warning(f"Bad GFM-Expanded key '{key}': {ex}")
    # Guard against every key failing to parse
    if not ivs:
        return {}
    ivs.sort()
    merged: List[Interval] = []
    cur = ivs[0]
    for nxt in ivs[1:]:
        if nxt.start <= cur.end + tol:
            cur.end = max(cur.end, nxt.end)
            for at, hs in nxt.assets.items():
                for h in hs:
                    if h not in cur.assets[at]:
                        cur.assets[at].append(h)
        else:
            merged.append(cur)
            cur = nxt
    merged.append(cur)
    out: Dict[str, Dict[str, List[str]]] = {}
    fmt = "%Y-%m-%dT%H:%M:%SZ"
    for iv in merged:
        key = f"{iv.start.strftime(fmt)}/{iv.end.strftime(fmt)}"
        out[key] = iv.assets
    return out
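
# Worked example (hypothetical keys and hrefs): with the default 3-day tolerance,
# two observations 2 days apart merge into one interval, while a third 10 days
# later starts a new one:
#   merge_gfm_expanded({
#       "2024-05-01T00:00:00Z/2024-05-01T12:00:00Z": {"extents": ["a.tif"]},
#       "2024-05-03T00:00:00Z/2024-05-03T12:00:00Z": {"extents": ["b.tif"]},
#       "2024-05-13T00:00:00Z/2024-05-13T12:00:00Z": {"extents": ["c.tif"]},
#   })
#   -> {"2024-05-01T00:00:00Z/2024-05-03T12:00:00Z": {"extents": ["a.tif", "b.tif"]},
#       "2024-05-13T00:00:00Z/2024-05-13T12:00:00Z": {"extents": ["c.tif"]}}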

def main() -> None:
    p = argparse.ArgumentParser(
        description="Group STAC flood items into extents/flowfiles"
    )
    p.add_argument("-u", "--api-url", required=True, help="STAC API root URL")
    p.add_argument(
        "-c",
        "--collections",
        nargs="+",
        required=True,
        help="One or more STAC collection IDs",
    )
    p.add_argument(
        "-r",
        "--roi",
        type=Path,
        help="GeoJSON file with exactly one Polygon/MultiPolygon",
    )
    p.add_argument("-d", "--datetime", help="STAC datetime or interval")
    p.add_argument("-o", "--output-file", help="Write JSON output to this path")
    args = p.parse_args()
    # load & validate single-feature ROI
    intersects = None
    if args.roi:
        gj = json.loads(args.roi.read_text(encoding="utf-8"))
        t = gj.get("type")
        if t == "FeatureCollection":
            feats = gj.get("features", [])
            if len(feats) != 1:
                logger.error("ROI FC must contain exactly one feature.")
                return
            intersects = feats[0]["geometry"]
        elif t == "Feature":
            intersects = gj["geometry"]
        elif t in ("Polygon", "MultiPolygon"):
            intersects = gj
        else:
            logger.error("ROI must be FC(1 feature), Feature, or Polygon/MultiPolygon.")
            return
        logger.info("Loaded single-feature ROI for intersects filter")
    # open STAC client
    try:
        client = Client.open(args.api_url)
        logger.info(f"Connected to {args.api_url}")
    except Exception as ex:
        logger.error(f"Could not open STAC API: {ex}")
        return
    # build search kwargs (no max_items → all items)
    search_kw = {
        "collections": args.collections,
        "datetime": args.datetime,
        **({"intersects": intersects} if intersects else {}),
    }
    search_kw = {k: v for k, v in search_kw.items() if v is not None}
    try:
        logger.info(f"Searching with {search_kw}")
        search = client.search(**search_kw)
        items = search.items()
        grouped = format_results(items)
        if "gfm_expanded" in grouped:
            grouped["gfm_expanded"] = merge_gfm_expanded(grouped["gfm_expanded"])
        out = dictify(grouped)
        text = json.dumps(out, indent=2)
        if args.output_file:
            Path(args.output_file).write_text(text, encoding="utf-8")
            logger.info(f"Wrote to {args.output_file}")
        else:
            print(text)
    except requests.RequestException as rex:
        logger.error(f"STAC request failed: {rex}")
    except Exception as ex:
        logger.exception(f"Unexpected error: {ex}")

if __name__ == "__main__":
    """
    Group STAC flood-model items into extents/flowfiles by:
      • gfm → group by dfo_event_id
      • gfm-expanded → initial time-range, then merge close intervals
      • hwm → each item is its own survey
      • ble/nws/ripple/usgs → asset-level grouping by regex
    Spatial filter via a single-feature GeoJSON ROI (--roi).
    """
    main()
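
# Example invocation (hypothetical endpoint, script name, and file paths):
#   python group_stac_items.py -u https://stac.example.com \
#       -c gfm-collection ble-collection ripple-fim-collection \
#       -r roi.geojson -d 2024-01-01/2024-12-31 -o grouped.json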