Last active
June 9, 2026 20:54
-
-
Save zeryx/87a6b5a5cc56c7a6e0fd61d6f9a378e2 to your computer and use it in GitHub Desktop.
Reproducer: jax.experimental.compute_on gpu_stream annotation silent no-op in JAX 0.9.1
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Reproducer: does compute_on("gpu_stream:N") annotate a GEMM in JAX 0.9.1? | |
| Hypothesis under test: a bare elementwise add (x + y) gets fused/elided by XLA, | |
| leaving no standalone op to carry _xla_stream_annotation -> annotation disappears. | |
| A GEMM is a real cuBLAS/cutlass kernel that won't be fused away, so it should | |
| retain the annotation if the mechanism works at all. | |
| We check BOTH: | |
| - lowered (pre-compilation StableHLO) .lower().as_text() | |
| - compiled (post-optimization HLO) .lower().compile().as_text() | |
| Environment: | |
| - JAX 0.9.1 (pip install "jax[cuda12]==0.9.1") | |
| - XLA_FLAGS must be set BEFORE importing jax. | |
| Run: python compute_on_repro.py | |
| """ | |
| import os | |
# Must run before `import jax` (see the module docstring). The flag is appended
# rather than assigned so that any XLA_FLAGS the caller already exported
# (for example HLO dump options) are kept instead of being overwritten.
_STREAM_ANNOTATION_FLAG = "--xla_gpu_experimental_stream_annotation=true"
_existing_xla_flags = os.environ.get("XLA_FLAGS", "")
if _STREAM_ANNOTATION_FLAG not in _existing_xla_flags.split():
    os.environ["XLA_FLAGS"] = f"{_existing_xla_flags} {_STREAM_ANNOTATION_FLAG}".strip()
| import jax | |
| import jax.numpy as jnp | |
| from jax.experimental.compute_on import compute_on | |
| from jax._src.compute_on import compute_on2 # not publicly exported in 0.9.1 | |
# Report the JAX version in use, then build the all-ones float32 operands that
# both scenarios below lower against: two square matrices and one bias vector.
print(f"JAX {jax.__version__}")

_N = 1024
x = jnp.ones((_N, _N), dtype=jnp.float32)
w = jnp.ones((_N, _N), dtype=jnp.float32)
b = jnp.ones((_N,), dtype=jnp.float32)
def report(label, lowered):
    """Print whether the stream annotation appears before and after compilation.

    Args:
        label: Heading printed above the two result lines.
        lowered: Object exposing ``as_text()`` and ``compile()``; the call
            sites in this script pass the result of ``jax.jit(...).lower(...)``.

    Returns:
        A ``(lowered_text, compiled_text)`` tuple holding the two text dumps
        that were searched for the annotation marker.
    """
    marker = "_xla_stream_annotation"
    lowered_text = lowered.as_text()
    compiled_text = lowered.compile().as_text()

    in_lowered = marker in lowered_text
    in_compiled = marker in compiled_text

    print(f"{label}")
    print(f" annotation in LOWERED HLO = {in_lowered}")
    print(f" annotation in COMPILED HLO = {in_compiled}")
    return lowered_text, compiled_text
| # ── A: compute_on2 decorator wrapping a GEMM + bias add ───────────────────── | |
| @compute_on2(compute_type="gpu_stream:1", out_memory_spaces=jax.memory.Space.Device) | |
| def gemm_add(x, w, b): | |
| return x @ w + b | |
@jax.jit
def f_decorator(x, w, b):
    """Jitted entry point for scenario A; forwards its arguments to ``gemm_add``."""
    result = gemm_add(x, w, b)
    return result


# NOTE: f_decorator already carries @jax.jit, so it is wrapped a second time
# here before lowering. Kept as-is so the lowered module stays unchanged.
lowered_a = jax.jit(f_decorator).lower(x, w, b)
lo_a, co_a = report("A: compute_on2(GEMM+add)", lowered_a)
# ── B: context manager INSIDE jit around a GEMM + bias add ──────────────────
@jax.jit
def f_inside(x, w, b):
    """Scenario B: compute ``x @ w + b`` inside a ``compute_on("gpu_stream:1")`` block."""
    with compute_on("gpu_stream:1"):
        result = x @ w + b
        return result


# NOTE: f_inside already carries @jax.jit, so it is wrapped a second time here
# before lowering. Kept as-is so the lowered module stays unchanged.
lowered_b = jax.jit(f_inside).lower(x, w, b)
lo_b, co_b = report("B: ctx-inside(GEMM+add)", lowered_b)
# ── Diagnostics: show any annotation / custom-call / frontend lines ─────────
# For each of the four text dumps, collect the lines that mention any keyword,
# print how many were found, and show at most the first 12 (each stripped and
# cut to 180 characters).
keywords = ("_xla_stream_annotation", "frontend_attributes",
            "compute_on", "custom_call", "stream")
dumps = [("A-LOWERED", lo_a), ("A-COMPILED", co_a),
         ("B-LOWERED", lo_b), ("B-COMPILED", co_b)]

for tag, txt in dumps:
    hits = []
    for raw_line in txt.splitlines():
        if any(keyword in raw_line for keyword in keywords):
            hits.append(raw_line.strip()[:180])
    print(f"\n--- {tag}: {len(hits)} relevant line(s) ---")
    for shown in hits[:12]:
        print(" ", shown)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment