Created
May 15, 2025 16:17
-
-
Save robbiemu/c224d6daf2f05d918ac8cf759db33288 to your computer and use it in GitHub Desktop.
Print a grid showing which PyTorch wheels are available for each Python version.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import os | |
import re | |
from collections import defaultdict | |
from typing import List, Set, Tuple | |
import requests | |
from packaging.version import Version, InvalidVersion | |
# ---------------------------------------------------------------------------
# Remote data sources
# ---------------------------------------------------------------------------
# Release-cycle listing for Python; its cycles become the grid rows.
ENDOFLIFE_PYTHON_JSON = "https://endoflife.date/api/python.json"
# PyPI JSON metadata for the ``torch`` distribution.
TORCH_PYPI_JSON = "https://pypi.org/pypi/torch/json"
# PyTorch docs index page, scanned for ``/docs/X.Y/`` links.
TORCH_DOCS_VERSIONS = "https://pytorch.org/docs/versions.html"
# Base URL of the wheel index; channel directories hang off this.
WHEEL_ROOT = "https://download.pytorch.org/whl/"
# ---------------------------------------------------------------------------
# HTTP helper
# ---------------------------------------------------------------------------
def _get(url: str) -> str:
    """Download *url* and return the response body as text.

    The request uses a 30-second timeout. A 4xx/5xx status raises
    ``requests.HTTPError`` instead of returning an error page.
    """
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    body = response.text
    return body
# ---------------------------------------------------------------------------
# 1. Wheel-index parsing: one directory listing per channel (stable, nightly)
# ---------------------------------------------------------------------------
# Channel name -> URL of the listing that contains that channel's torch wheels.
WHEEL_DIRS = dict(
    stable=WHEEL_ROOT + "torch/",
    nightly=WHEEL_ROOT + "nightly/torch/",
)
# Matches torch wheel filenames shaped like
#   torch-<version>+<variant>...-cp<NN|NNN>-cp<NN|NNN>...whl
# and captures three groups (in this order for findall):
#   tv      - the torch version string (e.g. "2.2.0")
#   variant - the local-version tag after "+", up to the first character
#             outside [a-z0-9_]; "+cu118" -> "cu118", but "+rocm5.7" -> "rocm5"
#   py      - the 2- or 3-digit CPython tag digits (cp39 -> "39", cp310 -> "310")
# Wheels with no "+<variant>" local tag, or whose ABI tag is not cpNN/cpNNN
# (for example "none"), do not match and are therefore never counted.
_TORCH_WHEEL_RE = re.compile(
    r"torch-(?P<tv>[0-9a-zA-Z\.]+)\+(?P<variant>[a-z0-9_]+)[^\s]*-cp(?P<py>\d{2,3})-cp\d{2,3}.*?\.whl"
)
def _py_tag_to_version(tag: str) -> Version:
    """Convert the digits of a CPython wheel tag to a ``Version``.

    The first digit is the major version and the remaining digits are the
    minor version: ``"39"`` -> 3.9, ``"310"`` -> 3.10, ``"27"`` -> 2.7.
    (The previous implementation hard-coded major 3, mislabelling cp27
    wheels as Python 3.7.)

    Raises:
        ValueError: if *tag* is shorter than two characters or not all digits.
    """
    if len(tag) < 2 or not tag.isdigit():
        raise ValueError(f"not a CPython version tag: {tag!r}")
    return Version(f"{int(tag[0])}.{int(tag[1:])}")
def parse_wheel_listing(channel: str):
    """Scan the torch wheel listing for *channel* and group its builds.

    Args:
        channel: a key of ``WHEEL_DIRS`` ("stable" or "nightly"); any other
            value raises ``KeyError``.

    Returns:
        ``(compat, variants)``. ``compat`` is a ``defaultdict(set)`` mapping
        ``(torch_major, torch_minor, python_version)`` to the set of variant
        tags seen for that combination. ``variants`` is every variant tag
        seen anywhere in the listing. Only filenames matched by
        ``_TORCH_WHEEL_RE`` are considered. Entries whose version text is
        not a valid ``Version`` are skipped.
    """
    listing = _get(WHEEL_DIRS[channel])
    by_key: dict[Tuple[int, int, Version], set[str]] = defaultdict(set)
    seen: Set[str] = set()
    for version_text, build, cp_digits in _TORCH_WHEEL_RE.findall(listing):
        try:
            torch_ver = Version(version_text)
            python_ver = _py_tag_to_version(cp_digits)
        except InvalidVersion:
            continue
        by_key[(torch_ver.major, torch_ver.minor, python_ver)].add(build)
        seen.add(build)
    return by_key, seen
# ---------------------------------------------------------------------------
# 2. Canonical Python & Torch versions
# ---------------------------------------------------------------------------
def fetch_python_versions(min_py: Version) -> List[Version]:
    """Return the Python release cycles at or above *min_py*, sorted ascending.

    Cycles come from the endoflife.date Python API. Cycle strings that are
    not valid versions are skipped rather than aborting the run.

    Raises:
        requests.HTTPError: if the API responds with a 4xx/5xx status.
    """
    resp = requests.get(ENDOFLIFE_PYTHON_JSON, timeout=30)
    # Match _get(): fail on HTTP errors instead of parsing an error body.
    resp.raise_for_status()
    cycles = set()
    for entry in resp.json():
        try:
            cycle = Version(str(entry["cycle"]))
        except InvalidVersion:
            continue
        if cycle >= min_py:
            cycles.add(cycle)
    return sorted(cycles)
def _torch_versions_from_docs() -> Set[Version]:
    """Collect every ``X.Y`` that appears in a ``/docs/X.Y/`` link on the docs index."""
    page = _get(TORCH_DOCS_VERSIONS)
    matches = re.findall(r"/docs/(\d+\.\d+)/", page)
    return set(map(Version, matches))
def _torch_versions_from_pypi() -> Set[Version]:
    """Return every non-pre-release torch version listed in PyPI's release map.

    Release keys that are not valid PEP 440 versions are skipped instead of
    aborting the whole run.

    Raises:
        requests.HTTPError: if PyPI responds with a 4xx/5xx status.
    """
    resp = requests.get(TORCH_PYPI_JSON, timeout=30)
    # Match _get(): fail on HTTP errors instead of parsing an error body.
    resp.raise_for_status()
    found: Set[Version] = set()
    for key in resp.json()["releases"]:
        try:
            ver = Version(key)
        except InvalidVersion:
            continue
        if not ver.is_prerelease:
            found.add(ver)
    return found
def fetch_torch_versions(min_torch: Version) -> List[Version]:
    """Return the distinct torch ``major.minor`` lines at or above *min_torch*.

    Candidate versions are the union of those found on the PyTorch docs
    index and on PyPI. Pre-, dev- and post-releases are dropped. Each
    survivor is then truncated to ``major.minor``. The result is sorted
    ascending.
    """
    candidates = _torch_versions_from_docs() | _torch_versions_from_pypi()
    minor_lines = set()
    for ver in candidates:
        if ver < min_torch:
            continue
        if ver.is_prerelease or ver.is_devrelease or ver.is_postrelease:
            continue
        minor_lines.add(Version(f"{ver.major}.{ver.minor}"))
    return sorted(minor_lines)
# ---------------------------------------------------------------------------
# 3. Build compatibility grid
# ---------------------------------------------------------------------------
def target_satisfied(vset: Set[str], target: str) -> bool:
    """Return True if the build variants in *vset* meet *target*.

    Two targets are family wildcards matched by prefix:
    ``"cuda"`` is met by any variant starting with ``"cu"`` (e.g. ``cu118``),
    and ``"rocm"`` by any variant starting with ``"rocm"`` (e.g. ``rocm6``).
    Every other target, including ``"cpu"``, must appear in *vset* verbatim.
    """
    family_prefixes = {"cuda": "cu", "rocm": "rocm"}
    prefix = family_prefixes.get(target)
    if prefix is not None:
        return any(v.startswith(prefix) for v in vset)
    return target in vset
