| import math | |
| import time | |
| import torch | |
| import triton | |
| import triton.language as tl | |
| import triton.runtime as runtime | |
| import triton.tools.experimental_descriptor | |
| VERBOSE = False | |
| TORCH_TO_TRITON_DTYPE = { | |
| torch.float32: tl.float32, | |
| torch.float16: tl.float16, | |
| torch.bfloat16: tl.bfloat16, | |
| torch.float8_e4m3fn: tl.float8e4nv, | |
| } | |
| SM = torch.cuda.get_device_capability() | |
| # Computes quantiles from a list of values (e.g., elapsed times) | |
| def quantile(a, q, precision=5): | |
| n = len(a) | |
| a = sorted(a) | |
| def get_quantile(q): | |
| if not (0 <= q <= 1): | |
| raise ValueError("Quantiles must be in the range [0, 1]") | |
| point = q * (n - 1) | |
| lower = math.floor(point) | |
| upper = math.ceil(point) | |
| t = point - lower | |
| return round((1 - t) * a[lower] + t * a[upper], precision) | |
| return [get_quantile(qi) for qi in q] | |
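| # Illustrative sketch (not part of the benchmark, never called): a quick sanity check of the | |
| # linear-interpolation quantile above. The function name and values below are made up. | |
| def _quantile_example(): | |
|     values = [4.0, 1.0, 3.0, 2.0] | |
|     # Sorted values are [1, 2, 3, 4]; q=0.25 lands at point 0.75 between 1 and 2 -> 1.75, | |
|     # q=0.5 lands midway between 2 and 3 -> 2.5, and q=1.0 is simply the maximum. | |
|     assert quantile(values, [0.25, 0.5, 1.0]) == [1.75, 2.5, 4.0] | |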
| # Helper function to clear the device cache | |
| def clear_cache(cache): | |
| if triton.__version__ == "3.2.0": | |
| cache.zero_() | |
| elif hasattr(runtime.driver.active, "clear_cache"): | |
| runtime.driver.active.clear_cache(cache) | |
| else: | |
| raise RuntimeError("Unsupported Triton version.") | |
| # Helper function to check the correctness of a kernel output against ground truth output | |
| def correctness(x, truth, name): | |
| diff = x - truth | |
| absdiff = torch.abs(diff) | |
| absmax = absdiff.max() | |
| mae = absdiff.mean() | |
| mse = (diff ** 2).mean() | |
| cossim = torch.nn.functional.cosine_similarity(x.flatten(), truth.flatten(), dim=0) | |
| print(f"{name}: {absmax=:.3f}, {mae=:.3f}, {mse=:.3f} {cossim=:.3f}") | |
| # Helper function to compute the scale for fp8 quantization | |
| def get_scale(tensor, granularity): | |
| if granularity == "per_tensor": | |
| return tensor.abs().max().to(torch.float32).view(1, 1) | |
| elif granularity == "per_row": | |
| return tensor.abs().amax(dim=1, keepdim=True).to(torch.float32) | |
| else: | |
| raise ValueError(f"Unsupported granularity: {granularity}") | |
| # Helper function to benchmark a given function with specified parameters. | |
| # It runs the matmul kernel with input tensors of shape (M, K) and (K, N), | |
| # and measures the average time taken for the operation over multiple runs. | |
| def benchmark(fn, M: int, K: int, N: int, granularity: str, dtype=torch.float8_e4m3fn, device="cuda", num_warmups=8, num_repeats=32, out_dtype: torch.dtype = torch.bfloat16): | |
| # Make sure each run gets the same input for the same shape (hacky, but we just want to get something running quickly) | |
| torch.manual_seed(42) | |
| init_dtype = torch.bfloat16 | |
| # Create input tensors | |
| a = torch.randn(M, K, dtype=init_dtype, device=device) | |
| b = torch.randn(K, N, dtype=init_dtype, device=device) | |
| scale_a = scale_b = None | |
| quant_min = torch.finfo(dtype).min | |
| quant_max = torch.finfo(dtype).max | |
| # Compute quantization scales (explanation in the blog post) | |
| if granularity == "per_tensor": | |
| scale_a = get_scale(a, granularity) | |
| scale_b = get_scale(b.t(), granularity) | |
| elif granularity == "per_row": | |
| scale_a = get_scale(a, granularity) | |
| scale_b = get_scale(b.t(), granularity) | |
| scale_b = scale_b.t().contiguous() | |
| if granularity is not None and dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: | |
| scale_a = scale_a / quant_max | |
| scale_b = scale_b / quant_max | |
| a = (a / scale_a).clamp(quant_min, quant_max) | |
| b = (b / scale_b).clamp(quant_min, quant_max) | |
| else: | |
| scale_a = scale_b = None | |
| a = a.to(dtype) | |
| b = b.to(dtype).t().contiguous().t() # Ensure b is column-major for cuBLAS compatibility | |
| out = torch.zeros((M, N), dtype=out_dtype, device=device) | |
| fn(a, b, scale_a, scale_b, out) | |
| torch.cuda.synchronize() | |
| time.sleep(0.5) | |
| cache = runtime.driver.active.get_empty_cache_for_benchmark() | |
| start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_repeats)] | |
| end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_repeats)] | |
| # Warmup phase to ensure the kernel is compiled and ready | |
| for _ in range(num_warmups): | |
| clear_cache(cache) | |
| out = torch.zeros((M, N), dtype=out_dtype, device=device) | |
| fn(a, b, scale_a, scale_b, out) | |
| torch.cuda.synchronize() | |
| time.sleep(0.5) | |
| # Benchmarking phase | |
| for i in range(num_repeats): | |
| clear_cache(cache) | |
| out = torch.zeros((M, N), dtype=out_dtype, device=device) | |
| start_events[i].record() | |
| fn(a, b, scale_a, scale_b, out) | |
| end_events[i].record() | |
| torch.cuda.synchronize() | |
| elapsed_times = [start.elapsed_time(end) for start, end in zip(start_events, end_events)] | |
| mean_time = sum(elapsed_times) / len(elapsed_times) | |
| quantile_times = quantile(elapsed_times, [0.5, 0.2, 0.8], precision=5) | |
| if VERBOSE: | |
| print(f"===== Benchmarking {fn.__name__} =====") | |
| print(f"M: {M}, K: {K}, N: {N}, granularity: {granularity}, dtype: {dtype}") | |
| print(f"Mean time: {mean_time:.5f} ms") | |
| print(f"Quantiles (0.5, 0.2, 0.8): {quantile_times}") | |
| print() | |
| return mean_time, out | |
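| # Illustrative sketch (not part of the benchmark driver, never called): how benchmark() can be used | |
| # directly for a single shape. The lambda matches the (a, b, scale_a, scale_b, out) signature that | |
| # benchmark() expects; the function name and shape below are made up. | |
| def _benchmark_example(): | |
|     bf16_mm = lambda a, b, scale_a, scale_b, out: torch.mm(a, b, out=out) | |
|     mean_ms, out = benchmark(bf16_mm, M=512, K=3072, N=3072, granularity=None, dtype=torch.bfloat16) | |
|     print(f"bf16 mm mean latency: {mean_ms:.3f} ms, output shape: {tuple(out.shape)}") | |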
| # Reference pytorch implementation for bf16 matmul. We will use it as the baseline for both performance and numerical correctness. | |
| def torch_mm(a, b, scale_a, scale_b, out): | |
| return torch.mm(a, b, out=out) | |
| # Reference pytorch implementation for fp8 scaled matrix multiplication. We will use it to compare performance | |
| # against Triton implementations. | |
| def torch_scaled_mm(a, b, scale_a, scale_b, out): | |
| return torch._scaled_mm(a, b, scale_a, scale_b, use_fast_accum=False, out=out) | |
| # Fast accumulation version of the scaled matrix multiplication in PyTorch. We will use it to compare performance | |
| # against Triton implementations. This function uses a faster accumulation mode, which has better performance | |
| # but different numerical properties. Search for "CUBLASLT_MATMUL_DESC_FAST_ACCUM" on https://docs.nvidia.com/cuda/cublas/ | |
| def torch_scaled_mm_fast_accum(a, b, scale_a, scale_b, out): | |
| return torch._scaled_mm(a, b, scale_a, scale_b, use_fast_accum=True, out=out) | |
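| # Illustrative sketch (not called by main, requires an SM89+ GPU for torch._scaled_mm): a minimal | |
| # end-to-end example of the per-tensor fp8 quantization scheme used in benchmark() above, checked | |
| # against the bf16 reference with correctness(). The function name and shapes below are made up. | |
| def _scaled_mm_example(M=64, K=64, N=64): | |
|     a = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") | |
|     b = torch.randn(K, N, dtype=torch.bfloat16, device="cuda") | |
|     quant_max = torch.finfo(torch.float8_e4m3fn).max | |
|     # scale = amax / fp8_max, so dividing by the scale fills the representable fp8 range. | |
|     scale_a = get_scale(a, "per_tensor") / quant_max | |
|     scale_b = get_scale(b.t(), "per_tensor") / quant_max | |
|     a_fp8 = (a / scale_a).clamp(-quant_max, quant_max).to(torch.float8_e4m3fn) | |
|     # cuBLAS expects the second operand in column-major layout. | |
|     b_fp8 = (b / scale_b).clamp(-quant_max, quant_max).to(torch.float8_e4m3fn).t().contiguous().t() | |
|     out = torch.zeros(M, N, dtype=torch.bfloat16, device="cuda") | |
|     torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out) | |
|     # The dequantized fp8 result should be close to the bf16 reference. | |
|     correctness(out, a @ b, "scaled_mm_example") | |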
| def main(): | |
| # M, K, N | |
| shapes = [ | |
| # (128, 128, 128), | |
| (512, 3072, 3072 * 4), | |
| (512, 3072 * 4, 3072), | |
| (1152, 3072, 3072 * 4), | |
| (1152, 3072 * 4, 3072), | |
| (2304, 3072, 3072 * 4), | |
| (2304, 3072 * 4, 3072), | |
| (4608, 3072, 3072 * 4), | |
| (4608, 3072 * 4, 3072), | |
| ] | |
| fp8_dtype = torch.float8_e4m3fn if SM >= (8, 9) else torch.bfloat16 | |
| USE_SPLIT_K = False | |
| out_dtype = torch.float16 if USE_SPLIT_K else torch.bfloat16 | |
| for shape in shapes: | |
| M, K, N = shape | |
| print(f"========== Running benchmarks for shape: M={M}, K={K}, N={N} ==========") | |
| # Benchmark torch.mm with bfloat16 for baseline | |
| baseline, baseline_out = benchmark(torch_mm, M, K, N, granularity=None, dtype=torch.bfloat16) | |
| # Benchmark torch._scaled_mm with fast accumulation | |
| if SM >= (8, 9): | |
| scaled_mm_fast_accum_pt, scaled_mm_fast_accum_pt_out = benchmark(torch_scaled_mm_fast_accum, M, K, N, granularity="per_tensor", dtype=fp8_dtype) | |
| scaled_mm_fast_accum_pr, scaled_mm_fast_accum_pr_out = benchmark(torch_scaled_mm_fast_accum, M, K, N, granularity="per_row", dtype=fp8_dtype) | |
| # Benchmark torch._scaled_mm without fast accumulation | |
| if SM >= (8, 9): | |
| scaled_mm_pt, scaled_mm_pt_out = benchmark(torch_scaled_mm, M, K, N, granularity="per_tensor", dtype=fp8_dtype) | |
| scaled_mm_pr, scaled_mm_pr_out = benchmark(torch_scaled_mm, M, K, N, granularity="per_row", dtype=fp8_dtype) | |
| # Benchmark triton fp8 per tensor | |
| triton_fp8_pt, triton_fp8_pt_out = benchmark(triton_fp8, M, K, N, granularity="per_tensor", dtype=fp8_dtype, out_dtype=out_dtype) | |
| triton_fp8_pt_out = triton_fp8_pt_out.to(torch.bfloat16) | |
| # Benchmark triton fp8 per row | |
| triton_fp8_pr, triton_fp8_pr_out = benchmark(triton_fp8, M, K, N, granularity="per_row", dtype=fp8_dtype, out_dtype=out_dtype) | |
| triton_fp8_pr_out = triton_fp8_pr_out.to(torch.bfloat16) | |
| # Benchmark triton fp8 per tensor persistent | |
| triton_fp8_pt_persistent, triton_fp8_pt_out_persistent = benchmark(triton_fp8_persistent, M, K, N, granularity="per_tensor", dtype=fp8_dtype, out_dtype=out_dtype) | |
| triton_fp8_pt_out_persistent = triton_fp8_pt_out_persistent.to(torch.bfloat16) | |
| # Benchmark triton fp8 per row persistent | |
| triton_fp8_pr_persistent, triton_fp8_pr_out_persistent = benchmark(triton_fp8_persistent, M, K, N, granularity="per_row", dtype=fp8_dtype, out_dtype=out_dtype) | |
| triton_fp8_pr_out_persistent = triton_fp8_pr_out_persistent.to(torch.bfloat16) | |
| # Benchmark triton fp8 per tensor persistent TMA | |
| triton_fp8_pt_persistent_tma, triton_fp8_pt_out_persistent_tma = benchmark(triton_fp8_persistent_tma, M, K, N, granularity="per_tensor", dtype=fp8_dtype, out_dtype=out_dtype) | |
| triton_fp8_pt_out_persistent_tma = triton_fp8_pt_out_persistent_tma.to(torch.bfloat16) | |
| # Benchmark triton fp8 per row persistent TMA | |
| triton_fp8_pr_persistent_tma, triton_fp8_pr_out_persistent_tma = benchmark(triton_fp8_persistent_tma, M, K, N, granularity="per_row", dtype=fp8_dtype, out_dtype=out_dtype) | |
| triton_fp8_pr_out_persistent_tma = triton_fp8_pr_out_persistent_tma.to(torch.bfloat16) | |
| # Benchmark triton fp8 per tensor persistent TMA with cooperative warp specialization | |
| triton_fp8_pt_persistent_tma_ws_cooperative, triton_fp8_pt_out_persistent_tma_ws_cooperative = benchmark(triton_fp8_persistent_tma_ws_cooperative, M, K, N, granularity="per_tensor", dtype=fp8_dtype, out_dtype=out_dtype) | |
| triton_fp8_pt_out_persistent_tma_ws_cooperative = triton_fp8_pt_out_persistent_tma_ws_cooperative.to(torch.bfloat16) | |
| # Benchmark triton fp8 per row persistent TMA with cooperative warp specialization | |
| triton_fp8_pr_persistent_tma_ws_cooperative, triton_fp8_pr_out_persistent_tma_ws_cooperative = benchmark(triton_fp8_persistent_tma_ws_cooperative, M, K, N, granularity="per_row", dtype=fp8_dtype, out_dtype=out_dtype) | |
| triton_fp8_pr_out_persistent_tma_ws_cooperative = triton_fp8_pr_out_persistent_tma_ws_cooperative.to(torch.bfloat16) | |
| if SM >= (8, 9): | |
| correctness(scaled_mm_fast_accum_pt_out, baseline_out, "scaled_mm_fast_accum_pt") | |
| correctness(scaled_mm_fast_accum_pr_out, baseline_out, "scaled_mm_fast_accum_pr") | |
| correctness(scaled_mm_pt_out, baseline_out, "scaled_mm_pt") | |
| correctness(scaled_mm_pr_out, baseline_out, "scaled_mm_pr") | |
| correctness(triton_fp8_pt_out, baseline_out, "triton_fp8_pt") | |
| correctness(triton_fp8_pr_out, baseline_out, "triton_fp8_pr") | |
| correctness(triton_fp8_pt_out_persistent, baseline_out, "triton_fp8_pt_persistent") | |
| correctness(triton_fp8_pr_out_persistent, baseline_out, "triton_fp8_pr_persistent") | |
| correctness(triton_fp8_pt_out_persistent_tma, baseline_out, "triton_fp8_pt_persistent_tma") | |
| correctness(triton_fp8_pr_out_persistent_tma, baseline_out, "triton_fp8_pr_persistent_tma") | |
| correctness(triton_fp8_pt_out_persistent_tma_ws_cooperative, baseline_out, "triton_fp8_pt_persistent_tma_ws_cooperative") | |
| correctness(triton_fp8_pr_out_persistent_tma_ws_cooperative, baseline_out, "triton_fp8_pr_persistent_tma_ws_cooperative") | |
| if SM >= (8, 9): | |
| print(f"speedup scaled_mm_fast_accum_pt: {baseline / scaled_mm_fast_accum_pt:.2f}x") | |
| print(f"speedup scaled_mm_fast_accum_pr: {baseline / scaled_mm_fast_accum_pr:.2f}x") | |
| print(f"speedup scaled_mm_pt: {baseline / scaled_mm_pt:.2f}x") | |
| print(f"speedup scaled_mm_pr: {baseline / scaled_mm_pr:.2f}x") | |
| print(f"speedup triton_fp8_pt: {baseline / triton_fp8_pt:.2f}x") | |
| print(f"speedup triton_fp8_pr: {baseline / triton_fp8_pr:.2f}x") | |
| print(f"speedup triton_fp8_pt_persistent: {baseline / triton_fp8_pt_persistent:.2f}x") | |
| print(f"speedup triton_fp8_pr_persistent: {baseline / triton_fp8_pr_persistent:.2f}x") | |
| print(f"speedup triton_fp8_pt_persistent_tma: {baseline / triton_fp8_pt_persistent_tma:.2f}x") | |
| print(f"speedup triton_fp8_pr_persistent_tma: {baseline / triton_fp8_pr_persistent_tma:.2f}x") | |
| print(f"speedup triton_fp8_pt_persistent_tma_ws_cooperative: {baseline / triton_fp8_pt_persistent_tma_ws_cooperative:.2f}x") | |
| print(f"speedup triton_fp8_pr_persistent_tma_ws_cooperative: {baseline / triton_fp8_pr_persistent_tma_ws_cooperative:.2f}x") | |
| print() | |
| # Configs for autotuning the triton kernels. This is a dense grid of configurations; ideally it | |
| # should be pruned based on the problem sizes you're working with, reusing known-good configurations. | |
| # Explanation of the parameters: | |
| # - BLOCK_M: Number of rows each program processes from the input matrix A and produces in the output matrix. | |
| # - BLOCK_N: Number of columns each program processes from the input matrix B and produces in the output matrix. | |
| # Together with BLOCK_M, defines the size of the output submatrix per program. | |
| # - BLOCK_K: Width of the reduction tile. The K-dimension is divided into chunks of this size. | |
| # Each program accumulates partial results over (K // BLOCK_K) steps. | |
| # - GROUP_M: Improves L2 Cache utilization. See the official triton documentation for more details: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html | |
| # - SPLIT_K: Splits the K-dimension reduction across multiple programs whose partial results are combined | |
| # with atomics. Use SPLIT_K > 1 only if the reduction becomes the bottleneck (very large K relative to M and N). | |
| # - num_warps: Number of warps (each warp is 32 threads) assigned per program. Higher values increase parallelism | |
| # but also resource usage (registers/shared memory). 4 or 8 warps are typical choices. | |
| # - num_stages: Number of pipeline stages used to overlap computation and memory loads. More stages can improve | |
| # latency hiding with better overlap, but this may increase resource pressure and lead to slowdowns | |
| # if not tuned properly. | |
| # | |
| # Warp-specialization parameters (enabled when warp_specialization=True): | |
| # - NUM_CONSUMER_GROUPS and num_buffers_warp_spec: Specific to Hopper and newer GPUs | |
| # https://pytorch.org/blog/warp-specialization/ | |
| def get_autotune_configs(warp_specialization: bool = False): | |
| if not warp_specialization: | |
| configs = [ | |
| triton.Config( | |
| { | |
| "BLOCK_M": BLOCK_M, | |
| "BLOCK_N": BLOCK_N, | |
| "BLOCK_K": BLOCK_K, | |
| "GROUP_M": GROUP_M, | |
| "SPLIT_K": SPLIT_K, | |
| }, | |
| num_warps=num_warps, | |
| num_stages=num_stages, | |
| ) | |
| for BLOCK_M in [64, 128] | |
| for BLOCK_N in [64, 128] | |
| for BLOCK_K in [64, 128, 256] | |
| for GROUP_M in [8] | |
| # for SPLIT_K in [1, 2, 4] # Seems like there's no speedup with SPLIT_K > 1, so we keep it simple | |
| for SPLIT_K in [1] | |
| for num_warps in [4, 8] | |
| # for num_stages in [3, 4] | |
| for num_stages in [3] | |
| ] | |
| else: | |
| configs = [ | |
| triton.Config( | |
| { | |
| "BLOCK_M": BLOCK_M, | |
| "BLOCK_N": BLOCK_N, | |
| "BLOCK_K": BLOCK_K, | |
| "GROUP_M": GROUP_M, | |
| "SPLIT_K": SPLIT_K, | |
| "NUM_CONSUMER_GROUPS": NUM_CONSUMER_GROUPS, | |
| }, | |
| num_stages=num_stages, | |
| num_warps=num_warps, | |
| num_consumer_groups=0 if NUM_CONSUMER_GROUPS == 1 else NUM_CONSUMER_GROUPS, | |
| num_buffers_warp_spec=num_buffers_warp_spec, | |
| ) | |
| for BLOCK_M in [64, 128] | |
| for BLOCK_N in [128, 256] | |
| for BLOCK_K in [64, 128] | |
| for GROUP_M in [8] | |
| for SPLIT_K in [1] | |
| for num_warps in [4] | |
| for num_stages in [2, 3] | |
| for NUM_CONSUMER_GROUPS in [1, 2] | |
| for num_buffers_warp_spec in [3] | |
| ] | |
| return configs | |
| # Simple rules to prune the dense grid of autotuning configurations based on problem size | |
| def early_config_prune(configs, args, **kwargs): | |
| device = torch.cuda.current_device() | |
| max_shared_memory = triton.runtime.driver.active.utils.get_device_properties(device)["max_shared_mem"] | |
| is_persistent_tma = "_in_element_dtype" in kwargs and "_out_element_dtype" in kwargs | |
| if "a_ptr" in args: | |
| element_size = args["a_ptr"].element_size() | |
| elif "a_desc_ptr" in args: | |
| in_element_dtype = kwargs.get("_in_element_dtype") | |
| element_size = in_element_dtype.primitive_bitwidth // 8 | |
| else: | |
| raise ValueError("Either 'a_ptr' or 'a_desc_ptr' must be provided in args") | |
| M = kwargs.get("M") | |
| K = kwargs.get("K") | |
| N = kwargs.get("N") | |
| pruned_configs = [] | |
| for config in configs: | |
| BLOCK_M = config.kwargs["BLOCK_M"] | |
| BLOCK_N = config.kwargs["BLOCK_N"] | |
| BLOCK_K = config.kwargs["BLOCK_K"] | |
| SPLIT_K = config.kwargs["SPLIT_K"] | |
| num_stages = config.num_stages | |
| required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * element_size | |
| if required_shared_memory > max_shared_memory: | |
| continue | |
| if M % BLOCK_M != 0: | |
| continue | |
| if N % BLOCK_N != 0: | |
| continue | |
| if K < 4096 and SPLIT_K > 1: | |
| continue | |
| if K % (BLOCK_K * SPLIT_K) != 0: | |
| continue | |
| if is_persistent_tma and BLOCK_K >= 256: | |
| # Triton hangs indefinitely when using the persistent matmul with TMA and BLOCK_K >= 256 | |
| continue | |
| if not is_persistent_tma and BLOCK_K <= 64: | |
| # Small BLOCK_K is not fast except when using TMA | |
| continue | |
| if config.kwargs.get("NUM_CONSUMER_GROUPS", 1) > 1 and BLOCK_M <= 64: | |
| # Hangs indefinitely | |
| continue | |
| pruned_configs.append(config) | |
| print(f"Pruned configs: {len(pruned_configs)} out of {len(configs)}") | |
| return pruned_configs | |
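| # Illustrative sketch (made-up numbers, never called): the shared-memory estimate used by | |
| # early_config_prune above. For an fp8 input (element_size = 1 byte) with BLOCK_M = BLOCK_N = BLOCK_K = 128 | |
| # and num_stages = 3, the A and B tiles buffered across the pipeline stages need | |
| # (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * element_size bytes, which must fit within the device's | |
| # max shared memory per block (e.g. 232448 bytes reported on H100), otherwise the config is skipped. | |
| def _shared_memory_estimate_example(): | |
|     BLOCK_M, BLOCK_N, BLOCK_K, num_stages, element_size = 128, 128, 128, 3, 1 | |
|     required = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * element_size | |
|     assert required == 98304  # 96 KiB, comfortably below the H100 limit | |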
| # Kernel for Triton FP8 matrix multiplication. | |
| @triton.autotune(configs=get_autotune_configs(), key=["M", "K", "N", "PER_ROW"], prune_configs_by={"early_config_prune": early_config_prune}) | |
| @triton.jit | |
| def matmul_fp8_kernel( | |
| a_ptr, | |
| b_ptr, | |
| scale_a_ptr, | |
| scale_b_ptr, | |
| out_ptr, | |
| M, | |
| K, | |
| N, | |
| PER_ROW: tl.constexpr, | |
| BLOCK_M: tl.constexpr, | |
| BLOCK_N: tl.constexpr, | |
| BLOCK_K: tl.constexpr, | |
| GROUP_M: tl.constexpr, | |
| SPLIT_K: tl.constexpr, | |
| USE_FAST_ACCUM: tl.constexpr = True, | |
| ): | |
| pid = tl.program_id(0) | |
| pid_k = tl.program_id(1) | |
| # L2 cache optimization: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html#l2-cache-optimizations | |
| num_pid_m = tl.cdiv(M, BLOCK_M) | |
| num_pid_n = tl.cdiv(N, BLOCK_N) | |
| num_pid_in_group = num_pid_n * GROUP_M | |
| group_id = pid // num_pid_in_group | |
| first_pid_m = group_id * GROUP_M | |
| group_size_m = min(num_pid_m - first_pid_m, GROUP_M) | |
| pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) | |
| pid_n = (pid % num_pid_in_group) // group_size_m | |
| # SPLIT_K is the number of programs the K-dimension reduction has been split across | |
| # chunk_size_k is the size of each split across the programs handling the K dimension | |
| chunk_size_k = K // SPLIT_K | |
| # num_chunks_k is the number of chunks of size BLOCK_K, which each program iterates over to perform reduction | |
| num_chunks_k = chunk_size_k // BLOCK_K | |
| offsets_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) | |
| offsets_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) | |
| offsets_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K) | |
| offsets_am = tl.max_contiguous(tl.multiple_of(offsets_m, BLOCK_M), BLOCK_M) | |
| offsets_bn = tl.max_contiguous(tl.multiple_of(offsets_n, BLOCK_N), BLOCK_N) | |
| # Pointers to a block of A [BLOCK_M, BLOCK_K] and B [BLOCK_K, BLOCK_N] | |
| ptrs_a = a_ptr + offsets_am[:, None] * K + offsets_k[None, :] | |
| ptrs_b = b_ptr + offsets_k[:, None] + offsets_bn[None, :] * K | |
| # Tile of the output matrix that every program computes | |
| accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) | |
| # Load scales if doing scaled mm | |
| if scale_a_ptr is not None and scale_b_ptr is not None: | |
| if PER_ROW: | |
| scale_a = tl.load(scale_a_ptr + offsets_am).to(tl.float32) | |
| scale_b = tl.load(scale_b_ptr + offsets_bn).to(tl.float32) | |
| scale = scale_a[:, None] * scale_b[None, :] | |
| else: | |
| scale_a = tl.load(scale_a_ptr).to(tl.float32) | |
| scale_b = tl.load(scale_b_ptr).to(tl.float32) | |
| scale = scale_a * scale_b | |
| # Iterate over the K dimension in chunks of BLOCK_K, reduce across these chunks, and accumulate | |
| for k in range(num_chunks_k): | |
| a = tl.load(ptrs_a) | |
| b = tl.load(ptrs_b) | |
| if USE_FAST_ACCUM: | |
| accumulator = tl.dot(a, b, accumulator, out_dtype=tl.float32) | |
| else: | |
| accumulator += tl.dot(a, b, out_dtype=tl.float32) | |
| ptrs_a += BLOCK_K * SPLIT_K | |
| ptrs_b += BLOCK_K * SPLIT_K | |
| # Multiply the scale if provided. | |
| # For per-tensor scaling, all output values [BLOCK_M, BLOCK_N] are multiplied by the same scale. | |
| # For per-row scaling, each row of the output is multiplied by the corresponding scale. | |
| if scale_a_ptr is not None and scale_b_ptr is not None: | |
| accumulator = scale * accumulator | |
| # Cast the accumulator to the output dtype (bf16/fp16). | |
| accumulator = accumulator.to(out_ptr.dtype.element_ty) | |
| ptrs_out = out_ptr + offsets_am[:, None] * N + offsets_bn[None, :] | |
| if SPLIT_K == 1: | |
| # If we haven't split the reduction dimension across different programs, we can store the result directly. | |
| tl.store(ptrs_out, accumulator) | |
| else: | |
| # If we have split the reduction dimension, each [BLOCK_M, BLOCK_N] tile needs to be accumulated | |
| # across all SPLIT_K programs to produce the final result. | |
| tl.atomic_add(ptrs_out, accumulator, sem="relaxed") | |
| def triton_fp8(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, out: torch.Tensor): | |
| M, K = a.shape | |
| _, N = b.shape | |
| per_row = scale_a.shape[0] > 1 | |
| grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), META["SPLIT_K"]) | |
| matmul_fp8_kernel[grid](a, b, scale_a, scale_b, out, M=M, K=K, N=N, PER_ROW=per_row) | |
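| # Illustrative sketch (pure Python, never called): re-derives on the host the pid -> (pid_m, pid_n) | |
| # mapping that matmul_fp8_kernel uses for its L2 cache optimization. Consecutive programs walk down | |
| # GROUP_M rows of output tiles before moving to the next column, so the corresponding blocks of A and B | |
| # get reused from L2 while several tile columns are computed. Function name and sizes below are made up. | |
| def _grouped_tile_order_example(M=512, N=512, BLOCK_M=128, BLOCK_N=128, GROUP_M=2): | |
|     num_pid_m = triton.cdiv(M, BLOCK_M) | |
|     num_pid_n = triton.cdiv(N, BLOCK_N) | |
|     num_pid_in_group = GROUP_M * num_pid_n | |
|     order = [] | |
|     for pid in range(num_pid_m * num_pid_n): | |
|         group_id = pid // num_pid_in_group | |
|         first_pid_m = group_id * GROUP_M | |
|         group_size_m = min(num_pid_m - first_pid_m, GROUP_M) | |
|         pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) | |
|         pid_n = (pid % num_pid_in_group) // group_size_m | |
|         order.append((pid_m, pid_n)) | |
|     # The first group covers tile rows 0-1 across all four tile columns before touching row 2. | |
|     assert order[:8] == [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2), (0, 3), (1, 3)] | |
|     return order | |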
| # Persistent kernel for FP8 matrix multiplication. | |
| # A GPU has a fixed number of SMs. | |
| # With persistence, we launch at most as many thread blocks as there are SMs, so every launched block stays resident. | |
| # Each persistent thread block computes one or more tiles of the output matrix, depending on the problem size. | |
| # The normal kernel above launches one thread block per output tile to cover the entire matrix; with | |
| # persistence, we save the overhead of launching and scheduling that many thread blocks. | |
| @triton.autotune(configs=get_autotune_configs(), key=["M", "K", "N", "PER_ROW"], prune_configs_by={"early_config_prune": early_config_prune}) | |
| @triton.jit | |
| def matmul_fp8_persistent_kernel( | |
| a_ptr, | |
| b_ptr, | |
| scale_a_ptr, | |
| scale_b_ptr, | |
| out_ptr, | |
| M, | |
| K, | |
| N, | |
| PER_ROW: tl.constexpr, | |
| BLOCK_M: tl.constexpr, | |
| BLOCK_N: tl.constexpr, | |
| BLOCK_K: tl.constexpr, | |
| GROUP_M: tl.constexpr, | |
| SPLIT_K: tl.constexpr, | |
| NUM_SMS: tl.constexpr, | |
| USE_FAST_ACCUM: tl.constexpr = True, | |
| ): | |
| start_pid = tl.program_id(0) | |
| # Number of output tiles along the M and N dimensions. | |
| num_pid_m = tl.cdiv(M, BLOCK_M) | |
| num_pid_n = tl.cdiv(N, BLOCK_N) | |
| # Number of chunks of size BLOCK_K along the K dimension | |
| k_tiles = tl.cdiv(K, BLOCK_K) | |
| # Total number of output tiles to be computed in order to cover the entire matrix. | |
| # (In the non-persistent kernel this is also the number of programs launched; here we launch at most NUM_SMS.) | |
| num_tiles = num_pid_m * num_pid_n | |
| # Number of output tiles each SM will compute results for | |
| tiles_per_SM = num_tiles // NUM_SMS | |
| if start_pid < num_tiles % NUM_SMS: | |
| tiles_per_SM += 1 | |
| tile_id = start_pid - NUM_SMS | |
| ki = -1 | |
| num_pid_in_group = GROUP_M * num_pid_n | |
| pid_m = 0 | |
| pid_n = 0 | |
| offsets_am = tl.arange(0, BLOCK_M) | |
| offsets_bn = tl.arange(0, BLOCK_N) | |
| accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) | |
| # Load scales if doing scaled mm | |
| if scale_a_ptr is not None and scale_b_ptr is not None: | |
| if PER_ROW: | |
| scale_a = tl.load(scale_a_ptr + offsets_am).to(tl.float32) | |
| scale_b = tl.load(scale_b_ptr + offsets_bn).to(tl.float32) | |
| scale = scale_a[:, None] * scale_b[None, :] | |
| else: | |
| scale_a = tl.load(scale_a_ptr).to(tl.float32) | |
| scale_b = tl.load(scale_b_ptr).to(tl.float32) | |
| scale = scale_a * scale_b | |
| # The total number of loop iterations is the number of BLOCK_K-sized reduction steps, | |
| # summed across all tiles that this program will handle. | |
| for _ in range(k_tiles * tiles_per_SM): | |
| # If reduction over K dimension is complete, we start processing the next tile. If | |
| # not, we continue reducing over the next chunk of size BLOCK_K. | |
| ki = tl.where(ki == k_tiles - 1, 0, ki + 1) | |
| if ki == 0: | |
| # Each time we start a new tile, we have to keep in mind that NUM_SMS programs are | |
| # already launched and processing their respective tiles. In order to avoid processing | |
| # the same tile, we offset the tile_id by NUM_SMS. | |
| tile_id += NUM_SMS | |
| # L2 cache optimization: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html#l2-cache-optimizations | |
| group_id = tile_id // num_pid_in_group | |
| first_pid_m = group_id * GROUP_M | |
| group_size_m = min(num_pid_m - first_pid_m, GROUP_M) | |
| pid_m = first_pid_m + (tile_id % group_size_m) | |
| pid_n = (tile_id % num_pid_in_group) // group_size_m | |
| offsets_am = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) | |
| offsets_bn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) | |
| offsets_am = tl.max_contiguous(tl.multiple_of(offsets_am, BLOCK_M), BLOCK_M) | |
| offsets_bn = tl.max_contiguous(tl.multiple_of(offsets_bn, BLOCK_N), BLOCK_N) | |
| # Load A [BLOCK_M, BLOCK_K] and B [BLOCK_K, BLOCK_N] tiles for the current program, | |
| # reduce across the BLOCK_K dimension, and accumulate the results. | |
| offsets_k = ki * BLOCK_K + tl.arange(0, BLOCK_K) | |
| ptrs_a = a_ptr + offsets_am[:, None] * K + offsets_k[None, :] | |
| ptrs_b = b_ptr + offsets_k[:, None] + offsets_bn[None, :] * K | |
| a = tl.load(ptrs_a) | |
| b = tl.load(ptrs_b) | |
| if USE_FAST_ACCUM: | |
| accumulator = tl.dot(a, b, accumulator, out_dtype=tl.float32) | |
| else: | |
| accumulator += tl.dot(a, b, out_dtype=tl.float32) | |
| # After the last reduction over the K dimension for the current tile, store | |
| # the accumulated results to the output tile [BLOCK_M, BLOCK_N]. | |
| if ki == k_tiles - 1: | |
| # Multiply by the scale if provided, then cast to the output dtype (bf16/fp16). | |
| if scale_a_ptr is not None and scale_b_ptr is not None: | |
| accumulator = scale * accumulator | |
| accumulator = accumulator.to(out_ptr.dtype.element_ty) | |
| offsets_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) | |
| offsets_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) | |
| ptrs_out = out_ptr + offsets_cm[:, None] * N + offsets_cn[None, :] | |
| if SPLIT_K == 1: | |
| tl.store(ptrs_out, accumulator) | |
| else: | |
| tl.static_assert(SPLIT_K == 1, "Persistent kernel does not support SPLIT_K > 1") | |
| # Reset the accumulator for the next output tile. | |
| accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) | |
| def triton_fp8_persistent(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, out: torch.Tensor): | |
| NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count | |
| M, K = a.shape | |
| _, N = b.shape | |
| per_row = scale_a.shape[0] > 1 | |
| grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"])),) | |
| matmul_fp8_persistent_kernel[grid](a, b, scale_a, scale_b, out, M=M, K=K, N=N, NUM_SMS=NUM_SMS, PER_ROW=per_row) | |
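| # Illustrative sketch (pure Python, never called): which output tiles each persistent program handles. | |
| # Mirroring the tile_id bookkeeping in matmul_fp8_persistent_kernel, program p starts at tile p and then | |
| # strides by NUM_SMS, so only NUM_SMS thread blocks are ever launched. Numbers below are made up. | |
| def _persistent_schedule_example(num_tiles=10, NUM_SMS=4): | |
|     schedule = {p: list(range(p, num_tiles, NUM_SMS)) for p in range(NUM_SMS)} | |
|     # Programs 0 and 1 get tiles_per_SM + 1 = 3 tiles, programs 2 and 3 get tiles_per_SM = 2 tiles. | |
|     assert schedule == {0: [0, 4, 8], 1: [1, 5, 9], 2: [2, 6], 3: [3, 7]} | |
|     return schedule | |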
| # Persistent matmul kernel which makes use of specialized hardware units available on Hopper and newer GPUs: https://pytorch.org/blog/hopper-tma-unit/ | |
| @triton.autotune(configs=get_autotune_configs(), key=["M", "K", "N", "PER_ROW"], prune_configs_by={"early_config_prune": early_config_prune}) | |
| @triton.jit | |
| def matmul_fp8_persistent_tma_kernel( | |
| a_desc_ptr, | |
| b_desc_ptr, | |
| scale_a_ptr, | |
| scale_b_ptr, | |
| out_desc_ptr, | |
| M, | |
| K, | |
| N, | |
| PER_ROW: tl.constexpr, | |
| BLOCK_M: tl.constexpr, | |
| BLOCK_N: tl.constexpr, | |
| BLOCK_K: tl.constexpr, | |
| GROUP_M: tl.constexpr, | |
| SPLIT_K: tl.constexpr, | |
| NUM_SMS: tl.constexpr, | |
| USE_FAST_ACCUM: tl.constexpr = True, | |
| _in_element_dtype: tl.constexpr = None, | |
| _out_element_dtype: tl.constexpr = None, | |
| ): | |
| start_pid = tl.program_id(0) | |
| num_pid_m = tl.cdiv(M, BLOCK_M) | |
| num_pid_n = tl.cdiv(N, BLOCK_N) | |
| k_tiles = tl.cdiv(K, BLOCK_K) | |
| num_tiles = num_pid_m * num_pid_n | |
| tiles_per_SM = num_tiles // NUM_SMS | |
| if start_pid < num_tiles % NUM_SMS: | |
| tiles_per_SM += 1 | |
| tile_id = start_pid - NUM_SMS | |
| ki = -1 | |
| num_pid_in_group = GROUP_M * num_pid_n | |
| pid_m = 0 | |
| pid_n = 0 | |
| offsets_am = 0 | |
| offsets_bn = 0 | |
| accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) | |
| if scale_a_ptr is not None and scale_b_ptr is not None: | |
| if PER_ROW: | |
| scale_a = tl.load(scale_a_ptr + tl.arange(0, BLOCK_M)).to(tl.float32) | |
| scale_b = tl.load(scale_b_ptr + tl.arange(0, BLOCK_N)).to(tl.float32) | |
| scale = scale_a[:, None] * scale_b[None, :] | |
| else: | |
| scale_a = tl.load(scale_a_ptr).to(tl.float32) | |
| scale_b = tl.load(scale_b_ptr).to(tl.float32) | |
| scale = scale_a * scale_b | |
| for _ in range(k_tiles * tiles_per_SM): | |
| ki = tl.where(ki == k_tiles - 1, 0, ki + 1) | |
| if ki == 0: | |
| tile_id += NUM_SMS | |
| group_id = tile_id // num_pid_in_group | |
| first_pid_m = group_id * GROUP_M | |
| group_size_m = min(num_pid_m - first_pid_m, GROUP_M) | |
| pid_m = first_pid_m + (tile_id % group_size_m) | |
| pid_n = (tile_id % num_pid_in_group) // group_size_m | |
| offsets_am = pid_m * BLOCK_M | |
| offsets_bn = pid_n * BLOCK_N | |
| offsets_k = ki * BLOCK_K | |
| a = tl._experimental_descriptor_load(a_desc_ptr, [offsets_am, offsets_k], [BLOCK_M, BLOCK_K], _in_element_dtype) | |
| b = tl._experimental_descriptor_load(b_desc_ptr, [offsets_bn, offsets_k], [BLOCK_N, BLOCK_K], _in_element_dtype) | |
| if USE_FAST_ACCUM: | |
| accumulator = tl.dot(a, b.T, accumulator, out_dtype=tl.float32) | |
| else: | |
| accumulator += tl.dot(a, b.T, out_dtype=tl.float32) | |
| if ki == k_tiles - 1: | |
| if scale_a_ptr is not None and scale_b_ptr is not None: | |
| accumulator = scale * accumulator | |
| accumulator = accumulator.to(_out_element_dtype) | |
| if SPLIT_K == 1: | |
| tl._experimental_descriptor_store(out_desc_ptr, accumulator, [offsets_am, offsets_bn]) | |
| else: | |
| tl.static_assert(SPLIT_K == 1, "Persistent kernel does not support SPLIT_K > 1") | |
| accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) | |
| def triton_fp8_persistent_tma(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, out: torch.Tensor): | |
| NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count | |
| M, K = a.shape | |
| _, N = b.shape | |
| desc_helper = TMAAutotuneHelper() | |
| desc_helper.init_tma_descriptor("a") | |
| desc_helper.init_tma_descriptor("b") | |
| desc_helper.init_tma_descriptor("out") | |
| def grid(META): | |
| desc_helper.fill_2d_tma_descriptor("a", a.data_ptr(), M, K, META["BLOCK_M"], META["BLOCK_K"], a.element_size()) | |
| desc_helper.fill_2d_tma_descriptor("b", b.data_ptr(), N, K, META["BLOCK_N"], META["BLOCK_K"], b.element_size()) | |
| desc_helper.fill_2d_tma_descriptor("out", out.data_ptr(), M, N, META["BLOCK_M"], META["BLOCK_N"], out.element_size()) | |
| return (min(NUM_SMS, triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"])),) | |
| desc_a = desc_helper.get_tma_descriptor_kernel_param("a") | |
| desc_b = desc_helper.get_tma_descriptor_kernel_param("b") | |
| desc_out = desc_helper.get_tma_descriptor_kernel_param("out") | |
| _in_element_dtype = TORCH_TO_TRITON_DTYPE.get(a.dtype) | |
| _out_element_dtype = TORCH_TO_TRITON_DTYPE.get(out.dtype) | |
| per_row = scale_a.shape[0] > 1 | |
| matmul_fp8_persistent_tma_kernel[grid](desc_a, desc_b, scale_a, scale_b, desc_out, M=M, K=K, N=N, NUM_SMS=NUM_SMS, PER_ROW=per_row, _in_element_dtype=_in_element_dtype, _out_element_dtype=_out_element_dtype) | |
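| # Illustrative sketch (never called): why the "b" descriptor above is filled with dims (N, K) and why the | |
| # kernel multiplies by b.T. benchmark() stores b (K, N) in column-major order, so the same memory read as a | |
| # row-major (N, K) matrix is exactly b transposed; TMA loads [BLOCK_N, BLOCK_K] tiles of that row-major view | |
| # and the kernel transposes them back inside tl.dot. Function name and sizes below are made up. | |
| def _b_layout_example(K=8, N=4): | |
|     b = torch.arange(K * N, dtype=torch.float32).reshape(K, N).t().contiguous().t() | |
|     assert b.stride() == (1, K)  # column-major, as prepared in benchmark() | |
|     b_rowmajor_view = torch.as_strided(b, size=(N, K), stride=(K, 1)) | |
|     assert torch.equal(b_rowmajor_view, b.t()) | |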
| # Persistent matmul kernel utilizing specialized TMA hardware on Hopper and newer GPUs, along with cooperative warp specialization. | |
| @triton.autotune(configs=get_autotune_configs(warp_specialization=True), key=["M", "K", "N", "PER_ROW"], prune_configs_by={"early_config_prune": early_config_prune}) | |
| @triton.jit | |
| def matmul_fp8_persistent_tma_ws_cooperative_kernel( | |
| a_desc_ptr, | |
| b_desc_ptr, | |
| scale_a_ptr, | |
| scale_b_ptr, | |
| out_desc_ptr, | |
| M, | |
| K, | |
| N, | |
| PER_ROW: tl.constexpr, | |
| BLOCK_M: tl.constexpr, | |
| BLOCK_N: tl.constexpr, | |
| BLOCK_K: tl.constexpr, | |
| GROUP_M: tl.constexpr, | |
| SPLIT_K: tl.constexpr, | |
| NUM_CONSUMER_GROUPS: tl.constexpr, | |
| USE_FAST_ACCUM: tl.constexpr = True, | |
| _in_element_dtype: tl.constexpr = None, | |
| _out_element_dtype: tl.constexpr = None, | |
| ): | |
| pid = tl.program_id(0) | |
| num_programs = tl.num_programs(0) | |
| num_tiles = tl.cdiv(M, BLOCK_M) * tl.cdiv(N, BLOCK_N) | |
| num_pid_m = tl.cdiv(M, BLOCK_M) | |
| num_pid_n = tl.cdiv(N, BLOCK_N) | |
| num_pid_in_group = GROUP_M * num_pid_n | |
| num_chunks_k = tl.cdiv(K, BLOCK_K) | |
| if scale_a_ptr is not None and scale_b_ptr is not None: | |
| if PER_ROW: | |
| scale_a = tl.load(scale_a_ptr + tl.arange(0, BLOCK_M)).to(tl.float32) | |
| scale_b = tl.load(scale_b_ptr + tl.arange(0, BLOCK_N)).to(tl.float32) | |
| scale = scale_a[:, None] * scale_b[None, :] | |
| else: | |
| scale_a = tl.load(scale_a_ptr).to(tl.float32) | |
| scale_b = tl.load(scale_b_ptr).to(tl.float32) | |
| scale = scale_a * scale_b | |
| for pid in range(pid, num_tiles, num_programs): | |
| group_id = pid // num_pid_in_group | |
| first_pid_m = group_id * GROUP_M | |
| group_size_m = min(num_pid_m - first_pid_m, GROUP_M) | |
| pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) | |
| pid_n = (pid % num_pid_in_group) // group_size_m | |
| offsets_am = pid_m * BLOCK_M | |
| offsets_bn = pid_n * BLOCK_N | |
| offsets_k = 0 | |
| accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) | |
| for _ in range(0, num_chunks_k): | |
| a = tl._experimental_descriptor_load(a_desc_ptr, [offsets_am, offsets_k], [BLOCK_M, BLOCK_K], _in_element_dtype) | |
| b = tl._experimental_descriptor_load(b_desc_ptr, [offsets_bn, offsets_k], [BLOCK_N, BLOCK_K], _in_element_dtype) | |
| if USE_FAST_ACCUM: | |
| accumulator = tl.dot(a, b.T, accumulator, out_dtype=tl.float32) | |
| else: | |
| accumulator += tl.dot(a, b.T, out_dtype=tl.float32) | |
| offsets_k += BLOCK_K | |
| if scale_a_ptr is not None and scale_b_ptr is not None: | |
| accumulator = scale * accumulator | |
| accumulator = accumulator.to(_out_element_dtype) | |
| if SPLIT_K == 1: | |
| tl._experimental_descriptor_store(out_desc_ptr, accumulator, [offsets_am, offsets_bn]) | |
| else: | |
| tl.static_assert(SPLIT_K == 1, "Persistent kernel does not support SPLIT_K > 1") | |
| def triton_fp8_persistent_tma_ws_cooperative(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, out: torch.Tensor): | |
| NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count | |
| M, K = a.shape | |
| _, N = b.shape | |
| desc_helper = TMAAutotuneHelper() | |
| desc_helper.init_tma_descriptor("a") | |
| desc_helper.init_tma_descriptor("b") | |
| desc_helper.init_tma_descriptor("out") | |
| def grid(META): | |
| BLOCK_M = META["BLOCK_M"] // META["NUM_CONSUMER_GROUPS"] | |
| desc_helper.fill_2d_tma_descriptor("a", a.data_ptr(), M, K, BLOCK_M, META["BLOCK_K"], a.element_size()) | |
| desc_helper.fill_2d_tma_descriptor("b", b.data_ptr(), N, K, META["BLOCK_N"], META["BLOCK_K"], b.element_size()) | |
| desc_helper.fill_2d_tma_descriptor("out", out.data_ptr(), M, N, BLOCK_M, META["BLOCK_N"], out.element_size()) | |
| return (min(NUM_SMS, triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"])),) | |
| desc_a = desc_helper.get_tma_descriptor_kernel_param("a") | |
| desc_b = desc_helper.get_tma_descriptor_kernel_param("b") | |
| desc_out = desc_helper.get_tma_descriptor_kernel_param("out") | |
| _in_element_dtype = TORCH_TO_TRITON_DTYPE.get(a.dtype) | |
| _out_element_dtype = TORCH_TO_TRITON_DTYPE.get(out.dtype) | |
| per_row = scale_a.shape[0] > 1 | |
| matmul_fp8_persistent_tma_ws_cooperative_kernel[grid](desc_a, desc_b, scale_a, scale_b, desc_out, M=M, K=K, N=N, PER_ROW=per_row, _in_element_dtype=_in_element_dtype, _out_element_dtype=_out_element_dtype) | |
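| # Illustrative note (made-up numbers, never called): with cooperative warp specialization, the BLOCK_M x BLOCK_N | |
| # output tile of a program is split row-wise across the consumer warp groups (see the PyTorch warp-specialization | |
| # blog linked above), which is why the grid function above fills the "a" and "out" TMA descriptors with a tile | |
| # height of BLOCK_M // NUM_CONSUMER_GROUPS while "b" keeps the full BLOCK_N. | |
| def _ws_descriptor_tile_example(BLOCK_M=128, NUM_CONSUMER_GROUPS=2): | |
|     per_consumer_rows = BLOCK_M // NUM_CONSUMER_GROUPS | |
|     assert per_consumer_rows == 64  # each of the two consumer groups computes a 64-row slice | |
|     return per_consumer_rows | |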
| HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl) | |
| if HAS_TMA_DESC: | |
| print("TMA benchmarks will be running with experimental grid constant TMA descriptor.", ) | |
| else: | |
| print("TMA benchmarks will be running without grid constant TMA descriptor.", ) | |
| class TMAAutotuneHelper: | |
| class KernelParamWrapper: | |
| def __init__(self, desc): | |
| self.desc = desc | |
| def tma_desc_cpu_ptr(self): | |
| return self.desc.data_ptr() | |
| TMA_SIZE = 128 | |
| def __init__(self): | |
| if HAS_TMA_DESC: | |
| self.descriptors = {} | |
| else: | |
| self.cuda_descriptors = {} | |
| def init_tma_descriptor(self, name): | |
| if HAS_TMA_DESC: | |
| self.descriptors[name] = torch.empty(TMAAutotuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8) | |
| else: | |
| self.cuda_descriptors[name] = torch.empty(TMAAutotuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8) | |
| def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size): | |
| if HAS_TMA_DESC: | |
| desc_x = self.descriptors[name] | |
| assert desc_x.data_ptr() % 64 == 0 | |
| triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dim, block_dim, element_size, desc_x.data_ptr()) | |
| else: | |
| desc_x = self.cuda_descriptors[name] | |
| buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) | |
| triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dim, block_dim, element_size, buf_x.data_ptr()) | |
| desc_x.copy_(buf_x, non_blocking=True) | |
| def fill_2d_tma_descriptor(self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size): | |
| if HAS_TMA_DESC: | |
| desc_x = self.descriptors[name] | |
| assert desc_x.data_ptr() % 64 == 0 | |
| triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()) | |
| else: | |
| desc_x = self.cuda_descriptors[name] | |
| buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) | |
| triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr()) | |
| desc_x.copy_(buf_x, non_blocking=True) | |
| def get_tma_descriptor_kernel_param(self, name): | |
| if HAS_TMA_DESC: | |
| assert self.descriptors[name] is not None | |
| return self.KernelParamWrapper(self.descriptors[name]) | |
| else: | |
| assert self.cuda_descriptors[name] is not None | |
| return self.cuda_descriptors[name] | |
| if __name__ == "__main__": | |
| main() |