@zeryx
Last active June 2, 2026 21:03
Recipe: automatic NCCL symmetric buffers in XLA/JAX (pure Python, no rebuild) — xla_gpu_experimental_enable_nccl_symmetric_buffers

Recipe: automatic NCCL symmetric buffers in XLA/JAX (pure Python)

Turn on XLA's automatic NCCL symmetric-buffer registration for its built-in collectives (psum, all-reduce, all-gather, …) — no custom C++, no rebuild, runs on stock jaxlib. XLA window-registers the collective buffers for you via ncclCommWindowRegister(..., NCCL_WIN_COLL_SYMMETRIC).

Verified on 2× NVIDIA RTX PRO 6000 Blackwell (sm_120) with the jax-toolbox image ghcr.io/nvidia/jax:jax-2026-06-02 (jax/jaxlib 0.10.2.dev20260602, NCCL 2.28.8).

Want your own kernel to receive symmetric buffers and call FindSymmetricMemory / ncclGetLsaPointer? That's the manual path and it requires building C++ inside XLA — see the companion gist: https://gist.github.com/zeryx/eb3f5daf23bb50d9194a6388bae65abd


The one knob

export XLA_FLAGS="--xla_gpu_experimental_enable_nccl_symmetric_buffers=true \
                  --xla_gpu_enable_nccl_user_buffers=true"

With those set, any built-in collective that runs on >=2 Hopper+/Blackwell GPUs gets its collective buffers registered as NCCL symmetric memory automatically. That's the whole API surface — you don't touch the symmetric pointers yourself (XLA's runtime does, internally).
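symmetric_buffer_demo.py assigns XLA_FLAGS outright, which discards any XLA flags you had already set. If you need to keep them, the sketch below appends the two symmetric-buffer flags instead. The helper name enable_symmetric_buffers is hypothetical (not part of JAX or XLA). Like the demo, it has to run before `import jax`. A flag name you already set explicitly, with any value, is left alone.

```python
import os

# The two flags from "The one knob" above.
SYMMETRIC_FLAGS = (
    "--xla_gpu_experimental_enable_nccl_symmetric_buffers=true",
    "--xla_gpu_enable_nccl_user_buffers=true",
)


def enable_symmetric_buffers(environ=os.environ):
    """Hypothetical helper: append the symmetric-buffer flags to XLA_FLAGS.

    Existing flags are preserved. If a flag name is already present (with any
    value), the user's explicit setting wins and it is not duplicated.
    """
    current = environ.get("XLA_FLAGS", "").split()
    present = {flag.split("=", 1)[0] for flag in current}
    for flag in SYMMETRIC_FLAGS:
        if flag.split("=", 1)[0] not in present:
            current.append(flag)
    environ["XLA_FLAGS"] = " ".join(current)
    return environ["XLA_FLAGS"]
```

Usage: call enable_symmetric_buffers() at the very top of your entry point, then `import jax`.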


Files

  • symmetric_buffer_demo.py: self-contained. It sets the flags itself, runs a pmap all-reduce, then captures XLA's/NCCL's C++ log stream and prints the actual ncclCommWindowRegister / "Register symmetric buffer" events, asserting both that registration happened and that the math is correct.
  • verify_symmetric.py — minimal correctness check across a few sizes (a "minimal yet large map"); pair it with the flags + debug env below.
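symmetric_buffer_demo.py captures logs at the file-descriptor level because XLA's and NCCL's native logging writes to fds 1 and 2 directly. That output bypasses Python's sys.stdout/sys.stderr objects, so contextlib.redirect_stdout would not see it. The standalone sketch below shows the same os.dup/os.dup2 technique and needs no GPU. It assumes a raw os.write on fd 2 is a fair stand-in for a C++ logger. The names capture_fds and demo are illustrative.

```python
import contextlib
import os
import sys
import tempfile


@contextlib.contextmanager
def capture_fds(fds=(1, 2)):
    """Redirect the given OS-level file descriptors into one temp file."""
    # Flush Python-level buffers so earlier prints are not swallowed.
    sys.stdout.flush()
    sys.stderr.flush()
    tmp = tempfile.TemporaryFile(mode="w+b")
    saved = {fd: os.dup(fd) for fd in fds}  # keep copies to restore later
    try:
        for fd in fds:
            os.dup2(tmp.fileno(), fd)  # fd now points at the temp file
        yield tmp
    finally:
        # Flush again so Python output written inside the block lands in tmp.
        sys.stdout.flush()
        sys.stderr.flush()
        for fd, copy in saved.items():
            os.dup2(copy, fd)
            os.close(copy)


def demo():
    with capture_fds() as f:
        # Stand-in for a C++ logger: writes straight to fd 2, bypassing
        # sys.stderr entirely.
        os.write(2, b"Register symmetric buffer for NCCL communicator\n")
    f.seek(0)
    captured = f.read().decode("utf-8", "replace")
    f.close()
    return captured
```

The flushes only cover Python's own buffers. If native code buffers output in libc's stdout, that buffer is separate and this sketch does not flush it.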

Run it

docker run --rm --runtime=nvidia --gpus all --ipc=host \
  --ulimit memlock=-1 --ulimit stack=67108864 --shm-size 16g \
  -v "$PWD:/work" -w /work \
  ghcr.io/nvidia/jax:jax-2026-06-02 \
  python symmetric_buffer_demo.py

symmetric_buffer_demo.py sets the XLA flags + debug logging internally, so the above is all you need. For verify_symmetric.py, pass the flags yourself:

docker run --rm --runtime=nvidia --gpus all --ipc=host \
  --ulimit memlock=-1 --ulimit stack=67108864 --shm-size 16g \
  -v "$PWD:/work" -w /work \
  -e XLA_PYTHON_CLIENT_PREALLOCATE=false -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.10 \
  -e XLA_FLAGS="--xla_gpu_experimental_enable_nccl_symmetric_buffers=true --xla_gpu_enable_nccl_user_buffers=true" \
  -e NCCL_DEBUG=INFO -e NCCL_DEBUG_SUBSYS=INIT,REG \
  -e TF_CPP_MIN_LOG_LEVEL=0 -e TF_CPP_VMODULE=nccl_symmetric_memory=3,nccl_communicator=3 \
  ghcr.io/nvidia/jax:jax-2026-06-02 \
  python verify_symmetric.py

Expected evidence

nccl_communicator.cc:444] [0] Register symmetric buffer for NCCL communicator; buffer=0x402000000; size=33554432; comm=...
NCCL INFO Symmetric VA size=96GB
NCCL INFO register comm ... buffer 0x402000000 size 33554432
[result] ALL ALL-REDUCES CORRECT
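verify_symmetric.py does not parse its own logs. To check a saved log from it, a minimal parser can reuse the demo's regexes. The sketch below loosens the comm field to \S+ so that it also accepts the elided "..." in the excerpt above. It then checks that XLA and NCCL report the same buffer address and size. The function name parse_registrations is hypothetical, and SAMPLE is the expected-evidence excerpt above, copied verbatim.

```python
import re

XLA_RE = re.compile(
    r"Register symmetric buffer for NCCL communicator; "
    r"buffer=(\w+); size=(\d+); comm=(\S+)")
NCCL_RE = re.compile(r"register comm (\S+) buffer (\w+) size (\d+)")
VA_RE = re.compile(r"Symmetric VA size=(\S+)")


def parse_registrations(log: str):
    """Return (xla, nccl, va): sets of (buffer, size_bytes) plus VA sizes."""
    xla = {(buf, int(size)) for buf, size, _comm in XLA_RE.findall(log)}
    nccl = {(buf, int(size)) for _comm, buf, size in NCCL_RE.findall(log)}
    return xla, nccl, VA_RE.findall(log)


SAMPLE = """\
nccl_communicator.cc:444] [0] Register symmetric buffer for NCCL communicator; buffer=0x402000000; size=33554432; comm=...
NCCL INFO Symmetric VA size=96GB
NCCL INFO register comm ... buffer 0x402000000 size 33554432
[result] ALL ALL-REDUCES CORRECT
"""

xla, nccl, va = parse_registrations(SAMPLE)
# Buffers XLA registered that NCCL also reports as window-registered.
matched = xla & nccl
# 33554432 bytes == 8 * 1024 * 1024 float32 values, the demo's per-device shard.
assert matched == {("0x402000000", 8 * 1024 * 1024 * 4)}
print(f"{len(matched)} matched registration(s); VA window {va}")
```

The logged size matches the 32 MiB float32 shard that both scripts all-reduce on their first call. That is arithmetic consistency, not proof that every registered buffer has exactly that size.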

Notes / gotchas

  • Hopper+ GPUs. The auto path no-ops on older GPUs.
  • Leave room for the symmetric allocator. It allocates via ncclMemAlloc outside the BFC pool, so don't let XLA preallocate all memory: set XLA_PYTHON_CLIENT_PREALLOCATE=false (and/or a small XLA_PYTHON_CLIENT_MEM_FRACTION). Without this you'll see errors like "could not allocate collective ... out of memory".
  • PCIe is fine. Multimem (NVLS) needs NVLink+NVSwitch and will be skipped on PCIe cards, but symmetric registration + P2P/LSA still works.
  • This only accelerates XLA's own collectives. It does not let your custom kernel reach peer memory — use the manual recipe for that.
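The configuration side of the notes above can be checked before `import jax`. The sketch below assumes the environment variable names used in this gist (XLA_FLAGS, XLA_PYTHON_CLIENT_PREALLOCATE, XLA_PYTHON_CLIENT_MEM_FRACTION). The function preflight and its messages are hypothetical. It only inspects configuration and cannot detect the GPU generation or the interconnect.

```python
import os

REQUIRED_FLAGS = (
    "--xla_gpu_experimental_enable_nccl_symmetric_buffers=true",
    "--xla_gpu_enable_nccl_user_buffers=true",
)


def preflight(environ=os.environ):
    """Return a list of configuration problems (empty list means OK)."""
    problems = []
    flags = set(environ.get("XLA_FLAGS", "").split())
    for flag in REQUIRED_FLAGS:
        if flag not in flags:
            problems.append(f"XLA_FLAGS is missing {flag}")

    # JAX preallocates by default, so an unset variable counts as "true".
    prealloc = environ.get("XLA_PYTHON_CLIENT_PREALLOCATE", "true").lower()
    frac = environ.get("XLA_PYTHON_CLIENT_MEM_FRACTION")
    if prealloc != "false" and frac is None:
        problems.append(
            "XLA will preallocate most device memory; set "
            "XLA_PYTHON_CLIENT_PREALLOCATE=false and/or a small "
            "XLA_PYTHON_CLIENT_MEM_FRACTION to leave room for ncclMemAlloc")
    if frac is not None:
        try:
            value = float(frac)
        except ValueError:
            problems.append(f"XLA_PYTHON_CLIENT_MEM_FRACTION={frac!r} is not a number")
        else:
            if not 0.0 < value < 1.0:
                problems.append(f"XLA_PYTHON_CLIENT_MEM_FRACTION={value} leaves no headroom")
    return problems
```

Usage: `for p in preflight(): print("warning:", p)` near the top of your script, before `import jax`.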

symmetric_buffer_demo.py

"""Self-contained demo: create NCCL symmetric buffers from JAX+XLA and prove it.

Run: python symmetric_buffer_demo.py   (on >= 2 Hopper+/Blackwell GPUs)

Everything needed is in THIS file:
  1. We configure the symmetric-buffer path ourselves (XLA flags below) instead
     of relying on external env vars.
  2. We run a real cross-device collective (pmap psum all-reduce).
  3. We capture XLA's / NCCL's own C++ log stream at the file-descriptor level
     while the collective runs, then print the actual symmetric-buffer creation
     events: XLA's `NcclSymmetricMemory -> Register symmetric buffer ...` and
     NCCL's `register comm ... buffer ... size ...`
     (these come from ncclCommWindowRegister(..., NCCL_WIN_COLL_SYMMETRIC)).
  4. We assert both that symmetric buffers were created AND that the math is
     correct.

The symmetric-buffer *creation* itself lives in XLA's runtime (there is no
pure-Python NCCL window API); what this script controls and demonstrates is the
configuration that triggers it and the evidence that it happened.
"""
import os

# ---- (1) configure the symmetric-buffer path BEFORE importing jax ----------
# These two flags make XLA's collective allocator use ncclMemAlloc + register a
# symmetric window (ncclCommWindowRegister) for collective buffers.
os.environ["XLA_FLAGS"] = (
    "--xla_gpu_experimental_enable_nccl_symmetric_buffers=true "
    "--xla_gpu_enable_nccl_user_buffers=true"
)
# Don't grab all device memory: the symmetric allocator needs room outside BFC.
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.10")
# Ask XLA + NCCL to log buffer registration so we can capture it.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
os.environ["TF_CPP_VMODULE"] = "nccl_symmetric_memory=3,nccl_communicator=3"
os.environ["NCCL_DEBUG"] = "INFO"
os.environ["NCCL_DEBUG_SUBSYS"] = "INIT,REG"

import contextlib
import re
import sys
import tempfile
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp


@contextlib.contextmanager
def capture_c_stderr():
    """Capture fd-level stdout+stderr (where XLA/NCCL C++ logs go)."""
    # Flush Python's own buffers so earlier prints don't end up in the capture.
    sys.stdout.flush()
    sys.stderr.flush()
    tmp = tempfile.TemporaryFile(mode="w+b")
    saved_out, saved_err = os.dup(1), os.dup(2)
    os.dup2(tmp.fileno(), 1)
    os.dup2(tmp.fileno(), 2)
    try:
        yield tmp
    finally:
        sys.stdout.flush()
        sys.stderr.flush()
        os.dup2(saved_out, 1)
        os.dup2(saved_err, 2)
        os.close(saved_out)
        os.close(saved_err)


def main():
    devs = jax.devices()
    n = len(devs)
    print(f"[info] jax {jax.__version__}, {n} devices: {devs[0].device_kind} x{n}")
    if n < 2:
        raise SystemExit("need >= 2 devices")

    @partial(jax.pmap, axis_name="i")
    def all_reduce(x):
        return jax.lax.psum(x, axis_name="i")

    elems = 8 * 1024 * 1024  # 32 MiB of float32 per device
    base = jnp.arange(n, dtype=jnp.float32).reshape(n, 1)  # row r filled with r
    shard = jnp.broadcast_to(base, (n, elems))

    # ---- (2)+(3) run the collective while capturing the C++ log stream -----
    with capture_c_stderr() as logf:
        out = all_reduce(shard)
        out.block_until_ready()
    logf.seek(0)
    log = logf.read().decode("utf-8", "replace")
    logf.close()

    # ---- (3) show the symmetric-buffer CREATION events ---------------------
    xla_reg = re.findall(
        r"Register symmetric buffer for NCCL communicator; "
        r"buffer=(\w+); size=(\d+); comm=(\w+)", log)
    nccl_reg = re.findall(r"register comm (\w+) buffer (\w+) size (\d+)", log)
    va = re.findall(r"Symmetric VA size=(\S+)", log)
    print("\n=== symmetric buffers created (XLA NcclSymmetricMemory) ===")
    for buf, size, comm in xla_reg:
        print(f"  XLA register: buffer={buf} size={int(size)//(1<<20)}MiB comm={comm}")
    print("=== NCCL window registration (ncclCommWindowRegister, NCCL_WIN_COLL_SYMMETRIC) ===")
    for comm, buf, size in nccl_reg:
        print(f"  NCCL register: comm={comm} buffer={buf} size={int(size)//(1<<20)}MiB")
    if va:
        print(f"  NCCL symmetric VA window size: {va[0]}")
    assert xla_reg, "no symmetric buffers were registered; flag not honored?"
    assert nccl_reg, "NCCL did not report a window registration"
    print(f"\n[ok] {len(xla_reg)} symmetric buffer registration(s) observed")

    # ---- (4) correctness of the all-reduce that used those buffers ---------
    expected = n * (n - 1) / 2
    host = np.asarray(out)
    assert np.allclose(host, expected), f"{host.min()}..{host.max()} != {expected}"
    print(f"[check] all-reduce({elems} elems/dev) == {expected} (min={host.min()}, max={host.max()})")

    # a couple more sizes -> repeated FindSymmetricMemory lookups at runtime
    for e in (1 << 20, 16 << 20):
        s = jnp.broadcast_to(jnp.arange(n, dtype=jnp.float32).reshape(n, 1), (n, e))
        r = all_reduce(s)
        r.block_until_ready()
        assert np.allclose(np.asarray(r), expected)
        print(f"[check] all-reduce({e} elems/dev) OK")

    print("\n[result] SYMMETRIC BUFFERS CREATED + ALL-REDUCES CORRECT")


if __name__ == "__main__":
    main()
verify_symmetric.py

"""Verify XLA sets up NCCL symmetric buffers and exercises FindSymmetricMemory.

Strategy (no custom C++ needed): the prebuilt XLA in jax-toolbox already
contains NcclSymmetricMemory::Create -> ncclCommWindowRegister(...,
NCCL_WIN_COLL_SYMMETRIC) and the runtime FindSymmetricMemory lookup. We turn the
path on with XLA flags, run a real cross-device all-reduce (psum) over a large
array, and (a) verify numerical correctness and (b) rely on XLA/NCCL logs to
prove the symmetric registration + lookup actually happened.
"""
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp


def main():
    devs = jax.devices()
    n = len(devs)
    print(f"[info] jax {jax.__version__}, {n} devices: {[d.device_kind for d in devs]}",
          flush=True)
    if n < 2:
        raise SystemExit("need >= 2 devices")

    # ---- the collective: all-reduce (psum) across all devices via pmap -------
    @partial(jax.pmap, axis_name="i")
    def all_reduce(x):
        return jax.lax.psum(x, axis_name="i")

    # "minimal yet large map": one large shard per device, all-reduced.
    # Sized to fit alongside other jobs sharing these GPUs.
    elems = 8 * 1024 * 1024
    base = jnp.arange(n, dtype=jnp.float32).reshape(n, 1)  # [0,1,...]
    shard = jnp.broadcast_to(base, (n, elems))  # row r filled with r

    # Warmup (triggers compile + symmetric registration during prepare phase).
    out = all_reduce(shard)
    out.block_until_ready()

    # Correctness: every device should hold sum_r r = n*(n-1)/2 in every element.
    expected = n * (n - 1) / 2
    host = np.asarray(out)
    ok = np.allclose(host, expected)
    print(f"[check] elements/device={elems} ({elems*4/1e6:.0f} MB), "
          f"expected={expected}, got[0,0]={host[0,0]}, got[{n-1},{elems-1}]={host[n-1,elems-1]}",
          flush=True)
    assert ok, f"all-reduce mismatch: {host.min()}..{host.max()} != {expected}"

    # A few more sizes to exercise repeated FindSymmetricMemory lookups.
    for e in (1 << 20, 4 << 20, 16 << 20):
        s = jnp.broadcast_to(jnp.arange(n, dtype=jnp.float32).reshape(n, 1), (n, e))
        r = all_reduce(s)
        r.block_until_ready()
        assert np.allclose(np.asarray(r), expected)
        print(f"[check] size={e} ({e*4/1e6:.0f} MB/dev) all-reduce OK", flush=True)

    print("[result] ALL ALL-REDUCES CORRECT", flush=True)


if __name__ == "__main__":
    main()
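Both scripts check the all-reduce against the closed form n(n-1)/2. Row r of the input (device r) is filled with r, and psum adds the rows, so every element ends up as 0 + 1 + ... + (n-1). The host-side NumPy sketch below reproduces that check for any device count. reference_psum is a hypothetical helper. It models only the result psum is expected to produce, not NCCL or symmetric memory.

```python
import numpy as np


def reference_psum(shards: np.ndarray) -> np.ndarray:
    """Host-side model of pmap(psum): every device gets the sum of all rows."""
    total = shards.sum(axis=0, keepdims=True)    # shape (1, elems)
    return np.broadcast_to(total, shards.shape)  # same result on every "device"


def check(n: int, elems: int = 1024) -> float:
    # Same input layout as the scripts: row r (device r) filled with r.
    shards = np.broadcast_to(np.arange(n, dtype=np.float32).reshape(n, 1), (n, elems))
    expected = n * (n - 1) / 2
    assert np.allclose(reference_psum(shards), expected)
    return expected


for n in (2, 4, 8):
    print(n, check(n))  # 2 -> 1.0, 4 -> 6.0, 8 -> 28.0
```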