Last active
June 10, 2026 19:32
-
-
Save zeryx/248efdfcf23a37a1e26ce8e9480e0552 to your computer and use it in GitHub Desktop.
JAX compute_on2: place a Transformer Engine matmul+add on a chosen CUDA stream, verified with nsys cuda_gpu_trace
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """Parse nsys cuda_gpu_trace CSV; report which CUDA stream each kernel class ran on.""" | |
| import csv, sys | |
| from collections import defaultdict | |
# The nsys cuda_gpu_trace CSV export is the one required command-line argument.
# Fail with a usage line instead of an IndexError traceback when it is missing.
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} <cuda_gpu_trace.csv>")
path = sys.argv[1]
def classify(name):
    """Map a trace record name to a coarse class label.

    Matching is a case-insensitive substring test. The rule groups are tried
    in order and the first group with any matching keyword wins, so an NCCL
    name is reported as "NCCL-comm" even if it also contains a GEMM keyword.

    Args:
        name: Kernel / operation name taken from the trace's name column.

    Returns:
        One of "NCCL-comm", "GEMM-compute", "memcpy/set" or "other".
    """
    lowered = name.lower()
    rules = (
        ("NCCL-comm", ("nccl", "allgather", "reducescatter", "sendrecv")),
        ("GEMM-compute", ("gemm", "cutlass", "sm90", "sm100", "sm120", "cublas",
                          "ampere", "volta", "turing", "s16816", "s1688",
                          "implicit", "dot")),
        ("memcpy/set", ("memcpy", "memset")),
    )
    for label, keywords in rules:
        for keyword in keywords:
            if keyword in lowered:
                return label
    return "other"
# Read the whole CSV into memory; the header row is located afterwards because
# the export may carry non-table lines before it.
with open(path, newline="") as src:
    rows = list(csv.reader(src))

# The header is the first row whose lower-cased, comma-joined text mentions a
# stream column ("strm" or "stream") and a "name" column.
lowered_rows = (",".join(cell.lower() for cell in cells) for cells in rows)
hdr_idx = next(
    (pos for pos, text in enumerate(lowered_rows)
     if ("strm" in text or "stream" in text) and "name" in text),
    None,
)
if hdr_idx is None:
    print(" !! no stream column; first rows:", rows[:3])
    sys.exit(0)

# Normalized header cells, used by the column lookup below.
hdr = [cell.strip().lower() for cell in rows[hdr_idx]]
def col(*cands):
    """Return the index of the header column matching one of *cands*.

    Reads the module-level ``hdr`` list (stripped, lower-cased header cells).
    An exact match is preferred, trying the candidates in the order given.
    Failing that, the first header cell that contains any candidate as a
    substring is used.

    Returns:
        The column index, or None when nothing matches.
    """
    exact = next((hdr.index(want) for want in cands if want in hdr), None)
    if exact is not None:
        return exact
    return next(
        (pos for pos, title in enumerate(hdr)
         if any(want in title for want in cands)),
        None,
    )
# Resolve the columns of interest once; duration and device are optional.
strm_i = col("strm", "stream")
name_i = col("name")
dur_i = col("duration (ns)", "duration")
dev_i = col("device")

# agg[device][stream id][kernel class] -> [record count, summed duration]
agg = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: [0, 0.0])))

for rec in rows[hdr_idx + 1:]:
    # Skip rows too short to hold both the stream and the name cell.
    if len(rec) <= max(strm_i, name_i):
        continue

    # Only rows whose stream cell is an (optionally negative) integer count;
    # anything else is treated as a non-data line.
    stream_txt = rec[strm_i].strip()
    if not (stream_txt and stream_txt.lstrip("-").isdigit()):
        continue

    device = "?"
    if dev_i is not None and dev_i < len(rec):
        device = rec[dev_i].strip()

    kind = classify(rec[name_i])

    # Duration may contain thousands separators; an unparsable cell counts as 0.
    elapsed = 0.0
    if dur_i is not None and dur_i < len(rec):
        try:
            elapsed = float(rec[dur_i].replace(",", ""))
        except ValueError:
            elapsed = 0.0

    bucket = agg[device][int(stream_txt)][kind]
    bucket[0] += 1
    bucket[1] += elapsed
# Grand total over every device / stream / class bucket collected in `agg`
# (agg[device][stream][class] -> [count, summed duration]).
total = sum(
    stat[0]
    for per_stream in agg.values()
    for per_class in per_stream.values()
    for stat in per_class.values()
)
print(f" parsed {total} kernel records across {len(agg)} device(s)")

for dev in sorted(agg):
    print(f"\n ===== device: {dev} =====")
    print(f" {'stream':>8} | {'class':<14} | {'count':>7} | {'gpu_ms':>10}")
    print(f" {'-'*8}-+-{'-'*14}-+-{'-'*7}-+-{'-'*10}")
    dagg = agg[dev]
    for s in sorted(dagg):
        # Within a stream, list the class with the largest summed duration first.
        for cls in sorted(dagg[s], key=lambda c: -dagg[s][c][1]):
            cnt, dur = dagg[s][cls]
            # Summed duration is divided by 1e6 for the gpu_ms column
            # (assumes the duration column is in ns - confirm for your export).
            print(f" {s:>8} | {cls:<14} | {cnt:>7} | {dur/1e6:>10.3f}")

    # Verdict for this device: how many distinct streams carried GEMM kernels.
    gemm_streams = sorted({s for s in dagg if "GEMM-compute" in dagg[s]})
    print(f" GEMM-compute kernels ran on stream(s): {gemm_streams} (count={len(gemm_streams)})")
    if len(gemm_streams) >= 2:
        print(" RESULT: GEMMs span MULTIPLE CUDA streams -> stream annotation honored at RUNTIME")
    elif gemm_streams:
        print(" RESULT: all GEMMs on ONE stream -> annotation NOT honored at runtime")
    else:
        # Zero GEMM kernels is not evidence either way; do not report it as
        # "all on one stream".
        print(" RESULT: no GEMM-compute kernels found on this device -> cannot tell whether the annotation was honored")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/bash
# Profile the compute_on2 + TE matmul+add script with nsys, export the
# cuda_gpu_trace report as CSV, and report per-kernel CUDA streams.
set -e

work=/work
report="$work/te"

cd "$work"

# 1) Capture CUDA + NVTX activity for the JAX script, overwriting any old report.
nsys profile \
    -o "$report" \
    --force-overwrite true \
    --trace=cuda,nvtx \
    --sample=none \
    python "$work/gist_te_compute_on2.py"

# 2) Export the GPU trace as CSV on stdout; stderr output is discarded.
nsys stats \
    --report cuda_gpu_trace \
    --format csv \
    --force-export true \
    "$report.nsys-rep" > "$report.csv" 2>/dev/null

# 3) Summarize which stream each kernel class ran on.
python "$work/analyze.py" "$report.csv"
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Place a Transformer Engine matmul+add (dense w/ bias) on a chosen CUDA stream | |
| with jax._src.compute_on.compute_on2, and prove it at runtime with nsys. | |
| Why compute_on2 and not the public `compute_on` context manager? | |
| The context manager only stamps `_xla_stream_annotation` onto ops in the HLO; | |
| XLA is free to leave them on the default stream (verified: no runtime shift). | |
| compute_on2 lowers the wrapped region to a separate, non-inlineable called | |
| function with the annotation on the call + explicit output memory space — the | |
| structural form XLA actually turns into a side CUDA stream. | |
| Note on XLA_FLAGS: --xla_gpu_experimental_stream_annotation=true is NOT required | |
| for compute_on2 (verified: the stream split happens with the flag unset). That | |
| flag only gates the pass that acts on the _xla_stream_annotation attribute used | |
| by the public `compute_on` context manager; compute_on2 relocates via its | |
| async-region + output-memory-space lowering, independent of the flag. | |
| Run (single GPU): | |
| $ nsys profile -o te --trace=cuda,nvtx --sample=none python this_file.py | |
| $ nsys stats --report cuda_gpu_trace --format csv te.nsys-rep > te.csv | |
| $ python analyze.py te.csv # prints the CUDA stream per kernel class | |
| Expected: the te_gemm_ffi GEMM kernels run on TWO distinct CUDA streams | |
| (default + the compute_on2 stream). If you instead used the `compute_on` | |
| context manager, they would all share ONE stream. | |
| """ | |
| import os | |
| # Note: --xla_gpu_experimental_stream_annotation=true is NOT needed for compute_on2 | |
| # (see header). It is only consulted by the public `compute_on` context-manager | |
| # path, and would have to be exported before importing jax (XLA reads XLA_FLAGS at | |
| # backend init). We deliberately do not set it here. | |
| import jax | |
| import jax.numpy as jnp | |
| from jax._src.compute_on import compute_on2 # internal API (not re-exported) | |
| from transformer_engine.jax.dense import dense as te_dense | |
# Report the JAX version and the first visible device before doing any work.
print("jax", jax.__version__, "| device", jax.devices()[0], flush=True)

# Square problem: x is (M, K), each kernel is (K, N), the bias is (N,).
M = K = N = 2048
x = jnp.ones(shape=(M, K), dtype=jnp.bfloat16)
k0 = jnp.ones(shape=(K, N), dtype=jnp.bfloat16)
# k1 holds different values from k0 (distinct -> no CSE merge of the two GEMMs).
k1 = (jnp.ones(shape=(K, N), dtype=jnp.bfloat16) * 2).astype(jnp.bfloat16)
b = jnp.ones(shape=(N,), dtype=jnp.bfloat16)
# matmul + add (TE dense with bias), pinned to gpu_stream:1.
@compute_on2(
    compute_type="gpu_stream:1",
    out_memory_spaces=jax.memory.Space.Device,
)
def te_matmul_add_on_stream1(x, k, b):
    """Return ``te_dense(x, k, b)``; the compute_on2 wrapper requests gpu_stream:1."""
    result = te_dense(x, k, b)
    return result
@jax.jit
def f(x, k0, k1, b):
    """Sum two TE dense results: one called directly, one via the compute_on2 wrapper.

    ``k0`` feeds the plain ``te_dense`` call (default stream); ``k1`` feeds
    ``te_matmul_add_on_stream1`` (compute_on2 -> CUDA stream 1). Both share
    the input ``x`` and bias ``b``.
    """
    on_default = te_dense(x, k0, b)
    on_stream1 = te_matmul_add_on_stream1(x, k1, b)
    return on_default + on_stream1
# Confirm the annotation reached the te_gemm_ffi custom-call in compiled HLO.
hlo_text = jax.jit(f).lower(x, k0, k1, b).compile().as_text()
has_annotation = '_xla_stream_annotation="1"' in hlo_text
gemm_calls = hlo_text.count("te_gemm_ffi")
print("annotation in compiled HLO =", has_annotation,
      "| te_gemm_ffi custom-calls =", gemm_calls, flush=True)

# One call outside the annotated loop, waited on before the loop starts.
out = f(x, k0, k1, b)
jax.block_until_ready(out)

# NOTE(review): the source's indentation was lost; the final block_until_ready
# is placed after the loop here - confirm against the original file.
for step in range(20):  # iterations so the trace is unambiguous
    with jax.profiler.TraceAnnotation(f"te_matmul_add_{step}"):
        out = f(x, k0, k1, b)
jax.block_until_ready(out)

print("done | out[0,0] =", float(out[0, 0]), flush=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment