@nihalpasham
Created September 20, 2025 12:39
Making GP-GPU programming generally accessible

GP-GPU Programming Ramblings

#nvidia #cuda #gpu #arch #metal #rocm #triton

  • Tensor Cores and Tensor Memory Accelerator
  • Split Barriers (Ampere and above)
  • Thread Block Cluster (Hopper and above)
  • Producer-Consumer Queues

Larger picture

Global Memory (8192 × 8192):
┌─────────────────────────────────────┐
│ [0,0]   [0,1]   ...   [0,8191]      │
│ [1,0]   [1,1]   ...   [1,8191]      │
│  ...     ...    ...     ...         │
│[8191,0][8191,1] ... [8191,8191]     │
└─────────────────────────────────────┘

TMA copies 64×64 tiles:
┌─────────┬─────────┬─────┬───────────┐
│ Tile    │ Tile    │ ... │ Tile      │
│ (0,0)   │ (0,1)   │     │ (0,127)   │
├─────────┼─────────┼─────┼───────────┤
│ Tile    │ Tile    │ ... │ Tile      │
│ (1,0)   │ (1,1)   │     │ (1,127)   │
├─────────┼─────────┼─────┼───────────┤
│ ...     │ ...     │ ... │ ...       │
├─────────┼─────────┼─────┼───────────┤
│ Tile    │ Tile    │ ... │ Tile      │
│ (127,0) │ (127,1) │     │ (127,127) │
└─────────┴─────────┴─────┴───────────┘
  • Here M = N = K = 8192. The whole 8192×8192 matrix is split into 64×64 tiles: (8192/64) × (8192/64) = 128 × 128 = 16,384 tiles in total.
  • One CUDA block computes one 64×64 tile of C and then writes it out.

What a single block does for its 64×64 tile

  • K is processed in chunks of 64 (BK=64). Since K=8192, that’s 8192/64 = 128 iterations.

  • In each iteration:

    • Load the current 64×64 slices from A and B into shared memory.
    • Run 4 tensor-core ops of size m64n64k16 (because 4×16 = 64 along K).
    • Each op accumulates into the same register tile d[ ][ ] (per-thread registers).
  • Registers d[ ][ ] start at zero for this tile and keep accumulating across all 128 iterations.

  • After the 128 iterations, the register tile holds the final 64×64 results and is written to C (column‑major).

  • Note: each of the block's 128 threads (one warpgroup of 4 warps) holds 32 fp32 accumulator registers; 32 × 128 = 4096 = 64 × 64.
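To make the loop nest concrete, here is a minimal CPU sketch of what one block computes, written in plain C++ so it runs anywhere. It assumes row-major A (M×K) and B (K×N) and a column-major C; the names compute_tile and tiled_matmul are hypothetical, one local array stands in for the accumulator registers spread across the 128 threads, and the inner k16 loop only mimics the shape of a wgmma.m64n64k16 call, not how tensor cores execute it.

```cpp
#include <cassert>
#include <vector>

// Tile sizes mirroring the text: BM = BN = BK = 64, WGMMA_K = 16.
constexpr int BM = 64, BN = 64, BK = 64, WGMMA_K = 16;

// CPU model of ONE block: compute output tile (tile_m, tile_n) of C = A * B.
// Assumed layouts: A is M x K row-major, B is K x N row-major,
// C is M x N column-major (leading dimension M), as in the kernel's store.
void compute_tile(const std::vector<float>& A, const std::vector<float>& B,
                  std::vector<float>& C, int M, int N, int K,
                  int tile_m, int tile_n) {
    float d[BM][BN] = {};                          // accumulators start at zero for this tile
    for (int k0 = 0; k0 < K; k0 += BK) {           // one K-step = one 64x64 load of sA and sB
        for (int s = 0; s < BK / WGMMA_K; ++s) {   // 4 "m64n64k16" ops per K-step
            for (int i = 0; i < BM; ++i)
                for (int j = 0; j < BN; ++j)
                    for (int kk = 0; kk < WGMMA_K; ++kk) {
                        int k = k0 + s * WGMMA_K + kk;
                        d[i][j] += A[(tile_m * BM + i) * K + k] * B[k * N + tile_n * BN + j];
                    }
        }
    }
    for (int i = 0; i < BM; ++i)                   // store the finished tile (column-major C)
        for (int j = 0; j < BN; ++j)
            C[(tile_n * BN + j) * M + (tile_m * BM + i)] = d[i][j];
}

// Grid model: one "block" per 64x64 output tile (128 x 128 tiles for 8192 x 8192).
void tiled_matmul(const std::vector<float>& A, const std::vector<float>& B,
                  std::vector<float>& C, int M, int N, int K) {
    assert(M % BM == 0 && N % BN == 0 && K % BK == 0);
    for (int tm = 0; tm < M / BM; ++tm)
        for (int tn = 0; tn < N / BN; ++tn)
            compute_tile(A, B, C, M, N, K, tm, tn);
}
```

Running tiled_matmul on a small problem (say 128×128×128) and comparing against a naive triple loop is a quick way to convince yourself that splitting K into 64-wide steps and 16-wide slices does not change the result.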

Whole-matrix picture

  • There are 128 tiles across M and 128 tiles across N → launch 128×128 blocks.
  • Each block independently:
    1. accumulates its 64×64 results over 128 K-steps via 4 WGMMA calls per step,
    2. stores its finished 64×64 tile to C.

Tiny recap

  • Per 64×64 tile: 128 K-steps; each step issues 4× wgmma.m64n64k16; all accumulate in registers.
  • Whole matrix: 16,384 such tiles; each computed by one block and stored to C.

Tensor Memory Accelerator (TMA) - Loads

  • TMA is a hardware unit introduced with the Hopper architecture (sm_90): it asynchronously copies tiles of multi-dimensional arrays between global memory (GMEM) and shared memory (SMEM).
  • In effect, it lets you tell the GPU: "Here's how this big matrix is sliced into 64×64 chunks; copy chunk (i, j) into shared memory."
  • TMA loads can write tiles into SMEM in the swizzled layouts that the tensor-core (WGMMA) instructions expect.
  • TMA takes a tiling configuration for a given matrix and can then load any requested tile into SMEM; a single instruction, issued by a single thread, moves a whole tile.
  • To use it, we first create a TMA descriptor (a CUtensorMap, also called a tensor map) on the host, describing the global tensor and the tile ("box") shape.

Example functions for creating a TMA descriptor:

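#include <cassert>
#include <cstdint>
#include <cuda.h>          // CUtensorMap, cuTensorMapEncodeTiled (driver API, link with -lcuda)
#include <cuda_bf16.h>     // __nv_bfloat16
#include <cuda_runtime.h>  // cudaMalloc, cudaMemcpy

typedef __nv_bfloat16 bf16;

// Describe a row-major matrix of (BlockMajorSize*blocks_height) rows by
// (BlockMinorSize*blocks_width) columns, copied in boxes of
// BlockMajorSize rows by BlockMinorSize columns. Dimension 0 is the contiguous one.
template <int BlockMajorSize, int BlockMinorSize>
void create_tensor_map(CUtensorMap *tma_map, bf16* gmem_ptr, int blocks_height, int blocks_width) {
    void* gmem_address = (void*)gmem_ptr;
    uint64_t gmem_prob_shape[5] = {(uint64_t)BlockMinorSize*blocks_width, (uint64_t)BlockMajorSize*blocks_height, 1, 1, 1};
    // Byte strides; entry 0 (element size) is implicit, so only entries 1.. are passed below.
    uint64_t gmem_prob_stride[5] = {sizeof(bf16), sizeof(bf16) * BlockMinorSize*blocks_width, 0, 0, 0};
    uint32_t smem_box_shape[5] = {uint32_t(BlockMinorSize), uint32_t(BlockMajorSize), 1, 1, 1};
    uint32_t smem_box_stride[5] = {1, 1, 1, 1, 1};  // element strides: copy every element

    CUresult result = cuTensorMapEncodeTiled(
        tma_map, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, 2, gmem_address, gmem_prob_shape,
        gmem_prob_stride + 1, smem_box_shape, smem_box_stride, CU_TENSOR_MAP_INTERLEAVE_NONE,
        CU_TENSOR_MAP_SWIZZLE_128B, CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);

    assert(result == CUDA_SUCCESS);
}

// Encode the descriptor on the host, then copy it into device memory so a
// kernel can receive it as a CUtensorMap pointer. Defined after create_tensor_map
// so the template is already declared where it is used.
template <int BlockMajorSize, int BlockMinorSize>
__host__ static inline CUtensorMap* allocate_and_create_tensor_map(bf16* src, int blocks_height, int blocks_width) {
    CUtensorMap *tma_map_d;
    cudaMalloc(&tma_map_d, sizeof(CUtensorMap));
    CUtensorMap tma_map_host;
    create_tensor_map<BlockMajorSize, BlockMinorSize>(&tma_map_host, src, blocks_height, blocks_width);
    cudaMemcpy(tma_map_d, &tma_map_host, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
    return tma_map_d;
}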

This function sets up a hardware-accelerated memory copy descriptor that tells the GPU how to efficiently move data from global memory to shared memory using TMA instructions.

Assumed Parameters

  • BlockMajorSize = 64
  • BlockMinorSize = 64
  • blocks_width = 128
  • blocks_height = 128

Breaking Down Each Array

1. Global Memory Problem Shape (gmem_prob_shape)

uint64_t gmem_prob_shape[5] = {
    (uint64_t)BlockMinorSize*blocks_width,  // 64 * 128 = 8192
    (uint64_t)BlockMajorSize*blocks_height, // 64 * 128 = 8192  
    1, 1, 1
};

This describes the total tensor dimensions in global memory:

  • Dimension 0: Width = 8192 elements
  • Dimension 1: Height = 8192 elements
  • Dimensions 2-4: Unused (set to 1)

So we have an 8192 × 8192 matrix in global memory.

2. Global Memory Stride (gmem_prob_stride)

uint64_t gmem_prob_stride[5] = {
    sizeof(bf16),                                    // 2 bytes
    sizeof(bf16) * BlockMinorSize*blocks_width,      // 2 * 8192 = 16384 bytes
    0, 0, 0
};

This tells how to navigate through global memory:

  • Stride 0: 2 bytes to next element in same row
  • Stride 1: 16384 bytes to next row (skip entire row width)
  • Strides 2-4: Unused
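To sanity-check these numbers, the address arithmetic can be replayed on the host. This plain C++ sketch (tile_to_coords and byte_offset are hypothetical names) assumes the 8192×8192 row-major bf16 matrix above. It also reflects that a TMA copy addresses a box by the element coordinates of its first element, innermost dimension first, rather than by tile index.

```cpp
#include <cassert>
#include <cstdint>

// Values from the text: bf16 = 2 bytes, 8192 x 8192 row-major matrix, 64 x 64 boxes.
constexpr uint64_t kElemBytes = 2;                     // sizeof(bf16)
constexpr uint64_t kWidth = 64 * 128;                  // gmem_prob_shape[0] = 8192
constexpr uint64_t kRowStride = kElemBytes * kWidth;   // gmem_prob_stride[1] = 16384 bytes
constexpr uint64_t kBoxCols = 64, kBoxRows = 64;       // smem_box_shape[0], smem_box_shape[1]

// Element coordinates of a box's first element, innermost dimension first.
struct Coords { uint64_t c0, c1; };                    // c0 = column, c1 = row

Coords tile_to_coords(uint64_t tile_row, uint64_t tile_col) {
    return {tile_col * kBoxCols, tile_row * kBoxRows};
}

// Byte offset (from the matrix base) of the element at coordinates c.
uint64_t byte_offset(Coords c) {
    return c.c0 * kElemBytes + c.c1 * kRowStride;
}
```

For example, tile (1, 0) starts 64 rows down, at byte 64 × 16384 = 1,048,576, and the last element of tile (127, 127) sits 2 bytes before the end of the 128 MiB matrix.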

3. Shared Memory Box Shape (smem_box_shape)

uint32_t smem_box_shape[5] = {
    uint32_t(BlockMinorSize),  // 64
    uint32_t(BlockMajorSize),  // 64
    1, 1, 1
};

This describes the tile size we copy to shared memory:

  • 64 × 64 tile of bf16 elements
  • Each TMA operation copies this much data

4. Shared Memory Box Stride (smem_box_stride)

uint32_t smem_box_stride[5] = {1, 1, 1, 1, 1};

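These are the element traversal strides for each dimension (elementStrides in the driver API). {1, 1, 1, 1, 1} means every element inside the box is copied with no skipping, which is the usual setting for dense tiles.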

The TMA Map Creation

CUresult result = cuTensorMapEncodeTiled(
    tma_map,                           // Output: the TMA descriptor
    CU_TENSOR_MAP_DATA_TYPE_BFLOAT16, // Data type: bf16
    2,                                // 2D tensor
    gmem_address,                     // Global memory base address
    gmem_prob_shape,                  // Global tensor dimensions
    gmem_prob_stride + 1,             // Skip first stride element
    smem_box_shape,                   // Tile size for shared memory
    smem_box_stride,                  // Shared memory layout
    CU_TENSOR_MAP_INTERLEAVE_NONE,    // No interleaving
    CU_TENSOR_MAP_SWIZZLE_128B,       // 128-byte swizzling for bank conflicts
    CU_TENSOR_MAP_L2_PROMOTION_NONE,  // No L2 cache promotion
    CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE // No out-of-bounds fill
);
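Before calling cuTensorMapEncodeTiled it can help to check the parameters against the driver's documented requirements. The host-only sketch below checks just a few of them (the struct and function names are hypothetical, and the list is not exhaustive): a 16-byte-aligned base address, passed global strides that are multiples of 16 bytes, box dimensions between 1 and 256, an inner box row that is a multiple of 16 bytes, and, with 128-byte swizzle, an inner box row no wider than 128 bytes.

```cpp
#include <cassert>
#include <cstdint>

struct TiledParams {
    uint64_t global_address;   // base address of the matrix
    uint64_t global_stride1;   // bytes between rows (the value passed via gmem_prob_stride + 1)
    uint32_t box0, box1;       // smem_box_shape[0], smem_box_shape[1], in elements
    uint32_t elem_bytes;       // 2 for bf16
    uint32_t swizzle_bytes;    // 128 for CU_TENSOR_MAP_SWIZZLE_128B
};

bool plausible_tiled_params(const TiledParams& p) {
    if (p.global_address % 16 != 0) return false;                  // 16-byte aligned base
    if (p.global_stride1 % 16 != 0) return false;                  // strides: multiples of 16 bytes
    if (p.box0 == 0 || p.box0 > 256) return false;                 // box dims in [1, 256]
    if (p.box1 == 0 || p.box1 > 256) return false;
    uint64_t inner_bytes = uint64_t(p.box0) * p.elem_bytes;
    if (inner_bytes % 16 != 0) return false;                       // inner row: multiple of 16 bytes
    if (inner_bytes > p.swizzle_bytes) return false;               // must fit in the swizzle span
    return true;
}
```

With the parameters from the text (64 × 2 bytes = 128-byte inner rows, 16384-byte row stride) every check passes; widening the box to 128 columns would exceed the 128-byte swizzle span.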


Key Parameters Explained

gmem_prob_stride + 1

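This passes the stride array starting at its second entry, so the driver sees:

{16384, 0, 0, 0}  // Instead of {2, 16384, 0, 0, 0}

cuTensorMapEncodeTiled expects only tensorRank - 1 global strides, in bytes: the stride of dimension 0 is implicitly the element size. For this 2D tensor only the first value (the 16384-byte row stride) is read, and each stride passed must be a multiple of 16 bytes.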

CU_TENSOR_MAP_SWIZZLE_128B

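Applies 128-byte swizzling: within each group of eight 128-byte rows, the 16-byte chunks of a row are permuted (the chunk index is XORed with the row index modulo 8). A column of the tile is then spread across different shared-memory banks instead of hitting the same ones, which avoids bank conflicts when the data is read. Here the box's inner dimension is 64 × 2 bytes = 128 bytes, exactly the swizzle span.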
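The permutation can be written as a one-line address transform. This is a sketch assuming the commonly described 128-byte pattern (shared memory viewed as 128-byte rows in 1024-byte groups, with address bits 4..6 XORed with bits 7..9); swizzle128 and swizzled_offset are hypothetical helper names, not a CUDA API.

```cpp
#include <cassert>
#include <cstdint>

// 128-byte swizzle on a byte offset inside a 1024-byte-aligned tile.
uint32_t swizzle128(uint32_t byte_offset) {
    uint32_t row_in_group = (byte_offset >> 7) & 0x7;  // which of the 8 rows (bits 7..9)
    return byte_offset ^ (row_in_group << 4);          // permute the 16-byte chunk (bits 4..6)
}

// Where element (row, col) of a 64x64 bf16 tile (128 bytes per row) lands.
uint32_t swizzled_offset(uint32_t row, uint32_t col) {
    uint32_t logical = row * 128 + col * 2;            // unswizzled row-major byte offset
    return swizzle128(logical);
}
```

Each 16-byte chunk covers four of the 32 four-byte banks, so column 0 of rows 0..7 ends up in eight different chunk positions, i.e. eight disjoint bank groups, instead of all in the same one. Row 0 is left unchanged, and the transform is a permutation of each 1024-byte group.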

How It's Used

After creating this TMA map, you can use it in kernels:

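// Copy one 64×64 tile from global memory to shared memory using the 2D descriptor above.
// dst_ptr and mbar_ptr are 32-bit shared-space addresses, e.g.
// static_cast<uint32_t>(__cvta_generic_to_shared(&sA[0])); tma_ptr is the CUtensorMap*.
// Coordinates are element offsets, innermost dimension first: {column, row};
// tile_row / tile_col (illustrative names) select which 64×64 tile to copy.
asm volatile (
    "cp.async.bulk.tensor.2d.shared::cluster.global.tile.mbarrier::complete_tx::bytes"
    " [%0], [%1, {%3, %4}], [%2];"
    :
    : "r"(dst_ptr), "l"(tma_ptr), "r"(mbar_ptr),
      "r"(tile_col * 64), "r"(tile_row * 64)
    : "memory"
);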

Notes:

  • One TMA instruction = one 64×64 bf16 tile (64 × 64 × 2 bytes = 8 KB).
  • 8192×8192 matrix = 16,384 tiles in total.
  • Copying the entire matrix would take 16,384 TMA instructions.
  • Typical usage: copy tiles one at a time as the computation needs them.
  • The power of TMA isn't copying everything at once; it's efficiently copying exactly what you need, when you need it!

Visual breakdown

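8192×8192 Global Matrix (128×128 tiles):
┌─────────┬─────────┬─────────┬─────┬───────────┐
│ (0,0)   │ (0,1)   │ (0,2)   │ ... │ (0,127)   │
├─────────┼─────────┼─────────┼─────┼───────────┤
│ (1,0)   │ (1,1)   │ (1,2)   │ ... │ (1,127)   │  ← One TMA instruction copies
├─────────┼─────────┼─────────┼─────┼───────────┤     ONE of these tiles
│ (2,0)   │ (2,1)   │ (2,2)   │ ... │ (2,127)   │
├─────────┼─────────┼─────────┼─────┼───────────┤
│ ...     │ ...     │ ...     │ ... │ ...       │
├─────────┼─────────┼─────────┼─────┼───────────┤
│ (127,0) │ (127,1) │ (127,2) │ ... │ (127,127) │
└─────────┴─────────┴─────────┴─────┴───────────┘

A single TMA instruction copies one tile: tile (r, c) is the 64×64 box whose top-left element sits at row 64·r, column 64·c, so the instruction's coordinates are {64·c, 64·r} (innermost dimension first).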

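To copy multiple tiles, we typically loop over K, loading one A tile and one B tile per step. A sketch using the libcu++ TMA helpers (cde = cuda::device::experimental), assuming a single shared-memory stage, __shared__ arrays sA/sB, a __shared__ barrier named bar initialized with blockDim.x, and BK = BM = BN = 64:

// Copy tiles for matrix multiplication (single stage: no overlap of copy and compute)
for (int k = 0; k < num_k_tiles; k++) {
    barrier::arrival_token token;
    if (threadIdx.x == 0) {
        // One thread issues both copies; coordinates are element offsets {K, row}
        cde::cp_async_bulk_tensor_2d_global_to_shared(&sA[0], tma_map_A, k * BK, block_row * BM, bar);
        cde::cp_async_bulk_tensor_2d_global_to_shared(&sB[0], tma_map_B, k * BK, block_col * BN, bar);
        // Arrive once and tell the barrier how many bytes the two copies will deliver
        token = cuda::device::barrier_arrive_tx(bar, 1, sizeof(sA) + sizeof(sB));
    } else {
        token = bar.arrive();
    }
    // Wait until every thread has arrived AND all bytes have landed in SMEM
    bar.wait(std::move(token));

    // Do computation on these tiles
    wgmma(...);

    // No thread may still be reading sA/sB when the next iteration overwrites them
    __syncthreads();
}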

Summary

This function creates a hardware descriptor that enables:

  • Fast bulk copies of 64×64 tiles from an 8192×8192 global matrix
  • Automatic address calculation for different tile positions
  • Optimized memory access patterns with swizzling
  • Hardware-accelerated transfers using TMA instructions

Split barriers:

Before Ampere (Pascal, Volta, Turing)

  • __syncthreads() (PTX bar.sync) is a monolithic barrier: all threads in the block must reach it before any proceed.
    • If some threads arrive much earlier, they must stall until the last one arrives.
    • Coarse-grained → limits overlap of independent work.

Ampere: Split Barriers

  • Introduced mbarrier objects in shared memory (exposed in CUDA C++ as cuda::barrier), whose arrive and wait steps are separate operations.
  • Benefits:
    1. Decoupled arrival & waiting → a warp can “signal” it’s done with some work (arrive), then continue with independent tasks before later executing wait.
    2. Better overlap of computation & synchronization → allows useful work to be done instead of idling at the barrier.
    3. Finer-grained control → a barrier's expected arrival count can cover any group of threads (for example, a few warps, or one producer thread plus the consumers), not just the whole block.
    4. Improved latency hiding → especially important for async copies (cp.async) into shared memory, where producer warps can continue issuing loads while consumer warps wait only when needed.

It's a 64-bit Value with Special Hardware Meaning

asm volatile ("mbarrier.init.shared.b64 [%0], %1;" // Initializes a hardware barrier in shared memory
		// Converts a C++ object address to a shared memory address
	:: "r"(static_cast<_CUDA_VSTD::uint32_t>(__cvta_generic_to_shared(&__b->__barrier))),
       "r"(static_cast<_CUDA_VSTD::uint32_t>(__expected)) // Sets the expected arrival count
    : "memory"); 

// Breaking Down the Syntax
// 1. asm volatile
//      • asm = inline assembly block
//      • volatile = "Don't optimize this away, it has side effects"
// 2. The instruction: "mbarrier.init.shared.b64 [%0], %1;"
//      • mbarrier.init = initialize a memory barrier
//      • shared = located in shared memory
//      • b64 = 64-bit barrier object
//      • [%0] = memory address (first operand)
//      • %1 = expected arrival count (second operand)
// 3. Input constraints: :: "r"(...), "r"(...)
//    The :: separates outputs (none) from inputs:
//      • First "r"(...) → %0 ("r" = 32-bit integer register, hence the uint32_t shared-space address)
//      • Second "r"(...) → %1 (32-bit register holding the count)
// 4. Clobber list: : "memory"
//      • "memory" = this instruction may read or write memory, so the compiler must not keep memory values cached in registers across it

// Enables hardware-accelerated synchronization for all threads in the block
// It's essentially telling the GPU: This memory location is now a hardware barrier that expects N threads to synchronize on it!

The b64 tells us this is a 64-bit barrier object in shared memory.

Scope: Within a Single SM (Streaming Multiprocessor)

  • All threads in the thread block can access the same shared memory
  • All threads see the same 64-bit barrier value
  • Hardware recognizes this specific memory location as a barrier

The 64-bit Layout (Conceptual only: PTX treats the mbarrier object as opaque, so the real bit layout is not documented):

[63:48] Phase + Hardware State
[47:32] Expected Arrival Count  
[31:0 ] Current Arrival Count

How Threads Interact with It

Thread Synchronization Pattern:

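__shared__ uint64_t barrier;  // The special 64-bit value (8-byte aligned)

// mbarrier instructions take a 32-bit shared-space address, not a generic pointer
uint32_t bar_addr = static_cast<uint32_t>(__cvta_generic_to_shared(&barrier));

// Thread 0 initializes: expected arrival count = every thread in the block
if (threadIdx.x == 0) {
    asm volatile ("mbarrier.init.shared.b64 [%0], %1;"
        :: "r"(bar_addr), "r"(blockDim.x) : "memory");
}
__syncthreads();  // make the initialized barrier visible to all threads

// ALL threads arrive; state records the phase this thread arrived in
uint64_t state;
asm volatile ("mbarrier.arrive.shared.b64 %0, [%1];"
    : "=l"(state) : "r"(bar_addr) : "memory");

// ALL threads wait: test_wait is non-blocking, so poll until the phase completes
uint32_t done = 0;
while (!done) {
    asm volatile ("{\n\t.reg .pred p;\n\t"
        "mbarrier.test_wait.shared.b64 p, [%1], %2;\n\t"
        "selp.u32 %0, 1, 0, p;\n\t}"
        : "=r"(done) : "r"(bar_addr), "l"(state) : "memory");
}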

The Hardware Magic

What Makes It "Special":

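  1. Dedicated instructions: only the mbarrier.* instructions interpret the 64-bit word as a barrier; ordinary loads and stores just see an integer
  2. Atomic operations: each arrive atomically updates the arrival bookkeeping, so many threads can arrive concurrently
  3. Efficient waiting: mbarrier.test_wait is a non-blocking check, while mbarrier.try_wait (sm_90) can suspend the thread for a hardware-dependent time instead of spinning
  4. Memory ordering: by default an arrive has release semantics and a completed wait has acquire semantics, so shared-memory writes made before arriving are visible to threads whose wait has returned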

When Threads "Update" the Barrier:

// Thread arrives - conceptually, the hardware atomically does:
// barrier[31:0]++;                 // Increment arrival count
// if (arrival_count == expected_count) {
//     barrier[63:48] ^= 1;         // Flip phase bit - releases all waiters
//     barrier[31:0] = 0;           // Re-arm the count for the next phase
// }

Visual Example

Initial state:    [Phase=0][Expected=4][Current=0]
Thread 0 arrives: [Phase=0][Expected=4][Current=1] 
Thread 1 arrives: [Phase=0][Expected=4][Current=2]
Thread 2 arrives: [Phase=0][Expected=4][Current=3]
Thread 3 arrives: [Phase=1][Expected=4][Current=0] ← 4th arrival completes the phase: the phase bit flips, all waiters are released, and the count re-arms for the next phase
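The arrival logic above is easy to model on the host. The toy struct below follows the conceptual counter-and-phase picture (it is a hypothetical model, not the real, opaque hardware encoding), and its test_wait mirrors the idea behind mbarrier.test_wait: a thread remembers the phase it arrived in and checks whether that phase has completed.

```cpp
#include <cassert>
#include <cstdint>

struct ToyMbarrier {
    uint32_t phase = 0;
    uint32_t expected = 0;
    uint32_t current = 0;

    void init(uint32_t count) { phase = 0; expected = count; current = 0; }

    // Returns the phase this arrival belongs to (plays the role of the state token).
    uint32_t arrive() {
        uint32_t my_phase = phase;
        if (++current == expected) {   // last arrival completes the phase
            phase ^= 1;                // flip phase bit: waiters on my_phase are released
            current = 0;               // re-arm for the next phase
        }
        return my_phase;
    }

    // Non-blocking check: has the phase we arrived in completed?
    bool test_wait(uint32_t arrived_phase) const { return phase != arrived_phase; }
};
```

Replaying the four-thread example: after three arrivals nobody can pass, and the fourth arrival flips the phase and releases all four tokens at once.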

Key Insight

  • It's just a 64-bit integer in shared memory
  • But the hardware treats it specially when you use mbarrier.* instructions
  • All threads in the block can interact with the same barrier
  • Hardware ensures atomicity and efficient synchronization

The "magic" isn't in the memory itself - it's in the specialized hardware instructions (mbarrier.arrive, mbarrier.test_wait, etc.) that know how to interpret and manipulate this 64-bit value as a synchronization primitive!

So a barrier is essentially a hardware-accelerated 64-bit counter with special semantics that all threads in the block (CTA) can safely and efficiently synchronize on.

Tensor Core Wgmma Instruction:

What's Really Happening

We perform 4 WGMMA operations on the loaded tiles

Further Subdivision

WGMMA64 Operation:

  • Each wgmma64 performs: 64×16 @ 16×64 → 64×64
  • sA: 64×16 slice (64 rows, 16 columns)
  • sB: 16×64 slice (16 rows, 64 columns)
  • Result: 64×64 accumulation into d

The 4 Operations:

wgmma64<1, 1, 1, 0, 0>(d, &sA[0],        &sB[0]);          // K=0:15
wgmma64<1, 1, 1, 0, 0>(d, &sA[WGMMA_K],  &sB[WGMMA_K]);    // K=16:31  
wgmma64<1, 1, 1, 0, 0>(d, &sA[2*WGMMA_K], &sB[2*WGMMA_K]); // K=32:47
wgmma64<1, 1, 1, 0, 0>(d, &sA[3*WGMMA_K], &sB[3*WGMMA_K]); // K=48:63

Where WGMMA_K = 16.

Visual Breakdown

sA (64×64) - 4 vertical slices of 64×16:

sA (64×64):
┌─────┬─────┬─────┬─────┐
│64×16│64×16│64×16│64×16│
│  0  │ 16  │ 32  │ 48  │
│     │     │     │     │
│     │     │     │     │
│     │     │     │     │
└─────┴─────┴─────┴─────┘

sB (64×64) - 4 horizontal slices of 16×64:

sB (64×64):
┌─────────────────────────┐ ← 16×64 (slice 0)
├─────────────────────────┤ ← 16×64 (slice 16) 
├─────────────────────────┤ ← 16×64 (slice 32)
└─────────────────────────┘ ← 16×64 (slice 48)

The 4 WGMMA Operations

wgmma64<1, 1, 1, 0, 0>(d, &sA[0],        &sB[0]);          // sA[64×16₀] × sB[16₀×64]
wgmma64<1, 1, 1, 0, 0>(d, &sA[WGMMA_K],  &sB[WGMMA_K]);    // sA[64×16₁] × sB[16₁×64]
wgmma64<1, 1, 1, 0, 0>(d, &sA[2*WGMMA_K], &sB[2*WGMMA_K]); // sA[64×16₂] × sB[16₂×64]
wgmma64<1, 1, 1, 0, 0>(d, &sA[3*WGMMA_K], &sB[3*WGMMA_K]); // sA[64×16₃] × sB[16₃×64]

Note: For sA, the first wgmma instruction wgmma64<1, 1, 1, 0, 0>(d, &sA[0], &sB[0]) automatically picks a 64×16 slice from the 64×64 tile in row-major form:

Row 0:  elements 0-15
Row 1:  elements 64-79  
Row 2:  elements 128-143
Row 3:  elements 192-207
...
Row 63: elements 4032-4047

This happens because:

  1. sA is laid out in SMEM as a 64×64 tile with K (BK=64) as the contiguous dimension
  2. The wgmma instruction expects A data in a specific swizzled pattern that TMA automatically provides
  3. When you pass &sA[0], the tensor core hardware reads the first 16 columns (K=0 to K=15) across all 64 rows

The pattern continues for subsequent calls:

  • &sA[16]: columns 16-31 across all 64 rows
  • &sA[32]: columns 32-47 across all 64 rows
  • &sA[48]: columns 48-63 across all 64 rows

So, each wgmma call processes one vertical 64×16 "stripe" of the sA tile, and the +16 offset selects which stripe along the K dimension.
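The offsets listed above follow from one formula. A small host sketch (slice_row_first and slice_row_last are hypothetical names), assuming the logical, unswizzled 64×64 layout with K contiguous; with 128-byte swizzling the physical SMEM positions are permuted, but, as noted above, the tensor cores expect exactly that swizzled layout.

```cpp
#include <cassert>

constexpr int BK = 64, WGMMA_K = 16;

// First and last logical element offset of row r inside K-slice s (s = 0..3),
// i.e. the elements the wgmma call at &sA[s * WGMMA_K] reads from that row.
int slice_row_first(int s, int r) { return r * BK + s * WGMMA_K; }
int slice_row_last(int s, int r)  { return slice_row_first(s, r) + WGMMA_K - 1; }
```

For slice 0 this reproduces the list above (row 1 → 64..79, row 63 → 4032..4047); slice 3 starts at element 48 of each row.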

Each operation multiplies:

  • sA slice: 64×16 (64 rows, 16 columns)
  • sB slice: 16×64 (16 rows, 64 columns)
  • Result: 64×64 accumulated into d

Matrix Multiplication Breakdown

C[64×64] = A[64×16₀] × B[16₀×64] +   // K-slice 0:15
           A[64×16₁] × B[16₁×64] +   // K-slice 16:31
           A[64×16₂] × B[16₂×64] +   // K-slice 32:47
           A[64×16₃] × B[16₃×64]     // K-slice 48:63

So each of sA and sB splits into exactly 4 K-slices: 64×16 slices for sA and 16×64 slices for sB.

Each WGMMA Operation:

// Operation 1: K-dimension 0:15
sA[0:63, 0:15] × sB[0:15, 0:63] → d[0:63, 0:63]

// Operation 2: K-dimension 16:31  
sA[0:63, 16:31] × sB[16:31, 0:63] → d[0:63, 0:63] (accumulate)

// Operation 3: K-dimension 32:47
sA[0:63, 32:47] × sB[32:47, 0:63] → d[0:63, 0:63] (accumulate)

// Operation 4: K-dimension 48:63
sA[0:63, 48:63] × sB[48:63, 0:63] → d[0:63, 0:63] (accumulate)

Why 4 Operations?

Hardware Limitation:

  • For bf16 inputs, each WGMMA instruction has a fixed K of 16 (shape m64nNk16), so it covers only 16 elements of the K-dimension at a time
  • Our tile has K=64, so we need 64/16 = 4 operations

Matrix Multiplication Math:

C[64×64] = A[64×64] × B[64×64]

// Broken down by K-dimension:
C = A[0:63, 0:15]  × B[0:15, 0:63]  +   // First 16 K-elements
    A[0:63, 16:31] × B[16:31, 0:63] +   // Next 16 K-elements
    A[0:63, 32:47] × B[32:47, 0:63] +   // Next 16 K-elements
    A[0:63, 48:63] × B[48:63, 0:63]     // Last 16 K-elements
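Putting both levels of the K split together, with t indexing the 64-wide K-steps (one TMA load each) and s the four 16-wide slices inside a step, one output tile is (A's rows and B's columns taken relative to the tile being computed):

```latex
C_{\text{tile}} \;=\; \sum_{t=0}^{K/64-1}\;\sum_{s=0}^{3}
  A\left[\,0{:}63,\; 64t+16s \,{:}\, 64t+16s+15\,\right]\,
  B\left[\,64t+16s \,{:}\, 64t+16s+15,\; 0{:}63\,\right]
```

With K = 8192 that is 128 × 4 = 512 m64n64k16 operations accumulated into the same registers for every tile.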

Memory Addressing

sA Addressing:

&sA[0]         // Points to sA[0:63, 0:15]
&sA[WGMMA_K]   // Points to sA[0:63, 16:31] (skip 16 columns)
&sA[2*WGMMA_K] // Points to sA[0:63, 32:47] (skip 32 columns)  
&sA[3*WGMMA_K] // Points to sA[0:63, 48:63] (skip 48 columns)

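sB Addressing (&sB[WGMMA_K] is only 16 elements past &sB[0], so sB is also stored with K contiguous: each of its 64 N-rows holds 64 consecutive K-values, and an offset of 16 elements moves 16 steps along K):

&sB[0]         // Logical sB[0:15, 0:63]  (K = 0..15)
&sB[WGMMA_K]   // Logical sB[16:31, 0:63] (K = 16..31)
&sB[2*WGMMA_K] // Logical sB[32:47, 0:63] (K = 32..47)
&sB[3*WGMMA_K] // Logical sB[48:63, 0:63] (K = 48..63)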

Summary

  • Load: One 64×64 tile each for A and B into shared memory
  • Compute: 4 WGMMA operations, each handling 16 elements of the K-dimension
  • Result: Complete 64×64 matrix multiplication accumulated in register d
  • Subdivision: Along the K-dimension (inner dimension), not the output dimensions

So we’re processing the same 64×64 output tile but breaking down the inner K-dimension into 4 chunks of 16 elements each!

Regular Stores

To reiterate

  • Launch: 16,384 blocks (128×128 tiles), each with 128 threads = 4 warps.
  • Each thread accumulates 32 results (in d[ ][ ]) and stores those 32 into its block’s 64×64 tile of C (not the whole tile per thread).
// Store
    {
        int tid = threadIdx.x;
        int lane = tid % 32;
        int warp = tid / 32;
        uint32_t row = warp*16 + lane / 4;
        bf16 *block_C = C + num_block_n*BN*M + num_block_m*BM;

        for (int m_it = 0; m_it < BM/WGMMA_M; ++m_it) {
            for (int n_it = 0; n_it < BN/WGMMA_N; ++n_it) {
                for (int w = 0; w < WGMMA_N/16; ++w) {
                    int col = 16*w + 2*(tid % 4);
                    #define IDX(i, j) ((j + n_it*WGMMA_N)*M + ((i) + m_it*WGMMA_M))

                    block_C[IDX(row, col)] = d[w][0];
                    block_C[IDX(row, col+1)] = d[w][1];
                    block_C[IDX(row+8, col)] = d[w][2];
                    block_C[IDX(row+8, col+1)] = d[w][3];
    
                    block_C[IDX(row, col+8)] = d[w][4];
                    block_C[IDX(row, col+9)] = d[w][5];
                    block_C[IDX(row+8, col+8)] = d[w][6];
                    block_C[IDX(row+8, col+9)] = d[w][7];

                    #undef IDX
                }
            }
        }
    }

Intuition

  • The 64×64 tile is split across the 4 warps in the warp‑group:
    • Each warp handles 16 rows: rows [warp*16 .. warp*16+15].
  • Within a warp:
    • Each thread writes 2 rows: row_base = warp*16 + lane/4, and row_base+8.
    • Columns are covered in 4 groups of 16 (w = 0..3). In each group, a thread writes 4 columns: col, col+1, col+8, col+9. Across w=0..3, that’s 16 distinct columns per thread.
  • Therefore:
    • One thread covers 2 rows × 16 columns = 32 elements.
    • The 4 threads with the same row pair (tid%4 = 0..3) together cover 2 rows × 64 columns.
  • Sanity check: 128 threads × 32 elements = 4096 = 64×64.

Ties to code

uint32_t row = warp*16 + lane / 4;   // 2 rows per thread: row and row+8
int col = 16*w + 2*(tid % 4);        // 4 cols per w: col,col+1,col+8,col+9

What this means

  • Each warp handles a 16-row band. Within that band:
    • lanes 0–3 → row 0; 4–7 → row 1; …; 28–31 → row 7
    • Each thread writes two rows: row and row+8 (so 16 rows per warp total)
  • Columns are handled in 4 groups of 16 columns each (w = 0..3; WGMMA_N=64):
    • Inside each 16-wide group, each thread writes 2 adjacent columns at col and col+1, and also col+8 and col+9 (so 4 columns per w)
    • Over w=0..3, that’s 16 columns per thread’s rows, and collectively the warp covers all 64 columns

Visual examples

  1. tid = 0 (warp=0, lane=0)
  • Rows: row = 0, and row+8 = 8
  • Columns per w:
    • w=0 → col=0 → {0,1,8,9}
    • w=1 → col=16 → {16,17,24,25}
    • w=2 → col=32 → {32,33,40,41}
    • w=3 → col=48 → {48,49,56,57}
  • Writes 32 values total across rows {0,8}
  2. tid = 1 (warp=0, lane=1)
  • Rows: {0, 8}
  • Columns per w:
    • w=0 → col=2 → {2,3,10,11}
    • w=1 → col=18 → {18,19,26,27}
    • w=2 → col=34 → {34,35,42,43}
    • w=3 → col=50 → {50,51,58,59}
  3. tid = 2 (warp=0, lane=2)
  • Rows: {0, 8}
  • Columns per w:
    • w=0 → col=4 → {4,5,12,13}
    • w=1 → col=20 → {20,21,28,29}
    • w=2 → col=36 → {36,37,44,45}
    • w=3 → col=52 → {52,53,60,61}
  4. tid = 3 (warp=0, lane=3)
  • Rows: {0, 8}
  • Columns per w:
    • w=0 → col=6 → {6,7,14,15}
    • w=1 → col=22 → {22,23,30,31}
    • w=2 → col=38 → {38,39,46,47}
    • w=3 → col=54 → {54,55,62,63}

Now move down one row in the band:

  5. tid = 4 (warp=0, lane=4)
  • Rows: {1, 9}
  • Columns per w (same pattern as tid=0, just different rows):
    • w=0 → {0,1,8,9}
    • w=1 → {16,17,24,25}
    • w=2 → {32,33,40,41}
    • w=3 → {48,49,56,57}

Next warp (warp 1) jumps to the next 16-row band:

  6. tid = 32 (warp=1, lane=0)
  • Rows: {16, 24}
  • Columns per w (same as tid=0, just shifted rows):
    • w=0 → {0,1,8,9}
    • w=1 → {16,17,24,25}
    • w=2 → {32,33,40,41}
    • w=3 → {48,49,56,57}

Mental model

  • Warps tile rows in 16-row bands: warp 0 → rows [0..15], warp 1 → [16..31], warp 2 → [32..47], warp 3 → [48..63].
  • Inside a warp, lanes come in groups of 4 that share the same base row; each thread writes 2 rows: row and row+8.
  • The 64 columns are split into four 16-wide groups (w=0..3). Each thread writes 4 columns per group (two adjacent, plus the same two at +8), so 16 columns per thread across all w.
  • Over 128 threads, this perfectly partitions the 64×64 tile; each thread contributes 32 elements.
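The claim that the mapping perfectly partitions the tile can be checked mechanically. This host sketch replays the store indexing for one block (store_position and covers_tile_exactly_once are hypothetical names; it assumes BM = BN = WGMMA_M = WGMMA_N = 64, so m_it = n_it = 0) and counts how often each element of the 64×64 tile is written.

```cpp
#include <cassert>

struct RC { int row, col; };

// (row, col) inside the 64x64 tile that thread tid writes d[w][e] to.
RC store_position(int tid, int w, int e) {
    int lane = tid % 32, warp = tid / 32;
    int row = warp * 16 + lane / 4;       // base row: one 16-row band per warp
    int col = 16 * w + 2 * (tid % 4);     // base column inside the w-th 16-wide group
    // d[w][0..7] -> (row,col) (row,col+1) (row+8,col) (row+8,col+1)
    //               (row,col+8) (row,col+9) (row+8,col+8) (row+8,col+9)
    static const int dr[8] = {0, 0, 8, 8, 0, 0, 8, 8};
    static const int dc[8] = {0, 1, 0, 1, 8, 9, 8, 9};
    return {row + dr[e], col + dc[e]};
}

// 128 threads x 4 groups x 8 values must hit every tile element exactly once.
bool covers_tile_exactly_once() {
    int hits[64][64] = {};
    for (int tid = 0; tid < 128; ++tid)
        for (int w = 0; w < 4; ++w)
            for (int e = 0; e < 8; ++e) {
                RC p = store_position(tid, w, e);
                if (p.row < 0 || p.row >= 64 || p.col < 0 || p.col >= 64) return false;
                ++hits[p.row][p.col];
            }
    for (int r = 0; r < 64; ++r)
        for (int c = 0; c < 64; ++c)
            if (hits[r][c] != 1) return false;
    return true;
}
```

Individual positions match the worked examples: thread 0's d[1][4] goes to (0, 24), and thread 33 (warp 1, lane 1) writes d[3][7] to (24, 59).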

Producer ↔ Consumer: Hiding load latencies

  • Only thread 0 of warpgroup 0 issues the global→shared loads (via TMA), while the other warpgroups are consumers that compute.
  • Two arrays of barriers are used, with one barrier of each kind per queue slot:
    • full[q]: signals “tile q is filled and safe to read” (producer → consumers).
    • empty[q]: signals “tile q is no longer in use and safe to overwrite” (consumers → producer).

Pointers in the code

Producer (wg 0, tid 0) doing the loads:

// Producer
if (wg_idx == 0) {
  if (tid == 0) {
    int qidx = 0;
    for (int block_k_iter = 0; block_k_iter < num_blocks_k; ++block_k_iter, ++qidx) {
      if (qidx == QSIZE) qidx = 0;
      empty[qidx].wait(empty[qidx].arrive());
      cde::cp_async_bulk_tensor_2d_global_to_shared(&sA[qidx*BK*BM], tensorMapA, ... , full[qidx]);
      cde::cp_async_bulk_tensor_2d_global_to_shared(&sB[qidx*BK*BN], tensorMapB, ... , full[qidx]);
      barrier::arrival_token _ = cuda::device::barrier_arrive_tx(full[qidx], 1, (BK*BN+BK*BM)*sizeof(bf16));
    }
  }
}

Consumers initialize empties, then wait for “full,” compute, and finally “arrive” on empty:

} else {
  for (int i = 0; i < QSIZE; ++i) {
    barrier::arrival_token _ = empty[i].arrive();   // Pre-post EMPTY on all slots
  }
  int qidx = 0;
  for (int block_k_iter = 0; block_k_iter < num_blocks_k; ++block_k_iter, ++qidx) {
    if (qidx == QSIZE) qidx = 0;
    full[qidx].wait(full[qidx].arrive());           // Wait until tile is FULL
    // ... compute using sA/sB for slot qidx ...
    barrier::arrival_token _ = empty[qidx].arrive();// Mark slot qidx EMPTY again
  }
}

Barriers are allocated per slot and initialized with the correct participant count:

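__shared__ barrier full[QSIZE], empty[QSIZE];
if (threadIdx.x == 0) {
  for (int i = 0; i < QSIZE; ++i) {
    init(&full[i],  num_consumers * 128 + 1);
    init(&empty[i], num_consumers * 128 + 1);
  }
  cde::fence_proxy_async_shared_cta();  // make the initialized barriers visible to the async proxy (TMA)
}
__syncthreads();  // no thread may use a barrier before it is initialized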

Notes:

  • Each warpgroup is 128 threads; wg_idx == 0 is producer; num_consumers = (#warpgroups - 1).
  • The expected participant count for each barrier phase is: all consumer threads (num_consumers*128) + 1 producer thread.

What the two barrier arrays do and why both are needed

Think of a ring buffer with QSIZE slots (e.g., 5). Each slot has two independent barriers, one for each direction of ownership transfer:

  1. empty[q]: producer-side guard against overwrite

    • Meaning: “Every consumer is done with slot q; producer may overwrite it.”
    • Usage:
      • On startup, all consumers pre-post empty[i].arrive() for every slot i, so all slots begin EMPTY.
      • Before filling slot q, the producer does empty[q].wait(empty[q].arrive()), i.e., it arrives and then waits until all consumer threads have arrived for that phase. Only then does it refill the slot.
  2. full[q]: consumer-side guard against premature read

    • Meaning: “Slot q contains a new tile; consumers may safely read it.”
    • Usage:
      • The producer starts the TMA copies into slot q and calls barrier_arrive_tx(full[q], 1, bytes), which records its own arrival and tells full[q] how many bytes the copies will deliver.
      • Consumers do full[q].wait(full[q].arrive()) for the slot they need next. They arrive early, then wait until the TMA engine has written all expected bytes (each copy decrements the barrier's transaction count as its data lands), which releases them to compute safely.

Why two arrays instead of one?

  • You need two independent, alternating conditions per slot:
    • Producer must not overwrite until consumers finish (empty barrier).
    • Consumers must not read until producer (TMA) finishes (full barrier).
  • Keeping them separate decouples the direction of waiting and makes the ring buffer pipelining clean: full controls availability to consumers; empty controls reusability by the producer. A single barrier would require more fragile phase accounting and increase the chance of read/overwrite races.

How the counts line up per phase

  • init(..., num_consumers*128 + 1) sets the number of arrivals needed to complete a barrier phase for both arrays.
  • For full[q] per iteration:
    • Each consumer thread calls arrive(), then wait() on full[q].
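    • The producer's barrier_arrive_tx(full[q], 1, bytes) is the “+1” arrival, and it also raises the barrier's expected transaction count by the number of bytes the TMA will deliver.
    • As the TMA writes the tile into shared memory, it decrements that transaction count. The phase completes only when all arrivals are in and all expected bytes have landed; then every consumer's wait() returns.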
  • For empty[q] per iteration:
    • Each consumer thread calls empty[q].arrive() after finishing with the slot.
    • The producer does empty[q].wait(empty[q].arrive()) before reuse; its arrive() plus all consumer arrivals complete the phase and its wait() returns.
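The full[q] accounting can be modeled in a few lines. This is a simplified, host-only toy (ToyTxBarrier and its methods are hypothetical, not the libcu++ API): a phase completes only when the pending arrival count and the pending byte count are both zero.

```cpp
#include <cassert>
#include <cstdint>

struct ToyTxBarrier {
    int64_t pending_arrivals;
    int64_t pending_bytes = 0;
    int64_t expected;
    uint32_t phase = 0;

    explicit ToyTxBarrier(int64_t count) : pending_arrivals(count), expected(count) {}

    void maybe_complete() {
        if (pending_arrivals == 0 && pending_bytes == 0) {
            phase ^= 1;                    // release all waiters
            pending_arrivals = expected;   // re-arm for the next phase
        }
    }
    // Consumer thread: plain arrive.
    void arrive() { --pending_arrivals; maybe_complete(); }
    // Producer: like barrier_arrive_tx(bar, 1, bytes): expect more bytes, then arrive once.
    void arrive_tx(int64_t bytes) { pending_bytes += bytes; --pending_arrivals; maybe_complete(); }
    // TMA engine: bytes landing in shared memory.
    void complete_tx(int64_t bytes) { pending_bytes -= bytes; maybe_complete(); }
};
```

With two consumer warpgroups the count is 2 × 128 + 1 = 257; even after all 257 arrivals, the phase stays open until the 16 KB of A and B data (2 × 64 × 64 × 2 bytes) has arrived.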

With queue size = 5 (example timeline)

  • Startup:
    • Consumers pre-post empty[i].arrive() for i in 0..4 → all 5 slots start EMPTY.
  • Iteration 0 (q=0):
    • Producer waits on empty[0] (this completes as soon as the consumers' pre-posted arrivals are in, since no slot has been used yet), launches TMA into sA/sB for slot 0, then calls barrier_arrive_tx(full[0], ...).
    • Consumers arrive-and-wait on full[0]; they are released only when the TMA completes.
    • Consumers compute, then each arrives on empty[0].
  • Iteration 1 (q=1), 2 (q=2), ...:
    • Producer keeps staying ahead, using empty[q] to know when it’s safe to overwrite the slot and full[q] to notify consumers when a slot is ready.
    • The ring of 5 slots lets the producer stay several steps ahead of compute, overlapping TMA and WGMMA.
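The same protocol can be tried out on the CPU. In this host-side C++ analog (Sem and run_ring are hypothetical names; it uses std::thread), each slot gets two counting semaphores standing in for full[q] and empty[q], every empty slot is pre-posted exactly as the consumers do in the kernel, and one producer thread and one consumer thread play the roles of the TMA-issuing thread and the compute warpgroups. It mirrors the ownership protocol only, not how mbarriers are implemented.

```cpp
#include <cassert>
#include <condition_variable>
#include <mutex>
#include <thread>
#include <vector>

class Sem {                        // minimal counting semaphore
    std::mutex m;
    std::condition_variable cv;
    int count = 0;
public:
    void release() { { std::lock_guard<std::mutex> lk(m); ++count; } cv.notify_one(); }
    void acquire() {
        std::unique_lock<std::mutex> lk(m);
        cv.wait(lk, [this] { return count > 0; });
        --count;
    }
};

// Pass num_tiles "tiles" (ints) through a qsize-slot ring; return what the consumer saw.
std::vector<int> run_ring(int num_tiles, int qsize) {
    std::vector<int> slots(qsize);
    std::vector<Sem> full(qsize), empty(qsize);
    for (int q = 0; q < qsize; ++q) empty[q].release();   // pre-post EMPTY on every slot

    std::vector<int> seen;
    std::thread producer([&] {
        for (int k = 0; k < num_tiles; ++k) {
            int q = k % qsize;
            empty[q].acquire();        // wait until slot q may be overwritten
            slots[q] = k;              // "TMA copy" of tile k into slot q
            full[q].release();         // slot q is FULL
        }
    });
    std::thread consumer([&] {
        for (int k = 0; k < num_tiles; ++k) {
            int q = k % qsize;
            full[q].acquire();         // wait until slot q holds tile k
            seen.push_back(slots[q]);  // "compute" on the tile
            empty[q].release();        // slot q is EMPTY again
        }
    });
    producer.join();
    consumer.join();
    return seen;
}
```

Because empty[q].acquire() blocks once every slot is full, the producer can run at most QSIZE tiles ahead of the consumer: the ring depth is exactly how far loads are allowed to get ahead of compute.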

Summary

  • This kernel uses the producer-consumer approach.
  • Loading to shared memory is done by a single thread: thread 0 of warpgroup 0 (wg_idx == 0, tid == 0) issues all TMA copies and barrier transactions. One thread is enough because, once launched, a TMA copy proceeds independently of the thread that issued it.
  • Two barrier arrays: per queue slot, full[] (ready to read) and empty[] (safe to overwrite) provide independent synchronization in each direction, preventing both premature reads and premature overwrites while letting the ring buffer overlap loads with compute.

CUDA/C++ quirks:

__shared__ MyBarrier bar; is ONLY allocation

__shared__ barrier barA;  // Just allocates raw memory
__shared__ barrier barB;  // Just allocates raw memory

This is equivalent to:

__shared__ char barA_memory[sizeof(barrier)];  // Raw memory allocation
__shared__ char barB_memory[sizeof(barrier)];  // Raw memory allocation

No Constructor = No Initialization

The memory contains indeterminate values: whatever bits happened to be left there.

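if (threadIdx.x == 0) {
    init(&barA, blockDim.x);
    init(&barB, blockDim.x);
    cde::fence_proxy_async_shared_cta();  // make the initialized barriers visible to the async proxy (TMA)
}
__syncthreads();  // other threads must not touch barA/barB before this point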

The init() Call Does the Actual Construction

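init(&barA, blockDim.x);  // THIS is where the object finally gets constructed

Conceptually, init() plays the role of the constructor, as if by placement new:

new (&barA) barrier(blockDim.x);  // Conceptually, construction happens HERE

For a block-scoped barrier in shared memory on sm_80 and newer, this step is typically carried out by the mbarrier.init instruction shown earlier.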

Timeline

  1. Kernel Launch: __shared__ barrier barA;

    • ✅ Memory allocated in shared memory
    • ❌ Constructor NOT called
    • ❌ Object NOT initialized
    • 💀 Contains garbage data
  2. Manual Init: init(&barA, blockDim.x);

    • ✅ Constructor finally called via placement new
    • ✅ Object properly initialized
    • ✅ Safe to use

Why This Matters

If you tried to use barA before calling init():

__shared__ barrier barA;  // Just allocation

// DON'T DO THIS - undefined behavior!
auto token = barA.arrive();  // 💥 Undefined behavior: may hang, crash, or silently misbehave

You must call init() first:

__shared__ barrier barA;  // Allocation only

if (threadIdx.x == 0) {
    init(&barA, blockDim.x);  // Actual initialization
}
__syncthreads();

// NOW it's safe to use
auto token = barA.arrive();  // ✅ Works correctly

So __shared__ declarations are pure memory allocation in CUDA, with no initialization at all (not even zeroing)!
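The same allocate-now, construct-later split can be reproduced in ordinary host C++ with aligned raw storage plus placement new. This is an analogy only: ToyBarrier and toy_init are hypothetical stand-ins for cuda::barrier and init(), and unlike __shared__ memory, a host global array like this one is zero-initialized before construction.

```cpp
#include <cassert>
#include <cstddef>
#include <new>

// Stand-in for a barrier type whose state is only meaningful after construction.
struct ToyBarrier {
    std::ptrdiff_t expected;
    std::ptrdiff_t pending;
    explicit ToyBarrier(std::ptrdiff_t count) : expected(count), pending(count) {}
};

// Like a __shared__ declaration: raw, suitably aligned storage; no constructor runs.
alignas(ToyBarrier) unsigned char bar_storage[sizeof(ToyBarrier)];

// Like init(): construct the object in place inside the existing storage.
ToyBarrier* toy_init(void* storage, std::ptrdiff_t count) {
    return ::new (storage) ToyBarrier(count);
}
```

After toy_init(bar_storage, 128) the object lives at the same address as the raw storage, just as barA keeps its shared-memory address once init() has run.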
