Created
May 1, 2026 10:12
-
-
Save saagarjha/f5b387e843a74fa27a666746453a939d to your computer and use it in GitHub Desktop.
Fetches CUDA headers to keep clangd happy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """Fetch a minimal CUDA headers tree for clangd, from NVIDIA's apt repo. | |
| Stdlib-only. Works anywhere with Python 3.9+. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import gzip | |
| import io | |
| import lzma | |
| import shutil | |
| import sys | |
| import tarfile | |
| import urllib.request | |
| from pathlib import Path | |
# Base URL of NVIDIA's CUDA apt repositories; "/<distro>/<arch>" is appended.
REPO = "https://developer.download.nvidia.com/compute/cuda/repos"

# Package base names; "-<version>" (e.g. "-12-9") is appended at lookup time.
# Chosen by walking every `#include "..."` in clang's
# clang/lib/Headers/__clang_cuda_runtime_wrapper.h: each included header is
# shipped by one of these packages, or sits behind `#if CUDA_VERSION < 9000`.
WANT_PACKAGES = [
    # cuda.h, cuda_runtime.h, sm_*_*, texture/surface_indirect, etc.
    "cuda-cudart-dev",
    # crt/host_runtime.h, crt/math_functions.hpp, crt/device_*
    "cuda-crt",
    # libcu++ (cuda/std/*); transitively required by cooperative_groups
    "cuda-cccl",
    # curand_mtgp32_kernel.h (force-included by the wrapper)
    "libcurand-dev",
]

# CUDA release in deb package-name form ("12-9" means CUDA 12.9).
# clang 21 fully supports CUDA through 12.8 and partially supports 12.9.
# CUDA 13+ removed deprecated headers (texture_fetch_functions.h etc.) that
# clang's wrapper still includes unconditionally.
DEFAULT_VERSION = "12-9"
def fetch(url: str, *, timeout: float = 60.0) -> bytes:
    """Download ``url`` and return the complete response body.

    Args:
        url: Any URL ``urllib.request.urlopen`` accepts (https://, file://, ...).
        timeout: Seconds to wait on each blocking socket operation (connect,
            individual reads). Without an explicit value, urlopen falls back to
            the global socket default, which is normally "no timeout", so a
            stalled connection would hang the script forever.

    Raises:
        urllib.error.URLError / OSError: on network failures, HTTP error
            statuses (``HTTPError``), or timeouts.
    """
    with urllib.request.urlopen(url, timeout=timeout) as r:
        return r.read()
def load_packages_index(distro: str, arch: str) -> list[dict[str, str]]:
    """Download and parse the apt ``Packages`` index for one distro/arch.

    ``Packages.gz`` is tried first, then ``Packages.xz``. The index is split
    into blank-line-separated stanzas. Each stanza becomes a dict of its
    top-level ``Field: value`` lines. Continuation lines (leading space/tab)
    are skipped, so multi-line fields keep only their first line.

    Raises:
        RuntimeError: if no index variant could be downloaded and
            decompressed. The per-variant failures are listed in the message,
            and the last one is chained as ``__cause__``.
    """
    base = f"{REPO}/{distro}/{arch}"
    failures = []
    last_err = None
    for name, decoder in [
        ("Packages.gz", gzip.decompress),
        ("Packages.xz", lzma.decompress),
    ]:
        try:
            raw = decoder(fetch(f"{base}/{name}"))
        except Exception as err:
            # Best-effort: a 404, network error or corrupt archive just means
            # "try the next variant", but remember why for the final error.
            failures.append(f"{name}: {err}")
            last_err = err
            continue
        break
    else:
        detail = "; ".join(failures)
        raise RuntimeError(f"no Packages index at {base} ({detail})") from last_err

    entries: list[dict[str, str]] = []
    for block in raw.decode("utf-8", "replace").split("\n\n"):
        if not block.strip():
            continue
        fields: dict[str, str] = {}
        for line in block.splitlines():
            # Only top-level "Key: value" lines; indented lines continue the
            # previous field and are not needed here.
            if ":" in line and not line.startswith((" ", "\t")):
                key, _, val = line.partition(":")
                fields[key.strip()] = val.strip()
        entries.append(fields)
    return entries
def _deb_split_version(version: str) -> tuple[int, str, str]:
    """Split a Debian ``[epoch:]upstream[-revision]`` string into its parts."""
    epoch, colon, rest = version.partition(":")
    if not colon:
        epoch, rest = "0", version
    upstream, dash, revision = rest.rpartition("-")
    if not dash:
        upstream, revision = rest, ""
    epoch_num = int(epoch) if epoch.isascii() and epoch.isdigit() else 0
    return epoch_num, upstream, revision


def _deb_char_weight(ch: str) -> int:
    """Sort weight of one char in a non-digit run (dpkg rules); "" = end."""
    if ch == "~":
        return -1  # "~" sorts before everything, even end-of-string
    if ch == "" or "0" <= ch <= "9":
        return 0
    if ch.isascii() and ch.isalpha():
        return ord(ch)  # letters sort before all other symbols
    return ord(ch) + 256


def _deb_cmp_fragment(a: str, b: str) -> int:
    """Compare two upstream (or revision) strings like dpkg's verrevcmp."""
    i = j = 0
    while i < len(a) or j < len(b):
        # Non-digit run: compare character by character using dpkg weights.
        while (i < len(a) and not "0" <= a[i] <= "9") or (
            j < len(b) and not "0" <= b[j] <= "9"
        ):
            wa = _deb_char_weight(a[i] if i < len(a) else "")
            wb = _deb_char_weight(b[j] if j < len(b) else "")
            if wa != wb:
                return -1 if wa < wb else 1
            i += 1
            j += 1
        # Digit run: compare numerically (leading zeros are insignificant).
        si, sj = i, j
        while i < len(a) and "0" <= a[i] <= "9":
            i += 1
        while j < len(b) and "0" <= b[j] <= "9":
            j += 1
        na, nb = int(a[si:i] or 0), int(b[sj:j] or 0)
        if na != nb:
            return -1 if na < nb else 1
    return 0


def _deb_version_cmp(a: str, b: str) -> int:
    """Return <0, 0 or >0 as Debian version ``a`` sorts before/equal/after ``b``."""
    ea, ua, ra = _deb_split_version(a)
    eb, ub, rb = _deb_split_version(b)
    if ea != eb:
        return -1 if ea < eb else 1
    return _deb_cmp_fragment(ua, ub) or _deb_cmp_fragment(ra, rb)


def pick_version(entries: list[dict[str, str]], pkg: str) -> dict[str, str]:
    """Latest deb for an exact package name (e.g. 'cuda-cudart-dev-12-9').

    "Latest" uses Debian version ordering: epoch, then upstream version, then
    revision, each compared dpkg-style. Plain string comparison would wrongly
    rank "12.4.99-1" above "12.4.127-1". A missing ``Version`` is treated as
    "". On ties, the first matching entry is kept.

    Exits the process via ``sys.exit`` with a hint if nothing matches ``pkg``.
    """
    best = None
    for entry in entries:
        if entry.get("Package") != pkg:
            continue
        if best is None or _deb_version_cmp(
            entry.get("Version", ""), best.get("Version", "")
        ) > 0:
            best = entry
    if best is None:
        sys.exit(f"no package {pkg!r} in index — try a different --version")
    return best
def iter_ar(data: bytes):
    """Yield (name, payload) for each member of a Unix `ar` archive.

    After the 8-byte ``!<arch>`` magic, each member has a 60-byte header. The
    name is in bytes 0-16, the decimal size in bytes 48-58, and the header
    ends with a 2-byte marker (backtick, newline). The payload follows and is
    padded to an even length. A trailing "/" on names is stripped. Any
    trailing bytes shorter than a header are ignored.

    Raises:
        ValueError: if the magic is missing, a member header is malformed, or
            a member's payload is truncated (instead of yielding a short
            payload that would fail confusingly further downstream).
    """
    if data[:8] != b"!<arch>\n":
        raise ValueError("not an ar archive")
    pos = 8
    while pos + 60 <= len(data):
        header = data[pos : pos + 60]
        if header[58:60] != b"`\n":
            raise ValueError(f"corrupt ar member header at offset {pos}")
        name = header[:16].rstrip().rstrip(b"/").decode()
        size = int(header[48:58].rstrip())
        pos += 60
        payload = data[pos : pos + size]
        if len(payload) != size:
            raise ValueError(
                f"truncated ar member {name!r}: expected {size} bytes, "
                f"got {len(payload)}"
            )
        yield name, payload
        pos += size + (size & 1)  # 2-byte align
def _legacy_checked_members(tf: tarfile.TarFile):
    """Yield ``tf``'s members, refusing ones unsafe to extract.

    Fallback for Pythons that lack tarfile extraction filters. Raises
    RuntimeError for absolute paths, paths or link targets that climb out of
    the destination via "..", and device/FIFO entries.
    """
    import posixpath

    def escapes(path: str) -> bool:
        norm = posixpath.normpath(path)
        return posixpath.isabs(norm) or norm == ".." or norm.startswith("../")

    for member in tf.getmembers():
        if escapes(member.name):
            unsafe = True
        elif member.issym():
            # Symlink targets are relative to the link's own directory.
            unsafe = escapes(
                posixpath.join(posixpath.dirname(member.name), member.linkname)
            )
        elif member.islnk():
            # Hard-link targets are relative to the archive root.
            unsafe = escapes(member.linkname)
        else:
            unsafe = member.ischr() or member.isblk() or member.isfifo()
        if unsafe:
            raise RuntimeError(
                f"refusing to extract unsafe .deb member {member.name!r}"
            )
        yield member


def extract_deb_data(deb: bytes, dest: Path) -> None:
    """Pull data.tar.* out of a .deb and unpack it under dest.

    ``.xz`` and ``.gz`` payloads are decompressed here. ``.zst`` exits with a
    message because the stdlib has no decoder for it. Any other name (e.g. a
    plain ``data.tar``) is passed straight to ``tarfile.open``, which
    auto-detects the compression formats it supports.

    Extraction uses tarfile's ``"data"`` filter when the interpreter provides
    it (3.12+, and security releases such as 3.9.17+). On older interpreters,
    passing ``filter=`` would raise TypeError, so the members are vetted by
    ``_legacy_checked_members`` instead.

    Raises:
        RuntimeError: if the .deb has no ``data.tar.*`` member, or (legacy
            path only) a member is unsafe to extract.
    """
    for name, payload in iter_ar(deb):
        if not name.startswith("data.tar"):
            continue
        if name.endswith(".xz"):
            payload = lzma.decompress(payload)
        elif name.endswith(".gz"):
            payload = gzip.decompress(payload)
        elif name.endswith(".zst"):
            sys.exit("data.tar.zst not supported by stdlib")
        with tarfile.open(fileobj=io.BytesIO(payload)) as tf:
            if hasattr(tarfile, "data_filter"):
                tf.extractall(dest, filter="data")
            else:
                tf.extractall(dest, members=_legacy_checked_members(tf))
        return
    raise RuntimeError("no data.tar.* in .deb")
def main() -> int:
    """Command-line entry point: fetch the header debs and build ``--out``.

    Steps:

    1. Download the apt index.
    2. Download each package in ``WANT_PACKAGES`` (suffixed with
       ``--version``) and unpack it into ``<out>/.staging``.
    3. Copy the first ``targets/*/include`` tree found there to
       ``<out>/include``.
    4. Create ``<out>/bin`` and print the clangd flags to use.

    The staging tree is removed whether or not the run succeeds.

    Returns:
        0 on success. Fatal conditions exit via ``sys.exit`` with a message.
    """
    # Under `python -OO` docstrings are stripped and __doc__ is None.
    description = __doc__.splitlines()[0] if __doc__ else None
    ap = argparse.ArgumentParser(description=description)
    ap.add_argument("--out", type=Path, default=Path("./cuda-headers"))
    ap.add_argument("--distro", default="ubuntu2404")
    ap.add_argument(
        "--arch", default="x86_64", help="x86_64, sbsa, cross-linux-sbsa, ..."
    )
    ap.add_argument(
        "--version",
        default=DEFAULT_VERSION,
        help=f"CUDA version in deb naming, e.g. '12-9' (default: {DEFAULT_VERSION})",
    )
    args = ap.parse_args()

    print(f"[1/3] fetching index from {args.distro}/{args.arch}")
    index = load_packages_index(args.distro, args.arch)

    staging = args.out / ".staging"
    if staging.exists():
        shutil.rmtree(staging)
    staging.mkdir(parents=True)
    try:
        for i, pkg in enumerate(WANT_PACKAGES, 1):
            entry = pick_version(index, f"{pkg}-{args.version}")
            url = f"{REPO}/{args.distro}/{args.arch}/{entry['Filename']}"
            size_kb = int(entry.get("Size", "0")) // 1024
            print(f"[2/3] {i}/{len(WANT_PACKAGES)} {entry['Package']} ({size_kb} KB)")
            extract_deb_data(fetch(url), staging)

        # Headers land at staging/usr/local/cuda-X.Y/targets/<arch>-linux/include
        src = next(staging.rglob("targets/*/include"), None)
        if src is None:
            sys.exit("no targets/<arch>/include in extracted debs — apt layout changed?")

        print(f"[3/3] laying out {args.out}")
        inc = args.out / "include"
        if inc.exists():
            shutil.rmtree(inc)
        shutil.copytree(src, inc)
        (args.out / "bin").mkdir(exist_ok=True)  # clang's detector checks for bin/
    finally:
        # Remove staging on success *and* when bailing out part-way (network
        # error, sys.exit above), so no partial extraction is left behind.
        # ignore_errors keeps a cleanup hiccup from masking the real error.
        shutil.rmtree(staging, ignore_errors=True)

    mb = sum(p.stat().st_size for p in inc.rglob("*") if p.is_file()) // 1024 // 1024
    print(f"done: {args.out} ({mb} MB)")
    print(f"point clangd at it with: --cuda-path={args.out.resolve()} -nocudalib")
    return 0
if __name__ == "__main__":
    # Run the CLI and use main()'s return value as the process exit status.
    raise SystemExit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment