"""Print a grid showing the availability of PyTorch wheels for each Python version.

Originally published as GitHub gist robbiemu/c224d6daf2f05d918ac8cf759db33288.
"""
import argparse
import os
import re
from collections import defaultdict
from typing import List, Set, Tuple
import requests
from packaging.version import Version, InvalidVersion
# ---------------------------------------------------------------------------
# Constants & endpoints
# ---------------------------------------------------------------------------
# Remote sources queried at runtime.
WHEEL_ROOT: str = "https://download.pytorch.org/whl/"  # root of the wheel index
ENDOFLIFE_PYTHON_JSON: str = "https://endoflife.date/api/python.json"  # Python release cycles
TORCH_PYPI_JSON: str = "https://pypi.org/pypi/torch/json"  # torch "releases" metadata
TORCH_DOCS_VERSIONS: str = "https://pytorch.org/docs/versions.html"  # docs version index page
# ---------------------------------------------------------------------------
# HTTP helper
# ---------------------------------------------------------------------------
def _get(url: str) -> str:
    """Fetch *url* and return the response body as text.

    Uses a 30-second timeout and raises for 4xx/5xx responses (via
    ``raise_for_status``) instead of handing back an error page.
    """
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    body = response.text
    return body
# ---------------------------------------------------------------------------
# 1. Parse single wheel listing per channel (stable & nightly)
# ---------------------------------------------------------------------------
# Index page URL for each release channel, keyed by channel name.
WHEEL_DIRS = dict(
    stable=WHEEL_ROOT + "torch/",
    nightly=WHEEL_ROOT + "nightly/torch/",
)
# capture 2‑digit tags (cp39) and 3‑digit (cp310, cp312 …)
_TORCH_WHEEL_RE = re.compile(
r"torch-(?P<tv>[0-9a-zA-Z\.]+)\+(?P<variant>[a-z0-9_]+)[^\s]*-cp(?P<py>\d{2,3})-cp\d{2,3}.*?\.whl"
)
def _py_tag_to_version(tag: str) -> Version:
    """Convert the digits of a cpXY / cpXYZ wheel tag to a Python ``Version``.

    The first digit is the major version and the remaining digits are the
    minor version: "39" -> 3.9, "310" -> 3.10, "27" -> 2.7. Reading the major
    digit from the tag (instead of assuming 3) keeps legacy cp27 wheels from
    being reported as Python 3.7.

    Raises ValueError if *tag* is not at least two ASCII decimal digits.
    """
    if len(tag) < 2 or not (tag.isascii() and tag.isdigit()):
        raise ValueError(f"unexpected CPython tag digits: {tag!r}")
    return Version(f"{tag[0]}.{int(tag[1:])}")
def parse_wheel_listing(channel: str):
    """Download and parse the wheel index page for *channel* ("stable" or "nightly").

    Returns a ``(compat, variants)`` pair:
      * ``compat`` is a ``defaultdict(set)`` mapping
        ``(torch_major, torch_minor, python_version)`` to the set of build
        variants seen for that combination;
      * ``variants`` is the set of every variant seen in the listing.

    Filenames whose versions raise ``InvalidVersion`` when parsed are skipped.
    """
    listing = _get(WHEEL_DIRS[channel])
    compat: dict[Tuple[int, int, Version], set[str]] = defaultdict(set)
    variants: Set[str] = set()
    for torch_str, build, py_digits in _TORCH_WHEEL_RE.findall(listing):
        try:
            torch_ver = Version(torch_str)
            python_ver = _py_tag_to_version(py_digits)
        except InvalidVersion:
            continue
        compat[(torch_ver.major, torch_ver.minor, python_ver)].add(build)
        variants.add(build)
    return compat, variants
# ---------------------------------------------------------------------------
# 2. Canonical Python & Torch versions
# ---------------------------------------------------------------------------
def fetch_python_versions(min_py: Version) -> List[Version]:
    """Return every Python release cycle >= *min_py*, sorted ascending.

    Cycles are the ``cycle`` field of each entry returned by the
    endoflife.date Python API, parsed with ``Version``. An error status from
    the API raises ``requests.HTTPError`` (like ``_get``), instead of failing
    later on an unexpected JSON body.
    """
    resp = requests.get(ENDOFLIFE_PYTHON_JSON, timeout=30)
    resp.raise_for_status()
    # Parse each cycle once, then filter.
    cycles = {Version(entry["cycle"]) for entry in resp.json()}
    return sorted(v for v in cycles if v >= min_py)
def _torch_versions_from_docs() -> Set[Version]:
    """Return the ``major.minor`` versions linked from the PyTorch docs index.

    Every ``/docs/X.Y/`` path on the versions page contributes one entry.
    """
    page = _get(TORCH_DOCS_VERSIONS)
    found = set()
    for minor_str in re.findall(r"/docs/(\d+\.\d+)/", page):
        found.add(Version(minor_str))
    return found
def _torch_versions_from_pypi() -> Set[Version]:
    """Return all non-prerelease torch versions listed in PyPI's JSON metadata.

    Keys of the ``releases`` mapping are parsed with ``Version``. A key that
    is not a valid PEP 440 version is skipped rather than aborting the whole
    run. An error status from PyPI raises ``requests.HTTPError``.
    """
    resp = requests.get(TORCH_PYPI_JSON, timeout=30)
    resp.raise_for_status()
    versions: Set[Version] = set()
    for raw in resp.json()["releases"]:
        try:
            ver = Version(raw)
        except InvalidVersion:
            continue
        if not ver.is_prerelease:
            versions.add(ver)
    return versions
def fetch_torch_versions(min_torch: Version) -> List[Version]:
    """Return the distinct torch ``major.minor`` series at or above *min_torch*.

    Versions are pooled from the docs index and from PyPI. Pre-, dev- and
    post-releases are ignored. The result is sorted ascending.
    """
    pooled = _torch_versions_from_docs()
    pooled |= _torch_versions_from_pypi()
    series = set()
    for ver in pooled:
        if ver < min_torch:
            continue
        if ver.is_prerelease or ver.is_devrelease or ver.is_postrelease:
            continue
        series.add(Version(f"{ver.major}.{ver.minor}"))
    return sorted(series)
# ---------------------------------------------------------------------------
# 3. Build compatibility grid
# ---------------------------------------------------------------------------
def target_satisfied(vset: Set[str], target: str) -> bool:
    """Return True if the variant set *vset* provides the requested *target*.

    "cuda" is a family target: any variant whose name starts with "cu"
    satisfies it. Every other target, including "cpu", must appear in *vset*
    verbatim.
    """
    if target != "cuda":
        return target in vset
    return any(name.startswith("cu") for name in vset)
def build_grid(min_py: Version, min_torch: Version, targets: List[str]):
    """Build the Python x torch wheel-availability grid.

    Fetches the Python release cycles (>= *min_py*), the canonical torch
    ``major.minor`` series (>= *min_torch*) and the stable and nightly wheel
    listings. It prints the variants discovered in each listing.

    Each cell is:
      * "✓" - the stable listing has wheels for that Python / torch series
        whose variants satisfy every entry of *targets*;
      * "!" - only the nightly listing does;
      * "X" - neither does.
    An empty *targets* list means "any wheel at all". A cell with no wheels
    is therefore always "X".

    Returns ``(torch_versions, grid)``: the column versions, and a list of
    ``(python_version, row)`` pairs with one row per Python version.
    """
    python_versions = fetch_python_versions(min_py)
    torch_versions_canonical = fetch_torch_versions(min_torch)
    stable_compat, stable_variants = parse_wheel_listing("stable")
    nightly_compat, nightly_variants = parse_wheel_listing("nightly")
    print("Stable variants discovered:", ", ".join(sorted(stable_variants)))
    print("Nightly variants discovered:", ", ".join(sorted(nightly_variants)))

    def has_all(vset: Set[str]) -> bool:
        # An empty vset means no wheel exists for this cell. Without the
        # emptiness check, all() over an empty *targets* list would be True.
        return bool(vset) and all(target_satisfied(vset, t) for t in targets)

    # Columns: the canonical series plus any series seen in either listing
    # (compat keys are (major, minor, python)), filtered by min_torch.
    all_keys = {(v.major, v.minor) for v in torch_versions_canonical}
    for compat in (stable_compat, nightly_compat):
        all_keys.update((major, minor) for major, minor, _ in compat)
    keys = sorted(k for k in all_keys if Version(f"{k[0]}.{k[1]}") >= min_torch)
    torch_versions = [Version(f"{m}.{n}") for m, n in keys]

    grid = []
    for py in python_versions:
        row = []
        for major, minor in keys:
            cell_key = (major, minor, py)
            if has_all(stable_compat.get(cell_key, set())):
                row.append("✓")
            elif has_all(nightly_compat.get(cell_key, set())):
                row.append("!")
            else:
                row.append("X")
        grid.append((py, row))
    return torch_versions, grid
# ---------------------------------------------------------------------------
# 4. CLI helpers (unchanged)
# ---------------------------------------------------------------------------
def _inf_min_py_from_toml(path: str) -> str | None:
try:
with open(path) as fh:
for ln in fh:
if ln.strip().startswith("requires-python"):
m = re.search(r'>=\s*([0-9]+\.[0-9]+)', ln)
if m:
return m.group(1)
except Exception:
pass
return None
def _inf_min_torch_from_toml(path: str) -> str | None:
try:
with open(path) as fh:
for ln in fh:
m = re.search(r'torch[^0-9]*>=\s*([0-9]+\.[0-9]+)', ln)
if m:
return m.group(1)
except Exception:
pass
return None
def _inf_min_torch_from_req(path: str) -> str | None:
try:
with open(path) as fh:
for ln in fh:
if ln.strip().startswith("torch"):
m = re.search(r'torch[^0-9]*([0-9]+\.[0-9]+)', ln)
if m:
return m.group(1)
except Exception:
pass
return None
def get_args():
    """Parse CLI arguments and resolve the minimum Python/torch versions and targets.

    Sources are applied in this order (later wins):
      1. pyproject.toml (``--toml``): Python and torch minimums;
      2. requirements file (``--requirements``): torch minimum only;
      3. explicit ``--min-python-version`` / ``--min-torch-version``.
    If nothing supplies a value, the fallbacks are Python 3.9 and torch 2.2.

    Returns ``(min_py, min_torch, targets)``. *targets* is the list of variant
    names from ``--targets``, with surrounding whitespace stripped and empty
    entries dropped, so ``--targets "cpu, cu118"`` works as expected.
    """
    p = argparse.ArgumentParser(description="Dynamic Python⇄Torch compatibility grid")
    p.add_argument(
        "--min-python-version",
        help="Lowest Python version to show (fallback: 3.9)",
    )
    p.add_argument(
        "--min-torch-version",
        help="Lowest torch major.minor to show (fallback: 2.2)",
    )
    p.add_argument(
        "--toml",
        nargs="?",
        const="pyproject.toml",
        help="Infer minimums from a pyproject.toml (path defaults to pyproject.toml)",
    )
    p.add_argument(
        "--requirements",
        nargs="?",
        const="requirements.txt",
        help="Infer the torch minimum from a requirements file (path defaults to requirements.txt)",
    )
    p.add_argument(
        "--targets",
        help="Comma-separated list of required build variants (e.g., cpu, cu118, rocm6)",
    )
    args = p.parse_args()

    min_py = None
    min_torch = None
    # pyproject.toml may supply both minimums.
    if args.toml and os.path.exists(args.toml):
        py = _inf_min_py_from_toml(args.toml)
        tor = _inf_min_torch_from_toml(args.toml)
        if py:
            min_py = Version(py)
        if tor:
            min_torch = Version(tor)
    # A requirements file overrides the torch minimum taken from the toml.
    if args.requirements and os.path.exists(args.requirements):
        tor = _inf_min_torch_from_req(args.requirements)
        if tor:
            min_torch = Version(tor)
    # Explicit CLI flags override anything inferred from files.
    if args.min_python_version:
        min_py = Version(args.min_python_version)
    if args.min_torch_version:
        min_torch = Version(args.min_torch_version)
    if min_py is None:
        min_py = Version("3.9")
    if min_torch is None:
        min_torch = Version("2.2")

    # Strip spaces and drop empty items so "cpu, cu118" or a trailing comma
    # do not produce targets that can never match a variant name.
    targets = [t.strip() for t in (args.targets or "").split(",") if t.strip()]
    return min_py, min_torch, targets
# ---------------------------------------------------------------------------
# 5. grid printing
# ---------------------------------------------------------------------------
def print_grid(torch_versions: List[Version], grid):
    """Print the compatibility grid as an aligned text table.

    *torch_versions* supplies the column headers. *grid* is a sequence of
    ``(python_version, row)`` pairs, where ``row`` holds one cell string per
    torch version, as returned by ``build_grid``.

    Each column is padded to its widest entry so cells line up under their
    headers. The first column is left-aligned; the rest are centered.
    """
    header = ["Python \\ Torch"] + [str(tv) for tv in torch_versions]
    rows = [[str(py)] + list(row) for py, row in grid]

    widths = [len(h) for h in header]
    for cells in rows:
        for i, cell in enumerate(cells[: len(widths)]):
            widths[i] = max(widths[i], len(cell))

    def render(cells):
        out = []
        for i, cell in enumerate(cells):
            if i >= len(widths):
                out.append(cell)  # more cells than headers: leave unpadded
            elif i == 0:
                out.append(cell.ljust(widths[i]))
            else:
                out.append(cell.center(widths[i]))
        return " | ".join(out).rstrip()

    print(render(header))
    print("-+-".join("-" * w for w in widths))
    for cells in rows:
        print(render(cells))
# ---------------------------------------------------------------------------
# __main__
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Resolve limits from the CLI and project files, then fetch, build and print.
    cli_min_py, cli_min_torch, cli_targets = get_args()
    print("Building dynamic compatibility grid — this may take a few seconds…")
    columns, rows = build_grid(cli_min_py, cli_min_torch, cli_targets)
    print_grid(columns, rows)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment