@ita9naiwa
Created July 20, 2025 01:08
benchmark.py
import os
os.environ["TRITON_CACHE_DIR"] = "./cache"
os.environ["TRITON_DUMP_DIR"] = "./cache"
import statistics as stats
import torch
import triton
import triton.language as tl
# ---------------------------
# Fused 2-dot kernel
# Dot A: (M1,K1)x(K1,N1) -> C1
# Dot B: (M2,K2)x(K2,N2) -> C2 (smaller)
# Shapes for the second dot are independent of the first (it could be a sub-problem).
# ---------------------------
@triton.jit
def dual_gemm_kernel(
    A1, B1, C1, M1, N1, K1,
    A2, B2, C2, M2, N2, K2,
    # strides (row-major assumed; passed explicitly for generality)
    sa1_m, sa1_k, sb1_k, sb1_n, sc1_m, sc1_n,
    sa2_m, sa2_k, sb2_k, sb2_n, sc2_m, sc2_n,
    # block sizes for the large dot
    BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr, BLOCK_K1: tl.constexpr,
    # block sizes for the small dot
    BLOCK_M2: tl.constexpr, BLOCK_N2: tl.constexpr, BLOCK_K2: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # --------- First (large) matmul ---------
    offs_m1 = pid_m * BLOCK_M1 + tl.arange(0, BLOCK_M1)
    offs_n1 = pid_n * BLOCK_N1 + tl.arange(0, BLOCK_N1)
    offs_k1 = tl.arange(0, BLOCK_K1)
    a1_ptrs = A1 + offs_m1[:, None] * sa1_m + offs_k1[None, :] * sa1_k
    b1_ptrs = B1 + offs_k1[:, None] * sb1_k + offs_n1[None, :] * sb1_n
    acc1 = tl.zeros([BLOCK_M1, BLOCK_N1], dtype=tl.float32)
    # NOTE: the loads only mask M/N; K is assumed to be a multiple of BLOCK_K
    # (true for every shape in the tests below).
    for k in range(0, K1, BLOCK_K1):
        a1 = tl.load(a1_ptrs, mask=offs_m1[:, None] < M1, other=0.)
        b1 = tl.load(b1_ptrs, mask=offs_n1[None, :] < N1, other=0.)
        acc1 += tl.dot(a1, b1)
        a1_ptrs += BLOCK_K1 * sa1_k
        b1_ptrs += BLOCK_K1 * sb1_k
    c1 = acc1.to(tl.float16)
    c1_ptrs = C1 + offs_m1[:, None] * sc1_m + offs_n1[None, :] * sc1_n
    tl.store(c1_ptrs, c1,
             mask=(offs_m1[:, None] < M1) & (offs_n1[None, :] < N1))
    # Barrier-like logical separation (no real sync needed: independent CTAs).
    # --------- Second (small) matmul ---------
    offs_m2 = pid_m * BLOCK_M2 + tl.arange(0, BLOCK_M2)
    offs_n2 = pid_n * BLOCK_N2 + tl.arange(0, BLOCK_N2)
    offs_k2 = tl.arange(0, BLOCK_K2)
    a2_ptrs = A2 + offs_m2[:, None] * sa2_m + offs_k2[None, :] * sa2_k
    b2_ptrs = B2 + offs_k2[:, None] * sb2_k + offs_n2[None, :] * sb2_n
    acc2 = tl.zeros([BLOCK_M2, BLOCK_N2], dtype=tl.float32)
    for k in range(0, K2, BLOCK_K2):
        a2 = tl.load(a2_ptrs, mask=offs_m2[:, None] < M2, other=0.)
        b2 = tl.load(b2_ptrs, mask=offs_n2[None, :] < N2, other=0.)
        acc2 += tl.dot(a2, b2)
        a2_ptrs += BLOCK_K2 * sa2_k
        b2_ptrs += BLOCK_K2 * sb2_k
    c2 = acc2.to(tl.float16)
    c2_ptrs = C2 + offs_m2[:, None] * sc2_m + offs_n2[None, :] * sc2_n
    tl.store(c2_ptrs, c2,
             mask=(offs_m2[:, None] < M2) & (offs_n2[None, :] < N2))
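
# Optional correctness check: a minimal sketch that launches dual_gemm_kernel
# once on small shapes and compares both outputs against torch.matmul. The
# shapes, blockings, and tolerances here are illustrative assumptions; the
# tolerances are loose because the kernel accumulates in fp32 and then rounds
# to fp16.
def check_dual_gemm():
    dev = 'cuda'
    M1, N1, K1 = 256, 256, 64
    M2, N2, K2 = 128, 128, 64
    BM1, BN1, BK1 = 128, 128, 32
    BM2, BN2, BK2 = 64, 64, 32
    A1 = torch.randn((M1, K1), device=dev, dtype=torch.float16)
    B1 = torch.randn((K1, N1), device=dev, dtype=torch.float16)
    C1 = torch.empty((M1, N1), device=dev, dtype=torch.float16)
    A2 = torch.randn((M2, K2), device=dev, dtype=torch.float16)
    B2 = torch.randn((K2, N2), device=dev, dtype=torch.float16)
    C2 = torch.empty((M2, N2), device=dev, dtype=torch.float16)
    # Grid sized for the large dot; for these shapes it also covers the small dot.
    grid = (triton.cdiv(M1, BM1), triton.cdiv(N1, BN1))
    dual_gemm_kernel[grid](
        A1, B1, C1, M1, N1, K1,
        A2, B2, C2, M2, N2, K2,
        A1.stride(0), A1.stride(1), B1.stride(0), B1.stride(1),
        C1.stride(0), C1.stride(1),
        A2.stride(0), A2.stride(1), B2.stride(0), B2.stride(1),
        C2.stride(0), C2.stride(1),
        BLOCK_M1=BM1, BLOCK_N1=BN1, BLOCK_K1=BK1,
        BLOCK_M2=BM2, BLOCK_N2=BN2, BLOCK_K2=BK2,
    )
    torch.testing.assert_close(C1, A1 @ B1, rtol=1e-2, atol=1e-1)
    torch.testing.assert_close(C2, A2 @ B2, rtol=1e-2, atol=1e-1)
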
def run_dual(
    M1, N1, K1, M2, N2, K2,
    mode,               # 'legacy' or 'new' (currently unused inside this function)
    reps=1000, warmup=1000,
    # distinct blockings are chosen so both dots appear distinctly in the dumped IR
    # (see the IR helper at the end of the file)
    BM1=128, BN1=128, BK1=32,
    BM2=64, BN2=64, BK2=32,
    dtype=torch.float16,
):
    torch.manual_seed(0)
    dev = 'cuda'
    A1 = torch.randn((M1, K1), device=dev, dtype=dtype)
    B1 = torch.randn((K1, N1), device=dev, dtype=dtype)
    C1 = torch.empty((M1, N1), device=dev, dtype=dtype)
    A2 = torch.randn((M2, K2), device=dev, dtype=dtype)
    B2 = torch.randn((K2, N2), device=dev, dtype=dtype)
    C2 = torch.empty((M2, N2), device=dev, dtype=dtype)
    grid0 = (triton.cdiv(M1, BM1), triton.cdiv(N1, BN1))
    # NOTE: the same grid is reused for both dots. For these shapes the second
    # (smaller) dot needs fewer CTAs than the grid provides, and the masks turn
    # the extra ones into no-ops; if the second dot ever needed a larger grid,
    # the grid would have to cover it explicitly.
    grid = grid0
    # Warmup
    for _ in range(warmup):
        dual_gemm_kernel[grid](
            A1, B1, C1, M1, N1, K1,
            A2, B2, C2, M2, N2, K2,
            A1.stride(0), A1.stride(1),
            B1.stride(0), B1.stride(1),
            C1.stride(0), C1.stride(1),
            A2.stride(0), A2.stride(1),
            B2.stride(0), B2.stride(1),
            C2.stride(0), C2.stride(1),
            BLOCK_M1=BM1, BLOCK_N1=BN1, BLOCK_K1=BK1,
            BLOCK_M2=BM2, BLOCK_N2=BN2, BLOCK_K2=BK2,
        )
    torch.cuda.synchronize()
    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)
    times = []
    for _ in range(reps):
        start_evt.record()
        dual_gemm_kernel[grid](
            A1, B1, C1, M1, N1, K1,
            A2, B2, C2, M2, N2, K2,
            A1.stride(0), A1.stride(1),
            B1.stride(0), B1.stride(1),
            C1.stride(0), C1.stride(1),
            A2.stride(0), A2.stride(1),
            B2.stride(0), B2.stride(1),
            C2.stride(0), C2.stride(1),
            BLOCK_M1=BM1, BLOCK_N1=BN1, BLOCK_K1=BK1,
            BLOCK_M2=BM2, BLOCK_N2=BN2, BLOCK_K2=BK2,
        )
        end_evt.record()
        torch.cuda.synchronize()
        times.append(start_evt.elapsed_time(end_evt))
    mean = stats.mean(times)
    # FLOPs for both dots (2*M*N*K each)
    flops1 = 2.0 * M1 * N1 * K1
    flops2 = 2.0 * M2 * N2 * K2
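    # For scale (first test shape below, computed from the formula above):
    #   flops1 = 2 * 4096 * 4096 * 512 ≈ 1.72e10
    #   flops2 = 2 * 512 * 512 * 64    ≈ 3.36e7
    # i.e. the small dot contributes well under 1% of the total FLOPs.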
    tflops = (flops1 + flops2) / (mean * 1e-3) / 1e12
    return mean, tflops, (flops1, flops2)
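
# Alternative timing: a minimal sketch using triton.testing.do_bench, which
# wraps the warmup/repeat/event bookkeeping done manually in run_dual and
# clears the L2 cache between runs, so its numbers are usually a bit more
# pessimistic than back-to-back launches. `launch_fn` is any zero-argument
# callable that enqueues the kernel, e.g. lambda: dual_gemm_kernel[grid](...).
def time_with_do_bench(launch_fn):
    return triton.testing.do_bench(launch_fn)  # mean time in ms (default settings)
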
if __name__ == "__main__":
    # Large dot vs. a much smaller one
    tests = [
        (4096, 4096, 512, 512, 512, 64),
        (4096, 4096, 128, 1024, 256, 64),
        (8192, 1024, 256, 512, 512, 128),
        (1024, 8192, 256, 512, 512, 64),
        (4032, 4096, 128, 512, 512, 64),
        (4096, 4096, 32, 512, 512, 32),
    ]
    for t in tests:
        (M1, N1, K1, M2, N2, K2) = t
        new = run_dual(*t, mode='new', dtype=torch.float16)[0]
        print(f"[dual] large=({M1},{N1},{K1}) small=({M2},{N2},{K2}) "
              f"new={new:.3f}ms")