Turn on XLA's automatic NCCL symmetric-buffer registration for its built-in
collectives (psum, all-reduce, all-gather, …) — no custom C++, no rebuild,
runs on stock jaxlib. XLA window-registers the collective buffers for you via
ncclCommWindowRegister(..., NCCL_WIN_COLL_SYMMETRIC).
Verified on 2× NVIDIA RTX PRO 6000 Blackwell (sm_120) with the jax-toolbox
image ghcr.io/nvidia/jax:jax-2026-06-02 (jax/jaxlib 0.10.2.dev20260602,
NCCL 2.28.8).
Want your own kernel to receive symmetric buffers and call
FindSymmetricMemory/ncclGetLsaPointer? That's the manual path and it requires building C++ inside XLA — see the companion gist: https://gist.github.com/zeryx/eb3f5daf23bb50d9194a6388bae65abd
XLA_FLAGS="--xla_gpu_experimental_enable_nccl_symmetric_buffers=true \
--xla_gpu_enable_nccl_user_buffers=true"

With those set, any built-in collective that runs on >=2 Hopper+/Blackwell GPUs gets its collective buffers registered as NCCL symmetric memory automatically. That's the whole API surface: you don't touch the symmetric pointers yourself (XLA's runtime does, internally).
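If you would rather set the flags from inside a script than on the command line, a minimal sketch follows. The helper name `with_symmetric_flags` is hypothetical (it is not part of JAX or XLA); only the two flag strings come from this README. It assumes the usual advice that `XLA_FLAGS` must be in the environment before `jax` is imported.

```python
import os

# The two flags from this README, verbatim.
SYMMETRIC_FLAGS = (
    "--xla_gpu_experimental_enable_nccl_symmetric_buffers=true",
    "--xla_gpu_enable_nccl_user_buffers=true",
)


def with_symmetric_flags(existing: str = "") -> str:
    """Return `existing` XLA_FLAGS with the two flags appended (no duplicates).

    Hypothetical helper for illustration; not a JAX or XLA API.
    """
    parts = existing.split()
    for flag in SYMMETRIC_FLAGS:
        if flag not in parts:
            parts.append(flag)
    return " ".join(parts)


# Keep whatever flags the user already exported and add ours.
os.environ["XLA_FLAGS"] = with_symmetric_flags(os.environ.get("XLA_FLAGS", ""))

# Import jax only after this point so the backend sees the flags.
# import jax
```

Appending rather than overwriting keeps any flags you already exported in the shell.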
- symmetric_buffer_demo.py: self-contained. It sets the flags itself, runs a pmap all-reduce, then captures XLA's/NCCL's C++ log stream and prints the actual ncclCommWindowRegister / "Register symmetric buffer" events, asserting both that registration happened and that the math is correct.
- verify_symmetric.py: minimal correctness check across a few sizes (a "minimal yet large map"); pair it with the flags + debug env below.
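Capturing a C++ log stream from Python needs a file-descriptor-level redirect, because XLA and NCCL write to the process's stdout/stderr directly rather than through `sys.stderr`. The sketch below shows one common way to do that with `os.dup2`. It is not the code from symmetric_buffer_demo.py, and `capture_fd` is a hypothetical name. NCCL's own debug lines go to stdout unless `NCCL_DEBUG_FILE` says otherwise, so you may want to capture fd 1 as well.

```python
import os
import sys
import tempfile
from contextlib import contextmanager


@contextmanager
def capture_fd(fd=2):
    """Capture everything written to file descriptor `fd` (2 = stderr).

    Yields a list; after the block exits, list[0] holds the captured text.
    Hypothetical helper for illustration.
    """
    sys.stdout.flush()
    sys.stderr.flush()
    saved = os.dup(fd)                      # remember the real destination
    captured = []
    with tempfile.TemporaryFile(mode="w+b") as tmp:
        os.dup2(tmp.fileno(), fd)           # fd now points at the temp file
        try:
            yield captured
        finally:
            sys.stdout.flush()
            sys.stderr.flush()
            os.dup2(saved, fd)              # restore the real destination
            os.close(saved)
            tmp.seek(0)
            captured.append(tmp.read().decode(errors="replace"))


# Stand-in for a native library: write straight to fd 2, bypassing sys.stderr.
with capture_fd(2) as logs:
    os.write(2, b"[0] Register symmetric buffer for NCCL communicator\n")

print("Register symmetric buffer" in logs[0])
```

In real use the body of the `with` block would run the collective (and block until it finishes) instead of the `os.write` stand-in.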
docker run --rm --runtime=nvidia --gpus all --ipc=host \
--ulimit memlock=-1 --ulimit stack=67108864 --shm-size 16g \
-v "$PWD:/work" -w /work \
ghcr.io/nvidia/jax:jax-2026-06-02 \
python symmetric_buffer_demo.py

symmetric_buffer_demo.py sets the XLA flags + debug logging internally, so the
above is all you need. For verify_symmetric.py, pass the flags yourself:
docker run --rm --runtime=nvidia --gpus all --ipc=host \
--ulimit memlock=-1 --ulimit stack=67108864 --shm-size 16g \
-v "$PWD:/work" -w /work \
-e XLA_PYTHON_CLIENT_PREALLOCATE=false -e XLA_PYTHON_CLIENT_MEM_FRACTION=0.10 \
-e XLA_FLAGS="--xla_gpu_experimental_enable_nccl_symmetric_buffers=true --xla_gpu_enable_nccl_user_buffers=true" \
-e NCCL_DEBUG=INFO -e NCCL_DEBUG_SUBSYS=INIT,REG \
-e TF_CPP_MIN_LOG_LEVEL=0 -e TF_CPP_VMODULE=nccl_symmetric_memory=3,nccl_communicator=3 \
ghcr.io/nvidia/jax:jax-2026-06-02 \
python verify_symmetric.py

Expected log lines (excerpt):

nccl_communicator.cc:444] [0] Register symmetric buffer for NCCL communicator; buffer=0x402000000; size=33554432; comm=...
NCCL INFO Symmetric VA size=96GB
NCCL INFO register comm ... buffer 0x402000000 size 33554432
[result] ALL ALL-REDUCES CORRECT
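To check a captured log programmatically, a small parser over the "Register symmetric buffer" line is enough. The sketch below assumes the exact line format shown in the sample above, which may differ in other XLA versions; `symmetric_registrations` is a hypothetical name.

```python
import re

# Matches the XLA log line shown above; the format is an assumption based on
# that one sample and may change between XLA versions.
REGISTER_RE = re.compile(
    r"Register symmetric buffer.*?buffer=(0x[0-9a-fA-F]+); size=(\d+)"
)


def symmetric_registrations(log_text):
    """Return a list of (address, size_in_bytes) for each registration line."""
    return [
        (int(addr, 16), int(size))
        for addr, size in REGISTER_RE.findall(log_text)
    ]


sample = (
    "nccl_communicator.cc:444] [0] Register symmetric buffer for NCCL "
    "communicator; buffer=0x402000000; size=33554432; comm=...\n"
    "NCCL INFO Symmetric VA size=96GB\n"
)

regs = symmetric_registrations(sample)
assert regs, "no symmetric buffer registration found in the log"
for addr, size in regs:
    print(f"buffer={addr:#x} size={size / 2**20:.0f} MiB")
```

For the sample line, the registered size of 33554432 bytes is 32 MiB.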
- Hopper+ GPUs. The auto path no-ops on older GPUs.
- Leave room for the symmetric allocator. It allocates via ncclMemAlloc outside the BFC pool, so don't let XLA preallocate all memory: set XLA_PYTHON_CLIENT_PREALLOCATE=false (and/or a small MEM_FRACTION). Without this you'll see "could not allocate collective ... out of memory".
- PCIe is fine. Multimem (NVLS) needs NVLink+NVSwitch and will be skipped on PCIe cards, but symmetric registration + P2P/LSA still works.
- This only accelerates XLA's own collectives. It does not let your custom kernel reach peer memory — use the manual recipe for that.
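The memory note above can also be applied when launching verify_symmetric.py from Python instead of through `docker run -e ...`. The sketch below builds the same environment as the docker command in this README. `symmetric_debug_env` is a hypothetical helper, and it overwrites any existing `XLA_FLAGS` rather than merging.

```python
import os


def symmetric_debug_env(mem_fraction=0.10, base=None):
    """Return a copy of the environment with this README's settings applied.

    Hypothetical helper; values mirror the `-e` options of the docker command.
    """
    env = dict(os.environ if base is None else base)
    env.update({
        # Leave device memory free for the ncclMemAlloc-based allocator.
        "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
        "XLA_PYTHON_CLIENT_MEM_FRACTION": f"{mem_fraction:.2f}",
        # Note: this replaces any XLA_FLAGS already present in `base`.
        "XLA_FLAGS": (
            "--xla_gpu_experimental_enable_nccl_symmetric_buffers=true "
            "--xla_gpu_enable_nccl_user_buffers=true"
        ),
        # Debug logging so the registration events show up.
        "NCCL_DEBUG": "INFO",
        "NCCL_DEBUG_SUBSYS": "INIT,REG",
        "TF_CPP_MIN_LOG_LEVEL": "0",
        "TF_CPP_VMODULE": "nccl_symmetric_memory=3,nccl_communicator=3",
    })
    return env


env = symmetric_debug_env()
print(env["XLA_PYTHON_CLIENT_PREALLOCATE"], env["XLA_PYTHON_CLIENT_MEM_FRACTION"])

# To launch the check with this environment:
# import subprocess, sys
# subprocess.run([sys.executable, "verify_symmetric.py"], env=env, check=True)
```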
- ncclCommWindowRegister: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#c.ncclCommWindowRegister
- JAX FFI: https://docs.jax.dev/en/latest/ffi.html
- jax-toolbox: https://github.com/NVIDIA/JAX-Toolbox