A minimal recipe for writing your own
CUDA communication kernel as an out-of-tree XLA FFI custom call: give it
symmetric NCCL buffers and reach peers with the NCCL device API
(ncclGetLsaPointer), all built against a released jaxlib, with no XLA
source checkout and no recompile.
This is the simplified sibling of the in-tree manual recipe. The in-tree one had
to be built inside /opt/xla because it used XLA's internal FFI collective
contexts (RequestSymmetricAddress / FindSymmetricMemory), which are not in
the stable FFI ABI. This recipe avoids them entirely.
The gating feature is per-custom-call memory-space coloring —
operands_memory_spaces / results_memory_spaces being honored on a plain
custom-call (XLA PR #39742, GetCustomCallOperandMemorySpace /
kOperandsMemorySpacesAttr). Without it, set_xla_metadata("{0:1}") on your
custom call is silently ignored, the buffer isn't placed symmetrically, and
ncclCommWindowRegister(NCCL_WIN_COLL_SYMMETRIC) fails.
- PR #39742 landed in XLA on 2026-03-26 (commit c73d5b8af35c).
- jax 0.9.x (XLA pin 2026-01-15) does not have it.
- jax 0.10.0 (XLA pin ~2026-04-15) is the first release that includes it.
- Everything else is older: jax.ffi.{pycapsule,register_ffi_target,ffi_call,include_dir} (≥0.6), set_xla_metadata (≥0.4.30).
- Also needs the bundled NCCL ≥ 2.28 for the device API (ncclGetLsaPointer, ncclDevCommCreate, the nccl_device/ headers) + ncclCommWindowRegister. jax 0.10.x bundles NCCL 2.29.x, so 0.10.0 covers this too.

So the floor is 0.10.0. Verified on 0.10.2.dev20260602 (NCCL 2.29.7).
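The version floor can be checked before loading the custom call. The sketch below is only an illustration: the helper names `release_tuple` and `meets_floor` are hypothetical, and it compares the numeric `X.Y.Z` prefix only, so a dev build numbered at the floor itself (for example a `0.10.0.dev…` nightly) could still predate the XLA pin.

```python
import re

# First jax release that, per the list above, carries the memory-space coloring.
FLOOR = (0, 10, 0)


def release_tuple(version: str) -> tuple:
    """Parse the leading 'X.Y.Z' of a version string, ignoring dev/rc suffixes."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def meets_floor(version: str, floor: tuple = FLOOR) -> bool:
    """True when the numeric release is at or above the floor."""
    return release_tuple(version) >= floor
```

In the driver this would presumably be called as `meets_floor(jax.__version__)`, failing early with a clear message instead of letting the coloring be silently ignored.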
- Bring your own NCCL comm. Rank 0 mints an ncclUniqueId and publishes it through the JAX distributed coordinator's key-value store (client.key_value_set_bytes / blocking_key_value_get_bytes); every process reads it and calls ncclCommInitRank. (EnsureComm in the .cu, exchange_uid in the .py.) This is multi-node correct: the coordinator spans all nodes and the KV store needs no shared filesystem. Do not bootstrap with an XLA collective (broadcast_one_to_all): it races against our own comm init and can deadlock; the KV store is a plain RPC, not a collective, so it's safe.
- Color the buffers collective. In Python, set_xla_metadata(operands_memory_spaces="{0:1}", results_memory_spaces="{0:1}") tags the custom call's input/output as memory-space 1 = kCollective, so XLA allocates them symmetrically (this is exactly what XLA PR #39742 enables for plain custom calls). No recompile: the colorer ships in jaxlib ≥ 0.10.
- Register + cache the windows. In the handler, call ncclCommWindowRegister(comm, ptr, bytes, &win, NCCL_WIN_COLL_SYMMETRIC) on the device pointers XLA passed, and cache the ncclWindow_t keyed by pointer (GetWindow). Registration is lazy and idempotent.
- Reduce with the device API. The kernel pulls every peer's symmetric src via ncclGetLsaPointer(win, 0, peer) and sums (AllReduceKernel).
Everything it touches is public: stable FFI ABI (PlatformStream, Buffer,
Result, Attr) + public NCCL host/device API. Verified surfaces are listed at
the bottom.
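The Python side of the first two steps might look roughly like the sketch below. It is not the contents of jax_byo_allreduce.py: the key name, the timeout, the FFI target name `"byo_allreduce"`, and the helper `register_and_wrap` are all hypothetical, `make_uid` stands in for whatever wraps the exported byo_get_unique_id, and how the uid is handed to EnsureComm is not shown. The `client` argument is assumed to be the distributed runtime client that jax creates in `jax.distributed.initialize` (reachable through a private attribute, so treat that access as an assumption too).

```python
import ctypes

UID_KEY = "byo_allreduce/nccl_uid"  # hypothetical key name
TIMEOUT_MS = 60_000                 # hypothetical timeout

# Coloring kwargs quoted from the step list: operand 0 / result 0 -> memory space 1.
COLORING = {"operands_memory_spaces": "{0:1}", "results_memory_spaces": "{0:1}"}


def exchange_uid(client, rank, make_uid, key=UID_KEY, timeout_ms=TIMEOUT_MS):
    """Rank 0 publishes the NCCL unique id; every rank (0 included) reads it back.

    Only plain KV-store RPCs are used, never an XLA collective.
    """
    if rank == 0:
        client.key_value_set_bytes(key, make_uid())
    return client.blocking_key_value_get_bytes(key, timeout_ms)


def register_and_wrap(so_path, target="byo_allreduce"):
    """Register the handler exported by the .so and return a jitted, colored call."""
    import jax  # lazy import: exchange_uid stays usable without jax installed
    from jax.experimental.xla_metadata import set_xla_metadata

    lib = ctypes.CDLL(so_path)
    jax.ffi.register_ffi_target(
        target, jax.ffi.pycapsule(lib.ByoAllReduce), platform="CUDA"
    )

    @jax.jit
    def allreduce(x):
        call = jax.ffi.ffi_call(target, jax.ShapeDtypeStruct(x.shape, x.dtype))
        # The metadata is attached to ops traced inside this context.
        with set_xla_metadata(**COLORING):
            return call(x)

    return allreduce
```

Passing the client in as an argument keeps the uid exchange independent of jax internals and easy to exercise with a stand-in object.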
| File | What it is |
|---|---|
| byo_allreduce.cu | the custom call: comm bootstrap, window register/cache, device-API kernel, exported handler ByoAllReduce + byo_get_unique_id |
| jax_byo_allreduce.py | JAX driver: distributed init, uid broadcast, register_ffi_target from the .so, set_xla_metadata coloring, ffi_call |
| build.sh | nvcc build to libbyo_allreduce.so against jax.ffi.include_dir(), no XLA tree |
| run.sh | docker + nvidia runtime orchestration: 1 process per GPU, tee to byo.log |
./run.sh # 2 GPUs, default image jax-2026-06-02
NPROC=2 N=1048576 ./run.sh
tail -f byo.log

Expected (2 GPUs): each rank fills src with its rank id, so the all-reduce sum
is sum(0..nproc-1):
[proc 0/2] N=1048576 expected=1.0 got[0]=1.0 got[-1]=1.0 ALL_OK=True
[proc 1/2] N=1048576 expected=1.0 got[0]=1.0 got[-1]=1.0 ALL_OK=True
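The expected value in those log lines is plain arithmetic and can be reproduced on the host. The sketch below is a pure-Python model of the reduction, not the kernel; both function names are hypothetical.

```python
def expected_allreduce_value(nproc: int) -> float:
    """Every element equals 0 + 1 + ... + (nproc - 1)."""
    return float(sum(range(nproc)))


def simulate_allreduce(nproc: int, n: int) -> list:
    """Model the test: rank r fills its src with r, then all srcs are summed elementwise."""
    srcs = [[float(rank)] * n for rank in range(nproc)]
    return [sum(column) for column in zip(*srcs)]
```

With two processes every element is 0 + 1 = 1.0, matching the log lines above; the general closed form is nproc * (nproc - 1) / 2.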
| Need | In-tree recipe | This recipe |
|---|---|---|
| symmetric buffer | XLA RequestSymmetricAddress (internal ctx) | set_xla_metadata color kCollective (Python) |
| get the window | FindSymmetricMemory (internal ctx) | you call ncclCommWindowRegister + cache |
| peer access | ncclGetLsaPointer (device API) | same |
| comm | XLA's clique (internal ctx) | your own ncclCommInitRank |
| build | inside /opt/xla, recompile XLA | nvcc one .so vs released jaxlib headers |
The internal contexts live in xla/ffi/ffi_structs.h / invoke.h and are absent
from the stable xla/ffi/api/ffi.h — that was the only reason the in-tree build
was unavoidable. By owning the comm and the registration, this recipe never
needs them.
ncclCommWindowRegister(NCCL_WIN_COLL_SYMMETRIC) requires the buffer to be at a
cross-rank symmetric address with ≥4096-byte alignment
(NCCL_WIN_REQUIRED_ALIGNMENT). Coloring kCollective is what is supposed to
guarantee that placement — and jax.distributed.initialize is what makes XLA's
collective allocator active so the coloring is meaningful. If the registration
returns an error, that's the signal the colored buffer wasn't actually placed
symmetrically; check NCCL_DEBUG=INFO REG lines in byo.log. This is the single
assumption the recipe rests on; everything else is mechanical.
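When a registration fails, one cheap thing to rule out is the alignment half of that requirement. The helpers below are a hypothetical debugging aid, assuming the device pointer has been obtained as an integer (for example, printed by the handler); they check alignment only and say nothing about whether the address is symmetric across ranks.

```python
# Alignment quoted in the text above for NCCL_WIN_REQUIRED_ALIGNMENT.
NCCL_WIN_REQUIRED_ALIGNMENT = 4096


def misalignment(ptr: int, alignment: int = NCCL_WIN_REQUIRED_ALIGNMENT) -> int:
    """Bytes by which ptr overshoots the previous aligned address (0 if aligned)."""
    return ptr % alignment


def is_window_aligned(ptr: int, alignment: int = NCCL_WIN_REQUIRED_ALIGNMENT) -> bool:
    """True when ptr satisfies the alignment the window registration expects."""
    return misalignment(ptr, alignment) == 0
```

An aligned pointer that still fails to register points back at placement, which is where the NCCL_DEBUG=INFO output is the better guide.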
- Cross-rank sync. The kernel is barrier-free; cross-rank ordering is done with host-side whole-comm barriers (a 1-element ncclAllReduce on the stream) bracketing the launch. An in-kernel ncclLsaBarrierSession intermittently deadlocks on a non-cooperative grid, so avoid it.
- Datatype/shape. Wired for f32, 1-D. Generalize by templating the kernel and binding AnyBuffer.
- Teardown is collective; don't skip it. byo_finalize() deregisters the windows and calls ncclCommDestroy on every rank, gated by a coordinator barrier. Leaving the comm to implicit at-exit cleanup hangs the surviving rank (ncclCommDestroy is collective).
- Topology. ncclGetLsaPointer peer loads need the peers to be load/store accessible (NVLink, or PCIe P2P). Verified working over PCIe P2P (CUMEM) on 2× RTX PRO 6000 Blackwell (no NVLink).
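The teardown ordering could be sketched on the Python side as follows. This is an assumption about where the barrier sits, not a description of the actual byo_finalize: here the coordinator barrier is taken in Python through the client's `wait_at_barrier`, and `byo_finalize` is passed in as a callable (for instance a ctypes handle to the exported function). The barrier id and timeout are hypothetical.

```python
def finalize_collectively(
    client,
    byo_finalize,
    barrier_id="byo_allreduce/teardown",  # hypothetical barrier name
    timeout_ms=60_000,                    # hypothetical timeout
):
    """Wait until every process reaches teardown, then destroy the comm.

    ncclCommDestroy is collective, so no rank should enter it until all
    ranks are known to be on their way in.
    """
    client.wait_at_barrier(barrier_id, timeout_ms)
    byo_finalize()
```

Calling this explicitly at the end of the driver, rather than relying on interpreter exit, keeps the destroy inside the window where all ranks are still alive.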
- jaxlib 0.10.2.dev20260602, NCCL 2.29.7, CUDA 12.9.
- ncclCommWindowRegister/Deregister/ncclMemAlloc exported by libnccl.so.2.
- device API used: ncclGetLsaPointer(win, 0, peer) (from nccl_device/); cross-rank sync via host-side ncclAllReduce barriers (no device comm needed).
- jax.ffi: pycapsule, register_ffi_target, ffi_call, include_dir (= /opt/jaxlibs/jaxlib/jaxlib/include, has xla/ffi/api/ffi.h).
- jax.experimental.xla_metadata.set_xla_metadata, jax.experimental.multihost_utils.broadcast_one_to_all.