
@zeryx
Last active June 3, 2026 14:07
Recipe: symmetric (NCCL-window-registered) buffers with XLA + JAX (ncclCommWindowRegister + operands/results_memory_spaces, openxla/xla#39742)

Recipe: NCCL symmetric buffers with XLA + JAX (manual registration)

Set up a buffer that NCCL has registered as symmetric memory (ncclCommWindowRegister(..., NCCL_WIN_COLL_SYMMETRIC)), reach it from a custom GPU kernel through the NCCL device API (ncclGetLsaPointer), and drive it from XLA/JAX via a custom call — including the JAX↔XLA bridge from openxla/xla#39742 (operands_memory_spaces / results_memory_spaces frontend attributes).

This recipe was verified end-to-end on 2× NVIDIA RTX PRO 6000 Blackwell GPUs (sm_120) with the jax-toolbox image jax-2026-06-02 (jax/jaxlib 0.10.2.dev20260602, XLA 9190ab7d75, NCCL 2.29.7); see VERIFIED_RESULTS.txt. The base ghcr.io/nvidia/jax:jax-2026-06-02 image is all you need. The maxtext image is not required: it only adds MaxText on top of the same /opt/xla source and bazel toolchain.
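The manual-registration test only exists in recent /opt/xla trees. Before a long pull and build, it can help to reject an unsuitable IMAGE. Below is a minimal pre-flight sketch. It assumes jax-toolbox tags look like `<flavor>-YYYY-MM-DD` and uses the 2026-06-02 cutoff from the script header. `tag_is_recent_enough` is a hypothetical helper, not part of run_symmetric_test.sh.

```python
import re
from datetime import date

# Oldest dated tag that, per this recipe, ships the manual-registration test.
MIN_TAG_DATE = date(2026, 6, 2)

# Dated jax-toolbox tags such as "jax-2026-06-02" or "maxtext-2026-06-02" (assumed format).
_DATED_TAG = re.compile(r"^(?P<flavor>[a-z0-9_]+)-(?P<y>\d{4})-(?P<m>\d{2})-(?P<d>\d{2})$")


def tag_is_recent_enough(image: str) -> bool:
    """Return True if `image` carries a dated tag on or after MIN_TAG_DATE.

    Rolling tags such as ghcr.io/nvidia/jax:jax have no date and are rejected.
    """
    _, _, tag = image.rpartition(":")
    m = _DATED_TAG.match(tag)
    if m is None:
        return False
    return date(int(m["y"]), int(m["m"]), int(m["d"])) >= MIN_TAG_DATE
```

For example, `tag_is_recent_enough("ghcr.io/nvidia/jax:maxtext-2026-06-02")` is True. The rolling `ghcr.io/nvidia/jax:jax` has no date and is rejected.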

Just want to speed up XLA's built-in collectives with symmetric buffers, no C++/rebuild? That's the automatic path — separate gist: https://gist.github.com/zeryx/d91336d808b1d6b16b72176765af439b

Reference doc: ncclCommWindowRegister


The stack (manual path)

JAX:  jax.ffi.ffi_call("my_call", ...) wrapped in
        set_xla_metadata(operands_memory_spaces="{0:1}",
                         results_memory_spaces="{0:1}")        # jax_symmetric_allreduce.py
   │  lowers to an HLO custom-call carrying those frontend_attributes
   ▼
XLA compile:  gpu_memory_space_assignment colorer (PR #39742) reads the
              attributes → colors those buffers kCollective (space 1)
   ▼
FFI prepare:  RequestSymmetricAddress(clique, buf)
              → NcclSymmetricMemory::Create → ncclCommWindowRegister(NCCL_WIN_COLL_SYMMETRIC)
                                                  # upstream_nccl_symmetric_memory.cc
FFI execute:  FindSymmetricMemory(clique, buf) → (NcclSymmetricMemory*, offset)
                                                  # upstream_collective_ops_ffi_test.cc
   ▼
Kernel:  ncclGetLsaPointer(win, offset, peer)     # upstream_collective_ops_ffi_kernels.cu.cc
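The kernel at the bottom of the stack needs no atomics, and the index arithmetic shows why. Each (rank, block, thread) triple owns a disjoint set of element indices. For each of its elements, it sums that element across every peer's src window and writes the sum into every peer's dst window.

The Python model below is a host-side sketch of that split. Plain lists stand in for the LSA-mapped windows, barriers are not modeled, and the function name is hypothetical. The defaults `block_dim=8, grid_dim=1` mirror the upstream DeviceAllReduce launch, which uses one block of 8 threads.

```python
def simulate_lsa_allreduce(inputs, block_dim=8, grid_dim=1):
    """Host-side model of NcclDevAllReduce's work split (barriers not modeled).

    inputs[peer] stands in for that peer's src window as reached through
    ncclGetLsaPointer(src_win, src_offset, peer); outputs[peer] stands in for
    that peer's dst window. Returns (outputs, owner), where owner maps each
    element index to the (rank, block, thread) that computed it.
    """
    n_ranks = len(inputs)
    count = len(inputs[0])
    outputs = [[None] * count for _ in range(n_ranks)]
    owner = {}
    # Same stride as the kernel: all threads of all blocks on all ranks.
    global_nthreads = block_dim * grid_dim * n_ranks
    for rank in range(n_ranks):
        for block in range(grid_dim):
            for tid in range(block_dim):
                # Same formula as the kernel's globalTid.
                global_tid = tid + block_dim * (rank + block * n_ranks)
                for o in range(global_tid, count, global_nthreads):
                    if o in owner:
                        raise RuntimeError(f"element {o} claimed twice")
                    owner[o] = (rank, block, tid)
                    # Read element o from every peer's src window...
                    v = sum(inputs[peer][o] for peer in range(n_ranks))
                    # ...and write the sum into every peer's dst window.
                    for peer in range(n_ranks):
                        outputs[peer][o] = v
    return outputs, owner
```

Because no element is claimed twice, every rank's dst ends up with the full sum without any cross-thread conflict. The real kernel then relies on the closing LSA barrier to make those remote writes visible.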

Files in this gist

Run it / what runs:

  • run_symmetric_test.sh — one script: pulls the image, compiles the XLA manual-registration test, auto-selects the per-arch binary, runs it with full NCCL/XLA debug logging, tee'd to a logfile.
  • VERIFIED_RESULTS.txt — the actual pass output + registration log lines.

The exact source the script compiles (verbatim from XLA 9190ab7d75):

  • upstream_collective_ops_ffi_test.cc: PrepareDeviceAllReduce (calls RequestSymmetricAddress), DeviceAllReduce (calls FindSymmetricMemory), and DeviceAllReduceWithFrontendAttributes (the JAX-bridge HLO).
  • upstream_collective_ops_ffi_kernels.cu.cc — the ncclGetLsaPointer kernel.
  • upstream_nccl_symmetric_memory.cc / .h — the ncclCommWindowRegister(..., NCCL_WIN_COLL_SYMMETRIC) call site.

JAX-side illustration:

  • jax_symmetric_allreduce.py — how the JAX↔XLA bridge looks in Python (set_xla_metadata(operands_memory_spaces=...)). Template; needs a registered custom-call target (built as above) to actually run.
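The Python file builds the attribute value with an f-string (`f"{{0:{SYMMETRIC}}}"`, i.e. `"{0:1}"`). If you tag more than one operand, a small formatter keeps the braces and separators straight. This is a sketch only. `memory_spaces_attr` is hypothetical, and joining several `index:space` pairs with `,` is an ASSUMPTION that this gist does not confirm; check the attribute parser added in openxla/xla#39742 before relying on it.

```python
def memory_spaces_attr(spaces):
    """Format {index: memory_space} for operands_/results_memory_spaces.

    {0: 1} -> "{0:1}", the form used in jax_symmetric_allreduce.py.
    ASSUMPTION: several entries are joined with ","; this gist only shows
    the single-entry form.
    """
    for idx, space in spaces.items():
        # Memory-space ints per this gist: 0=default, 1=kCollective, 2=temp.
        if idx < 0 or space not in (0, 1, 2):
            raise ValueError(f"bad entry {idx}:{space}")
    body = ",".join(f"{idx}:{space}" for idx, space in sorted(spaces.items()))
    return "{" + body + "}"
```

It would be used as `set_xla_metadata(operands_memory_spaces=memory_spaces_attr({0: 1}), results_memory_spaces=memory_spaces_attr({0: 1}))`, which produces the same strings as the template.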

The automatic-path demos (symmetric_buffer_demo.py, verify_symmetric.py) now live in their own gist: https://gist.github.com/zeryx/d91336d808b1d6b16b72176765af439b


Compilation process (how run_symmetric_test.sh works)

The manual handler can't run on stock jaxlib, so we build the upstream test target inside the jax-toolbox image (it ships /opt/xla source + bazel + a hermetic CC/CUDA/NCCL toolchain — no host clang/CUDA needed).

  1. Pull ghcr.io/nvidia/jax:jax-2026-06-02 (jax/jaxlib 0.10.2.dev20260602). Any dated jax-toolbox tag from 2026-06-02 onward works (maxtext too). The old rolling jax:jax tag (jaxlib 0.9.1) is too old: its /opt/xla lacks the test.
  2. Build in a GPU container:
    cd /opt/xla
    bazel build --config=cuda \
      --repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES=sm_90,sm_100,sm_120 \
      //xla/backends/gpu/tests:collective_ops_ffi_test
    • The hermetic toolchain downloads itself on the first build (bazel 7.7.0, clang, CUDA 12.9.1, NCCL 2.29.7, cuDNN 9.8.0, NVSHMEM 3.2.5).
    • Cold cache ≈ 25–60 min; warm re-runs are seconds.
    • The gpu_test macro emits per-arch binaries: collective_ops_ffi_test_rtx6000pro, _h100, _b200, _gb200, _nvgpu_any, … The script auto-picks the one matching your GPU (falling back to _nvgpu_any).
  3. Run the binary directly with the GPUs and full debug env:
    NCCL_DEBUG=TRACE NCCL_DEBUG_SUBSYS=ALL TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_MAX_VLOG_LEVEL=3 \
    ./bazel-bin/.../collective_ops_ffi_test_rtx6000pro \
        --gtest_filter='*DeviceAllReduce*'
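The binary auto-pick in step 2 is a substring match of each binary's suffix against the normalized GPU name. The model below is a Python sketch of the script's loop. It assumes the name is lowercased and stripped to `[a-z0-9]` before matching, and both function names are hypothetical.

```python
import re
from typing import List, Optional


def normalize_gpu_name(name: str) -> str:
    """Lowercase, then keep only [a-z0-9]: 'NVIDIA H100 80GB HBM3' -> 'nvidiah10080gbhbm3'."""
    return re.sub(r"[^a-z0-9]", "", name.lower())


def pick_binary(gpu_name: str, suffixes: List[str]) -> Optional[str]:
    """First suffix found inside the normalized name wins; otherwise fall back
    to nvgpu_any, or to the first candidate if nvgpu_any is absent."""
    gpu = normalize_gpu_name(gpu_name)
    pick = None
    for suf in sorted(suffixes):  # the shell glob also expands in sorted order
        if pick is None or suf == "nvgpu_any":
            pick = suf  # provisional choice, upgraded to nvgpu_any when seen
        if suf in gpu:
            return suf  # first substring match wins
    return pick
```

First-match-wins has one consequence. In the usual sort order, `b200` comes before `gb200` and is also a substring of a normalized name containing "GB200", so `b200` would be chosen. If that matters on your cluster, a longest-match rule would avoid it.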

Quick start

chmod +x run_symmetric_test.sh
./run_symmetric_test.sh                 # build + run, logs to ./symtest.log
tail -f symtest.log                     # (other terminal) full live logs

# only the JAX-bridge (frontend-attribute) case:
FILTER='*DeviceAllReduceWithFrontendAttributes*' ./run_symmetric_test.sh
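FILTER is passed straight to `--gtest_filter`. That flag takes `POSITIVE[-NEGATIVE]` pattern lists separated by `:`, with `*` and `?` wildcards. So the default `*DeviceAllReduce*` also selects the frontend-attribute case and any other test whose name contains DeviceAllReduce.

The sketch below approximates that matcher. It uses `fnmatch.fnmatchcase`, which additionally treats `[...]` as a character class where gtest does not; none of the patterns here use brackets. The test names are hypothetical: the fixture `CollectiveOpsTestFFI` and the Blocked/Delayed handlers appear in the source, but the exact test names do not.

```python
from fnmatch import fnmatchcase


def gtest_filter_matches(test_name: str, gtest_filter: str) -> bool:
    """Approximate --gtest_filter semantics: 'POS1:POS2-NEG1:NEG2'."""
    positive, _, negative = gtest_filter.partition("-")
    pos = positive.split(":") if positive else ["*"]  # "-NEG" alone means "all but NEG"
    neg = negative.split(":") if negative else []
    return (any(fnmatchcase(test_name, p) for p in pos)
            and not any(fnmatchcase(test_name, n) for n in neg))


# Hypothetical test names, for illustration only.
NAMES = [
    "CollectiveOpsTestFFI.AllReduce",
    "CollectiveOpsTestFFI.DeviceAllReduce",
    "CollectiveOpsTestFFI.BlockedDeviceAllReduce",
    "CollectiveOpsTestFFI.DelayedDeviceAllReduce",
    "CollectiveOpsTestFFI.DeviceAllReduceWithFrontendAttributes",
]
```

A negative list narrows a run. For example, `FILTER='*DeviceAllReduce*-*Delayed*' ./run_symmetric_test.sh` would skip any test whose name contains Delayed.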

Gotchas

  • Hopper+ only. Symmetric/LSA needs Hopper or newer; multimem (NVLS) needs NVLink+NVSwitch — on PCIe cards the multimem cases skip, but the symmetric + LSA-peer path runs (verified on PCIe Blackwell).
  • Equal sizes / symmetric addresses. ncclCommWindowRegister is collective and (by default) requires equal buffer sizes per rank — allocate from the collective space (memory-space int 1 = kCollective; 0=default, 2=temp).
  • Internal API. The handler links XLA-internal FFI collective contexts; it must be built in-tree (this recipe), not against the stable jaxlib FFI ABI.
  • Deregister on teardown (ncclCommWindowDeregister) — NcclSymmetricMemory's destructor does this; never deregister with a collective in flight.
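Two of these gotchas reduce to simple invariants: the memory-space integers, and equal per-rank buffer sizes for the collective ncclCommWindowRegister call. Below is a small host-side sketch of both. The enum and helper names are hypothetical, and the size check only validates sizes you have already gathered from each rank; it does not talk to NCCL.

```python
from enum import IntEnum


class MemorySpace(IntEnum):
    """Memory-space integers as listed in the gotchas above."""
    DEFAULT = 0
    COLLECTIVE = 1  # kCollective: the space whose buffers get window-registered
    TEMP = 2


def check_equal_window_sizes(size_bytes_by_rank: dict) -> int:
    """Return the common per-rank size, or raise if ranks disagree.

    Mirrors ncclCommWindowRegister's default requirement that every rank
    registers a buffer of the same size.
    """
    sizes = set(size_bytes_by_rank.values())
    if len(sizes) != 1:
        raise ValueError(f"unequal per-rank window sizes: {size_bytes_by_rank}")
    return sizes.pop()
```

Failing fast on the host is cheaper than debugging a hung or failed collective registration.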


"""JAX driver for the symmetric-memory all-reduce custom call.
This is the JAX <-> XLA bridge. The key idea: `jax.ffi.ffi_call` lowers to an
HLO custom-call, and `set_xla_metadata` attaches `frontend_attributes` to it.
We use the `operands_memory_spaces` / `results_memory_spaces` attributes added
in openxla/xla#39742 to ask XLA's GPU memory colorer to place operand 0 and
result 0 in memory space 1 (kCollective / symmetric). An FFI prepare handler (see the upstream
PrepareDeviceAllReduce in upstream_collective_ops_ffi_test.cc) then
window-registers those buffers via ncclCommWindowRegister(...,
NCCL_WIN_COLL_SYMMETRIC). This .py is a TEMPLATE: it needs a custom-call target
of that name built into XLA to actually run (see run_symmetric_test.sh).
Run on >= 2 Hopper+ GPUs, e.g.:
XLA_FLAGS="--xla_gpu_experimental_enable_nvshmem=false" \
python jax_symmetric_allreduce.py
(The FFI target `my_sym_allreduce` must be registered with XLA — i.e. built into
the jaxlib/XLA you are running against.)
"""
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.xla_metadata import set_xla_metadata
from jax.sharding import PartitionSpec as P
from jax.experimental.shard_map import shard_map
# Memory space integers understood by XLA's colorer:
# 0 = default, 1 = collective/symmetric, 2 = temp.
SYMMETRIC = 1
def sym_all_reduce(x):
"""All-reduce `x` via a symmetric-memory custom call.
`x` is the per-device shard; the custom call sums it across all
participating devices and returns the same shape/dtype.
"""
out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
call = jax.ffi.ffi_call("my_sym_allreduce", out_type)
# --- the bridge ---------------------------------------------------------
# Tag operand 0 and result 0 to live in symmetric memory. XLA reads these
# frontend attributes during buffer assignment (PR #39742) and colors the
# corresponding buffers kCollective, so they are eligible for NCCL window
# registration in the FFI prepare phase.
with set_xla_metadata(
operands_memory_spaces=f"{{0:{SYMMETRIC}}}",
results_memory_spaces=f"{{0:{SYMMETRIC}}}",
):
return call(x)
def main():
devices = jax.devices()
n = len(devices)
if n < 2:
raise SystemExit(f"Need >= 2 devices, found {n}")
mesh = jax.make_mesh((n,), ("x",))
# One row per device; after all-reduce every device should hold the column
# sum replicated.
data = jnp.arange(n * 8, dtype=jnp.uint32).reshape(n, 8)
@jax.jit
def run(d):
# shard_map gives each device its own shard and forms a clique across
# all of them, which the FFI prepare handler reuses.
f = shard_map(
lambda s: sym_all_reduce(s[0]), # s has shape (1, 8) per shard
mesh=mesh,
in_specs=P("x", None),
out_specs=P(None),
)
return f(d)
out = run(data)
expected = np.asarray(data).sum(axis=0)
print("result :", np.asarray(out))
print("expected:", expected)
np.testing.assert_array_equal(np.asarray(out), expected)
print("OK")
if __name__ == "__main__":
main()
#!/usr/bin/env bash
#
# Build + run XLA's MANUAL NCCL-symmetric-buffer registration test on local GPUs,
# with full debug logging, tee'd to a logfile you can tail live.
#
# Proves: a custom call whose FFI handler manually registers its buffers as NCCL
# symmetric memory (RequestSymmetricAddress -> ncclCommWindowRegister) and looks
# them up at run time (FindSymmetricMemory) + reaches peers from a device kernel
# (ncclGetLsaPointer). `DeviceAllReduceWithFrontendAttributes` is the JAX<->XLA
# bridge: operands_memory_spaces / results_memory_spaces frontend attributes.
#
# Requirements: docker + nvidia runtime, >=2 Hopper+/Blackwell GPUs, ~60GB disk,
# network (cold bazel build fetches hermetic toolchain + deps, ~25-60 min).
#
# Image: any jax-toolbox image works (maxtext NOT required) as long as the tag
# is recent enough that /opt/xla contains the manual-registration test — i.e. a
# DATED tag >= 2026-06-02. The old rolling `jax:jax` (jaxlib 0.9.1) is too old.
#
# Usage:
# ./run_symmetric_test.sh
# FILTER='*DeviceAllReduceWithFrontendAttributes*' ./run_symmetric_test.sh
# IMAGE=ghcr.io/nvidia/jax:maxtext-2026-06-02 ./run_symmetric_test.sh
# KEEP=0 ./run_symmetric_test.sh # remove the container when done
set -euo pipefail
# ---- config (override via env) --------------------------------------------
# Base jax image (smaller than maxtext); same /opt/xla + bazel + hermetic toolchain.
IMAGE="${IMAGE:-ghcr.io/nvidia/jax:jax-2026-06-02}" # jax/jaxlib 0.10.2.dev20260602
FILTER="${FILTER:-*DeviceAllReduce*}"
CAPS="${CAPS:-sm_90,sm_100,sm_120}" # Hopper + Blackwell
TARGET="//xla/backends/gpu/tests:collective_ops_ffi_test"
CONTAINER="${CONTAINER:-xla-symtest}"
KEEP="${KEEP:-1}" # keep container for fast re-runs
HERE="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
LOG="${LOG:-$HERE/symtest.log}"
# ---- mirror ALL output to the logfile and the screen ----------------------
: > "$LOG"
exec > >(tee -a "$LOG") 2>&1
ts() { date -u +%H:%M:%S; }
echo "############################################################"
echo "# symmetric-buffer test | log: $LOG"
echo "# image=$IMAGE filter=$FILTER archs=$CAPS keep=$KEEP"
echo "############################################################"
set -x # trace every command in this script
docker pull "$IMAGE"
docker rm -f "$CONTAINER" >/dev/null 2>&1 || true
# Long-lived container so the build cache survives for fast re-runs.
docker run -d --name "$CONTAINER" \
--runtime=nvidia --gpus all --ipc=host \
--ulimit memlock=-1 --ulimit stack=67108864 --shm-size 16g \
"$IMAGE" sleep infinity
set +x
echo "[$(ts)] ===== BUILD $TARGET (cold cache ~25-60 min) ====="
set -x
# --verbose_failures prints the full command of any failing action (add --subcommands to print every compile command).
docker exec "$CONTAINER" bash -lc "
set -ex
cd /opt/xla
bazel build --config=cuda \
--repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES=$CAPS \
--verbose_failures \
--color=no --curses=no \
$TARGET
"
set +x
echo "[$(ts)] ===== RUN (full NCCL TRACE + XLA VLOG=3) ====="
set -x
# Full debug env:
# NCCL_DEBUG=TRACE + SUBSYS=ALL -> every NCCL init / registration / transport line
# TF_CPP_MIN_LOG_LEVEL=0 -> all glog INFO
# TF_CPP_MAX_VLOG_LEVEL=3 -> every VLOG(<=3) in XLA (symmetric mem, comms, thunks)
docker exec \
-e NCCL_DEBUG=TRACE \
-e NCCL_DEBUG_SUBSYS=ALL \
-e TF_CPP_MIN_LOG_LEVEL=0 \
-e TF_CPP_MAX_VLOG_LEVEL=3 \
-e TF_CPP_VMODULE=nccl_symmetric_memory=3,nccl_communicator=3,nccl_collective_thunk=3,collective_thunk=3,gpu_executable=3,thunk=3 \
-e FILTER="$FILTER" \
"$CONTAINER" bash -lc '
set -ex
cd /opt/xla
BINDIR=$(dirname "$(readlink -f bazel-bin)")/bin/xla/backends/gpu/tests
nvidia-smi --query-gpu=index,name,compute_cap,memory.total --format=csv
# gpu_test macro emits arch-suffixed binaries (…_rtx6000pro, …_h100, …_nvgpu_any).
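# Normalize the GPU name: lowercase, then keep only [a-z0-9]
# (e.g. "NVIDIA H100 80GB HBM3" -> nvidiah10080gbhbm3).
gpu=$(nvidia-smi --query-gpu=name --format=csv,noheader | head -1 \
| tr "[:upper:]" "[:lower:]" | tr -cd "a-z0-9")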
pick=""
for f in "$BINDIR"/collective_ops_ffi_test_*; do
[ -x "$f" ] && [ ! -d "$f" ] || continue
case "$f" in *.params|*.cppmap|*.repo_mapping|*_manifest) continue;; esac
suf=${f##*collective_ops_ffi_test_}
if [ -z "$pick" ] || [ "$suf" = "nvgpu_any" ]; then pick="$f"; fi
case "$gpu" in *"$suf"*) pick="$f"; break;; esac
done
echo ">>> gpu=$gpu binary=$pick filter=$FILTER"
"$pick" --gtest_filter="$FILTER" --gtest_color=no --gtest_print_time=1
'
set +x
echo "[$(ts)] ===== DONE (exit from test above) ====="
if [ "$KEEP" = "1" ]; then
echo "container '$CONTAINER' kept (build cache warm). Re-run faster, or remove with:"
echo " docker rm -f $CONTAINER"
else
docker rm -f "$CONTAINER" >/dev/null 2>&1 || true
fi
echo "full log saved to: $LOG"
// SOURCE: openxla/xla @ 9190ab7d75ec35933e7cf1ed375ca6d08279e805 (jax-toolbox 2026-06-02, jax/jaxlib 0.10.2.dev20260602)
// Path in tree: xla/backends/gpu/tests/collective_ops_ffi_kernels.cu.cc
// Provided here verbatim so the gist matches exactly what run_symmetric_test.sh compiles.
#include "xla/backends/gpu/tests/collective_ops_ffi_kernels.h"
#include <cstddef>
#include <cstdint>
#include "absl/base/casts.h"
#include "third_party/nccl/nccl.h"
#include "xla/stream_executor/cuda/cuda_platform_id.h"
#include "xla/stream_executor/device_address.h"
#include "xla/stream_executor/gpu/gpu_kernel_registry.h"
#include "xla/stream_executor/kernel_spec.h"
#include "third_party/nccl/nccl_device.h"
namespace xla::gpu {
template <typename T>
static __global__ void NcclDevAllReduce(ncclDevComm dev_comm,
ncclWindow_t src_win,
ncclWindow_t dst_win, size_t src_offset,
size_t dst_offset, size_t count) {
ncclLsaBarrierSession<ncclCoopCta> bar(ncclCoopCta(), dev_comm,
ncclTeamTagLsa(), blockIdx.x);
bar.sync(ncclCoopCta(), cuda::memory_order_relaxed);
const int rank = dev_comm.lsaRank, nRanks = dev_comm.lsaSize;
const int globalTid = threadIdx.x + blockDim.x * (rank + blockIdx.x * nRanks);
const int globalNthreads = blockDim.x * gridDim.x * nRanks;
for (size_t o = globalTid; o < count; o += globalNthreads) {
T v = 0;
for (int peer = 0; peer < nRanks; peer++) {
T* inputPtr =
static_cast<T*>(ncclGetLsaPointer(src_win, src_offset, peer));
v += inputPtr[o];
}
for (int peer = 0; peer < nRanks; peer++) {
T* outputPtr =
static_cast<T*>(ncclGetLsaPointer(dst_win, dst_offset, peer));
outputPtr[o] = v;
}
}
bar.sync(ncclCoopCta(), cuda::memory_order_release);
}
// A trivial all-reduce for S32 data type that uses multimem instructions.
//
// WARNING: This kernel doesn't have any barriers and it is a caller
// responsibility to make sure that data is ready on all ranks.
static __global__ void MulticastAllReduce(uint32_t* src_mmem, uint32_t* dst,
size_t src_offset, size_t count) {
#if __CUDA_ARCH__ >= 900
int64_t offset = blockIdx.x * blockDim.x + threadIdx.x;
int64_t stride = blockDim.x * gridDim.x;
for (int64_t i = offset; i < count; i += stride) {
uint32_t data = 0;
asm volatile("multimem.ld_reduce.acquire.sys.global.add.u32 %0, [%1];"
: "=r"(data)
: "l"(src_mmem + src_offset + i)
: "memory");
dst[i] = data;
}
#endif // __CUDA_ARCH__ >= 900
}
// A trivial all-reduce for S32 data type that uses peer access.
//
// WARNING: This kernel doesn't have any barriers and it is a caller
// responsibility to make sure that data is ready on all ranks.
static __global__ void PeerAllReduce(uint32_t* src0, uint32_t* src1,
uint32_t* dst, size_t count) {
int64_t offset = blockIdx.x * blockDim.x + threadIdx.x;
int64_t stride = blockDim.x * gridDim.x;
for (int64_t i = offset; i < count; i += stride) {
uint32_t data = src0[i] + src1[i];
dst[i] = data;
}
}
static se::KernelLoaderSpec SymmetricAllReduceKernelSpec(int32_t arity) {
return se::KernelLoaderSpec::CreateInProcessSymbolSpec(
absl::bit_cast<void*>(&NcclDevAllReduce<int32_t>),
"SymmetricAllReduce_S32", arity);
}
static se::KernelLoaderSpec MulticastAllReduceKernelSpec(int32_t arity) {
return se::KernelLoaderSpec::CreateInProcessSymbolSpec(
absl::bit_cast<void*>(&MulticastAllReduce), "MulticastAllReduce_S32",
arity);
}
static se::KernelLoaderSpec Peer2AllReduceKernelSpec(int32_t arity) {
return se::KernelLoaderSpec::CreateInProcessSymbolSpec(
absl::bit_cast<void*>(&PeerAllReduce), "Peer2AllReduce_S32", arity);
}
} // namespace xla::gpu
GPU_KERNEL_REGISTRY_REGISTER_KERNEL_STATICALLY(
CollectiveSymmetricAllReduce, xla::gpu::SymmetricAllReduce,
stream_executor::cuda::kCudaPlatformId,
xla::gpu::SymmetricAllReduceKernelSpec);
GPU_KERNEL_REGISTRY_REGISTER_KERNEL_STATICALLY(
CollectiveMulticastAllReduce, xla::gpu::MultimemAllReduce,
stream_executor::cuda::kCudaPlatformId,
xla::gpu::MulticastAllReduceKernelSpec);
GPU_KERNEL_REGISTRY_REGISTER_KERNEL_STATICALLY(
CollectivePeer2AllReduce, xla::gpu::Peer2AllReduce,
stream_executor::cuda::kCudaPlatformId, xla::gpu::Peer2AllReduceKernelSpec);
// SOURCE: openxla/xla @ 9190ab7d75ec35933e7cf1ed375ca6d08279e805 (jax-toolbox 2026-06-02, jax/jaxlib 0.10.2.dev20260602)
// Path in tree: xla/backends/gpu/tests/collective_ops_ffi_test.cc
// Provided here verbatim so the gist matches exactly what run_symmetric_test.sh compiles.
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "absl/base/no_destructor.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "absl/synchronization/blocking_counter.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "xla/tsl/platform/status_macros.h"
#include "xla/backends/gpu/collectives/gpu_clique_key.h"
#include "xla/backends/gpu/collectives/gpu_collectives.h"
#include "xla/backends/gpu/collectives/gpu_communicator.h"
#include "xla/backends/gpu/ffi.h"
#include "xla/backends/gpu/runtime/collective_clique_requests.h"
#include "xla/backends/gpu/runtime/collective_cliques.h"
#include "xla/backends/gpu/runtime/collective_execution.h"
#include "xla/backends/gpu/runtime/collective_memory.h"
#include "xla/backends/gpu/runtime/collective_memory_requests.h"
#include "xla/backends/gpu/runtime/collective_params.h"
#include "xla/backends/gpu/tests/collective_ops_e2e_test_base.h"
#include "xla/backends/gpu/tests/collective_ops_ffi_kernels.h"
#include "xla/core/collectives/communicator.h"
#include "xla/core/collectives/rank_id.h"
#include "xla/core/collectives/reduction_kind.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/ffi.h"
#include "xla/future.h"
#include "xla/literal.h"
#include "xla/runtime/device_id.h"
#include "xla/service/rendezvous.h"
#include "xla/status_macros.h"
#include "xla/stream_executor/device_address.h"
#include "xla/stream_executor/gpu/gpu_kernel_registry.h"
#include "xla/stream_executor/launch_dim.h"
#include "xla/stream_executor/stream.h"
#include "xla/tests/literal_test_util.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/test.h"
#include "xla/xla_data.pb.h"
namespace xla::gpu {
using ::testing::Values;
struct SynchronizationSignals {
absl::Mutex mutex;
absl::BlockingCounter finished_kernels_counter;
explicit SynchronizationSignals(int num_expected_kernels)
: finished_kernels_counter(num_expected_kernels) {}
void IncrementFinishedKernels() {
absl::MutexLock lock(mutex);
finished_kernels_counter.DecrementCount();
}
};
absl::NoDestructor<std::unique_ptr<SynchronizationSignals>> global_signals;
class CollectiveOpsTestFFI : public CollectiveOpsE2ETestBase {
public:
CollectiveOpsTestFFI()
: CollectiveOpsE2ETestBase(/*memory_size=*/32 * kMB,
/*collectives_memory_size=*/32 * kMB) {}
void SetUp() override {
CollectiveOpsE2ETestBase::SetUp();
*global_signals =
std::make_unique<SynchronizationSignals>(/*num_expected_kernels=*/2);
}
void TearDown() override {
CollectiveOpsE2ETestBase::TearDown();
global_signals->reset();
}
};
static constexpr int64_t kNumReplicas = 2;
// In this test we execute all collective operations across all devices.
static ReplicaGroup AllDevices() {
ReplicaGroup group;
for (int64_t i = 0; i < kNumReplicas; ++i) {
group.add_replica_ids(i);
}
return group;
}
// This is a prepare handler that tells XLA:GPU runtime what collective cliques
// should be acquired before the execution starts. All collective operations
// must let XLA:GPU runtime know what cliques they need ahead of time.
static absl::Status PrepareAllReduce(
const CollectiveParams* collective_params,
CollectiveCliqueRequests* clique_requests) {
TF_RET_CHECK(collective_params && clique_requests);
// Request a clique that covers all devices (this test runs on 2 gpus).
ASSIGN_OR_RETURN(
GpuCliqueKey clique_key,
GetGpuCliqueKey(
*collective_params, {AllDevices()},
CollectiveOpGroupMode::COLLECTIVE_OP_GROUP_MODE_FLATTENED_ID, false));
std::vector<GlobalDeviceId> all_device_groups;
for (int i = 0; i < kNumReplicas; ++i) {
all_device_groups.push_back(GlobalDeviceId(i));
}
// Ask XLA:GPU runtime to acquire a clique for this key. Later we will be
// able to get access to it from the execute handler.
RETURN_IF_ERROR(clique_requests->RequestClique(
clique_key, /*device_groups=*/{all_device_groups}));
return absl::OkStatus();
}
// This is a prepare handler for device-initiated collective operation which
// in addition to the clique asks for device comms and symmetric memory.
static absl::Status PrepareDeviceAllReduce(
ffi::BufferR0<U32> src, ffi::Result<ffi::BufferR0<U32>> dst,
const CollectiveParams* collective_params,
CollectiveCliqueRequests* clique_requests,
CollectiveMemoryRequests* memory_requests) {
TF_RET_CHECK(collective_params && clique_requests);
// Request a clique that covers all devices (this test runs on 2 gpus).
ASSIGN_OR_RETURN(
GpuCliqueKey clique_key,
GetGpuCliqueKey(
*collective_params, {AllDevices()},
CollectiveOpGroupMode::COLLECTIVE_OP_GROUP_MODE_FLATTENED_ID));
// Ask for a device communicator with 8 lsa barriers.
CollectiveCliqueRequests::CliqueRequirements requirements;
requirements.dev_comm = GpuDeviceCommunicator::Requirements{8};
std::vector<GlobalDeviceId> all_device_groups;
for (int i = 0; i < kNumReplicas; ++i) {
all_device_groups.push_back(GlobalDeviceId(i));
}
// Request XLA:GPU runtime to acquire a clique for this key. Later we will be
// able to get access to it from the execute handler.
RETURN_IF_ERROR(clique_requests->RequestClique(
clique_key, /*device_groups=*/{all_device_groups}, requirements));
// Request src and dst buffers to be symmetric on the given clique.
RETURN_IF_ERROR(memory_requests->RequestSymmetricAddress(
clique_key, src.device_memory()));
RETURN_IF_ERROR(memory_requests->RequestSymmetricAddress(
clique_key, dst->device_memory()));
return absl::OkStatus();
}
// This is a prepare handler for device-initiated collective operation which
// uses collective multimem to access peer devices.
static absl::Status PrepareMulticastAllReduce(
ffi::BufferR0<U32> src, ffi::Result<ffi::BufferR0<U32>> dst,
const CollectiveParams* collective_params,
CollectiveCliqueRequests* clique_requests,
CollectiveMemoryRequests* memory_requests) {
TF_RET_CHECK(collective_params && memory_requests);
// Request a clique that covers all devices (this test runs on 2 gpus).
ASSIGN_OR_RETURN(
GpuCliqueKey clique_key,
GetGpuCliqueKey(
*collective_params, {AllDevices()},
CollectiveOpGroupMode::COLLECTIVE_OP_GROUP_MODE_FLATTENED_ID));
std::vector<GlobalDeviceId> all_device_groups;
for (int i = 0; i < kNumReplicas; ++i) {
all_device_groups.push_back(GlobalDeviceId(i));
}
RETURN_IF_ERROR(clique_requests->RequestClique(
clique_key, /*device_groups=*/{all_device_groups}));
// Request src buffer to be mapped to multimem on the given clique.
//
// IMPORTANT: We don't request the clique itself, because multimem addresses
// accessible directly to kernels without a need for support from the
// underlying collective library.
RETURN_IF_ERROR(memory_requests->RequestMulticastAddress(
clique_key, src.device_memory()));
return absl::OkStatus();
}
// This is a prepare handler for device-initiated collective operation which
// uses collective multimem to access peer devices, but does it via symmetric
// memory handle.
static absl::Status PrepareSymMulticastAllReduce(
ffi::BufferR0<U32> src, ffi::Result<ffi::BufferR0<U32>> dst,
const CollectiveParams* collective_params,
CollectiveCliqueRequests* clique_requests,
CollectiveMemoryRequests* memory_requests) {
TF_RET_CHECK(collective_params && memory_requests);
// Request a clique that covers all devices (this test runs on 2 gpus).
ASSIGN_OR_RETURN(
GpuCliqueKey clique_key,
GetGpuCliqueKey(
*collective_params, {AllDevices()},
CollectiveOpGroupMode::COLLECTIVE_OP_GROUP_MODE_FLATTENED_ID));
std::vector<GlobalDeviceId> all_device_groups;
for (int i = 0; i < kNumReplicas; ++i) {
all_device_groups.push_back(GlobalDeviceId(i));
}
RETURN_IF_ERROR(clique_requests->RequestClique(
clique_key, /*device_groups=*/{all_device_groups}));
// Request src buffer to be symmetric on the given clique.
RETURN_IF_ERROR(memory_requests->RequestSymmetricAddress(
clique_key, src.device_memory()));
return absl::OkStatus();
}
// This is a prepare handler for device-initiated collective operation which
// uses collective peer memory to access peer devices, but does it via symmetric
// memory handle.
static absl::Status PrepareSymPeerAllReduce(
ffi::BufferR0<U32> src, ffi::Result<ffi::BufferR0<U32>> dst,
const CollectiveParams* collective_params,
CollectiveCliqueRequests* clique_requests,
CollectiveMemoryRequests* memory_requests) {
TF_RET_CHECK(collective_params && memory_requests);
// Request a clique that covers all devices (this test runs on 2 gpus).
ASSIGN_OR_RETURN(
GpuCliqueKey clique_key,
GetGpuCliqueKey(
*collective_params, {AllDevices()},
CollectiveOpGroupMode::COLLECTIVE_OP_GROUP_MODE_FLATTENED_ID));
std::vector<GlobalDeviceId> all_device_groups;
for (int i = 0; i < kNumReplicas; ++i) {
all_device_groups.push_back(GlobalDeviceId(i));
}
RETURN_IF_ERROR(clique_requests->RequestClique(
clique_key, /*device_groups=*/{all_device_groups}));
// Request src buffer to be symmetric on the given clique.
RETURN_IF_ERROR(memory_requests->RequestSymmetricAddress(
clique_key, src.device_memory()));
return absl::OkStatus();
}
// This is a prepare handler for device-initiated collective operation which
// uses collective peer memory to access peer devices.
static absl::Status PreparePeerAllReduce(
ffi::BufferR0<U32> src, ffi::Result<ffi::BufferR0<U32>> dst,
const CollectiveParams* collective_params,
CollectiveMemoryRequests* memory_requests) {
TF_RET_CHECK(collective_params && memory_requests);
// Request a clique that covers all devices (this test runs on 2 gpus).
ASSIGN_OR_RETURN(
GpuCliqueKey clique_key,
GetGpuCliqueKey(
*collective_params, {AllDevices()},
CollectiveOpGroupMode::COLLECTIVE_OP_GROUP_MODE_FLATTENED_ID));
// Request src buffer from all peers in the given clique.
RETURN_IF_ERROR(
memory_requests->RequestPeerAddress(clique_key, src.device_memory()));
return absl::OkStatus();
}
// FFI handler that uses XLA:GPU collectives API to perform an all reduce. This
// is just a test that demonstrates how to use XLA:GPU collectives API in an FFI
// handler, builtin all-reduce is a much better option. This version
// demonstrates requesting a communication stream and synchronizing it with the
// main stream.
static absl::Status AllReduce(se::Stream* stream, se::Stream* comm_stream,
ffi::BufferR0<U32> src,
ffi::Result<ffi::BufferR0<U32>> dst,
const CollectiveParams* collective_params,
const CollectiveCliques* collective_cliques) {
TF_RET_CHECK(collective_params && collective_cliques);
ASSIGN_OR_RETURN(
GpuCliqueKey clique_key,
GetGpuCliqueKey(
*collective_params, {AllDevices()},
CollectiveOpGroupMode::COLLECTIVE_OP_GROUP_MODE_FLATTENED_ID));
// Get the communicator for the requested clique.
ASSIGN_OR_RETURN(Communicator * comm,
collective_cliques->GetComm(
clique_key, collective_params->global_device_id));
// Synchronize communication stream with the main stream: make the
// communication stream wait for all prior work on the main stream.
RETURN_IF_ERROR(comm_stream->WaitFor(stream));
// Launch all-reduce on the communication stream.
Future<> future =
comm->AllReduce(src.device_memory(), dst->device_memory(),
src.element_type(), src.element_count(),
ReductionKind::SUM, GpuCollectives::On(*comm_stream));
RETURN_IF_ERROR(future.Await());
// Synchronize main stream with the communication stream: make the main
// stream wait for the all-reduce to complete.
RETURN_IF_ERROR(stream->WaitFor(comm_stream));
return absl::OkStatus();
}
// FFI handler that launches device kernel that does all-reduce using NCCL
// device-side APIs.
static absl::Status DeviceAllReduce(se::Stream* stream, ffi::BufferR0<U32> src,
ffi::Result<ffi::BufferR0<U32>> dst,
const CollectiveParams* collective_params,
const CollectiveCliques* collective_cliques,
const CollectiveMemory* collective_memory) {
TF_RET_CHECK(collective_params && collective_cliques && collective_memory);
ASSIGN_OR_RETURN(
GpuCliqueKey clique_key,
GetGpuCliqueKey(
*collective_params, {AllDevices()},
CollectiveOpGroupMode::COLLECTIVE_OP_GROUP_MODE_FLATTENED_ID));
// Find collective memory for src and dst buffers.
auto [sym_src, src_offset] =
collective_memory->FindSymmetricMemory(clique_key, src.device_memory());
auto [sym_dst, dst_offset] =
collective_memory->FindSymmetricMemory(clique_key, dst->device_memory());
TF_RET_CHECK(sym_src && sym_dst);
// Get requested device communicator for a given clique.
auto rank = clique_key.rank(collective_params->global_device_id);
ASSIGN_OR_RETURN(
GpuDeviceCommunicator * dev_comm,
collective_cliques->GetDeviceComm(
clique_key, *rank, GpuDeviceCommunicator::Requirements{8}));
// Load custom kernel that does device-initiated collectives.
ASSIGN_OR_RETURN(auto kernel, se::gpu::GpuKernelRegistry::GetGlobalRegistry()
.LoadKernel<SymmetricAllReduce>(
collective_params->executor));
se::BlockDim block_dims(1);
se::ThreadDim thread_dims(8);
RETURN_IF_ERROR(kernel.Launch(thread_dims, block_dims, stream, dev_comm,
sym_src, sym_dst, src_offset, dst_offset,
src.element_count()));
RETURN_IF_ERROR(stream->BlockHostUntilDone());
SynchronizationSignals* signals = global_signals->get();
signals->IncrementFinishedKernels();
return absl::OkStatus();
}
static absl::Status BlockedDeviceAllReduce(
se::Stream* stream, ffi::BufferR0<U32> src,
ffi::Result<ffi::BufferR0<U32>> dst,
const CollectiveParams* collective_params,
const CollectiveCliques* collective_cliques,
const CollectiveMemory* collective_memory) {
RETURN_IF_ERROR(DeviceAllReduce(stream, src, dst, collective_params,
collective_cliques, collective_memory));
return stream->BlockHostUntilDone();
}
// FFI handler that launches a device kernel that does all-reduce using NCCL
// device-side APIs.
static absl::Status DelayedDeviceAllReduce(
se::Stream* stream, ffi::BufferR0<U32> src,
ffi::Result<ffi::BufferR0<U32>> dst,
const CollectiveParams* collective_params,
const CollectiveCliques* collective_cliques,
const CollectiveMemory* collective_memory) {
RETURN_IF_ERROR(
stream->DoHostCallback([]() { absl::SleepFor(absl::Seconds(1)); }));
RETURN_IF_ERROR(DeviceAllReduce(stream, src, dst, collective_params,
collective_cliques, collective_memory));
return absl::OkStatus();
}
static absl::Status MulticastAllReduce(
se::Stream* stream, ffi::BufferR0<U32> src,
ffi::Result<ffi::BufferR0<U32>> dst,
const CollectiveParams* collective_params,
const CollectiveMemory* collective_memory) {
TF_RET_CHECK(collective_params && collective_memory);
ASSIGN_OR_RETURN(
GpuCliqueKey clique_key,
GetGpuCliqueKey(
*collective_params, {AllDevices()},
CollectiveOpGroupMode::COLLECTIVE_OP_GROUP_MODE_FLATTENED_ID));
auto [src_mmem, src_offset] =
collective_memory->FindMultimemAddress(clique_key, src.device_memory());
TF_RET_CHECK(src_mmem != nullptr);
// Load custom kernel that does device-initiated collectives.
ASSIGN_OR_RETURN(auto kernel, se::gpu::GpuKernelRegistry::GetGlobalRegistry()
.LoadKernel<MultimemAllReduce>(
collective_params->executor));
// Create device addresses from multimem pointer.
auto src_addr =
se::DeviceAddress<uint32_t>::MakeFromByteSize(src_mmem, src.size_bytes());
// Block the host CPU thread until the asynchronous GPU copies / memory maps
// are complete.
RETURN_IF_ERROR(stream->BlockHostUntilDone());
// Because we launch a trivial kernel, we use a host-side rendezvous to make
// sure that both devices will execute the kernel together after inputs become
// ready on both devices. Any real kernel must use device-side barriers.
static constexpr int32_t kKey = 0;
const int32_t* key = &kKey;
RETURN_IF_ERROR(Rendezvous<const int32_t*>(
"MulticastAllReduce", key, 2, absl::Seconds(1), absl::Seconds(5)));
se::BlockDim block_dims(1);
se::ThreadDim thread_dims(8);
RETURN_IF_ERROR(kernel.Launch(thread_dims, block_dims, stream, src_addr,
dst->device_memory(), src_offset,
src.element_count()));
RETURN_IF_ERROR(stream->BlockHostUntilDone());
SynchronizationSignals* signals = global_signals->get();
signals->IncrementFinishedKernels();
return absl::OkStatus();
}
// FFI handler that launches a device kernel that does all-reduce using
// multicast memory access.
static absl::Status DelayedMulticastAllReduce(
se::Stream* stream, ffi::BufferR0<U32> src,
ffi::Result<ffi::BufferR0<U32>> dst,
const CollectiveParams* collective_params,
const CollectiveMemory* collective_memory) {
RETURN_IF_ERROR(
stream->DoHostCallback([]() { absl::SleepFor(absl::Seconds(1)); }));
RETURN_IF_ERROR(MulticastAllReduce(stream, src, dst, collective_params,
collective_memory));
return absl::OkStatus();
}
// FFI handler that launches a device kernel that does all-reduce using
// multicast memory access.
static absl::Status BlockedMulticastAllReduce(
se::Stream* stream, ffi::BufferR0<U32> src,
ffi::Result<ffi::BufferR0<U32>> dst,
const CollectiveParams* collective_params,
const CollectiveMemory* collective_memory) {
RETURN_IF_ERROR(MulticastAllReduce(stream, src, dst, collective_params,
collective_memory));
return stream->BlockHostUntilDone();
}
static absl::Status SymMulticastAllReduce(
se::Stream* stream, ffi::BufferR0<U32> src,
ffi::Result<ffi::BufferR0<U32>> dst,
const CollectiveParams* collective_params,
const CollectiveMemory* collective_memory) {
TF_RET_CHECK(collective_params && collective_memory);
ASSIGN_OR_RETURN(
GpuCliqueKey clique_key,
GetGpuCliqueKey(
*collective_params, {AllDevices()},
CollectiveOpGroupMode::COLLECTIVE_OP_GROUP_MODE_FLATTENED_ID));
// Find collective memory for src buffer.
auto [sym_src, src_offset] =
collective_memory->FindSymmetricMemory(clique_key, src.device_memory());
TF_RET_CHECK(sym_src != nullptr);
// Load custom kernel that does device-initiated collectives.
ASSIGN_OR_RETURN(auto kernel, se::gpu::GpuKernelRegistry::GetGlobalRegistry()
.LoadKernel<MultimemAllReduce>(
collective_params->executor));
// Get multimem address for the src buffer.
ASSIGN_OR_RETURN(auto src_multimem, sym_src->multimem_addr());
if (!src_multimem) {
return absl::InternalError("Multimem address can't be resolved");
}
// Block the host CPU thread until the asynchronous GPU copies / memory maps
// are complete.
RETURN_IF_ERROR(stream->BlockHostUntilDone());
// Because we launch a trivial kernel, we use a host-side rendezvous to make
// sure that both devices will execute the kernel together after inputs become
// ready on both devices. Any real kernel must use device-side barriers.
static constexpr int32_t kKey = 0;
const int32_t* key = &kKey;
RETURN_IF_ERROR(Rendezvous<const int32_t*>(
"MulticastAllReduce", key, 2, absl::Seconds(1), absl::Seconds(5)));
se::BlockDim block_dims(1);
se::ThreadDim thread_dims(8);
RETURN_IF_ERROR(kernel.Launch(thread_dims, block_dims, stream,
se::DeviceAddress<uint32_t>(src_multimem),
dst->device_memory(), src_offset,
src.element_count()));
RETURN_IF_ERROR(stream->BlockHostUntilDone());
SynchronizationSignals* signals = global_signals->get();
signals->IncrementFinishedKernels();
return absl::OkStatus();
}
// FFI handler that launches a device kernel that does all-reduce using
// multicast memory access through the symmetric memory handle.
static absl::Status DelayedSymMulticastAllReduce(
se::Stream* stream, ffi::BufferR0<U32> src,
ffi::Result<ffi::BufferR0<U32>> dst,
const CollectiveParams* collective_params,
const CollectiveMemory* collective_memory) {
RETURN_IF_ERROR(
stream->DoHostCallback([]() { absl::SleepFor(absl::Seconds(1)); }));
RETURN_IF_ERROR(SymMulticastAllReduce(stream, src, dst, collective_params,
collective_memory));
return absl::OkStatus();
}
// FFI handler that launches a device kernel that does all-reduce using
// multicast memory access through the symmetric memory handle.
static absl::Status BlockedSymMulticastAllReduce(
se::Stream* stream, ffi::BufferR0<U32> src,
ffi::Result<ffi::BufferR0<U32>> dst,
const CollectiveParams* collective_params,
const CollectiveMemory* collective_memory) {
RETURN_IF_ERROR(SymMulticastAllReduce(stream, src, dst, collective_params,
collective_memory));
return stream->BlockHostUntilDone();
}
static absl::Status SymPeerAllReduce(
se::Stream* stream, ffi::BufferR0<U32> src,
ffi::Result<ffi::BufferR0<U32>> dst,
const CollectiveParams* collective_params,
const CollectiveMemory* collective_memory) {
TF_RET_CHECK(collective_params && collective_memory);
ASSIGN_OR_RETURN(
GpuCliqueKey clique_key,
GetGpuCliqueKey(
*collective_params, {AllDevices()},
CollectiveOpGroupMode::COLLECTIVE_OP_GROUP_MODE_FLATTENED_ID));
// Find collective memory for src buffer.
auto [sym_src, src_offset] =
collective_memory->FindSymmetricMemory(clique_key, src.device_memory());
TF_RET_CHECK(sym_src != nullptr);
// Load custom kernel that does device-initiated collectives.
ASSIGN_OR_RETURN(auto kernel, se::gpu::GpuKernelRegistry::GetGlobalRegistry()
.LoadKernel<Peer2AllReduce>(
collective_params->executor));
// Get peer addresses for src buffer.
ASSIGN_OR_RETURN(auto src0, sym_src->peer_addr(RankId(0)));
ASSIGN_OR_RETURN(auto src1, sym_src->peer_addr(RankId(1)));
if (!src0 || !src1) {
return absl::InternalError("Peer address can't be resolved");
}
// Block the host CPU thread until the asynchronous GPU copies / memory maps
// are complete.
RETURN_IF_ERROR(stream->BlockHostUntilDone());
// Because we launch a trivial kernel, we use a host-side rendezvous to make
// sure that both devices will execute the kernel together after inputs become
// ready on both devices. Any real kernel must use device-side barriers.
static constexpr int32_t kKey = 0;
const int32_t* key = &kKey;
RETURN_IF_ERROR(Rendezvous<const int32_t*>(
"SymPeerAllReduce", key, 2, absl::Seconds(1), absl::Seconds(5)));
se::BlockDim block_dims(1);
se::ThreadDim thread_dims(8);
RETURN_IF_ERROR(kernel.Launch(thread_dims, block_dims, stream,
se::DeviceAddress<uint32_t>(src0),
se::DeviceAddress<uint32_t>(src1),
dst->device_memory(), src.element_count()));
RETURN_IF_ERROR(stream->BlockHostUntilDone());
SynchronizationSignals* signals = global_signals->get();
signals->IncrementFinishedKernels();
return absl::OkStatus();
}
// FFI handler that launches a device kernel that does all-reduce using peer
// memory access through the symmetric memory handle.
static absl::Status DelayedSymPeerAllReduce(
se::Stream* stream, ffi::BufferR0<U32> src,
ffi::Result<ffi::BufferR0<U32>> dst,
const CollectiveParams* collective_params,
const CollectiveMemory* collective_memory) {
RETURN_IF_ERROR(
stream->DoHostCallback([]() { absl::SleepFor(absl::Seconds(1)); }));
RETURN_IF_ERROR(
SymPeerAllReduce(stream, src, dst, collective_params, collective_memory));
return absl::OkStatus();
}
// FFI handler that launches a device kernel that does all-reduce using peer
// memory access through the symmetric memory handle.
static absl::Status BlockedSymPeerAllReduce(
se::Stream* stream, ffi::BufferR0<U32> src,
ffi::Result<ffi::BufferR0<U32>> dst,
const CollectiveParams* collective_params,
const CollectiveMemory* collective_memory) {
RETURN_IF_ERROR(
SymPeerAllReduce(stream, src, dst, collective_params, collective_memory));
return stream->BlockHostUntilDone();
}
// FFI handler that launches a device kernel that does all-reduce using peer
// memory access.
static absl::Status PeerAllReduce(se::Stream* stream, ffi::BufferR0<U32> src,
ffi::Result<ffi::BufferR0<U32>> dst,
const CollectiveParams* collective_params,
const CollectiveMemory* collective_memory) {
TF_RET_CHECK(collective_params && collective_memory);
ASSIGN_OR_RETURN(
GpuCliqueKey clique_key,
GetGpuCliqueKey(
*collective_params, {AllDevices()},
CollectiveOpGroupMode::COLLECTIVE_OP_GROUP_MODE_FLATTENED_ID));
auto src0 = collective_memory->FindPeerAddress(clique_key, RankId(0),
src.device_memory());
auto src1 = collective_memory->FindPeerAddress(clique_key, RankId(1),
src.device_memory());
TF_RET_CHECK(src0 && src1);
// Load custom kernel that does device-initiated collectives.
ASSIGN_OR_RETURN(auto kernel, se::gpu::GpuKernelRegistry::GetGlobalRegistry()
.LoadKernel<Peer2AllReduce>(
collective_params->executor));
// Block the host CPU thread until the asynchronous GPU copies / memory maps
// are complete.
RETURN_IF_ERROR(stream->BlockHostUntilDone());
// Because we launch a trivial kernel, we use a host-side rendezvous to make
// sure that both devices will execute the kernel together after inputs become
// ready on both devices. Any real kernel must use device-side barriers.
static constexpr int32_t kKey = 0;
const int32_t* key = &kKey;
RETURN_IF_ERROR(Rendezvous<const int32_t*>(
"PeerAllReduce", key, 2, absl::Seconds(1), absl::Seconds(5)));
se::BlockDim block_dims(1);
se::ThreadDim thread_dims(8);
RETURN_IF_ERROR(kernel.Launch(thread_dims, block_dims, stream, *src0, *src1,
dst->device_memory(), src.element_count()));
RETURN_IF_ERROR(stream->BlockHostUntilDone());
SynchronizationSignals* signals = global_signals->get();
signals->IncrementFinishedKernels();
return absl::OkStatus();
}
static absl::Status BlockedPeerAllReduce(
se::Stream* stream, ffi::BufferR0<U32> src,
ffi::Result<ffi::BufferR0<U32>> dst,
const CollectiveParams* collective_params,
const CollectiveMemory* collective_memory) {
RETURN_IF_ERROR(
PeerAllReduce(stream, src, dst, collective_params, collective_memory));
return stream->BlockHostUntilDone();
}
static absl::Status DelayedPeerAllReduce(
se::Stream* stream, ffi::BufferR0<U32> src,
ffi::Result<ffi::BufferR0<U32>> dst,
const CollectiveParams* collective_params,
const CollectiveMemory* collective_memory) {
RETURN_IF_ERROR(
PeerAllReduce(stream, src, dst, collective_params, collective_memory));
RETURN_IF_ERROR(
stream->DoHostCallback([]() { absl::SleepFor(absl::Seconds(2)); }));
return absl::OkStatus();
}
XLA_FFI_DEFINE_HANDLER(kPrepareAllReduce, PrepareAllReduce,
ffi::Ffi::BindPrepare()
.Ctx<ffi::CollectiveParams>()
.Ctx<ffi::CollectiveCliqueRequests>());
XLA_FFI_DEFINE_HANDLER(kAllReduce, AllReduce,
ffi::Ffi::Bind()
.Ctx<ffi::Stream>()
.Ctx<ffi::CommunicationStream<0>>()
.Arg<ffi::BufferR0<U32>>() // src
.Ret<ffi::BufferR0<U32>>() // dst
.Ctx<ffi::CollectiveParams>()
.Ctx<ffi::CollectiveCliques>());
XLA_FFI_DEFINE_HANDLER(kPrepareDeviceAllReduce, PrepareDeviceAllReduce,
ffi::Ffi::BindPrepare()
.Arg<ffi::BufferR0<U32>>() // src
.Ret<ffi::BufferR0<U32>>() // dst
.Ctx<ffi::CollectiveParams>()
.Ctx<ffi::CollectiveCliqueRequests>()
.Ctx<ffi::CollectiveMemoryRequests>());
XLA_FFI_DEFINE_HANDLER(kDeviceAllReduce, BlockedDeviceAllReduce,
ffi::Ffi::Bind()
.Ctx<ffi::Stream>()
.Arg<ffi::BufferR0<U32>>() // src
.Ret<ffi::BufferR0<U32>>() // dst
.Ctx<ffi::CollectiveParams>()
.Ctx<ffi::CollectiveCliques>()
.Ctx<ffi::CollectiveMemory>());
XLA_FFI_DEFINE_HANDLER(kDelayedDeviceAllReduce, DelayedDeviceAllReduce,
ffi::Ffi::Bind()
.Ctx<ffi::Stream>()
.Arg<ffi::BufferR0<U32>>() // src
.Ret<ffi::BufferR0<U32>>() // dst
.Ctx<ffi::CollectiveParams>()
.Ctx<ffi::CollectiveCliques>()
.Ctx<ffi::CollectiveMemory>());
XLA_FFI_DEFINE_HANDLER(kPrepareMulticastAllReduce, PrepareMulticastAllReduce,
ffi::Ffi::BindPrepare()
.Arg<ffi::BufferR0<U32>>() // src
.Ret<ffi::BufferR0<U32>>() // dst
.Ctx<ffi::CollectiveParams>()
.Ctx<ffi::CollectiveCliqueRequests>()
.Ctx<ffi::CollectiveMemoryRequests>());
XLA_FFI_DEFINE_HANDLER(kMulticastAllReduce, BlockedMulticastAllReduce,
ffi::Ffi::Bind()
.Ctx<ffi::Stream>()
.Arg<ffi::BufferR0<U32>>() // src
.Ret<ffi::BufferR0<U32>>() // dst
.Ctx<ffi::CollectiveParams>()
.Ctx<ffi::CollectiveMemory>());
XLA_FFI_DEFINE_HANDLER(kDelayedMulticastAllReduce, DelayedMulticastAllReduce,
ffi::Ffi::Bind()
.Ctx<ffi::Stream>()
.Arg<ffi::BufferR0<U32>>() // src
.Ret<ffi::BufferR0<U32>>() // dst
.Ctx<ffi::CollectiveParams>()
.Ctx<ffi::CollectiveMemory>());
XLA_FFI_DEFINE_HANDLER(kPrepareSymMulticastAllReduce,
PrepareSymMulticastAllReduce,
ffi::Ffi::BindPrepare()
.Arg<ffi::BufferR0<U32>>() // src
.Ret<ffi::BufferR0<U32>>() // dst
.Ctx<ffi::CollectiveParams>()
.Ctx<ffi::CollectiveCliqueRequests>()
.Ctx<ffi::CollectiveMemoryRequests>());
XLA_FFI_DEFINE_HANDLER(kSymMulticastAllReduce, BlockedSymMulticastAllReduce,
ffi::Ffi::Bind()
.Ctx<ffi::Stream>()
.Arg<ffi::BufferR0<U32>>() // src
.Ret<ffi::BufferR0<U32>>() // dst
.Ctx<ffi::CollectiveParams>()
.Ctx<ffi::CollectiveMemory>());
XLA_FFI_DEFINE_HANDLER(kDelayedSymMulticastAllReduce,
DelayedSymMulticastAllReduce,
ffi::Ffi::Bind()
.Ctx<ffi::Stream>()
.Arg<ffi::BufferR0<U32>>() // src
.Ret<ffi::BufferR0<U32>>() // dst
.Ctx<ffi::CollectiveParams>()
.Ctx<ffi::CollectiveMemory>());
XLA_FFI_DEFINE_HANDLER(kPrepareSymPeerAllReduce, PrepareSymPeerAllReduce,
ffi::Ffi::BindPrepare()
.Arg<ffi::BufferR0<U32>>() // src
.Ret<ffi::BufferR0<U32>>() // dst
.Ctx<ffi::CollectiveParams>()
.Ctx<ffi::CollectiveCliqueRequests>()
.Ctx<ffi::CollectiveMemoryRequests>());
XLA_FFI_DEFINE_HANDLER(kSymPeerAllReduce, BlockedSymPeerAllReduce,
ffi::Ffi::Bind()
.Ctx<ffi::Stream>()
.Arg<ffi::BufferR0<U32>>() // src
.Ret<ffi::BufferR0<U32>>() // dst
.Ctx<ffi::CollectiveParams>()
.Ctx<ffi::CollectiveMemory>());
XLA_FFI_DEFINE_HANDLER(kDelayedSymPeerAllReduce, DelayedSymPeerAllReduce,
ffi::Ffi::Bind()
.Ctx<ffi::Stream>()
.Arg<ffi::BufferR0<U32>>() // src
.Ret<ffi::BufferR0<U32>>() // dst
.Ctx<ffi::CollectiveParams>()
.Ctx<ffi::CollectiveMemory>());
XLA_FFI_DEFINE_HANDLER(kPreparePeerAllReduce, PreparePeerAllReduce,
ffi::Ffi::BindPrepare()
.Arg<ffi::BufferR0<U32>>() // src
.Ret<ffi::BufferR0<U32>>() // dst
.Ctx<ffi::CollectiveParams>()
.Ctx<ffi::CollectiveMemoryRequests>());
XLA_FFI_DEFINE_HANDLER(kPeerAllReduce, BlockedPeerAllReduce,
ffi::Ffi::Bind()
.Ctx<ffi::Stream>()
.Arg<ffi::BufferR0<U32>>() // src
.Ret<ffi::BufferR0<U32>>() // dst
.Ctx<ffi::CollectiveParams>()
.Ctx<ffi::CollectiveMemory>());
XLA_FFI_DEFINE_HANDLER(kDelayedPeerAllReduce, DelayedPeerAllReduce,
ffi::Ffi::Bind()
.Ctx<ffi::Stream>()
.Arg<ffi::BufferR0<U32>>() // src
.Ret<ffi::BufferR0<U32>>() // dst
.Ctx<ffi::CollectiveParams>()
.Ctx<ffi::CollectiveMemory>());
// Register handler bundle for the custom all-reduce operation.
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$all_reduce", "gpu",
XLA_FFI_Handler_Bundle{
/*instantiate=*/nullptr,
/*prepare=*/kPrepareAllReduce,
/*initialize=*/nullptr,
/*execute=*/kAllReduce,
});
// Register handler bundle for the custom all-reduce operation with
// device-initiated collective kernels that use multimem addresses.
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(),
"__xla_test_blocked_multimem_all_reduce", "gpu",
XLA_FFI_Handler_Bundle{
/*instantiate=*/nullptr,
/*prepare=*/kPrepareMulticastAllReduce,
/*initialize=*/nullptr,
/*execute=*/kMulticastAllReduce,
});
// Register handler bundle for the custom all-reduce operation with
// device-initiated collective kernels that use multimem addresses.
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(),
"__xla_test_delayed_multimem_all_reduce", "gpu",
XLA_FFI_Handler_Bundle{
/*instantiate=*/nullptr,
/*prepare=*/kPrepareMulticastAllReduce,
/*initialize=*/nullptr,
/*execute=*/kDelayedMulticastAllReduce,
});
// Register handler bundle for the custom all-reduce operation with
// device-initiated collective kernels that use multimem addresses.
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(),
"__xla_test_blocked_sym_multimem_all_reduce", "gpu",
XLA_FFI_Handler_Bundle{
/*instantiate=*/nullptr,
/*prepare=*/kPrepareSymMulticastAllReduce,
/*initialize=*/nullptr,
/*execute=*/kSymMulticastAllReduce,
});
// Register handler bundle for the custom all-reduce operation with
// device-initiated collective kernels that use multimem addresses.
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(),
"__xla_test_delayed_sym_multimem_all_reduce", "gpu",
XLA_FFI_Handler_Bundle{
/*instantiate=*/nullptr,
/*prepare=*/kPrepareSymMulticastAllReduce,
/*initialize=*/nullptr,
/*execute=*/kDelayedSymMulticastAllReduce,
});
// Register handler bundle for the custom all-reduce operation with
// device-initiated collective kernels that use peer addresses.
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(),
"__xla_test_blocked_sym_peer_all_reduce", "gpu",
XLA_FFI_Handler_Bundle{
/*instantiate=*/nullptr,
/*prepare=*/kPrepareSymPeerAllReduce,
/*initialize=*/nullptr,
/*execute=*/kSymPeerAllReduce,
});
// Register handler bundle for the custom all-reduce operation with
// device-initiated collective kernels that use peer addresses.
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(),
"__xla_test_delayed_sym_peer_all_reduce", "gpu",
XLA_FFI_Handler_Bundle{
/*instantiate=*/nullptr,
/*prepare=*/kPrepareSymPeerAllReduce,
/*initialize=*/nullptr,
/*execute=*/kDelayedSymPeerAllReduce,
});
// Register handler bundle for the custom all-reduce operation with
// device-initiated collective kernels that use peer addresses.
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(),
"__xla_test_blocked_peer_all_reduce", "gpu",
XLA_FFI_Handler_Bundle{
/*instantiate=*/nullptr,
/*prepare=*/kPreparePeerAllReduce,
/*initialize=*/nullptr,
/*execute=*/kPeerAllReduce,
});
// Register handler bundle for the custom all-reduce operation with
// device-initiated collective kernels that use peer addresses.
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(),
"__xla_test_delayed_peer_all_reduce", "gpu",
XLA_FFI_Handler_Bundle{
/*instantiate=*/nullptr,
/*prepare=*/kPreparePeerAllReduce,
/*initialize=*/nullptr,
/*execute=*/kDelayedPeerAllReduce,
});
// Register handler bundle for the custom all-reduce operation with
// device-initiated collective kernels that use blocked execution.
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(),
"__xla_test_blocked_device_all_reduce", "gpu",
XLA_FFI_Handler_Bundle{
/*instantiate=*/nullptr,
/*prepare=*/kPrepareDeviceAllReduce,
/*initialize=*/nullptr,
/*execute=*/kDeviceAllReduce,
});
// Register handler bundle for the custom all-reduce operation with
// device-initiated collective kernels that use delayed execution.
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(),
"__xla_test_delayed_device_all_reduce", "gpu",
XLA_FFI_Handler_Bundle{
/*instantiate=*/nullptr,
/*prepare=*/kPrepareDeviceAllReduce,
/*initialize=*/nullptr,
/*execute=*/kDelayedDeviceAllReduce,
});
TEST_F(CollectiveOpsTestFFI, AllReduce) {
if (device_count() < kNumReplicas) {
GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices ("
<< device_count() << " available)";
}
if (!IsHopperAndHigher()) {
GTEST_SKIP() << "NCCL symmetric memory requires Hopper+";
}
constexpr absl::string_view hlo_string = R"(
HloModule m, replica_count=2
ENTRY test_computation {
id = u32[] replica-id()
ROOT all-reduce = u32[] custom-call(id),
custom_call_target="__xla_test$$all_reduce",
api_version=API_VERSION_TYPED_FFI
}
)";
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnVerifiedModule(hlo_string, kNumReplicas));
TF_ASSERT_OK_AND_ASSIGN(
ExecutionResult execution_result,
ExecuteReplicated(std::move(module),
/*arguments=*/std::vector<Literal*>(),
/*run_hlo_passes=*/false));
absl::Span<const Literal> results = execution_result.results;
ASSERT_EQ(results.size(), kNumReplicas);
// sum [0, num_devices)
const uint32_t expected = kNumReplicas * (kNumReplicas - 1) / 2;
for (int i = 0; i < kNumReplicas; ++i) {
LiteralTestUtil::ExpectR0Equal<uint32_t>(expected, results[i]);
}
}
class AllReduceTest : public CollectiveOpsTestFFI,
public ::testing::WithParamInterface<absl::string_view> {
};
TEST_P(AllReduceTest, DeviceAllReduce) {
if (device_count() < kNumReplicas) {
GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices ("
<< device_count() << " available)";
}
if (!IsHopperAndHigher()) {
GTEST_SKIP() << "NCCL symmetric memory requires Hopper+";
}
std::string hlo_string = absl::Substitute(R"(
HloModule m, replica_count=2
ENTRY test_computation {
id = u32[] replica-id()
in = u32[]{:S(1)} copy(id)
all-reduce = u32[]{:S(1)} custom-call(in),
custom_call_target="__xla_test_$0_device_all_reduce",
api_version=API_VERSION_TYPED_FFI
ROOT out = u32[] copy(all-reduce)
}
)",
GetParam());
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnVerifiedModule(hlo_string, kNumReplicas));
TF_ASSERT_OK_AND_ASSIGN(
ExecutionResult execution_result,
ExecuteReplicated(std::move(module),
/*arguments=*/std::vector<Literal*>(),
/*run_hlo_passes=*/false));
SynchronizationSignals* signals = global_signals->get();
signals->finished_kernels_counter.Wait();
absl::Span<const Literal> results = execution_result.results;
ASSERT_EQ(results.size(), kNumReplicas);
// sum [0, num_devices)
const uint32_t expected = kNumReplicas * (kNumReplicas - 1) / 2;
for (int i = 0; i < kNumReplicas; ++i) {
LiteralTestUtil::ExpectR0Equal<uint32_t>(expected, results[i]);
}
}
TEST_P(AllReduceTest, PeerAllReduce) {
if (device_count() < kNumReplicas) {
GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices ("
<< device_count() << " available)";
}
if (!IsHopperAndHigher()) {
GTEST_SKIP() << "Test requires Hopper+ since on previous platforms there "
"is no guarantee that GPUs have direct peer access";
}
std::string hlo_string = absl::Substitute(R"(
HloModule m, replica_count=2
ENTRY test_computation {
id = u32[] replica-id()
ROOT all-reduce = u32[] custom-call(id),
custom_call_target="__xla_test_$0_peer_all_reduce",
api_version=API_VERSION_TYPED_FFI
}
)",
GetParam());
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnVerifiedModule(hlo_string, kNumReplicas));
TF_ASSERT_OK_AND_ASSIGN(
ExecutionResult execution_result,
ExecuteReplicated(std::move(module),
/*arguments=*/std::vector<Literal*>(),
/*run_hlo_passes=*/false));
SynchronizationSignals* signals = global_signals->get();
signals->finished_kernels_counter.Wait();
absl::Span<const Literal> results = execution_result.results;
ASSERT_EQ(results.size(), kNumReplicas);
// sum [0, num_devices)
const uint32_t expected = kNumReplicas * (kNumReplicas - 1) / 2;
for (int i = 0; i < kNumReplicas; ++i) {
LiteralTestUtil::ExpectR0Equal<uint32_t>(expected, results[i]);
}
}
TEST_P(AllReduceTest, MulticastAllReduce) {
if (device_count() < kNumReplicas) {
GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices ("
<< device_count() << " available)";
}
if (!IsHopperAndHigher()) {
GTEST_SKIP() << "Test requires Hopper+";
}
std::string hlo_string = absl::Substitute(R"(
HloModule m, replica_count=2
ENTRY test_computation {
c0 = u32[] constant(1)
in = u32[]{:S(1)} copy(c0)
all-reduce = u32[] custom-call(in),
custom_call_target="__xla_test_$0_multimem_all_reduce",
api_version=API_VERSION_TYPED_FFI
ROOT out = u32[] copy(all-reduce)
}
)",
GetParam());
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnVerifiedModule(hlo_string, kNumReplicas));
TF_ASSERT_OK_AND_ASSIGN(
ExecutionResult execution_result,
ExecuteReplicated(std::move(module),
/*arguments=*/std::vector<Literal*>(),
/*run_hlo_passes=*/false));
SynchronizationSignals* signals = global_signals->get();
signals->finished_kernels_counter.Wait();
absl::Span<const Literal> results = execution_result.results;
ASSERT_EQ(results.size(), kNumReplicas);
const uint32_t expected = 2;
for (int i = 0; i < kNumReplicas; ++i) {
LiteralTestUtil::ExpectR0Equal<uint32_t>(expected, results[i]);
}
}
TEST_P(AllReduceTest, SymMulticastAllReduce) {
if (device_count() < kNumReplicas) {
GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices ("
<< device_count() << " available)";
}
if (!IsHopperAndHigher()) {
GTEST_SKIP() << "Test requires Hopper+";
}
std::string hlo_string = absl::Substitute(R"(
HloModule m, replica_count=2
ENTRY test_computation {
c0 = u32[] constant(1)
in = u32[]{:S(1)} copy(c0)
all-reduce = u32[] custom-call(in),
custom_call_target="__xla_test_$0_sym_multimem_all_reduce",
api_version=API_VERSION_TYPED_FFI
ROOT out = u32[] copy(all-reduce)
}
)",
GetParam());
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnVerifiedModule(hlo_string, kNumReplicas));
TF_ASSERT_OK_AND_ASSIGN(
ExecutionResult execution_result,
ExecuteReplicated(std::move(module),
/*arguments=*/std::vector<Literal*>(),
/*run_hlo_passes=*/false));
SynchronizationSignals* signals = global_signals->get();
signals->finished_kernels_counter.Wait();
absl::Span<const Literal> results = execution_result.results;
ASSERT_EQ(results.size(), kNumReplicas);
const uint32_t expected = 2;
for (int i = 0; i < kNumReplicas; ++i) {
LiteralTestUtil::ExpectR0Equal<uint32_t>(expected, results[i]);
}
}
TEST_P(AllReduceTest, SymPeerAllReduce) {
if (device_count() < kNumReplicas) {
GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices ("
<< device_count() << " available)";
}
if (!IsHopperAndHigher()) {
GTEST_SKIP() << "Test requires Hopper+ since on previous platforms there "
"is no guarantee that GPUs have direct peer access";
}
std::string hlo_string = absl::Substitute(R"(
HloModule m, replica_count=2
ENTRY test_computation {
id = u32[] replica-id()
in = u32[]{:S(1)} copy(id)
all-reduce = u32[] custom-call(in),
custom_call_target="__xla_test_$0_sym_peer_all_reduce",
api_version=API_VERSION_TYPED_FFI
ROOT out = u32[] copy(all-reduce)
}
)",
GetParam());
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnVerifiedModule(hlo_string, kNumReplicas));
TF_ASSERT_OK_AND_ASSIGN(
ExecutionResult execution_result,
ExecuteReplicated(std::move(module),
/*arguments=*/std::vector<Literal*>(),
/*run_hlo_passes=*/false));
SynchronizationSignals* signals = global_signals->get();
signals->finished_kernels_counter.Wait();
absl::Span<const Literal> results = execution_result.results;
ASSERT_EQ(results.size(), kNumReplicas);
// sum [0, num_devices)
const uint32_t expected = kNumReplicas * (kNumReplicas - 1) / 2;
for (int i = 0; i < kNumReplicas; ++i) {
LiteralTestUtil::ExpectR0Equal<uint32_t>(expected, results[i]);
}
}
INSTANTIATE_TEST_SUITE_P(
AllReduceTests, AllReduceTest, Values("blocked", "delayed"),
[](const ::testing::TestParamInfo<absl::string_view>& info) {
return std::string(info.param);
});
// Same as DeviceAllReduce, but uses frontend_attributes to specify memory
// spaces instead of hardcoded S(1).
TEST_F(CollectiveOpsTestFFI, DeviceAllReduceWithFrontendAttributes) {
if (device_count() < kNumReplicas) {
GTEST_SKIP() << "Test requires at least " << kNumReplicas << " devices ("
<< device_count() << " available)";
}
if (!IsHopperAndHigher()) {
GTEST_SKIP() << "NCCL symmetric memory requires Hopper+";
}
constexpr absl::string_view hlo_string = R"(
HloModule m, replica_count=2
ENTRY test_computation {
id = u32[] replica-id()
all-reduce = u32[] custom-call(id),
custom_call_target="__xla_test_blocked_device_all_reduce",
api_version=API_VERSION_TYPED_FFI,
frontend_attributes={
operands_memory_spaces="{0:1}",
results_memory_spaces="{0:1}"
}
ROOT out = u32[] copy(all-reduce)
}
)";
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnVerifiedModule(hlo_string, kNumReplicas));
TF_ASSERT_OK_AND_ASSIGN(
ExecutionResult execution_result,
ExecuteReplicated(std::move(module),
/*arguments=*/std::vector<Literal*>(),
/*run_hlo_passes=*/true));
SynchronizationSignals* signals = global_signals->get();
signals->finished_kernels_counter.Wait();
absl::Span<const Literal> results = execution_result.results;
ASSERT_EQ(results.size(), kNumReplicas);
// sum [0, num_devices)
const uint32_t expected = kNumReplicas * (kNumReplicas - 1) / 2;
for (int i = 0; i < kNumReplicas; ++i) {
LiteralTestUtil::ExpectR0Equal<uint32_t>(expected, results[i]);
}
}
} // namespace xla::gpu
// SOURCE: openxla/xla @ 9190ab7d75ec35933e7cf1ed375ca6d08279e805 (jax-toolbox 2026-06-02, jax/jaxlib 0.10.2.dev20260602)
// Path in tree: xla/backends/gpu/collectives/nccl_symmetric_memory.cc
// Provided here verbatim so the gist matches exactly what run_symmetric_test.sh compiles.
#include "xla/backends/gpu/collectives/nccl_symmetric_memory.h"

#include <memory>
#include <string>

#include "absl/log/log.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "xla/backends/gpu/collectives/nccl_errors.h"
#include "xla/core/collectives/rank_id.h"
#include "xla/stream_executor/device_address.h"

// Include NCCL after XLA headers.
#include "third_party/nccl/nccl.h"
#include "third_party/nccl/nccl_device.h"

namespace xla::gpu {

NcclSymmetricMemory::NcclSymmetricMemory(
    ncclComm_t comm, ncclWindow_t win, stream_executor::DeviceAddressBase addr)
    : comm_(comm), win_(win), addr_(addr) {}

absl::StatusOr<std::unique_ptr<NcclSymmetricMemory>>
NcclSymmetricMemory::Create(ncclComm_t comm,
                            stream_executor::DeviceAddressBase addr) {
  VLOG(3) << absl::StrFormat(
      "Create NCCL symmetric memory on comm=%p from: ptr=%p; size=%ld", comm,
      addr.opaque(), addr.size());
  ncclWindow_t win;
  XLA_NCCL_RETURN_IF_ERROR(ncclCommWindowRegister(
      comm, addr.opaque(), addr.size(), &win, NCCL_WIN_COLL_SYMMETRIC));
  return absl::WrapUnique(new NcclSymmetricMemory(comm, win, addr));
}

NcclSymmetricMemory::~NcclSymmetricMemory() {
  VLOG(3) << absl::StrFormat("Destroy %v", *this);
  XLA_NCCL_LOG_IF_ERROR(ncclCommWindowDeregister(comm_, win_));
}

stream_executor::DeviceAddressBase NcclSymmetricMemory::addr() const {
  return addr_;
}

absl::StatusOr<stream_executor::DeviceAddressBase>
NcclSymmetricMemory::multimem_addr() const {
#if (NCCL_VERSION_CODE >= 22900) || defined(USE_NCCL_HOST_API)
  void* multimem = nullptr;
  XLA_NCCL_RETURN_IF_ERROR(ncclGetLsaMultimemDevicePointer(win_, 0, &multimem));
  if (multimem) {
    return stream_executor::DeviceAddressBase(multimem, addr_.size());
  }
#endif
  return absl::UnimplementedError(
      "Multimem not supported on this NCCL version or device");
}

absl::StatusOr<stream_executor::DeviceAddressBase>
NcclSymmetricMemory::peer_addr(RankId peer) const {
#if (NCCL_VERSION_CODE >= 22900) || defined(USE_NCCL_HOST_API)
  void* peer_addr = nullptr;
  XLA_NCCL_RETURN_IF_ERROR(
      ncclGetLsaDevicePointer(win_, 0, peer.value(), &peer_addr));
  if (peer_addr) {
    return stream_executor::DeviceAddressBase(peer_addr, addr_.size());
  }
#endif
  return absl::UnimplementedError(
      "Peer address not supported on this NCCL version or device");
}

std::string NcclSymmetricMemory::ToString() const {
  return absl::StrFormat(
      "NcclSymmetricMemory(comm=%p, win=%p, ptr=%p, size=%ld)", comm_, win_,
      addr_.opaque(), addr_.size());
}

NcclSymmetricMemory::PackedKernelArg NcclSymmetricMemory::PackKernelArg()
    const {
  return win_;
}

} // namespace xla::gpu
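Both `multimem_addr()` and `peer_addr()` compile their NCCL calls only under `#if (NCCL_VERSION_CODE >= 22900) || defined(USE_NCCL_HOST_API)`. The constant is easier to read once decoded.

As we understand the `NCCL_VERSION(X, Y, Z)` macro in `nccl.h`:
- releases up to 2.8.x encode as `X*1000 + Y*100 + Z`;
- later releases encode as `X*10000 + Y*100 + Z`.

So 22900 means NCCL 2.29.0, and the hermetic NCCL 2.29.7 of the verified run encodes as 22907. The compile-time sketch below restates that encoding. `NcclVersionCode`, `kLsaApiMinVersion` and `LsaApiCompiledIn` are hypothetical names, so compare against the macro in your own `nccl.h` before relying on it.

```cpp
// Hypothetical restatement of the NCCL_VERSION(X, Y, Z) encoding from nccl.h
// (assumption: up to 2.8.x it is X*1000 + Y*100 + Z, later X*10000 + Y*100 + Z).
constexpr int NcclVersionCode(int major, int minor, int patch) {
  return (major <= 2 && minor <= 8) ? major * 1000 + minor * 100 + patch
                                    : major * 10000 + minor * 100 + patch;
}

// Threshold used by the #if guards around multimem_addr() and peer_addr().
constexpr int kLsaApiMinVersion = 22900;

// True when those guards compile the NCCL-backed branch
// (this ignores the USE_NCCL_HOST_API escape hatch).
constexpr bool LsaApiCompiledIn(int version_code) {
  return version_code >= kLsaApiMinVersion;
}

static_assert(NcclVersionCode(2, 29, 0) == kLsaApiMinVersion,
              "22900 encodes NCCL 2.29.0");
static_assert(NcclVersionCode(2, 29, 7) == 22907,
              "hermetic NCCL of the verified run");
static_assert(LsaApiCompiledIn(NcclVersionCode(2, 29, 7)),
              "2.29.7 takes the NCCL branch");
static_assert(!LsaApiCompiledIn(NcclVersionCode(2, 28, 9)),
              "2.28.x falls back to UnimplementedError");
```

Because the preprocessor evaluates the gate, an older NCCL does not break the build. Both accessors simply return `UnimplementedError`. On 2.29+, a null pointer from NCCL (for example, a device without multimem support) takes the same error path, because the `if (multimem)` / `if (peer_addr)` checks fall through to it.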
// SOURCE: openxla/xla @ 9190ab7d75ec35933e7cf1ed375ca6d08279e805 (jax-toolbox 2026-06-02, jax/jaxlib 0.10.2.dev20260602)
// Path in tree: xla/backends/gpu/collectives/nccl_symmetric_memory.h
// Provided here verbatim so the gist matches exactly what run_symmetric_test.sh compiles.
#ifndef XLA_BACKENDS_GPU_COLLECTIVES_NCCL_SYMMETRIC_MEMORY_H_
#define XLA_BACKENDS_GPU_COLLECTIVES_NCCL_SYMMETRIC_MEMORY_H_

#include <memory>
#include <string>

#include "absl/status/statusor.h"
#include "third_party/nccl/nccl.h"
#include "xla/core/collectives/rank_id.h"
#include "xla/core/collectives/symmetric_memory.h"
#include "xla/stream_executor/device_address.h"

namespace xla::gpu {

// A NCCL window registration handle that makes local buffers accessible from
// remote peers via symmetric memory registration process.
class NcclSymmetricMemory final : public SymmetricMemory {
 public:
  ~NcclSymmetricMemory() final;
  static absl::StatusOr<std::unique_ptr<NcclSymmetricMemory>> Create(
      ncclComm_t comm, stream_executor::DeviceAddressBase addr);
  stream_executor::DeviceAddressBase addr() const final;
  absl::StatusOr<stream_executor::DeviceAddressBase> multimem_addr()
      const final;
  absl::StatusOr<stream_executor::DeviceAddressBase> peer_addr(
      RankId peer) const final;
  ncclWindow_t win() const { return win_; }
  std::string ToString() const final;
  PackedKernelArg PackKernelArg() const final;

 private:
  NcclSymmetricMemory(ncclComm_t comm, ncclWindow_t win,
                      stream_executor::DeviceAddressBase addr);
  ncclComm_t comm_;
  ncclWindow_t win_;
  stream_executor::DeviceAddressBase addr_;
};

}  // namespace xla::gpu

#endif // XLA_BACKENDS_GPU_COLLECTIVES_NCCL_SYMMETRIC_MEMORY_H_
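The header makes the constructor private and exposes a static `Create()` factory plus a `final` destructor. That pairing guarantees two things:
- an `NcclSymmetricMemory` object exists only after `ncclCommWindowRegister` has succeeded;
- every such object calls `ncclCommWindowDeregister` exactly once, when it is destroyed.

The sketch below reproduces that ownership pattern without any NCCL:
- a hypothetical `FakeNccl` only counts register and deregister calls;
- a hypothetical `WindowHandle` plays the role of `NcclSymmetricMemory`;
- on failure it returns `nullptr` where the real code returns an error `absl::Status`;
- it explicitly deletes the copy operations so a window cannot be deregistered twice.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>

// Hypothetical stand-in for ncclCommWindowRegister / ncclCommWindowDeregister:
// it only counts calls and never touches NCCL or GPU memory.
struct FakeNccl {
  int registered = 0;
  int deregistered = 0;
  bool fail_next_register = false;

  // Returns a non-zero "window" handle on success, 0 on failure.
  int Register(void* ptr, std::size_t size) {
    assert(ptr != nullptr && size > 0);
    if (fail_next_register) {
      fail_next_register = false;
      return 0;
    }
    return ++registered;
  }
  void Deregister(int /*win*/) { ++deregistered; }
};

// Same shape as NcclSymmetricMemory: register in the factory, deregister in
// the destructor. An object is only constructed after registration succeeded.
class WindowHandle {
 public:
  static std::unique_ptr<WindowHandle> Create(FakeNccl& nccl, void* ptr,
                                              std::size_t size) {
    int win = nccl.Register(ptr, size);
    if (win == 0) return nullptr;  // real code: return an error Status
    return std::unique_ptr<WindowHandle>(new WindowHandle(nccl, win));
  }
  ~WindowHandle() { nccl_.Deregister(win_); }

  // Copying would deregister the same window twice, so forbid it.
  WindowHandle(const WindowHandle&) = delete;
  WindowHandle& operator=(const WindowHandle&) = delete;

  int win() const { return win_; }

 private:
  WindowHandle(FakeNccl& nccl, int win) : nccl_(nccl), win_(win) {}
  FakeNccl& nccl_;
  int win_;
};
```

Returning a `std::unique_ptr` (as `Create()` does inside its `StatusOr`) ties deregistration to the owner's scope. A failed registration leaves nothing behind to clean up.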
# Verified run — manual NCCL symmetric registration
# image: ghcr.io/nvidia/jax:maxtext-2026-06-02 (jax/jaxlib 0.10.2.dev20260602, XLA 9190ab7d75)
# GPUs: 2x NVIDIA RTX PRO 6000 Blackwell (sm_120), binary: collective_ops_ffi_test_rtx6000pro
# build: BUILD_EXIT=0 (cold cache, ~25 min, hermetic NCCL 2.29.7)
## gtest summary
[ RUN      ] CollectiveOpsTestFFI.DeviceAllReduceWithFrontendAttributes
[       OK ] CollectiveOpsTestFFI.DeviceAllReduceWithFrontendAttributes (1468 ms)
[ RUN      ] AllReduceTests/AllReduceTest.DeviceAllReduce/blocked
[       OK ] AllReduceTests/AllReduceTest.DeviceAllReduce/blocked (31 ms)
[ RUN      ] AllReduceTests/AllReduceTest.DeviceAllReduce/delayed
[       OK ] AllReduceTests/AllReduceTest.DeviceAllReduce/delayed (1026 ms)
[==========] 3 tests from 2 test suites ran. (2526 ms total)
[  PASSED  ] 3 tests.
## manual registration evidence (NO auto flag set)
I0000 00:00:1780432258.017648 200 nccl_symmetric_memory.cc:43] Create NCCL symmetric memory on comm=0x7303a400a280 from: ptr=0x404000000; size=2097152
I0000 00:00:1780432258.017662 196 nccl_symmetric_memory.cc:43] Create NCCL symmetric memory on comm=0x7303ac009c90 from: ptr=0x402000000; size=2097152
a2722007bad8:1:273 [1] NCCL INFO Symmetric VA size=96GB
a2722007bad8:1:272 [0] NCCL INFO Symmetric VA size=96GB
a2722007bad8:1:273 [1] NCCL INFO Symmetric VA size=96GB
a2722007bad8:1:272 [0] NCCL INFO Symmetric VA size=96GB
a2722007bad8:1:200 [1] NCCL INFO register comm 0x7303a400a280 buffer 0x404000000 size 2097152
a2722007bad8:1:196 [0] NCCL INFO register comm 0x7303ac009c90 buffer 0x402000000 size 2097152
I0000 00:00:1780432258.203235 290 nccl_symmetric_memory.cc:43] Create NCCL symmetric memory on comm=0x7303a400a280 from: ptr=0x730748200000; size=2097152
I0000 00:00:1780432258.203239 286 nccl_symmetric_memory.cc:43] Create NCCL symmetric memory on comm=0x7303ac009c90 from: ptr=0x730746200000; size=2097152
a2722007bad8:1:290 [1] NCCL INFO register comm 0x7303a400a280 buffer 0x730748200000 size 2097152
a2722007bad8:1:286 [0] NCCL INFO register comm 0x7303ac009c90 buffer 0x730746200000 size 2097152
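The two log families above can be cross-checked mechanically:
- each XLA `Create NCCL symmetric memory` line comes from the `VLOG(3)` in `NcclSymmetricMemory::Create`;
- in this run, each one has an NCCL `register comm ... buffer ... size ...` line with the same communicator, pointer and size.

The C++ sketch below performs that check using only `<regex>`. The helper name `EveryCreateWasRegistered` is hypothetical. The patterns assume the exact log formats shown here, which may change between XLA and NCCL versions.

```cpp
#include <cassert>
#include <regex>
#include <set>
#include <sstream>
#include <string>
#include <tuple>

// (comm, ptr, size) exactly as printed in the logs.
using Registration = std::tuple<std::string, std::string, std::string>;

// Hypothetical checker: true if at least one XLA "Create" line was found and
// every one of them has an NCCL "register" line with the same comm/ptr/size.
bool EveryCreateWasRegistered(const std::string& log) {
  static const std::regex kCreate(
      R"(Create NCCL symmetric memory on comm=(0x[0-9a-f]+) from: ptr=(0x[0-9a-f]+); size=(\d+))");
  static const std::regex kRegister(
      R"(NCCL INFO register comm (0x[0-9a-f]+) buffer (0x[0-9a-f]+) size (\d+))");
  std::set<Registration> created;
  std::set<Registration> registered;
  std::istringstream lines(log);
  std::string line;
  std::smatch m;
  while (std::getline(lines, line)) {
    if (std::regex_search(line, m, kCreate)) {
      created.emplace(m[1].str(), m[2].str(), m[3].str());
    } else if (std::regex_search(line, m, kRegister)) {
      registered.emplace(m[1].str(), m[2].str(), m[3].str());
    }
  }
  if (created.empty()) return false;
  for (const Registration& r : created) {
    if (registered.count(r) == 0) return false;
  }
  return true;
}
```

Applied to the evidence block above, all four `Create` lines (two 2 MiB buffers on each of the two ranks) have a matching `register` line. The `Symmetric VA size` lines are ignored.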