def build_grid(min_py: Version, min_torch: Version, targets: List[str]):
    """Build the Python x torch availability grid.

    Args:
        min_py: lowest Python release cycle to include as a row.
        min_torch: lowest torch ``major.minor`` to include as a column.
        targets: build variants every cell must provide (see
            ``target_satisfied``). When empty, any wheel at all counts.

    Returns:
        ``(torch_versions, grid)``. ``torch_versions`` is the list of column
        headers as ``Version`` objects. ``grid`` is a list of
        ``(python_version, row)`` pairs; each row cell is ``"✓"`` (met by a
        stable wheel), ``"!"`` (met only by a nightly wheel) or ``"X"``.
    """
    python_versions = fetch_python_versions(min_py)
    torch_versions_canonical = fetch_torch_versions(min_torch)
    stable_compat, stable_variants = parse_wheel_listing("stable")
    nightly_compat, nightly_variants = parse_wheel_listing("nightly")
    print("Stable variants discovered:", ", ".join(sorted(stable_variants)))
    print("Nightly variants discovered:", ", ".join(sorted(nightly_variants)))

    # Columns: canonical releases plus every (major, minor) seen in either
    # wheel index, filtered by min_torch.
    minor_keys = {(v.major, v.minor) for v in torch_versions_canonical}
    minor_keys |= {(k[0], k[1]) for k in stable_compat}
    minor_keys |= {(k[0], k[1]) for k in nightly_compat}
    keys = sorted(
        (major, minor)
        for major, minor in minor_keys
        if Version(f"{major}.{minor}") >= min_torch
    )
    torch_versions = [Version(f"{major}.{minor}") for major, minor in keys]

    def _satisfies(vset: Set[str]) -> bool:
        # An empty set means no wheel exists for this cell. Without the
        # bool() guard, all() over an empty target list is vacuously True
        # and every cell would be reported as available.
        return bool(vset) and all(target_satisfied(vset, t) for t in targets)

    grid = []
    for py in python_versions:
        row = []
        for major, minor in keys:
            cell_key = (major, minor, py)
            if _satisfies(stable_compat.get(cell_key, set())):
                row.append("✓")
            elif _satisfies(nightly_compat.get(cell_key, set())):
                row.append("!")
            else:
                row.append("X")
        grid.append((py, row))
    return torch_versions, grid
# --------------------------------------------------------------------------- | |
# 4. CLI helpers (unchanged) | |
# --------------------------------------------------------------------------- | |
def _inf_min_py_from_toml(path: str) -> str | None: | |
try: | |
with open(path) as fh: | |
for ln in fh: | |
if ln.strip().startswith("requires-python"): | |
m = re.search(r'>=\s*([0-9]+\.[0-9]+)', ln) | |
if m: | |
return m.group(1) | |
except Exception: | |
pass | |
return None | |
def _inf_min_torch_from_toml(path: str) -> str | None: | |
try: | |
with open(path) as fh: | |
for ln in fh: | |
m = re.search(r'torch[^0-9]*>=\s*([0-9]+\.[0-9]+)', ln) | |
if m: | |
return m.group(1) | |
except Exception: | |
pass | |
return None | |
def _inf_min_torch_from_req(path: str) -> str | None: | |
try: | |
with open(path) as fh: | |
for ln in fh: | |
if ln.strip().startswith("torch"): | |
m = re.search(r'torch[^0-9]*([0-9]+\.[0-9]+)', ln) | |
if m: | |
return m.group(1) | |
except Exception: | |
pass | |
return None | |
def get_args():
    """Parse the command line and resolve minimum versions and build targets.

    Precedence, lowest to highest: built-in defaults (Python 3.9, torch 2.2),
    then values inferred from ``--toml``, then the torch minimum inferred
    from ``--requirements``, then explicit ``--min-python-version`` /
    ``--min-torch-version``. A ``--toml``/``--requirements`` path that does
    not exist is ignored.

    Returns:
        ``(min_py, min_torch, targets)``: two ``Version`` objects and a list
        of target names that are stripped, lower-cased and non-empty.
    """
    p = argparse.ArgumentParser(description="Dynamic Python⇄Torch compatibility grid")
    p.add_argument(
        "--min-python-version",
        help="Lowest Python version to show (overrides inferred values; default 3.9)",
    )
    p.add_argument(
        "--min-torch-version",
        help="Lowest torch major.minor to show (overrides inferred values; default 2.2)",
    )
    p.add_argument(
        "--toml",
        nargs="?",
        const="pyproject.toml",
        help="Infer minimums from a pyproject.toml (path defaults to pyproject.toml)",
    )
    p.add_argument(
        "--requirements",
        nargs="?",
        const="requirements.txt",
        help="Infer the torch minimum from a requirements file (path defaults to requirements.txt)",
    )
    p.add_argument(
        "--targets",
        help="Comma-separated list of required build variants (e.g., cpu, cu118, rocm6)",
    )
    args = p.parse_args()

    min_py = None
    min_torch = None
    # Lowest precedence: values inferred from pyproject.toml.
    if args.toml and os.path.exists(args.toml):
        py = _inf_min_py_from_toml(args.toml)
        tor = _inf_min_torch_from_toml(args.toml)
        if py:
            min_py = Version(py)
        if tor:
            min_torch = Version(tor)
    # requirements.txt overrides the torch minimum taken from pyproject.toml.
    if args.requirements and os.path.exists(args.requirements):
        tor = _inf_min_torch_from_req(args.requirements)
        if tor:
            min_torch = Version(tor)
    # Explicit flags win over anything inferred.
    if args.min_python_version:
        min_py = Version(args.min_python_version)
    if args.min_torch_version:
        min_torch = Version(args.min_torch_version)
    # Built-in defaults for whatever is still unset.
    if min_py is None:
        min_py = Version("3.9")
    if min_torch is None:
        min_torch = Version("2.2")

    targets = []
    if args.targets:
        # Accept "cpu, cu118" and trailing commas. Parsed variant tags are
        # lower-case only, so normalise case as well.
        targets = [t.strip().lower() for t in args.targets.split(",") if t.strip()]
    return min_py, min_torch, targets
# ---------------------------------------------------------------------------
# 5. Grid printing
# ---------------------------------------------------------------------------
def print_grid(torch_versions: List[Version], grid):
    """Print *grid* as a pipe-separated table with one column per torch version.

    Each data cell is padded to the width of its header cell, so rows line
    up with the header and the ``-+-`` separator. Python versions are
    left-aligned and availability marks are centred.

    Args:
        torch_versions: column headers, in display order.
        grid: iterable of ``(python_version, row)`` pairs, where ``row`` has
            one string cell per entry of *torch_versions*.
    """
    header = ["Python \\ Torch"] + [str(tv) for tv in torch_versions]
    widths = [len(h) for h in header]
    print(" | ".join(header))
    print("-+-".join(["-" * w for w in widths]))
    for py, row in grid:
        cells = [str(py).ljust(widths[0])]
        cells.extend(mark.center(width) for mark, width in zip(row, widths[1:]))
        print(" | ".join(cells))
# ---------------------------------------------------------------------------
# Script entry point
# ---------------------------------------------------------------------------
def main() -> None:
    """Resolve CLI options, build the compatibility grid and print it."""
    min_py, min_torch, targets = get_args()
    print("Building dynamic compatibility grid — this may take a few seconds…")
    torch_cols, grid_rows = build_grid(min_py, min_torch, targets)
    print_grid(torch_cols, grid_rows)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment