import math
import time
import torch
import triton
import triton.language as tl
import triton.runtime as runtime
import triton.tools.experimental_descriptor
VERBOSE = False
TORCH_TO_TRITON_DTYPE = {
    torch.float32: tl.float32,
    torch.float16: tl.float16,
    torch.bfloat16: tl.bfloat16,
    torch.float8_e4m3fn: tl.float8e4nv,
}
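# The TMA-based kernels receive opaque descriptor pointers instead of typed tensors, so the element
# dtype of the loaded tiles is passed to them explicitly as a Triton dtype via this mapping (see the
# _in_element_dtype / _out_element_dtype kernel arguments below).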
SM = torch.cuda.get_device_capability()
# Computes quantiles from a list of values (e.g., elapsed times)
def quantile(a, q, precision=5):
    n = len(a)
    a = sorted(a)
    def get_quantile(q):
        if not (0 <= q <= 1):
            raise ValueError("Quantiles must be in the range [0, 1]")
        point = q * (n - 1)
        lower = math.floor(point)
        upper = math.ceil(point)
        t = point - lower
        return round((1 - t) * a[lower] + t * a[upper], precision)
    return [get_quantile(q) for q in q]
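# Example: quantile([3.0, 1.0, 2.0, 4.0], [0.5]) sorts the values and linearly interpolates
# between the two middle elements, returning [2.5].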
# Helper function to clear the device cache
def clear_cache(cache):
    if triton.__version__ == "3.2.0":
        cache.zero_()
    elif hasattr(runtime.driver.active, "clear_cache"):
        runtime.driver.active.clear_cache(cache)
    else:
        raise RuntimeError("Unsupported Triton version.")
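# The cache buffer is overwritten between timed runs so the L2 cache does not stay warm from a
# previous iteration and skew the measurements.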
# Helper function to check the correctness of a kernel output against the ground-truth output
def correctness(x, truth, name):
    diff = x - truth
    absdiff = torch.abs(diff)
    absmax = absdiff.max()
    mae = absdiff.mean()
    mse = (diff ** 2).mean()
    cossim = torch.nn.functional.cosine_similarity(x.flatten(), truth.flatten(), dim=0)
    print(f"{name}: {absmax=:.3f}, {mae=:.3f}, {mse=:.3f}, {cossim=:.3f}")
# Helper function to compute the scale for fp8 quantization
def get_scale(tensor, granularity):
    if granularity == "per_tensor":
        return tensor.abs().max().to(torch.float32).view(1, 1)
    elif granularity == "per_row":
        return tensor.abs().amax(dim=1, keepdim=True).to(torch.float32)
    else:
        raise ValueError(f"Unsupported granularity: {granularity}")
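# Note: these are raw absolute maxima. benchmark() below divides them by the fp8 max representable
# value so that a / scale_a maps each tensor onto the fp8 range, and the matmul kernels multiply the
# accumulator by scale_a * scale_b to approximately undo the quantization.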
# Helper function to benchmark a given function with specified parameters.
# It runs the matmul kernel with input tensors of shape (M, K) and (K, N),
# and measures the average time taken for the operation over multiple runs.
def benchmark(fn, M: int, K: int, N: int, granularity: str, dtype=torch.float8_e4m3fn, device="cuda", num_warmups=8, num_repeats=32, out_dtype: torch.dtype = torch.bfloat16):
    # Make sure each run gets the same input for the same shape (hacky, but we just want to get something running quickly)
    torch.manual_seed(42)
    init_dtype = torch.bfloat16
    # Create input tensors
    a = torch.randn(M, K, dtype=init_dtype, device=device)
    b = torch.randn(K, N, dtype=init_dtype, device=device)
    scale_a = scale_b = None
    quant_min = torch.finfo(dtype).min
    quant_max = torch.finfo(dtype).max
    # Compute quantization scales (explanation in the blog post)
    if granularity == "per_tensor":
        scale_a = get_scale(a, granularity)
        scale_b = get_scale(b.t(), granularity)
    elif granularity == "per_row":
        scale_a = get_scale(a, granularity)
        scale_b = get_scale(b.t(), granularity)
        scale_b = scale_b.t().contiguous()
    if granularity is not None and dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
        scale_a = scale_a / quant_max
        scale_b = scale_b / quant_max
        a = (a / scale_a).clamp(quant_min, quant_max)
        b = (b / scale_b).clamp(quant_min, quant_max)
    else:
        scale_a = scale_b = None
    a = a.to(dtype)
    b = b.to(dtype).t().contiguous().t()  # Ensure b is column-major for cuBLAS compatibility
    out = torch.zeros((M, N), dtype=out_dtype, device=device)
    fn(a, b, scale_a, scale_b, out)
    torch.cuda.synchronize()
    time.sleep(0.5)
    cache = runtime.driver.active.get_empty_cache_for_benchmark()
    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_repeats)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_repeats)]
    # Warmup phase to ensure the kernel is compiled and ready
    for _ in range(num_warmups):
        clear_cache(cache)
        out = torch.zeros((M, N), dtype=out_dtype, device=device)
        fn(a, b, scale_a, scale_b, out)
    torch.cuda.synchronize()
    time.sleep(0.5)
    # Benchmarking phase
    for i in range(num_repeats):
        clear_cache(cache)
        out = torch.zeros((M, N), dtype=out_dtype, device=device)
        start_events[i].record()
        fn(a, b, scale_a, scale_b, out)
        end_events[i].record()
    torch.cuda.synchronize()
    elapsed_times = [start.elapsed_time(end) for start, end in zip(start_events, end_events)]
    mean_time = sum(elapsed_times) / len(elapsed_times)
    quantile_times = quantile(elapsed_times, [0.5, 0.2, 0.8], precision=5)
    if VERBOSE:
        print(f"===== Benchmarking {fn.__name__} =====")
        print(f"M: {M}, K: {K}, N: {N}, granularity: {granularity}, dtype: {dtype}")
        print(f"Mean time: {mean_time:.5f} ms")
        print(f"Quantiles (0.5, 0.2, 0.8): {quantile_times}")
        print()
    return mean_time, out
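# Example (mirroring main() below): mean_ms, out = benchmark(torch_mm, 512, 3072, 3072 * 4,
# granularity=None, dtype=torch.bfloat16) times the bf16 baseline for a (512, 3072) x (3072, 12288) matmul.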
# Reference pytorch implementation for bf16 matmul. We will use it as a baseline for numerical correctness.
def torch_mm(a, b, scale_a, scale_b, out):
    return torch.mm(a, b, out=out)
# Reference pytorch implementation for fp8 scaled matrix multiplication. We will use it to compare performance
# against Triton implementations.
def torch_scaled_mm(a, b, scale_a, scale_b, out):
    return torch._scaled_mm(a, b, scale_a, scale_b, use_fast_accum=False, out=out)
# Fast accumulation version of the scaled matrix multiplication in PyTorch. We will use it to compare performance
# against Triton implementations. This function uses a faster accumulation method, which has better performance
# but different numerical properties. Search for "CUBLASLT_MATMUL_DESC_FAST_ACCUM" on https://docs.nvidia.com/cuda/cublas/
def torch_scaled_mm_fast_accum(a, b, scale_a, scale_b, out):
    return torch._scaled_mm(a, b, scale_a, scale_b, use_fast_accum=True, out=out)
def main():
    # M, K, N
    shapes = [
        # (128, 128, 128),
        (512, 3072, 3072 * 4),
        (512, 3072 * 4, 3072),
        (1152, 3072, 3072 * 4),
        (1152, 3072 * 4, 3072),
        (2304, 3072, 3072 * 4),
        (2304, 3072 * 4, 3072),
        (4608, 3072, 3072 * 4),
        (4608, 3072 * 4, 3072),
    ]
    fp8_dtype = torch.float8_e4m3fn if SM >= (8, 9) else torch.bfloat16
    USE_SPLIT_K = False
    out_dtype = torch.float16 if USE_SPLIT_K else torch.bfloat16
    for shape in shapes:
        M, K, N = shape
        print(f"========== Running benchmarks for shape: M={M}, K={K}, N={N} ==========")
        # Benchmark torch.mm with bfloat16 for baseline
        baseline, baseline_out = benchmark(torch_mm, M, K, N, granularity=None, dtype=torch.bfloat16)
        # Benchmark torch._scaled_mm with fast accumulation
        if SM >= (8, 9):
            scaled_mm_fast_accum_pt, scaled_mm_fast_accum_pt_out = benchmark(torch_scaled_mm_fast_accum, M, K, N, granularity="per_tensor", dtype=fp8_dtype)
            scaled_mm_fast_accum_pr, scaled_mm_fast_accum_pr_out = benchmark(torch_scaled_mm_fast_accum, M, K, N, granularity="per_row", dtype=fp8_dtype)
        # Benchmark torch._scaled_mm without fast accumulation
        if SM >= (8, 9):
            scaled_mm_pt, scaled_mm_pt_out = benchmark(torch_scaled_mm, M, K, N, granularity="per_tensor", dtype=fp8_dtype)
            scaled_mm_pr, scaled_mm_pr_out = benchmark(torch_scaled_mm, M, K, N, granularity="per_row", dtype=fp8_dtype)
        # Benchmark triton fp8 per tensor
        triton_fp8_pt, triton_fp8_pt_out = benchmark(triton_fp8, M, K, N, granularity="per_tensor", dtype=fp8_dtype, out_dtype=out_dtype)
        triton_fp8_pt_out = triton_fp8_pt_out.to(torch.bfloat16)
        # Benchmark triton fp8 per row
        triton_fp8_pr, triton_fp8_pr_out = benchmark(triton_fp8, M, K, N, granularity="per_row", dtype=fp8_dtype, out_dtype=out_dtype)
        triton_fp8_pr_out = triton_fp8_pr_out.to(torch.bfloat16)
        # Benchmark triton fp8 per tensor persistent
        triton_fp8_pt_persistent, triton_fp8_pt_out_persistent = benchmark(triton_fp8_persistent, M, K, N, granularity="per_tensor", dtype=fp8_dtype, out_dtype=out_dtype)
        triton_fp8_pt_out_persistent = triton_fp8_pt_out_persistent.to(torch.bfloat16)
        # Benchmark triton fp8 per row persistent
        triton_fp8_pr_persistent, triton_fp8_pr_out_persistent = benchmark(triton_fp8_persistent, M, K, N, granularity="per_row", dtype=fp8_dtype, out_dtype=out_dtype)
        triton_fp8_pr_out_persistent = triton_fp8_pr_out_persistent.to(torch.bfloat16)
        # Benchmark triton fp8 per tensor persistent TMA
        triton_fp8_pt_persistent_tma, triton_fp8_pt_out_persistent_tma = benchmark(triton_fp8_persistent_tma, M, K, N, granularity="per_tensor", dtype=fp8_dtype, out_dtype=out_dtype)
        triton_fp8_pt_out_persistent_tma = triton_fp8_pt_out_persistent_tma.to(torch.bfloat16)
        # Benchmark triton fp8 per row persistent TMA
        triton_fp8_pr_persistent_tma, triton_fp8_pr_out_persistent_tma = benchmark(triton_fp8_persistent_tma, M, K, N, granularity="per_row", dtype=fp8_dtype, out_dtype=out_dtype)
        triton_fp8_pr_out_persistent_tma = triton_fp8_pr_out_persistent_tma.to(torch.bfloat16)
        # Benchmark triton fp8 per tensor persistent TMA with cooperative warp specialization
        triton_fp8_pt_persistent_tma_ws_cooperative, triton_fp8_pt_out_persistent_tma_ws_cooperative = benchmark(triton_fp8_persistent_tma_ws_cooperative, M, K, N, granularity="per_tensor", dtype=fp8_dtype, out_dtype=out_dtype)
        triton_fp8_pt_out_persistent_tma_ws_cooperative = triton_fp8_pt_out_persistent_tma_ws_cooperative.to(torch.bfloat16)
        # Benchmark triton fp8 per row persistent TMA with cooperative warp specialization
        triton_fp8_pr_persistent_tma_ws_cooperative, triton_fp8_pr_out_persistent_tma_ws_cooperative = benchmark(triton_fp8_persistent_tma_ws_cooperative, M, K, N, granularity="per_row", dtype=fp8_dtype, out_dtype=out_dtype)
        triton_fp8_pr_out_persistent_tma_ws_cooperative = triton_fp8_pr_out_persistent_tma_ws_cooperative.to(torch.bfloat16)
        if SM >= (8, 9):
            correctness(scaled_mm_fast_accum_pt_out, baseline_out, "scaled_mm_fast_accum_pt")
            correctness(scaled_mm_fast_accum_pr_out, baseline_out, "scaled_mm_fast_accum_pr")
            correctness(scaled_mm_pt_out, baseline_out, "scaled_mm_pt")
            correctness(scaled_mm_pr_out, baseline_out, "scaled_mm_pr")
        correctness(triton_fp8_pt_out, baseline_out, "triton_fp8_pt")
        correctness(triton_fp8_pr_out, baseline_out, "triton_fp8_pr")
        correctness(triton_fp8_pt_out_persistent, baseline_out, "triton_fp8_pt_persistent")
        correctness(triton_fp8_pr_out_persistent, baseline_out, "triton_fp8_pr_persistent")
        correctness(triton_fp8_pt_out_persistent_tma, baseline_out, "triton_fp8_pt_persistent_tma")
        correctness(triton_fp8_pr_out_persistent_tma, baseline_out, "triton_fp8_pr_persistent_tma")
        correctness(triton_fp8_pt_out_persistent_tma_ws_cooperative, baseline_out, "triton_fp8_pt_persistent_tma_ws_cooperative")
        correctness(triton_fp8_pr_out_persistent_tma_ws_cooperative, baseline_out, "triton_fp8_pr_persistent_tma_ws_cooperative")
        if SM >= (8, 9):
            print(f"speedup scaled_mm_fast_accum_pt: {baseline / scaled_mm_fast_accum_pt:.2f}x")
            print(f"speedup scaled_mm_fast_accum_pr: {baseline / scaled_mm_fast_accum_pr:.2f}x")
            print(f"speedup scaled_mm_pt: {baseline / scaled_mm_pt:.2f}x")
            print(f"speedup scaled_mm_pr: {baseline / scaled_mm_pr:.2f}x")
        print(f"speedup triton_fp8_pt: {baseline / triton_fp8_pt:.2f}x")
        print(f"speedup triton_fp8_pr: {baseline / triton_fp8_pr:.2f}x")
        print(f"speedup triton_fp8_pt_persistent: {baseline / triton_fp8_pt_persistent:.2f}x")
        print(f"speedup triton_fp8_pr_persistent: {baseline / triton_fp8_pr_persistent:.2f}x")
        print(f"speedup triton_fp8_pt_persistent_tma: {baseline / triton_fp8_pt_persistent_tma:.2f}x")
        print(f"speedup triton_fp8_pr_persistent_tma: {baseline / triton_fp8_pr_persistent_tma:.2f}x")
        print(f"speedup triton_fp8_pt_persistent_tma_ws_cooperative: {baseline / triton_fp8_pt_persistent_tma_ws_cooperative:.2f}x")
        print(f"speedup triton_fp8_pr_persistent_tma_ws_cooperative: {baseline / triton_fp8_pr_persistent_tma_ws_cooperative:.2f}x")
        print()
# Configs for autotuning triton kernels. This is a dense grid of configurations, but ideally it
# should be pruned based on the problem sizes you're working with, or replaced with known best configurations.
# Explanation of the parameters:
# - BLOCK_M: Number of rows each program processes from the input matrix A and produces in the output matrix.
# - BLOCK_N: Number of columns each program processes from the input matrix B and produces in the output matrix.
#   Together with BLOCK_M, this defines the size of the output submatrix per program.
# - BLOCK_K: Width of the reduction tile. The K-dimension is divided into chunks of this size.
#   Each program accumulates partial results over (K // BLOCK_K) steps.
# - GROUP_M: Improves L2 cache utilization. See the official triton documentation for more details: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
# - SPLIT_K: Splits the K-dimension across multiple programs to reduce memory pressure. Use SPLIT_K > 1 only
#   if the reduction becomes a bottleneck (for large K).
# - num_warps: Number of warps (each warp is 32 threads) assigned per program. Higher values increase parallelism
#   but also resource usage (registers/shared memory). 4 or 8 warps are typical choices.
# - num_stages: Number of pipeline stages used to overlap computation and memory loads. More stages can improve
#   latency hiding through better overlap, but may also increase resource pressure and lead to a slowdown
#   if not tuned properly.
#
# Warp-specialization parameters (enabled when warp_specialization=True):
# - NUM_CONSUMER_GROUPS and num_buffers_warp_spec: Specific to Hopper and newer GPUs, see
#   https://pytorch.org/blog/warp-specialization/
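# As a rough resource estimate (the same formula used by early_config_prune below): the pipelined A and B
# tiles occupy about (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * element_size bytes of shared memory,
# e.g. (128 + 128) * 128 * 3 * 1 byte = 96 KiB for fp8 inputs with three stages.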
def get_autotune_configs(warp_specialization: bool = False):
    if not warp_specialization:
        configs = [
            triton.Config(
                {
                    "BLOCK_M": BLOCK_M,
                    "BLOCK_N": BLOCK_N,
                    "BLOCK_K": BLOCK_K,
                    "GROUP_M": GROUP_M,
                    "SPLIT_K": SPLIT_K,
                },
                num_warps=num_warps,
                num_stages=num_stages,
            )
            for BLOCK_M in [64, 128]
            for BLOCK_N in [64, 128]
            for BLOCK_K in [64, 128, 256]
            for GROUP_M in [8]
            # for SPLIT_K in [1, 2, 4]  # Seems like there's no speedup with SPLIT_K > 1, so we keep it simple
            for SPLIT_K in [1]
            for num_warps in [4, 8]
            # for num_stages in [3, 4]
            for num_stages in [3]
        ]
    else:
        configs = [
            triton.Config(
                {
                    "BLOCK_M": BLOCK_M,
                    "BLOCK_N": BLOCK_N,
                    "BLOCK_K": BLOCK_K,
                    "GROUP_M": GROUP_M,
                    "SPLIT_K": SPLIT_K,
                    "NUM_CONSUMER_GROUPS": NUM_CONSUMER_GROUPS,
                },
                num_stages=num_stages,
                num_warps=num_warps,
                num_consumer_groups=0 if NUM_CONSUMER_GROUPS == 1 else NUM_CONSUMER_GROUPS,
                num_buffers_warp_spec=num_buffers_warp_spec,
            )
            for BLOCK_M in [64, 128]
            for BLOCK_N in [128, 256]
            for BLOCK_K in [64, 128]
            for GROUP_M in [8]
            for SPLIT_K in [1]
            for num_warps in [4]
            for num_stages in [2, 3]
            for NUM_CONSUMER_GROUPS in [1, 2]
            for num_buffers_warp_spec in [3]
        ]
    return configs
# Simple rules to prune the dense grid of autotuning configurations based on problem size
def early_config_prune(configs, args, **kwargs):
    device = torch.cuda.current_device()
    max_shared_memory = triton.runtime.driver.active.utils.get_device_properties(device)["max_shared_mem"]
    is_persistent_tma = "_in_element_dtype" in kwargs and "_out_element_dtype" in kwargs
    if "a_ptr" in args:
        element_size = args["a_ptr"].element_size()
    elif "a_desc_ptr" in args:
        in_element_dtype = kwargs.get("_in_element_dtype")
        element_size = in_element_dtype.primitive_bitwidth // 8
    else:
        raise ValueError("Either 'a_ptr' or 'a_desc_ptr' must be provided in args")
    M = kwargs.get("M")
    K = kwargs.get("K")
    N = kwargs.get("N")
    pruned_configs = []
    for config in configs:
        BLOCK_M = config.kwargs["BLOCK_M"]
        BLOCK_N = config.kwargs["BLOCK_N"]
        BLOCK_K = config.kwargs["BLOCK_K"]
        SPLIT_K = config.kwargs["SPLIT_K"]
        num_stages = config.num_stages
        required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * element_size
        if required_shared_memory > max_shared_memory:
            continue
        if M % BLOCK_M != 0:
            continue
        if N % BLOCK_N != 0:
            continue
        if K < 4096 and SPLIT_K > 1:
            continue
        if K % (BLOCK_K * SPLIT_K) != 0:
            continue
        if is_persistent_tma and BLOCK_K >= 256:
            # Triton hangs indefinitely when using the persistent matmul with TMA and BLOCK_K >= 256
            continue
        if not is_persistent_tma and BLOCK_K <= 64:
            # Small BLOCK_K is only fast when using TMA
            continue
        if config.kwargs.get("NUM_CONSUMER_GROUPS", 1) > 1 and BLOCK_M <= 64:
            # Hangs indefinitely
            continue
        pruned_configs.append(config)
    print(f"Pruned configs: {len(pruned_configs)} out of {len(configs)}")
    return pruned_configs
# Kernel for Triton FP8 matrix multiplication.
@triton.autotune(configs=get_autotune_configs(), key=["M", "K", "N", "PER_ROW"], prune_configs_by={"early_config_prune": early_config_prune})
@triton.jit
def matmul_fp8_kernel(
    a_ptr,
    b_ptr,
    scale_a_ptr,
    scale_b_ptr,
    out_ptr,
    M,
    K,
    N,
    PER_ROW: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    USE_FAST_ACCUM: tl.constexpr = True,
):
    pid = tl.program_id(0)
    pid_k = tl.program_id(1)
    # L2 cache optimization: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html#l2-cache-optimizations
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = num_pid_n * GROUP_M
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    # SPLIT_K is the number of programs the K-dimension reduction has been split across.
    # chunk_size_k is the size of each split across the programs handling the K dimension.
    chunk_size_k = K // SPLIT_K
    # num_chunks_k is the number of chunks of size BLOCK_K that each program iterates over to perform the reduction
    num_chunks_k = chunk_size_k // BLOCK_K
    offsets_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offsets_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offsets_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    offsets_am = tl.max_contiguous(tl.multiple_of(offsets_m, BLOCK_M), BLOCK_M)
    offsets_bn = tl.max_contiguous(tl.multiple_of(offsets_n, BLOCK_N), BLOCK_N)
    # Pointers to a block of A [BLOCK_M, BLOCK_K] and B [BLOCK_K, BLOCK_N]
    ptrs_a = a_ptr + offsets_am[:, None] * K + offsets_k[None, :]
    ptrs_b = b_ptr + offsets_k[:, None] + offsets_bn[None, :] * K
    # Tile of the output matrix that every program computes
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Load scales if doing scaled mm
    if scale_a_ptr is not None and scale_b_ptr is not None:
        if PER_ROW:
            scale_a = tl.load(scale_a_ptr + offsets_am).to(tl.float32)
            scale_b = tl.load(scale_b_ptr + offsets_bn).to(tl.float32)
            scale = scale_a[:, None] * scale_b[None, :]
        else:
            scale_a = tl.load(scale_a_ptr).to(tl.float32)
            scale_b = tl.load(scale_b_ptr).to(tl.float32)
            scale = scale_a * scale_b
    # Iterate over the K dimension in chunks of BLOCK_K, reduce across these chunks, and accumulate
    for k in range(num_chunks_k):
        a = tl.load(ptrs_a)
        b = tl.load(ptrs_b)
        if USE_FAST_ACCUM:
            accumulator = tl.dot(a, b, accumulator, out_dtype=tl.float32)
        else:
            accumulator += tl.dot(a, b, out_dtype=tl.float32)
        ptrs_a += BLOCK_K * SPLIT_K
        ptrs_b += BLOCK_K * SPLIT_K
    # Multiply by the scale if provided.
    # For per-tensor scaling, all output values [BLOCK_M, BLOCK_N] are multiplied by the same scale.
    # For per-row scaling, each row of the output is multiplied by the corresponding scale.
    if scale_a_ptr is not None and scale_b_ptr is not None:
        accumulator = scale * accumulator
    # Cast the accumulator to the output dtype (bf16/fp16).
    accumulator = accumulator.to(out_ptr.dtype.element_ty)
    ptrs_out = out_ptr + offsets_am[:, None] * N + offsets_bn[None, :]
    if SPLIT_K == 1:
        # If we haven't split the reduction dimension across different programs, we can store the result directly.
        tl.store(ptrs_out, accumulator)
    else:
        # If we have split the reduction dimension, each [BLOCK_M, BLOCK_N] tile needs to be accumulated
        # across all SPLIT_K programs to produce the final result.
        tl.atomic_add(ptrs_out, accumulator, sem="relaxed")
def triton_fp8(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, out: torch.Tensor):
    M, K = a.shape
    _, N = b.shape
    per_row = scale_a.shape[0] > 1
    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"])
    matmul_fp8_kernel[grid](a, b, scale_a, scale_b, out, M=M, K=K, N=N, PER_ROW=per_row)
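# The first grid axis launches one program per [BLOCK_M, BLOCK_N] output tile; the second axis launches
# SPLIT_K programs that share the K reduction (combined via tl.atomic_add in the kernel when SPLIT_K > 1).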
# Persistent kernel for FP8 matrix multiplication.
# A GPU has a fixed number of SMs.
# With persistence, we launch as many thread blocks as possible to occupy available SMs.
# Persistent thread blocks can compute one or more tiles of the output matrix depending on the problem size.
# The normal kernel above launches one thread block per output tile to cover the entire matrix; with
# persistence, we save the overhead of launching and scheduling that many thread blocks.
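# Each program starts at tile index `start_pid` and advances by NUM_SMS each time it finishes a tile
# (`tile_id += NUM_SMS`), so the launched programs collectively cover every output tile exactly once.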
@triton.autotune(configs=get_autotune_configs(), key=["M", "K", "N", "PER_ROW"], prune_configs_by={"early_config_prune": early_config_prune})
@triton.jit
def matmul_fp8_persistent_kernel(
    a_ptr,
    b_ptr,
    scale_a_ptr,
    scale_b_ptr,
    out_ptr,
    M,
    K,
    N,
    PER_ROW: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    USE_FAST_ACCUM: tl.constexpr = True,
):
    start_pid = tl.program_id(0)
    # Number of programs launched across each dimension.
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    # Number of chunks of size BLOCK_K along the K dimension
    k_tiles = tl.cdiv(K, BLOCK_K)
    # Total number of output tiles to be computed in order to cover the entire matrix (the grid
    # launches at most NUM_SMS programs, so each program may handle several of these tiles).
    num_tiles = num_pid_m * num_pid_n
    # Number of output tiles each SM will compute results for
    tiles_per_SM = num_tiles // NUM_SMS
    if start_pid < num_tiles % NUM_SMS:
        tiles_per_SM += 1
    tile_id = start_pid - NUM_SMS
    ki = -1
    num_pid_in_group = GROUP_M * num_pid_n
    pid_m = 0
    pid_n = 0
    offsets_am = tl.arange(0, BLOCK_M)
    offsets_bn = tl.arange(0, BLOCK_N)
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Load scales if doing scaled mm
    if scale_a_ptr is not None and scale_b_ptr is not None:
        if PER_ROW:
            scale_a = tl.load(scale_a_ptr + offsets_am).to(tl.float32)
            scale_b = tl.load(scale_b_ptr + offsets_bn).to(tl.float32)
            scale = scale_a[:, None] * scale_b[None, :]
        else:
            scale_a = tl.load(scale_a_ptr).to(tl.float32)
            scale_b = tl.load(scale_b_ptr).to(tl.float32)
            scale = scale_a * scale_b
    # The total number of iterations is the number of chunked BLOCK_K reductions
    # across all tiles that this program will handle.
    for _ in range(k_tiles * tiles_per_SM):
        # If the reduction over the K dimension is complete, we start processing the next tile. If
        # not, we continue reducing over the next chunk of size BLOCK_K.
        ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
        if ki == 0:
            # Each time we start a new tile, we have to keep in mind that NUM_SMS programs are
            # already launched and processing their respective tiles. In order to avoid processing
            # the same tile, we offset the tile_id by NUM_SMS.
            tile_id += NUM_SMS
            # L2 cache optimization: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html#l2-cache-optimizations
            group_id = tile_id // num_pid_in_group
            first_pid_m = group_id * GROUP_M
            group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
            pid_m = first_pid_m + (tile_id % group_size_m)
            pid_n = (tile_id % num_pid_in_group) // group_size_m
            offsets_am = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
            offsets_bn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
            offsets_am = tl.max_contiguous(tl.multiple_of(offsets_am, BLOCK_M), BLOCK_M)
            offsets_bn = tl.max_contiguous(tl.multiple_of(offsets_bn, BLOCK_N), BLOCK_N)
        # Load A [BLOCK_M, BLOCK_K] and B [BLOCK_K, BLOCK_N] tiles for the current program,
        # reduce across the BLOCK_K dimension, and accumulate the results.
        offsets_k = ki * BLOCK_K + tl.arange(0, BLOCK_K)
        ptrs_a = a_ptr + offsets_am[:, None] * K + offsets_k[None, :]
        ptrs_b = b_ptr + offsets_k[:, None] + offsets_bn[None, :] * K
        a = tl.load(ptrs_a)
        b = tl.load(ptrs_b)
        if USE_FAST_ACCUM:
            accumulator = tl.dot(a, b, accumulator, out_dtype=tl.float32)
        else:
            accumulator += tl.dot(a, b, out_dtype=tl.float32)
        # After the last reduction over the K dimension for the current tile, store
        # the accumulated results to the output tile [BLOCK_M, BLOCK_N].
        if ki == k_tiles - 1:
            # Multiply by the scale if provided, then cast to the output dtype.
            if scale_a_ptr is not None and scale_b_ptr is not None:
                accumulator = scale * accumulator
            accumulator = accumulator.to(out_ptr.dtype.element_ty)
            offsets_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
            offsets_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
            ptrs_out = out_ptr + offsets_cm[:, None] * N + offsets_cn[None, :]
            if SPLIT_K == 1:
                tl.store(ptrs_out, accumulator)
            else:
                tl.static_assert(False, "Persistent kernel does not support SPLIT_K > 1")
            # Reset the accumulator for the next output tile.
            accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
def triton_fp8_persistent(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, out: torch.Tensor):
    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
    M, K = a.shape
    _, N = b.shape
    per_row = scale_a.shape[0] > 1
    grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"])),)
    matmul_fp8_persistent_kernel[grid](a, b, scale_a, scale_b, out, M=M, K=K, N=N, NUM_SMS=NUM_SMS, PER_ROW=per_row)
# Persistent matmul kernel which makes use of specialized hardware units available on Hopper and newer GPUs: https://pytorch.org/blog/hopper-tma-unit/
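# Rather than computing per-element pointers, the TMA kernels load and store whole tiles through TMA
# descriptors (tl._experimental_descriptor_load / _store). The descriptors are created and filled on the
# host by TMAAutotuneHelper (defined near the end of this file) inside the grid callable, so they are
# re-filled whenever the autotuner selects new block sizes.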
@triton.autotune(configs=get_autotune_configs(), key=["M", "K", "N", "PER_ROW"], prune_configs_by={"early_config_prune": early_config_prune})
@triton.jit
def matmul_fp8_persistent_tma_kernel(
    a_desc_ptr,
    b_desc_ptr,
    scale_a_ptr,
    scale_b_ptr,
    out_desc_ptr,
    M,
    K,
    N,
    PER_ROW: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    USE_FAST_ACCUM: tl.constexpr = True,
    _in_element_dtype: tl.constexpr = None,
    _out_element_dtype: tl.constexpr = None,
):
    start_pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    k_tiles = tl.cdiv(K, BLOCK_K)
    num_tiles = num_pid_m * num_pid_n
    tiles_per_SM = num_tiles // NUM_SMS
    if start_pid < num_tiles % NUM_SMS:
        tiles_per_SM += 1
    tile_id = start_pid - NUM_SMS
    ki = -1
    num_pid_in_group = GROUP_M * num_pid_n
    pid_m = 0
    pid_n = 0
    offsets_am = 0
    offsets_bn = 0
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    if scale_a_ptr is not None and scale_b_ptr is not None:
        if PER_ROW:
            scale_a = tl.load(scale_a_ptr + offsets_am).to(tl.float32)
            scale_b = tl.load(scale_b_ptr + offsets_bn).to(tl.float32)
            scale = scale_a[:, None] * scale_b[None, :]
        else:
            scale_a = tl.load(scale_a_ptr).to(tl.float32)
            scale_b = tl.load(scale_b_ptr).to(tl.float32)
            scale = scale_a * scale_b
    for _ in range(k_tiles * tiles_per_SM):
        ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
        if ki == 0:
            tile_id += NUM_SMS
            group_id = tile_id // num_pid_in_group
            first_pid_m = group_id * GROUP_M
            group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
            pid_m = first_pid_m + (tile_id % group_size_m)
            pid_n = (tile_id % num_pid_in_group) // group_size_m
            offsets_am = pid_m * BLOCK_M
            offsets_bn = pid_n * BLOCK_N
        offsets_k = ki * BLOCK_K
        a = tl._experimental_descriptor_load(a_desc_ptr, [offsets_am, offsets_k], [BLOCK_M, BLOCK_K], _in_element_dtype)
        b = tl._experimental_descriptor_load(b_desc_ptr, [offsets_bn, offsets_k], [BLOCK_N, BLOCK_K], _in_element_dtype)
        if USE_FAST_ACCUM:
            accumulator = tl.dot(a, b.T, accumulator, out_dtype=tl.float32)
        else:
            accumulator += tl.dot(a, b.T, out_dtype=tl.float32)
        if ki == k_tiles - 1:
            if scale_a_ptr is not None and scale_b_ptr is not None:
                accumulator = scale * accumulator
            accumulator = accumulator.to(_out_element_dtype)
            if SPLIT_K == 1:
                tl._experimental_descriptor_store(out_desc_ptr, accumulator, [offsets_am, offsets_bn])
            else:
                tl.static_assert(False, "Persistent kernel does not support SPLIT_K > 1")
            accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
def triton_fp8_persistent_tma(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, out: torch.Tensor):
    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
    M, K = a.shape
    _, N = b.shape
    desc_helper = TMAAutotuneHelper()
    desc_helper.init_tma_descriptor("a")
    desc_helper.init_tma_descriptor("b")
    desc_helper.init_tma_descriptor("out")
    def grid(META):
        desc_helper.fill_2d_tma_descriptor("a", a.data_ptr(), M, K, META["BLOCK_M"], META["BLOCK_K"], a.element_size())
        desc_helper.fill_2d_tma_descriptor("b", b.data_ptr(), N, K, META["BLOCK_N"], META["BLOCK_K"], b.element_size())
        desc_helper.fill_2d_tma_descriptor("out", out.data_ptr(), M, N, META["BLOCK_M"], META["BLOCK_N"], out.element_size())
        return (min(NUM_SMS, triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"])),)
    desc_a = desc_helper.get_tma_descriptor_kernel_param("a")
    desc_b = desc_helper.get_tma_descriptor_kernel_param("b")
    desc_out = desc_helper.get_tma_descriptor_kernel_param("out")
    _in_element_dtype = TORCH_TO_TRITON_DTYPE.get(a.dtype)
    _out_element_dtype = TORCH_TO_TRITON_DTYPE.get(out.dtype)
    per_row = scale_a.shape[0] > 1
    matmul_fp8_persistent_tma_kernel[grid](desc_a, desc_b, scale_a, scale_b, desc_out, M=M, K=K, N=N, NUM_SMS=NUM_SMS, PER_ROW=per_row, _in_element_dtype=_in_element_dtype, _out_element_dtype=_out_element_dtype)
# Persistent matmul kernel utilizing specialized TMA hardware on Hopper and newer GPUs, along with cooperative warp specialization.
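# Note: with NUM_CONSUMER_GROUPS > 1, the host-side wrapper below fills the A and output TMA descriptors
# with BLOCK_M // NUM_CONSUMER_GROUPS rows; the consumer warp groups are assumed to split each output tile
# along the M dimension (see the warp-specialization blog post linked above).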
@triton.autotune(configs=get_autotune_configs(warp_specialization=True), key=["M", "K", "N", "PER_ROW"], prune_configs_by={"early_config_prune": early_config_prune})
@triton.jit
def matmul_fp8_persistent_tma_ws_cooperative_kernel(
    a_desc_ptr,
    b_desc_ptr,
    scale_a_ptr,
    scale_b_ptr,
    out_desc_ptr,
    M,
    K,
    N,
    PER_ROW: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    NUM_CONSUMER_GROUPS: tl.constexpr,
    USE_FAST_ACCUM: tl.constexpr = True,
    _in_element_dtype: tl.constexpr = None,
    _out_element_dtype: tl.constexpr = None,
):
    pid = tl.program_id(0)
    num_programs = tl.num_programs(0)
    num_tiles = tl.cdiv(M, BLOCK_M) * tl.cdiv(N, BLOCK_N)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_M * num_pid_n
    num_chunks_k = tl.cdiv(K, BLOCK_K)
    if scale_a_ptr is not None and scale_b_ptr is not None:
        if PER_ROW:
            scale_a = tl.load(scale_a_ptr + tl.arange(0, BLOCK_M)).to(tl.float32)
            scale_b = tl.load(scale_b_ptr + tl.arange(0, BLOCK_N)).to(tl.float32)
            scale = scale_a[:, None] * scale_b[None, :]
        else:
            scale_a = tl.load(scale_a_ptr).to(tl.float32)
            scale_b = tl.load(scale_b_ptr).to(tl.float32)
            scale = scale_a * scale_b
    for pid in range(pid, num_tiles, num_programs):
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
        pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m
        offsets_am = pid_m * BLOCK_M
        offsets_bn = pid_n * BLOCK_N
        offsets_k = 0
        accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for _ in range(0, num_chunks_k):
            a = tl._experimental_descriptor_load(a_desc_ptr, [offsets_am, offsets_k], [BLOCK_M, BLOCK_K], _in_element_dtype)
            b = tl._experimental_descriptor_load(b_desc_ptr, [offsets_bn, offsets_k], [BLOCK_N, BLOCK_K], _in_element_dtype)
            if USE_FAST_ACCUM:
                accumulator = tl.dot(a, b.T, accumulator, out_dtype=tl.float32)
            else:
                accumulator += tl.dot(a, b.T, out_dtype=tl.float32)
            offsets_k += BLOCK_K
        if scale_a_ptr is not None and scale_b_ptr is not None:
            accumulator = scale * accumulator
        accumulator = accumulator.to(_out_element_dtype)
        if SPLIT_K == 1:
            tl._experimental_descriptor_store(out_desc_ptr, accumulator, [offsets_am, offsets_bn])
        else:
            tl.static_assert(False, "Persistent kernel does not support SPLIT_K > 1")
def triton_fp8_persistent_tma_ws_cooperative(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, out: torch.Tensor):
    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
    M, K = a.shape
    _, N = b.shape
    desc_helper = TMAAutotuneHelper()
    desc_helper.init_tma_descriptor("a")
    desc_helper.init_tma_descriptor("b")
    desc_helper.init_tma_descriptor("out")
    def grid(META):
        BLOCK_M = META["BLOCK_M"] // META["NUM_CONSUMER_GROUPS"]
        desc_helper.fill_2d_tma_descriptor("a", a.data_ptr(), M, K, BLOCK_M, META["BLOCK_K"], a.element_size())
        desc_helper.fill_2d_tma_descriptor("b", b.data_ptr(), N, K, META["BLOCK_N"], META["BLOCK_K"], b.element_size())
        desc_helper.fill_2d_tma_descriptor("out", out.data_ptr(), M, N, BLOCK_M, META["BLOCK_N"], out.element_size())
        return (min(NUM_SMS, triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"])),)
    desc_a = desc_helper.get_tma_descriptor_kernel_param("a")
    desc_b = desc_helper.get_tma_descriptor_kernel_param("b")
    desc_out = desc_helper.get_tma_descriptor_kernel_param("out")
    _in_element_dtype = TORCH_TO_TRITON_DTYPE.get(a.dtype)
    _out_element_dtype = TORCH_TO_TRITON_DTYPE.get(out.dtype)
    per_row = scale_a.shape[0] > 1
    matmul_fp8_persistent_tma_ws_cooperative_kernel[grid](desc_a, desc_b, scale_a, scale_b, desc_out, M=M, K=K, N=N, PER_ROW=per_row, _in_element_dtype=_in_element_dtype, _out_element_dtype=_out_element_dtype)
HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)
if HAS_TMA_DESC:
    print("TMA benchmarks will be running with experimental grid constant TMA descriptor.")
else:
    print("TMA benchmarks will be running without grid constant TMA descriptor.")
# Helper for allocating and filling the 128-byte host-side TMA descriptors consumed by the TMA kernels above.
class TMAAutotuneHelper:
    # Wraps a CPU descriptor buffer so the launcher can query its pointer via tma_desc_cpu_ptr().
    class KernelParamWrapper:
        def __init__(self, desc):
            self.desc = desc
        def tma_desc_cpu_ptr(self):
            return self.desc.data_ptr()
    TMA_SIZE = 128
    def __init__(self):
        if HAS_TMA_DESC:
            self.descriptors = {}
        else:
            self.cuda_descriptors = {}
    def init_tma_descriptor(self, name):
        if HAS_TMA_DESC:
            self.descriptors[name] = torch.empty(TMAAutotuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8)
        else:
            self.cuda_descriptors[name] = torch.empty(TMAAutotuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8)
    def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
        if HAS_TMA_DESC:
            desc_x = self.descriptors[name]
            assert desc_x.data_ptr() % 64 == 0
            triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dim, block_dim, element_size, desc_x.data_ptr())
        else:
            desc_x = self.cuda_descriptors[name]
            buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
            triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dim, block_dim, element_size, buf_x.data_ptr())
            desc_x.copy_(buf_x, non_blocking=True)
    def fill_2d_tma_descriptor(self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size):
        if HAS_TMA_DESC:
            desc_x = self.descriptors[name]
            assert desc_x.data_ptr() % 64 == 0
            triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr())
        else:
            desc_x = self.cuda_descriptors[name]
            buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
            triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr())
            desc_x.copy_(buf_x, non_blocking=True)
    def get_tma_descriptor_kernel_param(self, name):
        if HAS_TMA_DESC:
            assert self.descriptors[name] is not None
            return self.KernelParamWrapper(self.descriptors[name])
        else:
            assert self.cuda_descriptors[name] is not None
            return self.cuda_descriptors[name]
if __name__ == "__main__":
    main()