@zeryx
Last active June 3, 2026 19:29
JAX↔XLA bring-your-own-comm NCCL symmetric all-reduce — out-of-tree FFI custom call, no XLA recompile (verified jaxlib 0.10.2/NCCL 2.29.7 on 2x Blackwell)

BYO-comm symmetric all-reduce — XLA custom call, no XLA recompile

A minimal recipe for writing your own CUDA communication kernel as an out-of-tree XLA FFI custom call, giving it symmetric NCCL buffers, and reaching peers with the NCCL device API (ncclGetLsaPointer), all built against a released jaxlib, with no XLA source checkout and no recompile.

This is the simplified sibling of the in-tree manual recipe. The in-tree one had to be built inside /opt/xla because it used XLA's internal FFI collective contexts (RequestSymmetricAddress / FindSymmetricMemory), which are not in the stable FFI ABI. This recipe avoids them entirely.

Minimum version: jax / jaxlib 0.10.0

The gating feature is per-custom-call memory-space coloring: operands_memory_spaces / results_memory_spaces being honored on a plain custom-call (XLA PR #39742, GetCustomCallOperandMemorySpace / kOperandsMemorySpacesAttr). Without it, set_xla_metadata("{0:1}") on your custom call is silently ignored, the buffer isn't placed symmetrically, and ncclCommWindowRegister(NCCL_WIN_COLL_SYMMETRIC) fails.

  • PR #39742 landed in XLA on 2026-03-26 (commit c73d5b8af35c).
  • jax 0.9.x (XLA pin 2026-01-15) does not have it.
  • jax 0.10.0 (XLA pin ~2026-04-15) is the first release that includes it.
  • Everything else is older: jax.ffi.{pycapsule,register_ffi_target,ffi_call,include_dir} (≥0.6), set_xla_metadata (≥0.4.30).
  • Also needs the bundled NCCL ≥ 2.28 for the device API (ncclGetLsaPointer, ncclDevCommCreate, the nccl_device/ headers) + ncclCommWindowRegister. jax 0.10.x bundles NCCL 2.29.x, so 0.10.0 covers this too.

So the floor is 0.10.0. Verified on 0.10.2.dev20260602 (NCCL 2.29.7).
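
A driver could refuse to start on a jaxlib below this floor. The following is a minimal sketch using only the standard library; `release_tuple` and `meets_floor` are hypothetical helper names, and the comparison looks only at the leading `major.minor.patch` digits. A `.dev` nightly of 0.10.0 cut before the PR landed would pass this check incorrectly, so treat it as a coarse guard rather than proof that the coloring is honored.

```python
import re

# Floor quoted above: the first release said to honor per-custom-call coloring.
FLOOR = (0, 10, 0)


def release_tuple(version: str) -> tuple:
    """Return the leading (major, minor, patch) integers of a version string."""
    m = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if m is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(g) for g in m.groups())


def meets_floor(version: str, floor: tuple = FLOOR) -> bool:
    """True if the release part of `version` is at or above `floor`."""
    return release_tuple(version) >= floor


print(meets_floor("0.10.2.dev20260602"))
print(meets_floor("0.9.2"))
```

In a real driver the input string would come from the installed package (for example `jaxlib.__version__`), which this sketch deliberately does not import.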

The pattern (4 steps)

  1. Bring your own NCCL comm. Rank 0 mints an ncclUniqueId and publishes it through the JAX distributed coordinator's key-value store (client.key_value_set_bytes / blocking_key_value_get_bytes); every process reads it and calls ncclCommInitRank. (EnsureComm in the .cu, exchange_uid in the .py.) This is multi-node correct — the coordinator spans all nodes and the KV store needs no shared filesystem. Do not bootstrap with an XLA collective (broadcast_one_to_all): it races against our own comm init and can deadlock; the KV store is a plain RPC, not a collective, so it's safe.
  2. Color the buffers collective. In Python, set_xla_metadata(operands_memory_spaces="{0:1}", results_memory_spaces="{0:1}") tags the custom call's input/output as memory-space 1 = kCollective, so XLA allocates them symmetrically (this is exactly what XLA PR #39742 enables for plain custom calls). No recompile — the colorer ships in jaxlib ≥ 0.10.
  3. Register + cache the windows. In the handler, call ncclCommWindowRegister(comm, ptr, bytes, &win, NCCL_WIN_COLL_SYMMETRIC) on the device pointers XLA passed, and cache the ncclWindow_t keyed by pointer (GetWindow). Registration is lazy and idempotent.
  4. Reduce with the device API. The kernel pulls every peer's symmetric src via ncclGetLsaPointer(win, 0, peer) and sums (AllReduceKernel).
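
Step 1 can be illustrated without GPUs or a coordinator. The sketch below assumes a hypothetical `InMemoryKV` class that borrows the two method names quoted above (`key_value_set_bytes`, `blocking_key_value_get_bytes`) and simulates ranks with threads; `os.urandom` stands in for `ncclGetUniqueId`. It only shows the publish-then-blocking-read shape of the exchange, not the real coordinator client.

```python
import os
import threading

NCCL_UID_BYTES = 128
UID_KEY = "byo_allreduce/nccl_unique_id"


class InMemoryKV:
    """Hypothetical stand-in for the coordinator KV store (same method names)."""

    def __init__(self):
        self._data = {}
        self._cond = threading.Condition()

    def key_value_set_bytes(self, key, value):
        with self._cond:
            self._data[key] = bytes(value)
            self._cond.notify_all()

    def blocking_key_value_get_bytes(self, key, timeout_ms):
        with self._cond:
            # Block until some rank has published the key, or time out.
            if not self._cond.wait_for(lambda: key in self._data, timeout_ms / 1000):
                raise TimeoutError(key)
            return self._data[key]


def exchange_uid(kv, rank, mint):
    """Rank 0 publishes the id; every rank (rank 0 included) reads it back."""
    if rank == 0:
        kv.key_value_set_bytes(UID_KEY, mint())
    return kv.blocking_key_value_get_bytes(UID_KEY, 5_000)


def run(nranks):
    kv = InMemoryKV()
    got = [None] * nranks

    def worker(rank):
        got[rank] = exchange_uid(kv, rank, lambda: os.urandom(NCCL_UID_BYTES))

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(nranks)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return got


uids = run(4)
print(len(set(uids)), len(uids[0]))
```

Because the read blocks on a plain key lookup rather than on a collective, no rank needs to have a communicator yet, which is the property the recipe relies on.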

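Everything the custom call touches is public: the stable FFI ABI (PlatformStream, Buffer, Result, Attr) plus the public NCCL host/device API. The one exception is on the Python side, where the driver reads the coordinator client from jax._src.distributed, a private module. Verified surfaces are listed at the bottom.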
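
The "lazy and idempotent" registration in step 3 is a pointer-keyed cache. Here is a minimal Python sketch of that idea, assuming a hypothetical `register` callable in place of ncclCommWindowRegister and made-up addresses; it is not the C++ GetWindow itself.

```python
import threading


class WindowCache:
    """Register each device pointer at most once, then reuse the handle."""

    def __init__(self, register):
        self._register = register  # hypothetical callable: (ptr, nbytes) -> handle
        self._lock = threading.Lock()
        self._windows = {}  # device pointer -> window handle

    def get(self, ptr, nbytes):
        with self._lock:
            win = self._windows.get(ptr)
            if win is None:
                # First sighting of this pointer: pay the registration cost once.
                win = self._register(ptr, nbytes)
                self._windows[ptr] = win
            return win


calls = []


def fake_register(ptr, nbytes):
    calls.append((ptr, nbytes))
    return f"win@{ptr:#x}"


cache = WindowCache(fake_register)
a = cache.get(0x7F0000001000, 4096)
b = cache.get(0x7F0000001000, 4096)  # cache hit, no second registration
c = cache.get(0x7F0000002000, 4096)
print(a == b, len(calls))
```

One design consequence worth noting: a cache keyed only by pointer returns the old handle if the same address later comes back with a different size, so a production version might key on `(ptr, nbytes)` or deregister on mismatch.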

Files

File | What it is
byo_allreduce.cu | the custom call: comm bootstrap, window register/cache, device-API kernel, exported handler ByoAllReduce + byo_get_unique_id
jax_byo_allreduce.py | JAX driver: distributed init, uid exchange through the coordinator KV store, register_ffi_target from the .so, set_xla_metadata coloring, ffi_call
build.sh | nvcc build to libbyo_allreduce.so against jax.ffi.include_dir(); no XLA tree
run.sh | docker + nvidia runtime orchestration: 1 process per GPU, tee to byo.log

Run

./run.sh                      # 2 GPUs, default image jax-2026-06-02
NPROC=2 N=1048576 ./run.sh
tail -f byo.log

Expected (2 GPUs): each rank fills src with its rank id, so the all-reduce sum is sum(0..nproc-1):

[proc 0/2] N=1048576 expected=1.0 got[0]=1.0 got[-1]=1.0 ALL_OK=True
[proc 1/2] N=1048576 expected=1.0 got[0]=1.0 got[-1]=1.0 ALL_OK=True
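
The expected value is the arithmetic series 0 + 1 + ... + (nproc - 1) = nproc(nproc - 1)/2, which is 1.0 for two ranks. A small sketch that checks the closed form against a brute-force sum (the function names are illustrative, not part of the driver):

```python
def expected_sum(nproc: int) -> float:
    """Closed form for sum(0..nproc-1), as used by the driver's check."""
    return nproc * (nproc - 1) / 2.0


def brute_force_sum(nproc: int) -> float:
    """Add up each rank's fill value (its own rank id) directly."""
    return float(sum(range(nproc)))


for nproc in (2, 4, 8):
    print(nproc, expected_sum(nproc), brute_force_sum(nproc))
```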

Why no recompile (vs. the in-tree recipe)

Need | In-tree recipe | This recipe
symmetric buffer | XLA RequestSymmetricAddress (internal ctx) | set_xla_metadata color kCollective (Python)
get the window | FindSymmetricMemory (internal ctx) | you call ncclCommWindowRegister + cache
peer access | ncclGetLsaPointer (device API) | same
comm | XLA's clique (internal ctx) | your own ncclCommInitRank
build | inside /opt/xla, recompile XLA | nvcc one .so vs released jaxlib headers

The internal contexts live in xla/ffi/ffi_structs.h / invoke.h and are absent from the stable xla/ffi/api/ffi.h — that was the only reason the in-tree build was unavoidable. By owning the comm and the registration, this recipe never needs them.

The one thing to confirm (linchpin)

ncclCommWindowRegister(NCCL_WIN_COLL_SYMMETRIC) requires the buffer to be at a cross-rank symmetric address with ≥4096-byte alignment (NCCL_WIN_REQUIRED_ALIGNMENT). Coloring kCollective is what is supposed to guarantee that placement — and jax.distributed.initialize is what makes XLA's collective allocator active so the coloring is meaningful. If the registration returns an error, that's the signal the colored buffer wasn't actually placed symmetrically; check NCCL_DEBUG=INFO REG lines in byo.log. This is the single assumption the recipe rests on; everything else is mechanical.
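
When a registration fails, one cheap thing to rule out is the alignment half of that requirement. The sketch below assumes you have logged the device pointer as an integer; the 4096 figure is the one quoted above, and the helper names are hypothetical. It says nothing about the other half, whether the address is symmetric across ranks.

```python
WIN_ALIGN = 4096  # alignment the text quotes for NCCL_WIN_REQUIRED_ALIGNMENT


def misalignment(addr: int, align: int = WIN_ALIGN) -> int:
    """Bytes by which `addr` overshoots the previous aligned boundary."""
    return addr % align


def is_window_aligned(addr: int, align: int = WIN_ALIGN) -> bool:
    return misalignment(addr, align) == 0


# Made-up addresses, purely for illustration.
for addr in (0x7F0000001000, 0x7F0000001010):
    print(hex(addr), is_window_aligned(addr), misalignment(addr))
```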

Notes / production knobs

  • Cross-rank sync. The kernel is barrier-free; cross-rank ordering is done with host-side whole-comm barriers (a 1-element ncclAllReduce on the stream) bracketing the launch. An in-kernel ncclLsaBarrierSession intermittently deadlocks on a non-cooperative grid, so avoid it.
  • Datatype/shape. Wired for f32, 1-D. Generalize by templating the kernel and binding AnyBuffer.
  • Teardown is collective — don't skip it. byo_finalize() deregisters the windows and calls ncclCommDestroy on every rank, gated by a coordinator barrier. Leaving the comm to implicit at-exit cleanup hangs the surviving rank (ncclCommDestroy is collective).
  • Topology. ncclGetLsaPointer peer loads need the peers to be load/store accessible (NVLink, or PCIe P2P). Verified working over PCIe P2P (CUMEM) on 2× RTX PRO 6000 Blackwell (no NVLink).
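
The barrier bracketing from the cross-rank sync note can be simulated on the host. In this sketch threads play the ranks, a `threading.Barrier` stands in for the 1-element ncclAllReduce, and plain lists stand in for the symmetric buffers; it is a model of the ordering argument only, not of NCCL behavior.

```python
import threading


def run_ranks(nranks, n):
    src = [None] * nranks  # each rank's "symmetric" source buffer
    dst = [None] * nranks
    barrier = threading.Barrier(nranks)  # stand-in for the stream barrier

    def rank_main(rank):
        src[rank] = [float(rank)] * n  # produce this rank's input
        barrier.wait()  # pre-barrier: every src exists before any peer read
        # "Kernel": pull element i from every peer's src and sum.
        dst[rank] = [sum(src[p][i] for p in range(nranks)) for i in range(n)]
        barrier.wait()  # post-barrier: nobody reuses src while peers still read
        src[rank] = None  # buffer reuse is safe only after the post-barrier

    threads = [threading.Thread(target=rank_main, args=(r,)) for r in range(nranks)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return dst


result = run_ranks(4, 8)
print(result[0][0], result[3][-1])
```

Dropping the first `barrier.wait()` lets a fast rank read a peer's `src` while it is still `None`, and dropping the second lets a rank clear its `src` while a slower peer is still summing; those are the two hazards the pre- and post-barriers guard against.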

Verified against jax-2026-06-02

  • jaxlib 0.10.2.dev20260602, NCCL 2.29.7, CUDA 12.9.
  • ncclCommWindowRegister/Deregister/ncclMemAlloc exported by libnccl.so.2.
  • device API used: ncclGetLsaPointer(win, 0, peer) (from nccl_device/); cross-rank sync via host-side ncclAllReduce barriers (no device comm needed).
  • jax.ffi: pycapsule, register_ffi_target, ffi_call, include_dir (= /opt/jaxlibs/jaxlib/jaxlib/include, has xla/ffi/api/ffi.h).
  • jax.experimental.xla_metadata.set_xla_metadata. jax.experimental.multihost_utils.broadcast_one_to_all is also present, but the recipe deliberately does not use it for the uid bootstrap (see step 1).
#!/usr/bin/env bash
# Build the BYO-comm custom-call .so against the STOCK jaxlib FFI headers +
# the system NCCL. No XLA checkout, no XLA recompile.
set -euo pipefail
HERE="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$HERE"
# FFI headers ship inside the installed jaxlib wheel.
INC="$(python3 -c 'import jax; print(jax.ffi.include_dir())')"
echo "jax FFI include dir: $INC"
test -f "$INC/xla/ffi/api/ffi.h" || { echo "ffi.h not found under $INC"; exit 1; }
# Build a fat binary covering Hopper (sm_90) + Blackwell (sm_100/sm_120).
# NCCL symmetric memory requires Hopper or newer.
GENCODE=(
  -gencode arch=compute_90,code=sm_90
  -gencode arch=compute_100,code=sm_100
  -gencode arch=compute_120,code=sm_120
)
set -x
nvcc -O3 -std=c++17 --expt-relaxed-constexpr \
  --compiler-options -fPIC -shared \
  -I"$INC" \
  "${GENCODE[@]}" \
  byo_allreduce.cu -o libbyo_allreduce.so \
  -lnccl -lcudart
set +x
echo "built: $HERE/libbyo_allreduce.so"
// byo_allreduce.cu — Bring-Your-Own-Comm symmetric all-reduce custom call.
//
// This is the "devtech pattern": an out-of-tree XLA FFI custom call that
// 1. brings its OWN NCCL communicator (bootstrapped from a unique-id that
// JAX broadcasts to every process),
// 2. takes the input/output device buffers XLA hands it — which JAX has
// colored into the *collective* memory space via set_xla_metadata, so
// XLA allocates them symmetrically,
// 3. registers those buffers as NCCL symmetric windows with
// ncclCommWindowRegister(NCCL_WIN_COLL_SYMMETRIC) and caches the windows,
// 4. runs a one-shot all-reduce kernel that reaches every peer's copy of the
// input directly with the NCCL *device* API (ncclGetLsaPointer).
//
// Crucially it touches ONLY the stable public FFI ABI (xla/ffi/api/ffi.h:
// PlatformStream, Buffer, Result, Attr) plus the public NCCL host+device API.
// It never uses XLA's internal CollectiveMemory / RequestSymmetricAddress /
// FindSymmetricMemory contexts, so it builds against a *released* jaxlib with
// NO XLA recompile.
//
// Build: ./build.sh Run: ./run.sh
//
// Verified against the jax-toolbox image jax-2026-06-02:
// jaxlib 0.10.2.dev20260602, NCCL 2.29.7, CUDA 12.9.
#include <cuda_runtime.h>
#include <cuda/atomic>
#include <nccl.h>
#include <nccl_device.h>
#include <cstdint>
#include <cstring>
#include <mutex>
#include <string>  // std::string is used by the NCCL_RET / CUDA_RET macros
#include <unordered_map>
#include "xla/ffi/api/ffi.h"
namespace ffi = xla::ffi;
// ----------------------------------------------------------------------------
// error plumbing
// ----------------------------------------------------------------------------
#define NCCL_RET(expr) \
  do { \
    ncclResult_t _s = (expr); \
    if (_s != ncclSuccess) \
      return ffi::Error::Internal(std::string("NCCL: ") + #expr + " -> " + \
                                  ncclGetErrorString(_s)); \
  } while (0)
#define CUDA_RET(expr) \
  do { \
    cudaError_t _s = (expr); \
    if (_s != cudaSuccess) \
      return ffi::Error::Internal(std::string("CUDA: ") + #expr + " -> " + \
                                  cudaGetErrorString(_s)); \
  } while (0)
// ----------------------------------------------------------------------------
// process-global resources: one NCCL comm per process, plus a pointer-keyed
// cache of registered symmetric windows. (One GPU per process, so one comm is
// all we need; the cache is what "cache the window for later" from the devtech
// pattern means. No device comm is created: cross-rank sync is host-side.)
// ----------------------------------------------------------------------------
struct Resources {
  std::mutex mu;
  bool inited = false;
  ncclComm_t comm = nullptr;
  int nranks = 0;
  void* scratch = nullptr;  // 4B device scratch for the host-side barrier
  std::unordered_map<void*, ncclWindow_t> windows;  // device ptr -> window
};
static Resources& Res() {
  static Resources r;
  return r;
}
// Lazily build the comm from (rank, nranks, uid). All ranks pass the SAME uid
// (read back from the JAX coordinator KV store), so ncclCommInitRank
// rendezvouses correctly. The mutex guards against two handler invocations
// racing on first use.
static ffi::Error EnsureComm(int rank, int nranks,
                             ffi::Span<const uint8_t> uid) {
  Resources& r = Res();
  std::lock_guard<std::mutex> lk(r.mu);
  if (r.inited) return ffi::Error::Success();
  if (uid.size() != sizeof(ncclUniqueId))
    return ffi::Error::InvalidArgument("uid attr must be 128 bytes");
  ncclUniqueId id;
  std::memcpy(id.internal, uid.begin(), sizeof(ncclUniqueId));
  NCCL_RET(ncclCommInitRank(&r.comm, nranks, id, rank));
  r.nranks = nranks;
  CUDA_RET(cudaMalloc(&r.scratch, sizeof(int)));  // for the cross-rank barrier
  CUDA_RET(cudaMemset(r.scratch, 0, sizeof(int)));  // don't reduce garbage
  r.inited = true;
  return ffi::Error::Success();
}
// A whole-comm barrier on the device stream: a 1-element all-reduce. NCCL
// guarantees every rank reaches it and orders memory across ranks. We use this
// instead of an in-kernel device barrier (ncclLsaBarrierSession) because a
// per-CTA cross-rank device barrier on a non-cooperative grid can deadlock
// intermittently when the grid's CTAs are not all co-resident.
static ffi::Error StreamBarrier(Resources& r, cudaStream_t stream) {
  NCCL_RET(ncclAllReduce(r.scratch, r.scratch, 1, ncclInt, ncclSum, r.comm,
                         stream));
  return ffi::Error::Success();
}
// Register `ptr` as a symmetric window once, then reuse. The buffer must be
// symmetric across ranks — that is guaranteed by JAX having colored it into the
// collective memory space (operands/results_memory_spaces="{...:1}").
static ffi::Error GetWindow(void* ptr, size_t bytes, ncclWindow_t* out) {
  Resources& r = Res();
  std::lock_guard<std::mutex> lk(r.mu);
  auto it = r.windows.find(ptr);
  if (it != r.windows.end()) {
    *out = it->second;
    return ffi::Error::Success();
  }
  ncclWindow_t win = nullptr;
  // Collective + blocking: every rank must reach this with the same ptr-order
  // and size. They do, because every rank runs the same custom call.
  NCCL_RET(ncclCommWindowRegister(r.comm, ptr, bytes, &win,
                                  NCCL_WIN_COLL_SYMMETRIC));
  r.windows.emplace(ptr, win);
  *out = win;
  return ffi::Error::Success();
}
// ----------------------------------------------------------------------------
// the kernel — pull every peer's symmetric `src` with the NCCL device API and
// sum. Cross-rank synchronization is done by host-side ncclAllReduce barriers
// bracketing the launch (see AllReduceImpl), so the kernel itself is barrier-
// free and cannot deadlock on grid occupancy.
// ----------------------------------------------------------------------------
__global__ void AllReduceKernel(ncclWindow_t src_win, float* __restrict__ dst,
                                int npeers, size_t count) {
  // Grid-stride loop: each thread handles elements gid, gid+stride, ...
  const size_t gid = blockIdx.x * size_t(blockDim.x) + threadIdx.x;
  const size_t stride = size_t(gridDim.x) * blockDim.x;
  for (size_t i = gid; i < count; i += stride) {
    float acc = 0.f;
    for (int peer = 0; peer < npeers; ++peer) {
      // <-- THE device API: peer's pointer inside the symmetric window.
      const float* peer_src =
          static_cast<const float*>(ncclGetLsaPointer(src_win, 0, peer));
      acc += peer_src[i];
    }
    dst[i] = acc;
  }
}
// ----------------------------------------------------------------------------
// the FFI handler (host side)
// ----------------------------------------------------------------------------
static ffi::Error AllReduceImpl(cudaStream_t stream,
                                ffi::BufferR1<ffi::F32> src,
                                ffi::Result<ffi::BufferR1<ffi::F32>> dst,
                                int32_t rank, int32_t nranks,
                                ffi::Span<const uint8_t> uid) {
  if (auto e = EnsureComm(rank, nranks, uid); e.failure()) return e;
  const size_t count = src.element_count();
  const size_t bytes = src.size_bytes();
  void* src_ptr = src.untyped_data();
  void* dst_ptr = dst->untyped_data();
  // Register (or look up cached) symmetric windows for BOTH buffers, exactly as
  // the devtech pattern prescribes.
  ncclWindow_t src_win = nullptr, dst_win = nullptr;
  if (auto e = GetWindow(src_ptr, bytes, &src_win); e.failure()) return e;
  if (auto e = GetWindow(dst_ptr, bytes, &dst_win); e.failure()) return e;
  Resources& r = Res();
  // Pre-barrier: every rank's `src` is produced + visible before any peer read.
  if (auto e = StreamBarrier(r, stream); e.failure()) return e;
  const int block = 256;
  const int grid = 64;
  AllReduceKernel<<<grid, block, 0, stream>>>(
      src_win, static_cast<float*>(dst_ptr), r.nranks, count);
  CUDA_RET(cudaGetLastError());
  // Post-barrier: keep ranks in lockstep before buffers can be reused.
  if (auto e = StreamBarrier(r, stream); e.failure()) return e;
  return ffi::Error::Success();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    ByoAllReduce, AllReduceImpl,
    ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<cudaStream_t>>()
        .Arg<ffi::BufferR1<ffi::F32>>()
        .Ret<ffi::BufferR1<ffi::F32>>()
        .Attr<int32_t>("rank")
        .Attr<int32_t>("nranks")
        .Attr<ffi::Span<const uint8_t>>("uid"));
// ----------------------------------------------------------------------------
// tiny ctypes-callable helper so Python can mint the NCCL unique id without a
// cupy / mpi4py dependency. Rank 0 calls this and publishes the 128 bytes
// through the JAX coordinator KV store.
// ----------------------------------------------------------------------------
extern "C" int byo_get_unique_id(char* out_128) {
  ncclUniqueId id;
  ncclResult_t s = ncclGetUniqueId(&id);
  if (s != ncclSuccess) return static_cast<int>(s);
  std::memcpy(out_128, id.internal, sizeof(ncclUniqueId));
  return 0;
}
// Synchronized teardown. MUST be called by every rank while all are still
// alive (gate it behind a coordinator barrier on the Python side): ncclComm
// destroy is collective, so if one rank exits and leaves its comm to implicit
// at-exit cleanup, the surviving rank hangs forever in NCCL. Draining here
// avoids that.
extern "C" int byo_finalize() {
  Resources& r = Res();
  std::lock_guard<std::mutex> lk(r.mu);
  if (!r.inited) return 0;
  for (auto& kv : r.windows) ncclCommWindowDeregister(r.comm, kv.second);
  r.windows.clear();
  if (r.scratch) { cudaFree(r.scratch); r.scratch = nullptr; }
  ncclCommDestroy(r.comm);
  r.comm = nullptr;
  r.inited = false;
  return 0;
}
#!/usr/bin/env python3
"""Driver for the bring-your-own-comm symmetric all-reduce custom call.
One process per GPU. The launcher (run.sh) sets PROC_ID / NUM_PROCS /
COORD_ADDR and pins one GPU per process via CUDA_VISIBLE_DEVICES.
Flow:
1. jax.distributed.initialize -> XLA knows the global topology, which is what
makes the collective memory space actually allocate *symmetric* buffers.
2. rank 0 mints an NCCL unique id (via the .so's byo_get_unique_id) and
publishes the 128 bytes through the JAX distributed coordinator's
key-value store; every rank reads it back. This works across nodes.
(We deliberately do NOT bootstrap with an XLA collective like
broadcast_one_to_all -- issuing an XLA collective here races against our
own comm's init and can deadlock. The KV store is a coordinator RPC, not a
collective, so it is safe.)
3. register the FFI target straight from the .so symbol (no nanobind).
4. set_xla_metadata colors the custom call's operand+result into the
collective memory space ("{0:1}") -> symmetric allocation + the device
pointer our handler registers as an NCCL window.
5. call it. Each rank fills src with its rank id; the all-reduce sum is
sum(0..nproc-1), identical on every rank.
"""
import ctypes
import os

import numpy as np
import jax
import jax.numpy as jnp
from jax._src import distributed
from jax.experimental.xla_metadata import set_xla_metadata

NCCL_UID_BYTES = 128
UID_KV_KEY = "byo_allreduce/nccl_unique_id"


def exchange_uid(lib, proc_id, timeout_ms=60_000):
    """Rank 0 mints the NCCL unique id and publishes it through the JAX
    distributed coordinator's key-value store; every rank reads it back.

    This is the multi-node-correct bootstrap: the coordinator already spans all
    nodes (jax.distributed.initialize), and the KV store is a plain RPC -- not
    an XLA/NCCL collective -- so it can't deadlock against our comm init. The
    coordinator's KV store is fresh per run, so the key never goes stale.
    """
    client = distributed.global_state.client
    assert client is not None, "call jax.distributed.initialize() first"
    if proc_id == 0:
        buf = ctypes.create_string_buffer(NCCL_UID_BYTES)
        rc = lib.byo_get_unique_id(buf)
        assert rc == 0, f"ncclGetUniqueId failed: {rc}"
        client.key_value_set_bytes(UID_KV_KEY, buf.raw[:NCCL_UID_BYTES])
    raw = client.blocking_key_value_get_bytes(UID_KV_KEY, timeout_ms)
    return np.frombuffer(raw, dtype=np.uint8).copy()


LIB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "libbyo_allreduce.so")
N = int(os.environ.get("N", 1 << 20))


def main():
    proc_id = int(os.environ["PROC_ID"])
    nproc = int(os.environ["NUM_PROCS"])
    coord = os.environ.get("COORD_ADDR", "127.0.0.1:12345")

    # 1. distributed init (1 visible GPU per process -> local device 0)
    jax.distributed.initialize(coordinator_address=coord, num_processes=nproc,
                               process_id=proc_id, local_device_ids=[0])

    lib = ctypes.CDLL(LIB)
    lib.byo_get_unique_id.argtypes = [ctypes.c_char_p]
    lib.byo_get_unique_id.restype = ctypes.c_int

    # 2. rank 0 mints + publishes the NCCL unique id via the coordinator KV store
    uid = exchange_uid(lib, proc_id)

    # 3. register the FFI target directly from the exported .so symbol
    jax.ffi.register_ffi_target(
        "byo_all_reduce", jax.ffi.pycapsule(lib.ByoAllReduce), platform="CUDA")

    def all_reduce(x):
        out = jax.ShapeDtypeStruct(x.shape, x.dtype)
        return jax.ffi.ffi_call("byo_all_reduce", out)(
            x, rank=np.int32(proc_id), nranks=np.int32(nproc), uid=uid)

    # 4. color operand 0 + result 0 into the collective (symmetric) memory space
    @jax.jit
    def f(x):
        with set_xla_metadata(operands_memory_spaces="{0:1}",
                              results_memory_spaces="{0:1}"):
            return all_reduce(x)

    x = jnp.full((N,), float(proc_id), dtype=jnp.float32)
    y = f(x)
    y.block_until_ready()
    host = np.asarray(y)
    expected = nproc * (nproc - 1) / 2.0
    ok = bool(np.allclose(host, expected))
    print(f"[proc {proc_id}/{nproc}] N={N} expected={expected} "
          f"got[0]={host[0]} got[-1]={host[-1]} ALL_OK={ok}", flush=True)
    assert ok, "all-reduce result mismatch"

    # Clean teardown, in order:
    # (1) coordinator barrier so every rank has finished its GPU work and is
    #     present while the coordinator is definitely still up;
    # (2) destroy our NCCL comm on ALL ranks together -- ncclCommDestroy is
    #     collective, so if a rank exits and leaves its comm to implicit
    #     at-exit cleanup, the surviving rank hangs forever in NCCL. This was
    #     the cause of the intermittent "Shutdown barrier has failed" aborts;
    # (3) jax.distributed.shutdown() -- now a tight, already-synced barrier.
    client = distributed.global_state.client
    client.wait_at_barrier("byo_allreduce_done", 120_000)
    lib.byo_finalize.restype = ctypes.c_int
    lib.byo_finalize()
    client.wait_at_barrier("byo_finalized", 120_000)
    jax.distributed.shutdown()


if __name__ == "__main__":
    main()
#!/usr/bin/env bash
#
# Build + run the BYO-comm symmetric all-reduce against STOCK jaxlib — no XLA
# recompile. Spawns one process per GPU (1 GPU each) and does a cross-process
# all-reduce entirely through our own NCCL comm + symmetric windows.
#
# Requirements: docker + nvidia runtime, >=2 Hopper/Blackwell GPUs, NVLink (or
# P2P) between them. Logs tee'd to byo.log (tail -f byo.log).
#
# Usage:
# ./run.sh
# NPROC=2 N=1048576 IMAGE=ghcr.io/nvidia/jax:jax-2026-06-02 ./run.sh
set -euo pipefail
IMAGE="${IMAGE:-ghcr.io/nvidia/jax:jax-2026-06-02}"
CONTAINER="${CONTAINER:-byo-symtest}"
NPROC="${NPROC:-2}"
N="${N:-1048576}"
KEEP="${KEEP:-1}"
HERE="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
LOG="${LOG:-$HERE/byo.log}"
: > "$LOG"; exec > >(tee -a "$LOG") 2>&1
ts() { date -u +%H:%M:%S; }
echo "############################################################"
echo "# BYO-comm symmetric all-reduce | log: $LOG"
echo "# image=$IMAGE nproc=$NPROC N=$N"
echo "############################################################"
set -x
docker pull "$IMAGE"
docker rm -f "$CONTAINER" >/dev/null 2>&1 || true
# Mount the recipe dir; long-lived container so rebuilds are fast.
docker run -d --name "$CONTAINER" \
--runtime=nvidia --gpus all --ipc=host \
--ulimit memlock=-1 --ulimit stack=67108864 --shm-size 16g \
-v "$HERE":/work -w /work \
"$IMAGE" sleep infinity
set +x
echo "[$(ts)] ===== BUILD libbyo_allreduce.so (stock jaxlib headers) ====="
set -x
docker exec "$CONTAINER" bash -lc 'cd /work && ./build.sh'
set +x
echo "[$(ts)] ===== RUN $NPROC processes (1 GPU each) ====="
set -x
# Each process: own GPU (CUDA_VISIBLE_DEVICES), own rank (PROC_ID), shared
# coordinator. NCCL_DEBUG to watch the window registration + symmetric setup.
# NCCL_DEBUG=WARN keeps the log readable; set INFO + SUBSYS=INIT,REG to watch
# the symmetric-window registration. The NCCL unique id is exchanged through the
# JAX coordinator's KV store (see jax_byo_allreduce.py) -- multi-node ready, no
# shared filesystem needed.
docker exec \
  -e NPROC="$NPROC" -e N="$N" \
  -e NCCL_DEBUG="${NCCL_DEBUG:-WARN}" \
  "$CONTAINER" bash -lc '
    set -e
    cd /work
    nvidia-smi --query-gpu=index,name,compute_cap --format=csv
    pids=""
    for i in $(seq 0 $((NPROC-1))); do
      CUDA_VISIBLE_DEVICES=$i \
      PROC_ID=$i NUM_PROCS=$NPROC COORD_ADDR=127.0.0.1:12345 N=$N \
        python3 jax_byo_allreduce.py &
      pids="$pids $!"
    done
    rc=0
    for p in $pids; do wait $p || rc=1; done
    exit $rc
  '
set +x
echo "[$(ts)] ===== DONE ====="
if [ "$KEEP" = "1" ]; then
echo "container '$CONTAINER' kept. Re-run: docker exec $CONTAINER bash -lc 'cd /work && ./build.sh && ...'"
echo "remove with: docker rm -f $CONTAINER"
else
docker rm -f "$CONTAINER" >/dev/null 2>&1 || true
fi
echo "log: $LOG"
VERIFIED — 2026-06-03, container built from ghcr.io/nvidia/jax:jax-2026-06-02
Hardware: 2x NVIDIA RTX PRO 6000 Blackwell (sm_120 / compute_cap 12.0),
P2P-reachable over PCIe (CUMEM; no NVLink on these workstation cards).
Stack: jaxlib 0.10.2.dev20260602, NCCL 2.29.7, CUDA 12.9. NO XLA recompile.
BUILD (against stock jaxlib FFI headers, jax.ffi.include_dir()):
built: /work/libbyo_allreduce.so (clean compile; only benign -Wreturn-type
warnings from XLA's own api.h)
exported symbols: ByoAllReduce (T), byo_get_unique_id (T)
REGISTRATION (Python): jax.ffi.register_ffi_target("byo_all_reduce",
jax.ffi.pycapsule(lib.ByoAllReduce), platform="CUDA") -> OK
ncclGetUniqueId via lib.byo_get_unique_id -> rc=0
END-TO-END via ./run.sh (KEEP=0), 2 procs (1 GPU each), N=1048576 f32:
index, name, compute_cap
0, NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 12.0
1, NVIDIA RTX PRO 6000 Blackwell Workstation Edition, 12.0
[proc 1/2] N=1048576 expected=1.0 got[0]=1.0 got[-1]=1.0 ALL_OK=True
[proc 0/2] N=1048576 expected=1.0 got[0]=1.0 got[-1]=1.0 ALL_OK=True
===== DONE ===== RUNSH_EXIT=0
NCCL version 2.29.7+cuda13.2
(trailing WatchTasksAsync warnings are the coordinator closing AFTER the
verified results — cosmetic.)
CONCLUSION: the devtech pattern works out-of-tree on released jaxlib.
1. set_xla_metadata(operands/results_memory_spaces="{0:1}") -> kCollective
-> XLA placed the buffers symmetrically.
2. our OWN ncclCommInitRank comm (uid via the JAX coordinator KV store, NOT
an XLA collective) + ncclCommWindowRegister(NCCL_WIN_COLL_SYMMETRIC)
accepted those buffers.
3. device kernel ncclGetLsaPointer(win,0,peer) read peers correctly;
cross-rank sync via host-side ncclAllReduce barriers.
No XLA internal FFI contexts (RequestSymmetricAddress/FindSymmetricMemory),
no /opt/xla checkout, no recompile.
RELIABILITY: 10/10 consecutive trials at N=2^20 (after the three fixes below),
plus the packaged ./run.sh end-to-end (exit 0).
COORDINATION/SYNC HAZARDS found + fixed (see SHUTDOWN_ISSUE.md, full write-up):
1. Bootstrap: broadcast_one_to_all (XLA collective) for the uid DEADLOCKS
against our comm init. Fixed -> coordinator KV store (key_value_set_bytes /
blocking_key_value_get_bytes), a plain RPC.
2. In-kernel ncclLsaBarrierSession (per-CTA cross-rank barrier) intermittently
DEADLOCKS on a non-cooperative grid. Fixed -> barrier-free kernel +
host-side ncclAllReduce barriers bracketing the launch (also dropped
ncclDevCommCreate).
3. Not destroying our comm hangs the surviving rank at exit (ncclCommDestroy
is collective). Fixed -> byo_finalize() on all ranks, barrier-gated, before
jax.distributed.shutdown().