#nvidia #cuda #gpu #arch #metal #rocm #triton
- Tensor Cores and Tensor Memory Accelerator
- Split Barriers (Ampere and above)
- Thread Block Cluster (Hopper and above)
- Producer-Consumer Queues
Global Memory (8192 × 8192):
┌─────────────────────────────────────┐
│ [0,0] [0,1] ... [0,8191] │
│ [1,0] [1,1] ... [1,8191] │
│ ... ... ... ... │
│[8191,0][8191,1] ... [8191,8191] │
└─────────────────────────────────────┘
TMA copies 64×64 tiles:
┌───────┬───────┬─────┬───────┐
│ Tile │ Tile │ ... │ Tile │
│ (0,0) │ (0,1) │ │(0,127)│
├───────┼───────┼─────┼───────┤
│ Tile │ Tile │ ... │ Tile │
│ (1,0) │ (1,1) │ │(1,127)│
├───────┼───────┼─────┼───────┤
│ ... │ ... │ ... │ ... │
├───────┼───────┼─────┼───────┤
│ Tile │ Tile │ ... │ Tile │
│(127,0)│(127,1)│ │(127,127)│
└───────┴───────┴─────┴───────┘
- Whole matrix (8192×8192) is split into tiles with dimensions 64×64 → in total, we have 128×128 = 16,384 tiles.
- One CUDA block computes one 64×64 tile of C and then writes it out.
- K is processed in chunks of 64 (BK=64). Since K=8192, that’s 8192/64 = 128 iterations.
- In each iteration:
- Load the current 64×64 slices from A and B into shared memory.
- Run 4 tensor-core ops of size m64n64k16 (because 4×16 = 64 along K).
- Each op accumulates into the same register tile d[ ][ ] (per-thread registers).
- Registers d[ ][ ] start at zero for this tile and keep accumulating across all 128 iterations.
- After the 128 iterations, the register tile holds the final 64×64 results and is written to C (column-major).
- Note: each thread holds 32 accumulator registers; 32 × 128 = 4096 = 64 × 64.
- There are 128 tiles across M and 128 tiles across N → launch 128×128 blocks.
- Each block independently:
- accumulates its 64×64 results over 128 K-steps via 4 WGMMA calls per step,
- stores its finished 64×64 tile to C.
- Per 64×64 tile: 128 K-steps; each step issues 4× wgmma.m64n64k16; all accumulate in registers.
- Whole matrix: 16,384 such tiles; each computed by one block and stored to C (sketched below).
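Putting those bullets together, a rough per-block skeleton looks like the following. This is a sketch only: bf16, load_tile_to_smem(), and wgmma64() are placeholders for the TMA copy and tensor-core call discussed in the rest of this note.

```cpp
// Sketch of the per-block main loop (placeholders, not the real kernel).
constexpr int BM = 64, BN = 64, BK = 64, WGMMA_K = 16, K = 8192;

__global__ void matmul_block_sketch(const bf16 *A, const bf16 *B, bf16 *C) {
    __shared__ bf16 sA[BM * BK], sB[BK * BN];
    float d[BN / 16][8] = {};                       // 4×8 = 32 accumulators per thread, start at 0

    for (int k_iter = 0; k_iter < K / BK; ++k_iter) {          // 8192 / 64 = 128 K-steps
        // Load the current 64×64 slices of A and B into shared memory (via TMA in the real kernel)
        load_tile_to_smem(sA, A, /*tile_row=*/blockIdx.y, /*tile_col=*/k_iter);
        load_tile_to_smem(sB, B, /*tile_row=*/k_iter,     /*tile_col=*/blockIdx.x);
        __syncthreads();

        // Four m64n64k16 tensor-core ops cover the 64-wide K chunk (4 × 16 = 64)
        for (int k = 0; k < BK / WGMMA_K; ++k)
            wgmma64(d, &sA[k * WGMMA_K], &sB[k * WGMMA_K]);
        __syncthreads();
    }
    // d now holds this block's finished 64×64 tile of C; the store loop (shown later) writes it out
}
```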
- TMA is a new piece of hardware introduced in the Hopper architecture: a faster way to load tiles of multi-dimensional matrices between GMEM and SMEM.
- It’s essentially used to tell the GPU: “Here’s how to efficiently slice this big matrix into 64×64 chunks and copy them to shared memory.”
- TMA loads directly support the swizzling patterns required by tensor cores.
- TMA takes in a tiling configuration for a given matrix, and can load any requested tile into SMEM.
- We’ll need to create a TMA map descriptor for efficient data movement between global memory and shared memory on modern CUDA GPUs.
template <int BlockMajorSize, int BlockMinorSize>
__host__ static inline CUtensorMap* allocate_and_create_tensor_map(bf16* src, int blocks_height, int blocks_width) {
CUtensorMap *tma_map_d;
cudaMalloc(&tma_map_d, sizeof(CUtensorMap));
CUtensorMap tma_map_host;
create_tensor_map<BlockMajorSize, BlockMinorSize>(&tma_map_host, src, blocks_height, blocks_width);
cudaMemcpy(tma_map_d, &tma_map_host, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
return tma_map_d;
}
template <int BlockMajorSize, int BlockMinorSize>
void create_tensor_map(CUtensorMap *tma_map, bf16* gmem_ptr, int blocks_height, int blocks_width) {
void* gmem_address = (void*)gmem_ptr;
uint64_t gmem_prob_shape[5] = {(uint64_t)BlockMinorSize*blocks_width, (uint64_t)BlockMajorSize*blocks_height, 1, 1, 1};
uint64_t gmem_prob_stride[5] = {sizeof(bf16), sizeof(bf16) * BlockMinorSize*blocks_width, 0, 0, 0};
uint32_t smem_box_shape[5] = {uint32_t(BlockMinorSize), uint32_t(BlockMajorSize), 1, 1, 1};
uint32_t smem_box_stride[5] = {1, 1, 1, 1, 1};
CUresult result = cuTensorMapEncodeTiled(
tma_map, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 2, gmem_address, gmem_prob_shape,
gmem_prob_stride + 1, smem_box_shape, smem_box_stride, CU_TENSOR_MAP_INTERLEAVE_NONE,
CU_TENSOR_MAP_SWIZZLE_128B, CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
assert(result == CUDA_SUCCESS);
}
This function sets up a hardware-accelerated memory copy descriptor that tells the GPU how to efficiently move data from global memory to shared memory using TMA instructions.
BlockMajorSize = 64, BlockMinorSize = 64, blocks_width = 128, blocks_height = 128
uint64_t gmem_prob_shape[5] = {
(uint64_t)BlockMinorSize*blocks_width, // 64 * 128 = 8192
(uint64_t)BlockMajorSize*blocks_height, // 64 * 128 = 8192
1, 1, 1
};
This describes the total tensor dimensions in global memory:
- Dimension 0: Width = 8192 elements
- Dimension 1: Height = 8192 elements
- Dimensions 2-4: Unused (set to 1)
So we have an 8192 × 8192 matrix in global memory.
uint64_t gmem_prob_stride[5] = {
sizeof(bf16), // 2 bytes
sizeof(bf16) * BlockMinorSize*blocks_width, // 2 * 8192 = 16384 bytes
0, 0, 0
};
This tells how to navigate through global memory:
- Stride 0: 2 bytes to next element in same row
- Stride 1: 16384 bytes to next row (skip entire row width)
- Strides 2-4: Unused
uint32_t smem_box_shape[5] = {
uint32_t(BlockMinorSize), // 64
uint32_t(BlockMajorSize), // 64
1, 1, 1
};
This describes the tile size we copy to shared memory:
- 64 × 64 tile of bf16 elements
- Each TMA operation copies this much data
uint32_t smem_box_stride[5] = {1, 1, 1, 1, 1};
This is always {1, 1, 1, 1, 1} for a contiguous shared memory layout.
CUresult result = cuTensorMapEncodeTiled(
tma_map, // Output: the TMA descriptor
CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, // Data type: bf16
2, // 2D tensor
gmem_address, // Global memory base address
gmem_prob_shape, // Global tensor dimensions
gmem_prob_stride + 1, // Skip first stride element
smem_box_shape, // Tile size for shared memory
smem_box_stride, // Shared memory layout
CU_TENSOR_MAP_INTERLEAVE_NONE, // No interleaving
CU_TENSOR_MAP_SWIZZLE_128B, // 128-byte swizzling for bank conflicts
CU_TENSOR_MAP_L2_PROMOTION_NONE, // No L2 cache promotion
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE // No out-of-bounds fill
);
This skips the first stride element, effectively using:
{16384, 0, 0, 0}   // instead of {2, 16384, 0, 0, 0}
This is a common pattern for 2D tensors: the innermost stride is implied by the element size, so only the outer strides are passed.
Applies 128-byte swizzling to avoid shared memory bank conflicts when accessing the data.
After creating this TMA map, you can use it in kernels:
// Copy a 64×64 tile from global memory to shared memory
asm volatile (
"cp.async.bulk.tensor.3d.shared::cluster.global.tile.mbarrier::complete_tx::bytes"
" [%0], [%1, {%3, %4, %5}], [%2];"
:
: "r"(dst_ptr), "l"(tma_ptr), "r"(mbar_ptr),
"n"(0), "r"(global_row_idx), "r"(global_col_idx/64)
: "memory"
);
// Note:
// • One TMA instruction = One 64×64 tile (8KB)
// • 8192×8192 matrix = 16,384 tiles total
// • Need 16,384 TMA instructions to copy entire matrix
// • Typical usage: Copy tiles one-by-one as needed for computation
// • The power of TMA isn't copying everything at once - it's efficiently copying exactly what you need, when you need it!
8192×8192 Global Matrix (128×128 tiles):
┌─────┬─────┬─────┬─────┬─────┐
│(0,0)│(0,1)│(0,2)│ ... │(0,127)│
├─────┼─────┼─────┼─────┼─────┤
│(1,0)│(1,1)│(1,2)│ ... │(1,127)│ ← This instruction copies
├─────┼─────┼─────┼─────┼─────┤ ONE of these tiles
│(2,0)│(2,1)│(2,2)│ ... │(2,127)│
├─────┼─────┼─────┼─────┼─────┤
│ ... │ ... │ ... │ ... │ ... │
├─────┼─────┼─────┼─────┼─────┤
│(127,0)│(127,1)│... │(127,127)│
└─────┴─────┴─────┴─────┴─────┘
Single TMA instruction: copies tile (global_row_idx, global_col_idx/64).
To copy multiple tiles, we usually do:
// Copy tiles for matrix multiplication
for (int k = 0; k < num_k_tiles; k++) {
// Copy one A tile
asm volatile("cp.async.bulk.tensor.2d..."
:: "r"(&sA[0]), "l"(tma_map_A), "r"(k), "r"(block_row));
// Copy one B tile
asm volatile("cp.async.bulk.tensor.2d..."
:: "r"(&sB[0]), "l"(tma_map_B), "r"(k), "r"(block_col));
// Wait for copies to complete
barrier.wait();
// Do computation on these tiles
wgmma(...);
}
This function creates a hardware descriptor that enables:
- Fast bulk copies of 64×64 tiles from an 8192×8192 global matrix
- Automatic address calculation for different tile positions
- Optimized memory access patterns with swizzling
- Hardware-accelerated transfers using TMA instructions
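For context, a host-side launch sketch using this helper might look like the following. This is illustrative only: matmul_kernel is a hypothetical name, and the descriptor arguments for B are shown symmetrically rather than taken from the note.

```cpp
#include <cuda.h>
#include <cuda_bf16.h>
using bf16 = __nv_bfloat16;

// A, B, C are assumed to be device allocations of bf16 created elsewhere.
void launch_matmul(bf16 *A, bf16 *B, bf16 *C) {
    constexpr int M = 8192, N = 8192, K = 8192;
    constexpr int BM = 64, BN = 64, BK = 64;

    // One descriptor per input matrix: template params = 64×64 tile shape,
    // runtime args = how many tiles tall/wide the matrix is.
    CUtensorMap *tma_map_A = allocate_and_create_tensor_map<BM, BK>(A, M / BM, K / BK);
    CUtensorMap *tma_map_B = allocate_and_create_tensor_map<BN, BK>(B, N / BN, K / BK);

    // 128×128 blocks (one per 64×64 tile of C), 128 threads each.
    matmul_kernel<<<dim3(M / BM, N / BN), 128>>>(tma_map_A, tma_map_B, C);
}
```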
- __syncthreads() / bar.sync:
- A monolithic barrier: all threads in the block must reach it before any proceed.
- If some threads arrive much earlier, they must stall until the last one arrives.
- Coarse-grained → limits overlap of independent work.
- Introduced barrier.arrive and barrier.wait primitives.
- Benefits:
- Decoupled arrival & waiting → a warp can “signal” it’s done with some work (arrive), then continue with independent tasks before later executing wait.
- Better overlap of computation & synchronization → allows useful work to be done instead of idling at the barrier.
- Finer-grained control → synchronization can happen at warp-level rather than full block-level.
- Improved latency hiding → especially important for async copies (cp.async) into shared memory, where producer warps can continue issuing loads while consumer warps wait only when needed.
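As a concrete illustration of the decoupled pattern, here is a minimal sketch using libcu++'s cuda::barrier (the "independent work" step is a stand-in; nothing here is specific to this kernel):

```cpp
#include <cuda/barrier>
#include <utility>

using block_barrier = cuda::barrier<cuda::thread_scope_block>;

__global__ void split_barrier_demo(float *out) {
    __shared__ block_barrier bar;
    if (threadIdx.x == 0) init(&bar, blockDim.x);   // expected arrivals = all threads in the block
    __syncthreads();

    // ... produce some shared-memory data here ...

    // arrive() signals "my part is done" without blocking this thread
    block_barrier::arrival_token token = bar.arrive();

    // Independent work that doesn't touch the shared data can overlap with other threads' arrivals
    float local = threadIdx.x * 2.0f;

    // Block only at the point where the shared data is actually needed
    bar.wait(std::move(token));

    out[blockIdx.x * blockDim.x + threadIdx.x] = local;
}
```

When such a barrier object lives in shared memory, it is backed by the hardware mbarrier that the snippet below initializes directly.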
asm volatile ("mbarrier.init.shared.b64 [%0], %1;" // Initializes a hardware barrier in shared memory
// Converts a C++ object address to a shared memory address
:: "r"(static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(&__b->__barrier))),
"r"(static_cast<_CUDA_VSTD::uint32_t>(__expected)) // Sets the expected arrival count
: "memory");
// Breaking Down the Syntax
// 1. asm volatile
// • asm = Inline assembly block
// • volatile = "Don't optimize this away, it has side effects"
// 2. The Instruction: "mbarrier.init.shared.b64 [%0], %1;"
// • mbarrier.init = Initialize a memory barrier
// • shared = Located in shared memory
// • b64 = 64-bit barrier object
// • [%0] = Memory address (first parameter)
// • %1 = Expected count (second parameter)
// 3. Input Constraints: :: "r"(...), "r"(...)
// The :: separates outputs (none) from inputs:
// • First "r"(...) → %0 (register constraint)
// • Second "r"(...) → %1 (register constraint)
// 4. Clobber List: : "memory"
// • "memory" = This instruction may modify memory, so don't cache memory values across it
// Enables hardware-accelerated synchronization for all threads in the block
// It's essentially telling the GPU: "This memory location is now a hardware barrier that expects N threads to synchronize on it!"
The b64 tells us this is a 64-bit barrier object in shared memory.
About the scope:
- All threads in the thread block can access the same shared memory
- All threads see the same 64-bit barrier value
- Hardware recognizes this specific memory location as a barrier
[63:48] Phase + Hardware State
[47:32] Expected Arrival Count
[31:0 ] Current Arrival Count
__shared__ uint64_t barrier; // The special 64-bit value
uint32_t bar_addr = static_cast<uint32_t>(__cvta_generic_to_shared(&barrier)); // shared-space address for the PTX operands
// Thread 0 initializes with the expected arrival count
if (threadIdx.x == 0) {
    asm volatile ("mbarrier.init.shared.b64 [%0], %1;"
        :: "r"(bar_addr), "r"(blockDim.x) : "memory");
}
__syncthreads();
// ALL threads arrive at the barrier; arrival returns a 64-bit phase/state token
uint64_t token;
asm volatile ("mbarrier.arrive.shared.b64 %0, [%1];"
    : "=l"(token) : "r"(bar_addr) : "memory");
// ALL threads wait for completion (simplified: real code loops on test_wait/try_wait with the token until the phase completes)
asm volatile ("{ .reg .pred p; mbarrier.test_wait.shared.b64 p, [%0], %1; }"
    :: "r"(bar_addr), "l"(token) : "memory");
- Memory Controller Recognition: The SM's memory controller knows this address contains a barrier
- Atomic Operations: Hardware can atomically increment arrival count
- Efficient Waiting: Threads don't busy-wait, they can be suspended
- Cache Coherency: All threads see consistent state instantly
// Thread arrives - hardware atomically does:
// barrier[31:0]++; // Increment arrival count
// if (arrival_count == expected_count) {
// barrier[63:48] ^= 1; // Flip phase bit - releases all waiters
// }
Initial state:    [Phase=0][Expected=4][Current=0]
Thread 0 arrives: [Phase=0][Expected=4][Current=1]
Thread 1 arrives: [Phase=0][Expected=4][Current=2]
Thread 2 arrives: [Phase=0][Expected=4][Current=3]
Thread 3 arrives: [Phase=1][Expected=4][Current=4] ← Phase flips, all threads released!
- It's just a 64-bit integer in shared memory
- But the hardware treats it specially when you use mbarrier.* instructions
- All threads in the block can interact with the same barrier
- Hardware ensures atomicity and efficient synchronization
The "magic" isn't in the memory itself - it's in the specialized hardware instructions (mbarrier.arrive, mbarrier.test_wait, etc.) that know how to interpret and manipulate this 64-bit value as a synchronization primitive!
So a barrier is essentially a hardware-accelerated 64-bit counter with special semantics that all threads in the SM can safely and efficiently synchronize on.
We perform 4 WGMMA operations on the loaded tiles
- Each wgmma64 performs: 64×16 @ 16×64 → 64×64
- sA: 64×16 slice (64 rows, 16 columns)
- sB: 16×64 slice (16 rows, 64 columns)
- Result: 64×64 accumulation into d
wgmma64<1, 1, 1, 0, 0>(d, &sA[0], &sB[0]); // K=0:15
wgmma64<1, 1, 1, 0, 0>(d, &sA[WGMMA_K], &sB[WGMMA_K]); // K=16:31
wgmma64<1, 1, 1, 0, 0>(d, &sA[2*WGMMA_K], &sB[2*WGMMA_K]); // K=32:47
wgmma64<1, 1, 1, 0, 0>(d, &sA[3*WGMMA_K], &sB[3*WGMMA_K]);  // K=48:63
Where WGMMA_K = 16.
sA (64×64):
┌─────┬─────┬─────┬─────┐
│64×16│64×16│64×16│64×16│
│ 0 │ 16 │ 32 │ 48 │
│ │ │ │ │
│ │ │ │ │
│ │ │ │ │
└─────┴─────┴─────┴─────┘
sB (64×64):
┌─────────────────────────┐ ← 16×64 (slice 0)
├─────────────────────────┤ ← 16×64 (slice 16)
├─────────────────────────┤ ← 16×64 (slice 32)
└─────────────────────────┘ ← 16×64 (slice 48)
wgmma64<1, 1, 1, 0, 0>(d, &sA[0], &sB[0]); // sA[64×16₀] × sB[16₀×64]
wgmma64<1, 1, 1, 0, 0>(d, &sA[WGMMA_K], &sB[WGMMA_K]); // sA[64×16₁] × sB[16₁×64]
wgmma64<1, 1, 1, 0, 0>(d, &sA[2*WGMMA_K], &sB[2*WGMMA_K]); // sA[64×16₂] × sB[16₂×64]
wgmma64<1, 1, 1, 0, 0>(d, &sA[3*WGMMA_K], &sB[3*WGMMA_K]);  // sA[64×16₃] × sB[16₃×64]
Note: For sA, the first wgmma instruction wgmma64<1, 1, 1, 0, 0>(d, &sA[0], &sB[0]) automatically picks a 64×16 slice from the 64×64 tile in row-major form:
Row 0: elements 0-15
Row 1: elements 64-79
Row 2: elements 128-143
Row 3: elements 192-207
...
Row 63: elements 4032-4047
This happens because:
- sA is laid out in SMEM as a 64×64 tile with K (BK=64) as the contiguous dimension
- The wgmma instruction expects A data in a specific swizzled pattern that TMA automatically provides
- When you pass &sA[0], the tensor core hardware reads the first 16 columns (K=0 to K=15) across all 64 rows
The pattern continues for subsequent calls:
- &sA[16]: columns 16-31 across all 64 rows
- &sA[32]: columns 32-47 across all 64 rows
- &sA[48]: columns 48-63 across all 64 rows
So, each wgmma call processes one vertical 64×16 "stripe" of the sA tile, and the +16 offset selects which stripe along the K dimension.
Each operation multiplies:
- sA slice: 64×16 (64 rows, 16 columns)
- sB slice: 16×64 (16 rows, 64 columns)
- Result: 64×64 accumulated into d
C[64×64] = A[64×16₀] × B[16₀×64] + // K-slice 0:15
A[64×16₁] × B[16₁×64] + // K-slice 16:31
A[64×16₂] × B[16₂×64] + // K-slice 32:47
           A[64×16₃] × B[16₃×64]    // K-slice 48:63
There are exactly 4 such slices in each of sA and sB, and each sB slice is indeed 16×64.
// Operation 1: K-dimension 0:15
sA[0:63, 0:15] × sB[0:15, 0:63] → d[0:63, 0:63]
// Operation 2: K-dimension 16:31
sA[0:63, 16:31] × sB[16:31, 0:63] → d[0:63, 0:63] (accumulate)
// Operation 3: K-dimension 32:47
sA[0:63, 32:47] × sB[32:47, 0:63] → d[0:63, 0:63] (accumulate)
// Operation 4: K-dimension 48:63
sA[0:63, 48:63] × sB[48:63, 0:63] → d[0:63, 0:63]  (accumulate)
- The WGMMA instruction can only handle a K-dimension of 16 at a time
- Our tile has K=64, so we need 64/16 = 4 operations
C[64×64] = A[64×64] × B[64×64]
// Broken down by K-dimension:
C = A[64×0:15] × B[0:15×64] + // First 16 K-elements
A[64×16:31] × B[16:31×64] + // Next 16 K-elements
A[64×32:47] × B[32:47×64] + // Next 16 K-elements
A[64×48:63] × B[48:63×64] // Last 16 K-elements
&sA[0] // Points to sA[0:63, 0:15]
&sA[WGMMA_K] // Points to sA[0:63, 16:31] (skip 16 columns)
&sA[2*WGMMA_K] // Points to sA[0:63, 32:47] (skip 32 columns)
&sA[3*WGMMA_K] // Points to sA[0:63, 48:63] (skip 48 columns)
&sB[0]         // Points to sB[0:15, 0:63]
&sB[WGMMA_K] // Points to sB[16:31, 0:63] (skip 16 rows)
&sB[2*WGMMA_K] // Points to sB[32:47, 0:63] (skip 32 rows)
&sB[3*WGMMA_K] // Points to sB[48:63, 0:63] (skip 48 rows)
- Load: One 64×64 tile each for A and B into shared memory
- Compute: 4 WGMMA operations, each handling 16 elements of the K-dimension
- Result: Complete 64×64 matrix multiplication accumulated in register d
- Subdivision: Along the K-dimension (inner dimension), not the output dimensions
So we’re processing the same 64×64 output tile but breaking down the inner K-dimension into 4 chunks of 16 elements each!
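As a quick CPU-side sanity check (plain C++, nothing Hopper-specific), accumulating the product over K in chunks of 16 gives exactly the same result as a single pass over K=64, which is all the 4 wgmma calls rely on:

```cpp
// CPU check: accumulating over K in chunks of 16 equals the full K=64 product.
#include <cassert>
#include <vector>

int main() {
    constexpr int M = 64, N = 64, K = 64, KC = 16;   // KC plays the role of WGMMA_K
    std::vector<float> A(M * K, 1.0f), B(K * N, 2.0f);
    std::vector<float> C_full(M * N, 0.0f), C_chunked(M * N, 0.0f);

    // Full K in one pass
    for (int i = 0; i < M; ++i)
        for (int j = 0; j < N; ++j)
            for (int k = 0; k < K; ++k)
                C_full[i * N + j] += A[i * K + k] * B[k * N + j];

    // Same result with K split into K/KC = 4 chunks that accumulate into the same output
    for (int kc = 0; kc < K; kc += KC)               // the "4 wgmma calls"
        for (int i = 0; i < M; ++i)
            for (int j = 0; j < N; ++j)
                for (int k = kc; k < kc + KC; ++k)
                    C_chunked[i * N + j] += A[i * K + k] * B[k * N + j];

    for (int i = 0; i < M * N; ++i) assert(C_full[i] == C_chunked[i]);
    return 0;
}
```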
- Launch: 16,384 blocks (128×128 tiles), each with 128 threads = 4 warps.
- Each thread accumulates 32 results (in d[ ][ ]) and stores those 32 into its block’s 64×64 tile of C (not the whole tile per thread).
// Store
{
int tid = threadIdx.x;
int lane = tid % 32;
int warp = tid / 32;
uint32_t row = warp*16 + lane / 4;
bf16 *block_C = C + num_block_n*BN*M + num_block_m*BM;
for (int m_it = 0; m_it < BM/WGMMA_M; ++m_it) {
for (int n_it = 0; n_it < BN/WGMMA_N; ++n_it) {
for (int w = 0; w < WGMMA_N/16; ++w) {
int col = 16*w + 2*(tid % 4);
#define IDX(i, j) ((j + n_it*WGMMA_N)*M + ((i) + m_it*WGMMA_M))
block_C[IDX(row, col)] = d[w][0];
block_C[IDX(row, col+1)] = d[w][1];
block_C[IDX(row+8, col)] = d[w][2];
block_C[IDX(row+8, col+1)] = d[w][3];
block_C[IDX(row, col+8)] = d[w][4];
block_C[IDX(row, col+9)] = d[w][5];
block_C[IDX(row+8, col+8)] = d[w][6];
block_C[IDX(row+8, col+9)] = d[w][7];
#undef IDX
}
}
}
}
- The 64×64 tile is split across the 4 warps in the warp-group:
- Each warp handles 16 rows: rows [warp*16 .. warp*16+15].
- Within a warp:
- Each thread writes 2 rows: row_base = warp*16 + lane/4, and row_base+8.
- Columns are covered in 4 groups of 16 (w = 0..3). In each group, a thread writes 4 columns: col, col+1, col+8, col+9. Across w=0..3, that’s 16 distinct columns per thread.
- Therefore:
- One thread covers 2 rows × 16 columns = 32 elements.
- The 4 threads with the same row pair (tid%4 = 0..3) together cover 2 rows × 64 columns.
- Sanity check: 128 threads × 32 elements = 4096 = 64×64.
uint32_t row = warp*16 + lane / 4; // 2 rows per thread: row and row+8
int col = 16*w + 2*(tid % 4);      // 4 cols per w: col, col+1, col+8, col+9
- Each warp handles a 16-row band. Within that band:
- lanes 0–3 → row 0; 4–7 → row 1; …; 28–31 → row 7
- Each thread writes two rows: row and row+8 (so 16 rows per warp total)
- Columns are handled in 4 groups of 16 columns each (w = 0..3; WGMMA_N=64):
- Inside each 16-wide group, each thread writes 2 adjacent columns at col and col+1, and also col+8 and col+9 (so 4 columns per w)
- Over w=0..3, that’s 16 columns per thread’s rows, and collectively the warp covers all 64 columns
- tid = 0 (warp=0, lane=0)
- Rows: row = 0, and row+8 = 8
- Columns per w:
- w=0 → col=0 → {0,1,8,9}
- w=1 → col=16 → {16,17,24,25}
- w=2 → col=32 → {32,33,40,41}
- w=3 → col=48 → {48,49,56,57}
- Writes 32 values total across rows {0,8}
- tid = 1 (warp=0, lane=1)
- Rows: {0, 8}
- Columns per w:
- w=0 → col=2 → {2,3,10,11}
- w=1 → col=18 → {18,19,26,27}
- w=2 → col=34 → {34,35,42,43}
- w=3 → col=50 → {50,51,58,59}
- tid = 2 (warp=0, lane=2)
- Rows: {0, 8}
- Columns per w:
- w=0 → col=4 → {4,5,12,13}
- w=1 → col=20 → {20,21,28,29}
- w=2 → col=36 → {36,37,44,45}
- w=3 → col=52 → {52,53,60,61}
- tid = 3 (warp=0, lane=3)
- Rows: {0, 8}
- Columns per w:
- w=0 → col=6 → {6,7,14,15}
- w=1 → col=22 → {22,23,30,31}
- w=2 → col=38 → {38,39,46,47}
- w=3 → col=54 → {54,55,62,63}
Now move down one row in the band:
- tid = 4 (warp=0, lane=4)
- Rows: {1, 9}
- Columns per w (same pattern as tid=0, just different rows):
- w=0 → {0,1,8,9}
- w=1 → {16,17,24,25}
- w=2 → {32,33,40,41}
- w=3 → {48,49,56,57}
Next warp (warp 1) jumps to the next 16-row band:
- tid = 32 (warp=1, lane=0)
- Rows: {16, 24}
- Columns per w (same as tid=0, just shifted rows):
- w=0 → {0,1,8,9}
- w=1 → {16,17,24,25}
- w=2 → {32,33,40,41}
- w=3 → {48,49,56,57}
- Warps tile rows in 16-row bands: warp 0 → rows [0..15], warp 1 → [16..31], warp 2 → [32..47], warp 3 → [48..63].
- Inside a warp, lanes come in groups of 4 that share the same base row; each thread writes 2 rows: row and row+8.
- The 64 columns are split into four 16-wide groups (w=0..3). Each thread writes 4 columns per group (two adjacent, plus the same two at +8), so 16 columns per thread across all w.
- Over 128 threads, this perfectly partitions the 64×64 tile; each thread contributes 32 elements.
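To double-check this partition, a small CPU-side simulation of the store loop's index arithmetic (plain C++, illustrative) confirms that every cell of the 64×64 tile is written exactly once:

```cpp
// CPU check that the (row, col) pattern covers each of the 64×64 = 4096 cells exactly once.
#include <cassert>
#include <initializer_list>

int main() {
    constexpr int WGMMA_N = 64;
    int count[64][64] = {};

    for (int tid = 0; tid < 128; ++tid) {            // 128 threads = 4 warps
        int lane = tid % 32, warp = tid / 32;
        int row = warp * 16 + lane / 4;              // base row; the thread also writes row+8
        for (int w = 0; w < WGMMA_N / 16; ++w) {     // four 16-wide column groups
            int col = 16 * w + 2 * (tid % 4);
            for (int r : {row, row + 8})
                for (int c : {col, col + 1, col + 8, col + 9})
                    ++count[r][c];
        }
    }

    for (int r = 0; r < 64; ++r)
        for (int c = 0; c < 64; ++c)
            assert(count[r][c] == 1);                // every cell written exactly once
    return 0;
}
```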
- Only thread 0 of warpgroup 0 issues the global→shared loads (via TMA), while the other warpgroups are consumers that compute.
- Two arrays of barriers are used per queue slot:
- full[q]: signals “tile q is filled and safe to read” (producer → consumers).
- empty[q]: signals “tile q is no longer in use and safe to overwrite” (consumers → producer).
Producer (wg 0, tid 0) doing the loads:
// Producer
if (wg_idx == 0) {
if (tid == 0) {
int qidx = 0;
for (int block_k_iter = 0; block_k_iter < num_blocks_k; ++block_k_iter, ++qidx) {
if (qidx == QSIZE) qidx = 0;
empty[qidx].wait(empty[qidx].arrive());
cde::cp_async_bulk_tensor_2d_global_to_shared(&sA[qidx*BK*BM], tensorMapA, ... , full[qidx]);
cde::cp_async_bulk_tensor_2d_global_to_shared(&sB[qidx*BK*BN], tensorMapB, ... , full[qidx]);
barrier::arrival_token _ = cuda::device::barrier_arrive_tx(full[qidx], 1, (BK*BN+BK*BM)*sizeof(bf16));
}
}
}
Consumers initialize empties, then wait for “full,” compute, and finally “arrive” on empty:
} else {
for (int i = 0; i < QSIZE; ++i) {
barrier::arrival_token _ = empty[i].arrive(); // Pre-post EMPTY on all slots
}
int qidx = 0;
for (int block_k_iter = 0; block_k_iter < num_blocks_k; ++block_k_iter, ++qidx) {
if (qidx == QSIZE) qidx = 0;
full[qidx].wait(full[qidx].arrive()); // Wait until tile is FULL
// ... compute using sA/sB for slot qidx ...
barrier::arrival_token _ = empty[qidx].arrive();// Mark slot qidx EMPTY again
}
}
Barriers are allocated per slot and initialized with the correct participant count:
__shared__ barrier full[QSIZE], empty[QSIZE];
if (threadIdx.x == 0) {
for (int i = 0; i < QSIZE; ++i) {
init(&full[i], num_consumers * 128 + 1);
init(&empty[i], num_consumers * 128 + 1);
}
}
Notes:
- Each warpgroup is 128 threads; wg_idx == 0 is producer; num_consumers = (#warpgroups - 1).
- The expected participant count for each barrier phase is: all consumer threads (num_consumers*128) + 1 producer thread.
Think of a ring buffer with QSIZE slots (e.g., 5). Each slot has two independent “phases” to coordinate different directions of ownership:
- empty[q]: producer-side guard against overwrite
- Meaning: “Every consumer is done with slot q; producer may overwrite it.”
- Usage:
- On startup, all consumers pre-post empty[i].arrive() for every slot i, so all slots begin EMPTY.
- Before filling slot q, the producer does empty[q].wait(empty[q].arrive()), i.e., arrive-and-wait until all consumer threads have arrived for that phase. Only then it refills the slot.
- full[q]: consumer-side guard against premature read
- Meaning: “Slot q contains a new tile; consumers may safely read it.”
- Usage:
- Producer starts TMA copies into slot q and then calls barrier_arrive_tx(full[q], 1, bytes) to bind transaction completion to the full[q] barrier.
- Consumers do full[q].wait(full[q].arrive()) for the slot they need next. They “arrive” early, then “wait” until the TMA engine completes and arrives transactionally on the barrier, releasing them to compute safely.
Why two arrays instead of one?
- You need two independent, alternating conditions per slot:
- Producer must not overwrite until consumers finish (empty barrier).
- Consumers must not read until producer (TMA) finishes (full barrier).
- Keeping them separate decouples the direction of waiting and makes the ring buffer pipelining clean: full controls availability to consumers; empty controls reusability by the producer. A single barrier would require more fragile phase accounting and increase the chance of read/overwrite races.
- init(..., num_consumers*128 + 1) sets the number of arrivals needed to complete a barrier phase for both arrays.
- For full[q] per iteration:
- Each consumer thread calls arrive(), then wait() on full[q].
- The producer contributes via barrier_arrive_tx when the TMA completes (counts as the “+1” arrival).
- When all consumers have arrived and the TMA arrives, the phase completes and all consumers’ wait() returns.
- For empty[q] per iteration:
- Each consumer thread calls empty[q].arrive() after finishing with the slot.
- The producer does empty[q].wait(empty[q].arrive()) before reuse; its arrive() plus all consumer arrivals complete the phase and its wait() returns.
- Startup:
- Consumers pre-post empty[i].arrive() for i in 0..4 → all 5 slots start EMPTY.
- Iteration 0 (q=0):
- Producer waits on empty[0] (returns immediately on startup), launches TMA into sA/sB for slot 0, then barrier_arrive_tx(full[0], ...).
- Consumers arrive-and-wait on full[0]; they are released only when the TMA completes.
- Consumers compute, then each arrives on empty[0].
- Iteration 1 (q=1), 2 (q=2), ...:
- Producer keeps staying ahead, using empty[q] to know when it’s safe to overwrite the slot and full[q] to notify consumers when a slot is ready.
- The ring of 5 slots lets the producer stay several steps ahead of compute, overlapping TMA and WGMMA.
- This kernel uses the producer-consumer approach
- Loading to shared memory is performed by warp_group 0, but only thread_id = 0
- Only wg_idx == 0 and tid == 0 issues TMA copies and barrier transactions; that’s sufficient because Hopper’s TMA runs independently once launched.
- The barriers and the need for two barrier arrays: as above, full[] (ready-to-read) and empty[] (safe-to-overwrite) per queue slot provide independent synchronization directions to prevent both premature reads and overwrites while enabling pipelined overlap with a ring buffer.
__shared__ barrier barA; // Just allocates raw memory
__shared__ barrier barB; // Just allocates raw memory
This is equivalent to:
__shared__ char barA_memory[sizeof(barrier)]; // Raw memory allocation
__shared__ char barB_memory[sizeof(barrier)]; // Raw memory allocation
The memory contains garbage values - whatever random bits happened to be there.
if (threadIdx.x == 0) {
init(&barA, blockDim.x);
init(&barB, blockDim.x);
cde::fence_proxy_async_shared_cta();
}
init(&barA, blockDim.x); // THIS is where the constructor finally runs
Inside init(), placement new calls the constructor:
new (&barA) barrier(blockDim.x); // Constructor runs HERE
- Kernel Launch: __shared__ barrier barA;
- ✅ Memory allocated in shared memory
- ❌ Constructor NOT called
- ❌ Object NOT initialized
- 💀 Contains garbage data
- Manual Init: init(&barA, blockDim.x);
- ✅ Constructor finally called via placement new
- ✅ Object properly initialized
- ✅ Safe to use
If you tried to use barA before calling init():
__shared__ barrier barA; // Just allocation
// DON'T DO THIS - undefined behavior!
auto token = barA.arrive(); // 💥 Crash! Object not initialized
You must call init() first:
__shared__ barrier barA; // Allocation only
if (threadIdx.x == 0) {
init(&barA, blockDim.x); // Actual initialization
}
__syncthreads();
// NOW it's safe to use
auto token = barA.arrive(); // ✅ Works correctly
So __shared__ declarations are pure memory allocation, with no initialization, in CUDA!
