Skip to content

Instantly share code, notes, and snippets.

@saagarjha
Created May 1, 2026 10:12
Show Gist options
  • Select an option

  • Save saagarjha/f5b387e843a74fa27a666746453a939d to your computer and use it in GitHub Desktop.

Select an option

Save saagarjha/f5b387e843a74fa27a666746453a939d to your computer and use it in GitHub Desktop.
Fetches CUDA headers for clangd to be happy
#!/usr/bin/env python3
"""Fetch a minimal CUDA headers tree for clangd, from NVIDIA's apt repo.
Stdlib-only. Works anywhere with Python 3.9+.
"""
from __future__ import annotations
import argparse
import gzip
import io
import lzma
import shutil
import sys
import tarfile
import urllib.request
from pathlib import Path
# Base URL of NVIDIA's CUDA apt repositories; "<distro>/<arch>/..." is appended.
REPO = "https://developer.download.nvidia.com/compute/cuda/repos"

# Package name stems (main() appends "-<version>") whose headers clang's CUDA
# wrapper needs. Derived from clang/lib/Headers/__clang_cuda_runtime_wrapper.h:
# every #include "..." it pulls in is either shipped by one of these packages
# or sits behind #if CUDA_VERSION < 9000.
WANT_PACKAGES = [
    "cuda-cudart-dev",  # cuda.h, cuda_runtime.h, sm_*_*, texture/surface_indirect, etc.
    "cuda-crt",  # crt/host_runtime.h, crt/math_functions.hpp, crt/device_*
    "cuda-cccl",  # libcu++ (cuda/std/*); cooperative_groups needs it transitively
    "libcurand-dev",  # curand_mtgp32_kernel.h (force-included by the wrapper)
]

# CUDA release in deb naming ("MAJOR-MINOR"). clang 21 fully supports CUDA up
# to 12.8 and partially supports 12.9; CUDA 13+ dropped deprecated headers
# (texture_fetch_functions.h etc.) that clang's wrapper still includes
# unconditionally.
DEFAULT_VERSION = "12-9"
def fetch(url: str, timeout: float = 60.0) -> bytes:
    """Download *url* and return the whole response body as bytes.

    Args:
        url: URL to retrieve (any scheme ``urllib.request.urlopen`` handles).
        timeout: Seconds each blocking socket operation (connect, read) may
            take before giving up. This is an inactivity limit, not a cap on
            total download time, so large packages still complete as long as
            data keeps arriving.

    Raises:
        OSError: On network failures, timeouts and HTTP error statuses
            (``urllib.error.URLError`` / ``HTTPError`` are OSError subclasses).
    """
    # Without an explicit timeout, urlopen uses the global socket default,
    # which is normally "block forever", so a stalled mirror would hang the
    # script indefinitely.
    with urllib.request.urlopen(url, timeout=timeout) as response:
        return response.read()
def _parse_packages_text(text: str) -> list[dict[str, str]]:
    """Parse apt ``Packages`` text into one ``{field: value}`` dict per stanza.

    Stanzas are separated by an empty line (``"\\n\\n"``); whitespace-only
    stanzas are skipped. Only ``Field: value`` lines are kept. Continuation
    lines (starting with a space or tab) of multi-line fields such as
    ``Description`` are dropped, so those fields hold just their first line.
    """
    entries: list[dict[str, str]] = []
    for stanza in text.split("\n\n"):
        if not stanza.strip():
            continue
        fields: dict[str, str] = {}
        for line in stanza.splitlines():
            if line.startswith((" ", "\t")) or ":" not in line:
                continue
            key, _, value = line.partition(":")
            fields[key.strip()] = value.strip()
        entries.append(fields)
    return entries


