Skip to content

Instantly share code, notes, and snippets.

@zeryx
Last active June 9, 2026 20:54
Show Gist options
  • Select an option

  • Save zeryx/87a6b5a5cc56c7a6e0fd61d6f9a378e2 to your computer and use it in GitHub Desktop.

Select an option

Save zeryx/87a6b5a5cc56c7a6e0fd61d6f9a378e2 to your computer and use it in GitHub Desktop.
Reproducer: jax.experimental.compute_on gpu_stream annotation silent no-op in JAX 0.9.1
"""
Reproducer: does compute_on("gpu_stream:N") annotate a GEMM in JAX 0.9.1?
Hypothesis under test: a bare elementwise add (x + y) gets fused/elided by XLA,
leaving no standalone op to carry _xla_stream_annotation -> annotation disappears.
A GEMM is a real cuBLAS/cutlass kernel that won't be fused away, so it should
retain the annotation if the mechanism works at all.
We check BOTH:
- lowered (pre-compilation StableHLO) .lower().as_text()
- compiled (post-optimization HLO) .lower().compile().as_text()
Environment:
- JAX 0.9.1 (pip install "jax[cuda12]==0.9.1")
- XLA_FLAGS must be set BEFORE importing jax.
Run: python compute_on_repro.py
"""
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_experimental_stream_annotation=true"
import jax
import jax.numpy as jnp
from jax.experimental.compute_on import compute_on
from jax._src.compute_on import compute_on2 # not publicly exported in 0.9.1
print(f"JAX {jax.__version__}")
x = jnp.ones((1024, 1024), jnp.float32)
w = jnp.ones((1024, 1024), jnp.float32)
b = jnp.ones((1024,), jnp.float32)
def report(label, lowered):
    """Print whether ``_xla_stream_annotation`` survives lowering and compilation.

    Args:
        label: Heading printed above the two result lines.
        lowered: Result of ``jitted_fn.lower(...)``; must provide
            ``as_text()`` and ``compile()``, and the object returned by
            ``compile()`` must provide ``as_text()``.

    Returns:
        Tuple ``(lowered_text, compiled_text)`` for further inspection.
    """
    stages = {"LOWERED": lowered.as_text()}
    # The compiled text is the post-optimization HLO, i.e. where a fused or
    # elided annotated op would show up as a missing marker.
    stages["COMPILED"] = lowered.compile().as_text()
    print(f"{label}")
    for stage_name, stage_text in stages.items():
        found = "_xla_stream_annotation" in stage_text
        print(f" annotation in {stage_name} HLO = {found}")
    return stages["LOWERED"], stages["COMPILED"]
# ── A: compute_on2 decorator wrapping a GEMM + bias add ─────────────────────
@compute_on2(compute_type="gpu_stream:1", out_memory_spaces=jax.memory.Space.Device)
def gemm_add(x, w, b):
    """GEMM followed by a bias add, requested to run on GPU stream 1."""
    return x @ w + b


@jax.jit
def f_decorator(x, w, b):
    """Jitted entry point that calls the compute_on2-wrapped GEMM."""
    return gemm_add(x, w, b)


# f_decorator is already jitted, so lower it directly. Re-wrapping it in a
# second jax.jit only adds a redundant outer jit around the computation under
# test, which is noise in a reproducer about where the annotation ends up.
lo_a, co_a = report("A: compute_on2(GEMM+add)", f_decorator.lower(x, w, b))
# ── B: context manager INSIDE jit around a GEMM + bias add ──────────────────
@jax.jit
def f_inside(x, w, b):
    """Jitted GEMM + bias add traced under compute_on("gpu_stream:1")."""
    with compute_on("gpu_stream:1"):
        return x @ w + b


# f_inside is already jitted, so lower it directly instead of wrapping it in
# a redundant second jax.jit.
lo_b, co_b = report("B: ctx-inside(GEMM+add)", f_inside.lower(x, w, b))
# ── Diagnostics: show any annotation / custom-call / frontend lines ─────────
# Substrings that mark an HLO/StableHLO line as relevant, in priority order:
# the excerpt window is anchored on the first of these found in a line.
# ("stream" already matches "_xla_stream_annotation"; listing the latter first
# makes the annotation itself win the window when both appear.)
DIAG_KEYS = ("_xla_stream_annotation", "frontend_attributes",
             "compute_on", "custom_call", "stream")
DIAG_WIDTH = 180      # max characters echoed per matching line
DIAG_MAX_LINES = 12   # max matching lines echoed per module


def _excerpt(line, width=DIAG_WIDTH):
    """Return at most ``width`` characters of ``line``, plus ``...`` markers.

    A plain ``line[:width]`` keeps only the prefix. An instruction line may be
    longer than ``width`` (operands, custom-call targets, configs), so the
    attribute being searched for can sit past the cut. Instead, anchor the
    window on the highest-priority key from ``DIAG_KEYS`` that the line
    contains.
    """
    line = line.strip()
    if len(line) <= width:
        return line
    anchor = next((pos for pos in map(line.find, DIAG_KEYS) if pos >= 0), 0)
    # Show a little leading context, but never run the window past line end.
    start = max(0, min(anchor - width // 4, len(line) - width))
    end = start + width
    prefix = "..." if start > 0 else ""
    suffix = "..." if end < len(line) else ""
    return f"{prefix}{line[start:end]}{suffix}"


for tag, txt in [("A-LOWERED", lo_a), ("A-COMPILED", co_a),
                 ("B-LOWERED", lo_b), ("B-COMPILED", co_b)]:
    hits = [_excerpt(line) for line in txt.splitlines()
            if any(key in line for key in DIAG_KEYS)]
    print(f"\n--- {tag}: {len(hits)} relevant line(s) ---")
    for h in hits[:DIAG_MAX_LINES]:
        print(" ", h)
    if len(hits) > DIAG_MAX_LINES:
        print(f"   ... {len(hits) - DIAG_MAX_LINES} more not shown")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment