Skip to content

Instantly share code, notes, and snippets.

@zeryx
Last active June 10, 2026 19:32
Show Gist options
  • Select an option

  • Save zeryx/248efdfcf23a37a1e26ce8e9480e0552 to your computer and use it in GitHub Desktop.

Select an option

Save zeryx/248efdfcf23a37a1e26ce8e9480e0552 to your computer and use it in GitHub Desktop.
JAX compute_on2: place a Transformer Engine matmul+add on a chosen CUDA stream, verified with nsys cuda_gpu_trace
"""Parse nsys cuda_gpu_trace CSV; report which CUDA stream each kernel class ran on."""
import csv, sys
from collections import defaultdict
# The trace CSV to analyze is the first command-line argument. Without it
# the script used to die with a bare IndexError traceback; exit with a
# usage line on stderr instead (sys.exit with a string uses exit status 1).
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} <cuda_gpu_trace.csv>")
path = sys.argv[1]
def classify(name):
    """Bucket a GPU trace row into a coarse kernel class based on its name.

    Matching is a case-insensitive substring search, evaluated in priority
    order: communication markers first, then GEMM markers, then memory ops.

    Args:
        name: Kernel/operation name string taken from the trace.

    Returns:
        One of "NCCL-comm", "GEMM-compute", "memcpy/set" or "other".
    """
    comm_markers = ("nccl", "allgather", "reducescatter", "sendrecv")
    gemm_markers = (
        "gemm", "cutlass", "sm90", "sm100", "sm120", "cublas",
        "ampere", "volta", "turing", "s16816", "s1688", "implicit", "dot",
        # Recent cuBLASLt releases name their Hopper-and-newer GEMM kernels
        # "nvjet_*", which none of the markers above catch; without this
        # entry those GEMMs are silently counted as "other".
        "nvjet",
    )
    memory_markers = ("memcpy", "memset")

    lowered = name.lower()
    if any(marker in lowered for marker in comm_markers):
        return "NCCL-comm"
    if any(marker in lowered for marker in gemm_markers):
        return "GEMM-compute"
    if any(marker in lowered for marker in memory_markers):
        return "memcpy/set"
    return "other"
# Read every CSV record up front; the header is not assumed to be the first
# line, so it is searched for below.
with open(path, newline="") as f:
    reader = csv.reader(f)
    rows = list(reader)

# The header is the first row that mentions both a stream column
# ("strm" or "stream") and a "name" column, compared case-insensitively.
hdr_idx = None
for line_no, cells in enumerate(rows):
    lowered = ",".join(cell.lower() for cell in cells)
    mentions_stream = "strm" in lowered or "stream" in lowered
    if mentions_stream and "name" in lowered:
        hdr_idx = line_no
        break

# No usable header: show the first rows for debugging and stop (status 0).
if hdr_idx is None:
    print(" !! no stream column; first rows:", rows[:3])
    sys.exit(0)

# Normalized header titles, used for column lookup.
hdr = [title.strip().lower() for title in rows[hdr_idx]]
def col(*cands):
for c in cands:
if c in hdr: return hdr.index(c)
for idx, h in enumerate(hdr):
if any(c in h for c in cands): return idx
return None
# Column positions in the trace header. Stream and name are required (the
# header search guarantees both exist); duration and device are optional.
strm_i, name_i = col("strm", "stream"), col("name")
dur_i, dev_i = col("duration (ns)", "duration"), col("device")

# agg[device][stream_id][kernel_class] -> [record_count, summed_duration]
agg = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: [0, 0.0])))
for r in rows[hdr_idx + 1:]:
    # Skip rows too short to hold both required columns.
    if len(r) <= max(strm_i, name_i):
        continue
    # Rows whose stream field is not an integer are not trace records.
    # Parsing with int() directly also rejects oddities such as "--5",
    # which previously passed the digit check and then crashed in int().
    try:
        stream_id = int(r[strm_i].strip())
    except ValueError:
        continue
    dev = r[dev_i].strip() if (dev_i is not None and dev_i < len(r)) else "?"
    cls = classify(r[name_i])
    dur = 0.0
    if dur_i is not None and dur_i < len(r):
        try:
            # Tolerate thousands separators in the duration field.
            dur = float(r[dur_i].replace(",", ""))
        except ValueError:
            dur = 0.0
    bucket = agg[dev][stream_id][cls]
    bucket[0] += 1
    bucket[1] += dur

total = sum(c[0] for d in agg.values() for s in d.values() for c in s.values())
print(f" parsed {total} kernel records across {len(agg)} device(s)")

for dev in sorted(agg):
    print(f"\n ===== device: {dev} =====")
    print(f" {'stream':>8} | {'class':<14} | {'count':>7} | {'gpu_ms':>10}")
    print(f" {'-'*8}-+-{'-'*14}-+-{'-'*7}-+-{'-'*10}")
    dagg = agg[dev]
    for s in sorted(dagg):
        # Within a stream, list classes by descending total duration.
        for cls in sorted(dagg[s], key=lambda c: -dagg[s][c][1]):
            cnt, dur = dagg[s][cls]
            # Durations are divided by 1e6, i.e. treated as nanoseconds.
            print(f" {s:>8} | {cls:<14} | {cnt:>7} | {dur/1e6:>10.3f}")

    gemm_streams = sorted(s for s in dagg if "GEMM-compute" in dagg[s])
    print(f" GEMM-compute kernels ran on stream(s): {gemm_streams} (count={len(gemm_streams)})")
    if len(gemm_streams) >= 2:
        print(" RESULT: GEMMs span MULTIPLE CUDA streams -> stream annotation honored at RUNTIME")
    elif gemm_streams:
        print(" RESULT: all GEMMs on ONE stream -> annotation NOT honored at runtime")
    else:
        # Zero GEMM kernels is not evidence either way; previously this case
        # was reported as "annotation NOT honored", a false negative.
        print(" RESULT: no GEMM-compute kernels found -> cannot tell whether the annotation was honored")
#!/bin/bash
# Profile the compute_on2 + TE matmul+add script and report per-kernel CUDA streams.
#
# Environment:
#   WORKDIR  directory holding gist_te_compute_on2.py and analyze.py; the
#            profile and CSV are written there too (default: /work).
#
# -e: stop at the first failing command; -u: unset variables are errors;
# pipefail: a failing pipeline stage fails the pipeline.
set -euo pipefail

# Resolve to an absolute path so the paths below stay valid after `cd`.
WORKDIR="$(cd "${WORKDIR:-/work}" && pwd)"
cd "$WORKDIR"

# 1) Record a CUDA + NVTX trace of the JAX script (CPU sampling disabled).
nsys profile -o "$WORKDIR/te" --force-overwrite true --trace=cuda,nvtx --sample=none \
    python "$WORKDIR/gist_te_compute_on2.py"

# 2) Export the per-kernel GPU trace as CSV; stderr is discarded so only the
#    report text lands in the file.
nsys stats --report cuda_gpu_trace --format csv --force-export true "$WORKDIR/te.nsys-rep" > "$WORKDIR/te.csv" 2>/dev/null

# 3) Summarize which CUDA stream each kernel class ran on.
python "$WORKDIR/analyze.py" "$WORKDIR/te.csv"
"""
Place a Transformer Engine matmul+add (dense w/ bias) on a chosen CUDA stream
with jax._src.compute_on.compute_on2, and prove it at runtime with nsys.
Why compute_on2 and not the public `compute_on` context manager?
The context manager only stamps `_xla_stream_annotation` onto ops in the HLO;
XLA is free to leave them on the default stream (verified: no runtime shift).
compute_on2 lowers the wrapped region to a separate, non-inlineable called
function with the annotation on the call + explicit output memory space — the
structural form XLA actually turns into a side CUDA stream.
Note on XLA_FLAGS: --xla_gpu_experimental_stream_annotation=true is NOT required
for compute_on2 (verified: the stream split happens with the flag unset). That
flag only gates the pass that acts on the _xla_stream_annotation attribute used
by the public `compute_on` context manager; compute_on2 relocates via its
async-region + output-memory-space lowering, independent of the flag.
Run (single GPU):
$ nsys profile -o te --trace=cuda,nvtx --sample=none python this_file.py
$ nsys stats --report cuda_gpu_trace --format csv te.nsys-rep > te.csv
$ python analyze.py te.csv # prints the CUDA stream per kernel class
Expected: the te_gemm_ffi GEMM kernels run on TWO distinct CUDA streams
(default + the compute_on2 stream). If you instead used the `compute_on`
context manager, they would all share ONE stream.
"""
import os
# Note: --xla_gpu_experimental_stream_annotation=true is NOT needed for compute_on2
# (see header). It is only consulted by the public `compute_on` context-manager
# path, and would have to be exported before importing jax (XLA reads XLA_FLAGS at
# backend init). We deliberately do not set it here.
import jax
import jax.numpy as jnp
from jax._src.compute_on import compute_on2 # internal API (not re-exported)
from transformer_engine.jax.dense import dense as te_dense
# Report the JAX version and the first visible device before doing any work.
print("jax", jax.__version__, "| device", jax.devices()[0], flush=True)

# Square problem: (M, K) input times (K, N) kernels, plus a length-N bias.
M, K, N = (2048,) * 3

x = jnp.ones(shape=(M, K), dtype=jnp.bfloat16)
k0 = jnp.ones(shape=(K, N), dtype=jnp.bfloat16)
# Second kernel holds different values from k0: distinct -> no CSE merge.
k1 = (2 * jnp.ones(shape=(K, N), dtype=jnp.bfloat16)).astype(jnp.bfloat16)
b = jnp.ones(shape=(N,), dtype=jnp.bfloat16)
# TE dense with bias (matmul + add); the whole function is wrapped by
# compute_on2 with compute_type "gpu_stream:1" and Device output memory space.
@compute_on2(compute_type="gpu_stream:1", out_memory_spaces=jax.memory.Space.Device)
def te_matmul_add_on_stream1(x, k, b):
    """Return ``te_dense(x, k, b)``; stream placement comes from the decorator."""
    stream1_result = te_dense(x, k, b)
    return stream1_result
@jax.jit
def f(x, k0, k1, b):
    """Sum two TE dense results: one plain call, one through the compute_on2 wrapper.

    ``k0`` feeds the undecorated ``te_dense`` call; ``k1`` feeds
    ``te_matmul_add_on_stream1``. Both share the input ``x`` and bias ``b``.
    """
    on_default_stream = te_dense(x, k0, b)  # default stream
    on_side_stream = te_matmul_add_on_stream1(x, k1, b)  # compute_on2 -> CUDA stream 1
    combined = on_default_stream + on_side_stream
    return combined
# Confirm the annotation reached the te_gemm_ffi custom-call in compiled HLO.
# `co` is the text of the compiled module; the print reports whether the
# literal `_xla_stream_annotation="1"` occurs in it and how many times the
# substring "te_gemm_ffi" appears.
co = jax.jit(f).lower(x, k0, k1, b).compile().as_text()
print("annotation in compiled HLO =", '_xla_stream_annotation="1"' in co,
"| te_gemm_ffi custom-calls =", co.count("te_gemm_ffi"), flush=True)
# One call outside the annotated loop, waited on before the loop starts
# (presumably a warm-up so compilation is not part of the annotated runs).
out = f(x, k0, k1, b)
jax.block_until_ready(out)
# Run f 20 more times, each call inside its own named trace annotation.
for i in range(20): # iterations so the trace is unambiguous
with jax.profiler.TraceAnnotation(f"te_matmul_add_{i}"):
out = f(x, k0, k1, b)
# NOTE(review): indentation was lost in this copy, so it is unclear whether
# the wait below runs once per iteration or once after the loop - confirm
# against the original file.
jax.block_until_ready(out)
# Converting one element to a Python float reads the result back for printing.
print("done | out[0,0] =", float(out[0, 0]), flush=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment