Created May 21, 2025 23:50
Roughly comparable performance to the fp8 x fp8 first FC layer of the `_p_matmul_ogs.py` Mixture-of-Experts kernel in triton-lang/triton (python/triton_kernels), when routing is exactly uniform across all experts (an even split, with no variance).
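For intuition, here is a minimal PyTorch sketch (illustrative only; the variable names are hypothetical) of why exactly uniform routing reduces the first MoE FC layer to a grouped GEMM over equal-sized groups, using the same shapes as the `test_grouped_gemm(M=64, N=2048, K=5120, group_size=128)` call at the bottom of the file:

# Illustrative sketch: with perfectly uniform routing, every expert receives the
# same number of tokens, so the MoE FC1 becomes group_size independent GEMMs of
# identical shape -- the grouped-GEMM workload the kernel below benchmarks.
import torch

num_experts, tokens_per_expert, d_model, d_ff = 128, 64, 5120, 2048
tokens = [torch.randn(tokens_per_expert, d_model) for _ in range(num_experts)]  # A_i: [M, K]
w_fc1 = [torch.randn(d_ff, d_model) for _ in range(num_experts)]                # B_i: [N, K]
fc1_out = [a @ w.T for a, w in zip(tokens, w_fc1)]                              # C_i: [M, N]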
import pytest
from typing import Optional

import torch
import triton
import triton.language as tl

DEVICE = "cuda"


def num_sms():
    return torch.cuda.get_device_properties("cuda").multi_processor_count
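
# Persistent grouped-GEMM kernel: a fixed grid of NUM_SM programs walks the
# flattened (group, tile) work list, builds TMA tensor descriptors for the
# current group's A, B, and C, and accumulates each output tile in fp32 before
# casting to the output dtype. TMA tensor descriptors target Hopper-class GPUs
# and require a recent Triton build with tl.make_tensor_descriptor support.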
@triton.jit
def grouped_matmul_tma_kernel(
    # device tensor of matrix pointers
    group_a_ptrs,
    group_b_ptrs,
    group_c_ptrs,
    gm: tl.constexpr, gn: tl.constexpr, gk: tl.constexpr,
    # device tensor of leading dimension sizes; its shape is [group_size, 3]:
    # dim 0 is group_size, dim 1 holds the <lda, ldb, ldc> of each gemm
    g_lds,
    # number of gemms
    group_size: tl.constexpr,
    # number of virtual SMs
    NUM_SM: tl.constexpr,
    # tile sizes
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    # is the output FP8 or FP16
    FP8: tl.constexpr,
):
    dtype = tl.float8e4nv if FP8 else tl.float16
    num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
    num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
    num_tiles = num_m_tiles * num_n_tiles
    num_iterations = group_size * num_tiles
    start_pid = tl.program_id(axis=0)
    for idx in tl.range(start_pid, num_iterations, NUM_SM, flatten=True, disallow_acc_multi_buffer=False, warp_specialize=True):
        # faster - consecutive SMs work on the same group and only move to the next group once all of its tiles are mapped
        g = idx // num_tiles
        tile_idx = idx % num_tiles
        # slower - consecutive SMs each work on a different group and only move to the next tile once all groups are mapped
        # g = idx % group_size
        # tile_idx = idx // group_size
        lda = tl.load(g_lds + g * 3)
        ldb = tl.load(g_lds + g * 3 + 1)
        ldc = tl.load(g_lds + g * 3 + 2)
        a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(dtype))
        b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(dtype))
        c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(dtype))
        a_desc = tl.make_tensor_descriptor(
            a_ptr,
            shape=[gm, gk],
            strides=[lda, 1],
            block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
        )
        b_desc = tl.make_tensor_descriptor(
            b_ptr,
            shape=[gn, gk],
            strides=[ldb, 1],
            block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
        )
        c_desc = tl.make_tensor_descriptor(
            c_ptr,
            shape=[gm, gn],
            strides=[ldc, 1],
            block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
        )
        tile_m_idx = tile_idx // num_n_tiles
        tile_n_idx = tile_idx % num_n_tiles
        offs_am = tile_m_idx * BLOCK_SIZE_M
        offs_bn = tile_n_idx * BLOCK_SIZE_N
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for kk in range(0, tl.cdiv(gk, BLOCK_SIZE_K)):
            a = a_desc.load([offs_am, kk * BLOCK_SIZE_K])
            b = b_desc.load([offs_bn, kk * BLOCK_SIZE_K])
            accumulator += tl.dot(a, b.T)
        offs_cm = tile_m_idx * BLOCK_SIZE_M
        offs_cn = tile_n_idx * BLOCK_SIZE_N
        # cast to the output dtype selected by the FP8 flag (fp8e4nv or fp16)
        c = accumulator.to(dtype)
        c_desc.store([offs_cm, offs_cn], c)
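
# Host-side launcher: packs per-group base pointers and leading dimensions into
# device tensors, registers a scratch allocator for the device-side TMA
# descriptors, and launches the persistent kernel repeatedly (e.g. for timing).
# All groups are assumed to share the same (M, N, K).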
def group_gemm_tma_fn(group_A, group_B):
    assert len(group_A) == len(group_B)
    group_size = len(group_A)
    A_addrs = []
    B_addrs = []
    C_addrs = []
    g_sizes = []
    g_lds = []
    group_C = []
    M, K = group_A[0].shape
    N, _ = group_B[0].shape
    for i in range(group_size):
        A = group_A[i]
        B = group_B[i]
        C = torch.zeros((M, N), device=DEVICE).to(torch.float8_e4m3fn)
        group_C.append(C)
        A_addrs.append(A.data_ptr())
        B_addrs.append(B.data_ptr())
        C_addrs.append(C.data_ptr())
        g_sizes += [M, N, K]
        g_lds += [A.stride(0), B.stride(0), C.stride(0)]
    d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
    d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
    d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
    d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)

    def alloc_fn(size: int, alignment: int, stream: Optional[int]):
        return torch.empty(size, device="cuda", dtype=torch.int8)

    triton.set_allocator(alloc_fn)

    grid = lambda META: (META['NUM_SM'], )
    for _ in range(100):
        out = grouped_matmul_tma_kernel[grid](
            d_a_ptrs, d_b_ptrs, d_c_ptrs, M, N, K, d_g_lds, group_size,
            BLOCK_SIZE_M=64,
            BLOCK_SIZE_N=128,
            BLOCK_SIZE_K=128,
            FP8=torch.float8_e4m3fn == group_A[0].dtype, NUM_SM=num_sms(),
            num_stages=2)
        # print(out.asm["ttgir"], flush=True)
    return group_C
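
# Correctness check: compares the grouped TMA kernel against a per-group
# reference computed with torch._scaled_mm (fp8 x fp8, unit scales), upcasting
# both results to fp16 for the tolerance comparison.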
#@pytest.mark.parametrize("M", [128, 256, 512, 1024, 2048, 4096, 8192])
#@pytest.mark.parametrize("N", [256, 512, 1024, 2048, 4096, 8192])
#@pytest.mark.parametrize("K", [128, 256, 512, 1024, 2048, 4096])
#@pytest.mark.parametrize("group_size", [1, 4, 8, 16])
def test_grouped_gemm(M, N, K, group_size):
    group_A = []
    group_B_T = []
    for i in range(group_size):
        A = torch.rand((M, K), device=DEVICE).to(torch.float8_e4m3fn)
        # B is kept transposed as [N, K], the layout the kernel expects
        B_T = torch.rand((N, K), device=DEVICE) * 2 - 1
        B_T = B_T.to(torch.float8_e4m3fn)
        B_T = B_T.contiguous()
        group_A.append(A)
        group_B_T.append(B_T)
    a_s = torch.tensor(1.0, device=DEVICE)
    b_s = torch.tensor(1.0, device=DEVICE)
    # Reference: per-group fp8 matmul via torch._scaled_mm with unit scales.
    # B_T is [N, K], so it is transposed back to [K, N] for A [M, K] @ B [K, N].
    ref_out = []
    for a, b in zip(group_A, group_B_T):
        result = torch._scaled_mm(a, b.T, scale_a=a_s, scale_b=b_s)
        ref_out.append(result)
    tri_tma_out = group_gemm_tma_fn(group_A, group_B_T)
    for i in range(group_size):
        torch.testing.assert_close(ref_out[i].to(torch.float16), tri_tma_out[i].to(torch.float16), atol=1e-2, rtol=1e-2)
    print("PASS")


if __name__ == "__main__":
    test_grouped_gemm(M=64, N=2048, K=5120, group_size=128)