benchmark.py
import os

os.environ["TRITON_CACHE_DIR"] = "./cache"
os.environ["TRITON_DUMP_DIR"] = "./cache"

import torch
import triton
import triton.language as tl
import statistics as stats

# ---------------------------
# Fused 2-dot kernel
#   Dot A: (M1,K1) x (K1,N1) -> C1
#   Dot B: (M2,K2) x (K2,N2) -> C2 (smaller)
# Shapes for the second dot are independent (it could be a sub-problem).
# ---------------------------
@triton.jit
def dual_gemm_kernel(
    A1, B1, C1, M1, N1, K1,
    A2, B2, C2, M2, N2, K2,
    # strides (row-major assumed; passed explicitly for generality)
    sa1_m, sa1_k, sb1_k, sb1_n, sc1_m, sc1_n,
    sa2_m, sa2_k, sb2_k, sb2_n, sc2_m, sc2_n,
    # block sizes for the large dot
    BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr, BLOCK_K1: tl.constexpr,
    # block sizes for the small dot
    BLOCK_M2: tl.constexpr, BLOCK_N2: tl.constexpr, BLOCK_K2: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # --------- First (large) matmul ---------
    offs_m1 = pid_m * BLOCK_M1 + tl.arange(0, BLOCK_M1)
    offs_n1 = pid_n * BLOCK_N1 + tl.arange(0, BLOCK_N1)
    offs_k1 = tl.arange(0, BLOCK_K1)
    a1_ptrs = A1 + offs_m1[:, None] * sa1_m + offs_k1[None, :] * sa1_k
    b1_ptrs = B1 + offs_k1[:, None] * sb1_k + offs_n1[None, :] * sb1_n
    acc1 = tl.zeros([BLOCK_M1, BLOCK_N1], dtype=tl.float32)
    # The K dimension is not masked, so K1 must be a multiple of BLOCK_K1
    # (true for all shapes in the tests below).
    for k in range(0, K1, BLOCK_K1):
        a1 = tl.load(a1_ptrs, mask=offs_m1[:, None] < M1, other=0.)
        b1 = tl.load(b1_ptrs, mask=offs_n1[None, :] < N1, other=0.)
        acc1 += tl.dot(a1, b1)
        a1_ptrs += BLOCK_K1 * sa1_k
        b1_ptrs += BLOCK_K1 * sb1_k
    c1 = acc1.to(tl.float16)
    c1_ptrs = C1 + offs_m1[:, None] * sc1_m + offs_n1[None, :] * sc1_n
    tl.store(c1_ptrs,
             c1,
             mask=(offs_m1[:, None] < M1) & (offs_n1[None, :] < N1))
    # Logical separation only: the second dot does not depend on the first,
    # so no real synchronization is needed.
    # --------- Second (small) matmul ---------
    offs_m2 = pid_m * BLOCK_M2 + tl.arange(0, BLOCK_M2)
    offs_n2 = pid_n * BLOCK_N2 + tl.arange(0, BLOCK_N2)
    offs_k2 = tl.arange(0, BLOCK_K2)
    a2_ptrs = A2 + offs_m2[:, None] * sa2_m + offs_k2[None, :] * sa2_k
    b2_ptrs = B2 + offs_k2[:, None] * sb2_k + offs_n2[None, :] * sb2_n
    acc2 = tl.zeros([BLOCK_M2, BLOCK_N2], dtype=tl.float32)
    # As above, K2 must be a multiple of BLOCK_K2.
    for k in range(0, K2, BLOCK_K2):
        a2 = tl.load(a2_ptrs, mask=offs_m2[:, None] < M2, other=0.)
        b2 = tl.load(b2_ptrs, mask=offs_n2[None, :] < N2, other=0.)
        acc2 += tl.dot(a2, b2)
        a2_ptrs += BLOCK_K2 * sa2_k
        b2_ptrs += BLOCK_K2 * sb2_k
    c2 = acc2.to(tl.float16)
    c2_ptrs = C2 + offs_m2[:, None] * sc2_m + offs_n2[None, :] * sc2_n
    tl.store(c2_ptrs,
             c2,
             mask=(offs_m2[:, None] < M2) & (offs_n2[None, :] < N2))
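# ---------------------------------------------------------------------------
# Optional sanity check (not part of the original benchmark): a minimal sketch
# that launches the kernel once on small shapes and compares both outputs
# against torch.matmul. The helper name `check_dual_correctness`, the default
# shapes, and the printed diagnostics are assumptions for illustration only.
# ---------------------------------------------------------------------------
def check_dual_correctness(M1=256, N1=256, K1=64, M2=128, N2=128, K2=64,
                           BM1=128, BN1=128, BK1=32, BM2=64, BN2=64, BK2=32):
    dev = 'cuda'
    A1 = torch.randn((M1, K1), device=dev, dtype=torch.float16)
    B1 = torch.randn((K1, N1), device=dev, dtype=torch.float16)
    C1 = torch.empty((M1, N1), device=dev, dtype=torch.float16)
    A2 = torch.randn((M2, K2), device=dev, dtype=torch.float16)
    B2 = torch.randn((K2, N2), device=dev, dtype=torch.float16)
    C2 = torch.empty((M2, N2), device=dev, dtype=torch.float16)
    # The large dot's grid must also cover every tile of the small dot,
    # which holds for these default shapes.
    grid = (triton.cdiv(M1, BM1), triton.cdiv(N1, BN1))
    dual_gemm_kernel[grid](
        A1, B1, C1, M1, N1, K1,
        A2, B2, C2, M2, N2, K2,
        A1.stride(0), A1.stride(1),
        B1.stride(0), B1.stride(1),
        C1.stride(0), C1.stride(1),
        A2.stride(0), A2.stride(1),
        B2.stride(0), B2.stride(1),
        C2.stride(0), C2.stride(1),
        BLOCK_M1=BM1, BLOCK_N1=BN1, BLOCK_K1=BK1,
        BLOCK_M2=BM2, BLOCK_N2=BN2, BLOCK_K2=BK2,
    )
    ref1 = (A1.float() @ B1.float()).to(torch.float16)
    ref2 = (A2.float() @ B2.float()).to(torch.float16)
    # Both paths accumulate in fp32, so the max error should stay small.
    print("max |C1 - ref1| =", (C1 - ref1).abs().max().item())
    print("max |C2 - ref2| =", (C2 - ref2).abs().max().item())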
def run_dual(
    M1, N1, K1, M2, N2, K2,
    mode,  # 'legacy' or 'new' (currently unused inside this function)
    reps=1000, warmup=1000,
    # distinct blockings so both dots appear distinctly in the IR
    BM1=128, BN1=128, BK1=32,
    BM2=64, BN2=64, BK2=32,
    dtype=torch.float16,
):
    torch.manual_seed(0)
    dev = 'cuda'
    A1 = torch.randn((M1, K1), device=dev, dtype=dtype)
    B1 = torch.randn((K1, N1), device=dev, dtype=dtype)
    C1 = torch.empty((M1, N1), device=dev, dtype=dtype)
    A2 = torch.randn((M2, K2), device=dev, dtype=dtype)
    B2 = torch.randn((K2, N2), device=dev, dtype=dtype)
    C2 = torch.empty((M2, N2), device=dev, dtype=dtype)
    grid0 = (triton.cdiv(M1, BM1), triton.cdiv(N1, BN1))
    # NOTE: the same grid is reused for both dots; the second dot may get more
    # programs than it needs, but the masks turn the extra ones into no-ops.
    grid = grid0
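    # Illustration (first test case below): the large dot (4096, 4096) with
    # 128x128 tiles gives a (32, 32) grid of 1024 programs; the small dot
    # (512, 512) with 64x64 tiles only needs (8, 8) = 64 of them, and the rest
    # store nothing to C2.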
    # Warmup
    for _ in range(warmup):
        dual_gemm_kernel[grid](
            A1, B1, C1, M1, N1, K1,
            A2, B2, C2, M2, N2, K2,
            A1.stride(0), A1.stride(1),
            B1.stride(0), B1.stride(1),
            C1.stride(0), C1.stride(1),
            A2.stride(0), A2.stride(1),
            B2.stride(0), B2.stride(1),
            C2.stride(0), C2.stride(1),
            BLOCK_M1=BM1, BLOCK_N1=BN1, BLOCK_K1=BK1,
            BLOCK_M2=BM2, BLOCK_N2=BN2, BLOCK_K2=BK2,
        )
    torch.cuda.synchronize()
    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)
    times = []
    for _ in range(reps):
        start_evt.record()
        dual_gemm_kernel[grid](
            A1, B1, C1, M1, N1, K1,
            A2, B2, C2, M2, N2, K2,
            A1.stride(0), A1.stride(1),
            B1.stride(0), B1.stride(1),
            C1.stride(0), C1.stride(1),
            A2.stride(0), A2.stride(1),
            B2.stride(0), B2.stride(1),
            C2.stride(0), C2.stride(1),
            BLOCK_M1=BM1, BLOCK_N1=BN1, BLOCK_K1=BK1,
            BLOCK_M2=BM2, BLOCK_N2=BN2, BLOCK_K2=BK2,
        )
        end_evt.record()
        torch.cuda.synchronize()
        times.append(start_evt.elapsed_time(end_evt))  # elapsed_time is in ms
    mean = stats.mean(times)
    # FLOPs for both dots (2*M*N*K each)
    flops1 = 2.0 * M1 * N1 * K1
    flops2 = 2.0 * M2 * N2 * K2
    tflops = (flops1 + flops2) / (mean * 1e-3) / 1e12  # ms -> s, FLOP/s -> TFLOP/s
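    # Illustration (first test case below): flops1 = 2*4096*4096*512 ≈ 1.72e10
    # and flops2 = 2*512*512*64 ≈ 3.36e7, so the reported TFLOPS is dominated by
    # the large dot (the small one contributes under 0.2% of the work).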
    return mean, tflops, (flops1, flops2)


if __name__ == "__main__":
    # Large dot paired with a much smaller one
    tests = [
        (4096, 4096, 512, 512, 512, 64),
        (4096, 4096, 128, 1024, 256, 64),
        (8192, 1024, 256, 512, 512, 128),
        (1024, 8192, 256, 512, 512, 64),
        (4032, 4096, 128, 512, 512, 64),
        (4096, 4096, 32, 512, 512, 32),
    ]
    for t in tests:
        (M1, N1, K1, M2, N2, K2) = t
        new = run_dual(*t, mode='new', dtype=torch.float16)[0]
        print(f"[dual] large=({M1},{N1},{K1}) small=({M2},{N2},{K2}) "
              f"new={new:.3f}ms")