def load_packages_index(distro: str, arch: str) -> list[dict[str, str]]:
    """Download and parse the apt ``Packages`` index for one repo directory.

    Tries ``{REPO}/{distro}/{arch}/Packages.gz`` first, then ``Packages.xz``.

    Args:
        distro: Repo distro directory, e.g. ``"ubuntu2404"``.
        arch: Repo architecture directory, e.g. ``"x86_64"``.

    Returns:
        One dict per package stanza mapping field name to its first-line
        value (e.g. ``Package``, ``Version``, ``Filename``, ``Size``).

    Raises:
        RuntimeError: If neither index could be fetched and decompressed. The
            last underlying error is chained as ``__cause__`` so the real
            reason (HTTP 404, network failure, corrupt archive) is visible.
    """
    base = f"{REPO}/{distro}/{arch}"
    last_error = None
    for name, decompress in (
        ("Packages.gz", gzip.decompress),
        ("Packages.xz", lzma.decompress),
    ):
        try:
            raw = decompress(fetch(f"{base}/{name}"))
        except Exception as err:  # deliberate best-effort: any failure -> next format
            last_error = err
        else:
            break
    else:
        raise RuntimeError(f"no Packages index at {base}") from last_error
    return _parse_packages_text(raw.decode("utf-8", "replace"))
_DEB_DIGITS = frozenset("0123456789")


def _deb_char_order(ch: str) -> int:
    """Weight of one character in dpkg's non-digit comparison.

    End-of-string (``""``) and digits weigh 0. ``~`` weighs -1, so it sorts
    before everything, even end-of-string. ASCII letters sort before all other
    characters.
    """
    if ch == "" or ch in _DEB_DIGITS:
        return 0
    if ch == "~":
        return -1
    if ch.isascii() and ch.isalpha():
        return ord(ch)
    return ord(ch) + 256


def _compare_deb_fragment(a: str, b: str) -> int:
    """Compare two upstream-version or revision strings the way dpkg does.

    Alternates between a non-digit run, compared character by character with
    :func:`_deb_char_order`, and a digit run, compared numerically. Returns a
    negative, zero or positive int, like a classic ``cmp``.
    """
    i = j = 0
    while i < len(a) or j < len(b):
        while (i < len(a) and a[i] not in _DEB_DIGITS) or (
            j < len(b) and b[j] not in _DEB_DIGITS
        ):
            ac = _deb_char_order(a[i] if i < len(a) else "")
            bc = _deb_char_order(b[j] if j < len(b) else "")
            if ac != bc:
                return -1 if ac < bc else 1
            # Equal weights here imply the same non-digit char on both sides.
            i += 1
            j += 1
        i_end = i
        while i_end < len(a) and a[i_end] in _DEB_DIGITS:
            i_end += 1
        j_end = j
        while j_end < len(b) and b[j_end] in _DEB_DIGITS:
            j_end += 1
        a_num = int(a[i:i_end] or "0")
        b_num = int(b[j:j_end] or "0")
        if a_num != b_num:
            return -1 if a_num < b_num else 1
        i, j = i_end, j_end
    return 0


def _split_deb_version(version: str) -> tuple[int, str, str]:
    """Split ``[epoch:]upstream[-revision]`` into ``(epoch, upstream, revision)``.

    The epoch defaults to 0 when absent or non-numeric. The revision is
    whatever follows the last hyphen, or ``""`` when there is none.
    """
    epoch_text, colon, rest = version.partition(":")
    if not colon:
        epoch_text, rest = "0", version
    upstream, dash, revision = rest.rpartition("-")
    if not dash:
        upstream, revision = rest, ""
    try:
        epoch = int(epoch_text)
    except ValueError:
        epoch = 0
    return epoch, upstream, revision


def _compare_deb_versions(a: str, b: str) -> int:
    """Compare two Debian version strings: epoch, then upstream, then revision."""
    a_epoch, a_upstream, a_revision = _split_deb_version(a)
    b_epoch, b_upstream, b_revision = _split_deb_version(b)
    if a_epoch != b_epoch:
        return -1 if a_epoch < b_epoch else 1
    return _compare_deb_fragment(a_upstream, b_upstream) or _compare_deb_fragment(
        a_revision, b_revision
    )


def pick_version(entries: list[dict[str, str]], pkg: str) -> dict[str, str]:
    """Latest deb for an exact package name (e.g. 'cuda-cudart-dev-12-9').

    Versions are ordered with dpkg's algorithm rather than as plain strings:
    epoch first, then upstream version, then Debian revision. Digit runs
    compare numerically and ``~`` sorts before anything. So ``12.9.100-1``
    beats ``12.9.79-1`` and ``1.0-1`` beats ``1.0~rc1-1``. On ties, the entry
    listed first wins. An entry without ``Version`` is compared as ``""``.

    Args:
        entries: Parsed ``Packages`` stanzas.
        pkg: Exact value of the ``Package`` field to look for.

    Returns:
        The stanza dict of the highest version.

    Raises:
        SystemExit: If no stanza has ``Package == pkg``.
    """
    matches = [e for e in entries if e.get("Package") == pkg]
    if not matches:
        sys.exit(f"no package {pkg!r} in index — try a different --version")
    best = matches[0]
    for candidate in matches[1:]:
        if _compare_deb_versions(candidate.get("Version", ""), best.get("Version", "")) > 0:
            best = candidate
    return best
def iter_ar(data: bytes):
    """Yield ``(name, payload)`` for each member of a Unix ``ar`` archive.

    Handles the short-name System V / GNU layout used by ``.deb`` files. Each
    member has a 60-byte header: the name in bytes 0-15, the decimal size in
    bytes 48-57, and a backquote + newline terminator in bytes 58-59. The
    payload follows, padded to an even length. Trailing bytes shorter than
    one header are ignored.

    Args:
        data: The complete archive contents.

    Yields:
        The member name (trailing spaces and a GNU ``/`` terminator stripped)
        and the member's payload bytes.

    Raises:
        ValueError: If the archive magic or a member header is malformed, or a
            member's declared size is negative or runs past the end of *data*.
    """
    if data[:8] != b"!<arch>\n":
        raise ValueError("not an ar archive")
    pos = 8
    while pos + 60 <= len(data):
        header = data[pos : pos + 60]
        if header[58:60] != b"`\n":
            raise ValueError(f"bad ar member header at offset {pos}")
        name = header[:16].rstrip().rstrip(b"/").decode()
        size = int(header[48:58].rstrip())
        pos += 60
        # A negative size would move pos backwards (possibly forever); an
        # oversized one means the download was truncated.
        if size < 0 or pos + size > len(data):
            raise ValueError(f"truncated or corrupt ar member {name!r}")
        yield name, data[pos : pos + size]
        pos += size + (size & 1)  # payloads are padded to 2-byte alignment
def extract_deb_data(deb: bytes, dest: Path) -> None:
    """Pull data.tar.* out of a .deb and unpack it under *dest*.

    Only the first ``data.tar*`` member of the ``ar`` container is used.
    ``.xz`` and ``.gz`` payloads are decompressed in memory first. Any other
    name is handed to :mod:`tarfile` as-is, and its default read mode
    auto-detects the compressions it supports. Extraction uses tarfile's
    ``"data"`` filter, which rejects absolute paths, ``..`` escapes and links
    pointing outside *dest*.

    Args:
        deb: Raw bytes of the ``.deb`` file.
        dest: Directory to extract into.

    Raises:
        RuntimeError: If the package contains no ``data.tar*`` member.
        SystemExit: If the payload is zstd-compressed (no stdlib decoder), or
            if this Python's tarfile lacks extraction filters.
    """
    for name, payload in iter_ar(deb):
        if not name.startswith("data.tar"):
            continue
        if name.endswith(".xz"):
            payload = lzma.decompress(payload)
        elif name.endswith(".gz"):
            payload = gzip.decompress(payload)
        elif name.endswith(".zst"):
            sys.exit("data.tar.zst not supported by stdlib")
        # extractall(filter=...) exists only on Python 3.12+ and the PEP 706
        # security releases (3.8.17, 3.9.17, 3.10.12, 3.11.4). Older
        # interpreters would fail with an opaque TypeError, and extracting
        # without a filter allows path traversal.
        if not hasattr(tarfile, "data_filter"):
            sys.exit(
                "tarfile extraction filters unavailable; use Python 3.12+ "
                "or 3.9.17+/3.10.12+/3.11.4+"
            )
        with tarfile.open(fileobj=io.BytesIO(payload)) as tf:
            tf.extractall(dest, filter="data")
        return
    raise RuntimeError("no data.tar.* in .deb")
def main() -> int:
    """Download the CUDA header packages and lay them out for clangd.

    Reads ``--out``, ``--distro``, ``--arch`` and ``--version`` from the
    command line and fetches the apt index. Each package in ``WANT_PACKAGES``
    (suffixed with ``-<version>``) is unpacked into ``<out>/.staging``. The
    first ``targets/*/include`` directory found there is copied to
    ``<out>/include``, replacing any previous copy, and ``<out>/bin`` is
    created.

    Returns:
        0 on success. Unrecoverable problems (missing package, unexpected deb
        layout) exit through ``sys.exit`` with a message.
    """
    # __doc__ is None under `python -OO`; fall back to an empty description.
    ap = argparse.ArgumentParser(description=(__doc__ or "").partition("\n")[0])
    ap.add_argument("--out", type=Path, default=Path("./cuda-headers"))
    ap.add_argument("--distro", default="ubuntu2404")
    ap.add_argument(
        "--arch", default="x86_64", help="x86_64, sbsa, cross-linux-sbsa, ..."
    )
    ap.add_argument(
        "--version",
        default=DEFAULT_VERSION,
        help=f"CUDA version in deb naming, e.g. '12-9' (default: {DEFAULT_VERSION})",
    )
    args = ap.parse_args()

    print(f"[1/3] fetching index from {args.distro}/{args.arch}")
    index = load_packages_index(args.distro, args.arch)

    staging = args.out / ".staging"
    if staging.exists():
        shutil.rmtree(staging)
    staging.mkdir(parents=True)
    try:
        for i, pkg in enumerate(WANT_PACKAGES, 1):
            entry = pick_version(index, f"{pkg}-{args.version}")
            url = f"{REPO}/{args.distro}/{args.arch}/{entry['Filename']}"
            size_kb = int(entry.get("Size", "0")) // 1024
            print(f"[2/3] {i}/{len(WANT_PACKAGES)} {entry['Package']} ({size_kb} KB)")
            extract_deb_data(fetch(url), staging)

        # Headers land at staging/usr/local/cuda-X.Y/targets/<arch>-linux/include
        src = next(staging.rglob("targets/*/include"), None)
        if src is None:
            sys.exit("no targets/<arch>/include in extracted debs — apt layout changed?")

        print(f"[3/3] laying out {args.out}")
        inc = args.out / "include"
        if inc.exists():
            shutil.rmtree(inc)
        shutil.copytree(src, inc)
    finally:
        # Remove the unpacked debs even when a download, extraction or lookup
        # fails part-way (SystemExit included), so no partial tree is left
        # behind under --out.
        shutil.rmtree(staging, ignore_errors=True)

    (args.out / "bin").mkdir(exist_ok=True)  # clang's detector checks for bin/
    mb = sum(p.stat().st_size for p in inc.rglob("*") if p.is_file()) // 1024 // 1024
    print(f"done: {args.out} ({mb} MB)")
    print(f"point clangd at it with: --cuda-path={args.out.resolve()} -nocudalib")
    return 0
if __name__ == "__main__":
    # Run the CLI and hand main()'s return value to the interpreter as the
    # process exit status (equivalent to sys.exit(main())).
    raise SystemExit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment