triton autotuning somehow reports slower times

import torch
import torch._dynamo.config
import torch._inductor.config
import triton
import triton.language as tl

# Raise the recompile cap (dynamic=False below recompiles per input size) and
# disable CUDA graphs so the torch.compile baseline is timed without graph
# capture/replay effects.
torch._dynamo.config.cache_size_limit = 10000
torch._inductor.config.triton.cudagraphs = False
torch._inductor.config.triton.cudagraph_trees = False


def get_configs():
    # Four candidate autotune configs: {512, 1024} block sizes x {4, 8} warps.
    configs = []
    configs.extend([
        triton.Config({"BLOCK_SIZE": BLOCK_SIZE}, num_warps=num_warps, num_stages=1)
        for BLOCK_SIZE in [512, 1024]
        for num_warps in [4, 8]
    ])
    return configs


@triton.jit
def vector_add_kernel(a_ptr, b_ptr, out_ptr, N: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    # N is constexpr, so every distinct N compiles (and autotunes) a fresh
    # specialization of this kernel.
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N  # guard the tail block when N is not a multiple of BLOCK_SIZE
    a = tl.load(a_ptr + offsets, mask=mask, other=0.0)
    b = tl.load(b_ptr + offsets, mask=mask, other=0.0)
    out = a + b
    tl.store(out_ptr + offsets, out, mask=mask)


# Autotune over the configs above, re-tuning for each new value of N and using
# a shorter-than-default benchmark (8 ms warmup, 32 ms measurement).
autotuned_vector_add_kernel = triton.autotune(
    configs=get_configs(),
    key=["N"],
    do_bench=lambda kernel_call, quantiles: triton.testing.do_bench(
        kernel_call, quantiles=quantiles, warmup=8, rep=32
    ),
)(vector_add_kernel)
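
# The call above is equivalent to the more common decorator form, sketched
# here for reference (same semantics):
#
#   @triton.autotune(configs=get_configs(), key=["N"], do_bench=...)
#   @triton.jit
#   def vector_add_kernel(...): ...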


def vector_add_autotuned(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    assert a.shape == b.shape
    N = a.numel()
    out = torch.empty_like(a)
    # The grid depends on whichever BLOCK_SIZE the autotuner picks, hence the meta lambda.
    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
    autotuned_vector_add_kernel[grid](a, b, out, N)
    return out


def vector_add_static(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    assert a.shape == b.shape
    N = a.numel()
    out = torch.empty_like(a)
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(N, BLOCK_SIZE),)
    vector_add_kernel[grid](a, b, out, N, BLOCK_SIZE=BLOCK_SIZE, num_warps=4, num_stages=1)
    return out
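
# Note: each call to the autotuned wrapper goes through Autotuner.run (a
# config-cache lookup per launch, plus a full tuning pass the first time a
# given N is seen); the static launch above dispatches straight to a single
# compiled kernel.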


def vector_add_torch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a + b


# dynamic=False specializes on input shapes, so the size sweep below triggers
# one compile per N (which is why cache_size_limit was raised above).
vector_add_torch_compiled = torch.compile(vector_add_torch, mode="default", fullgraph=True, dynamic=False)


def correctness():
    N = 2**20
    a = torch.randn(N, device="cuda", dtype=torch.float32)
    b = torch.randn(N, device="cuda", dtype=torch.float32)
    out = vector_add_autotuned(a, b)
    assert torch.allclose(out, a + b)
    print("Correctness test passed!")


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["N"],
        x_vals=[2**i for i in range(10, 25)],
        line_arg="provider",
        line_vals=["triton_autotuned", "triton_static", "torch_eager", "torch_compile"],
        line_names=["triton_autotuned", "triton_static", "torch_eager", "torch_compile"],
        styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("yellow", "-")],
        ylabel="GFLOPS",
        plot_name="vector_add",
        args={},
    )
)
def benchmark_fwd(N, provider):
    a = torch.randn(N, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(N, device="cuda", dtype=torch.bfloat16)
    if provider == "triton_autotuned":
        fn = lambda: vector_add_autotuned(a, b)
    elif provider == "triton_static":
        fn = lambda: vector_add_static(a, b)
    elif provider == "torch_eager":
        fn = lambda: vector_add_torch(a, b)
    elif provider == "torch_compile":
        fn = lambda: vector_add_torch_compiled(a, b)
    else:
        raise ValueError(f"Unknown provider: {provider}")
    ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=256, rep=1024, quantiles=[0.5, 0.2, 0.8])
    # One add per element; min/max GFLOPS come from max/min latency respectively.
    flop = N
    gflops = (flop * 1e-9) / (ms * 1e-3)
    min_gflops = (flop * 1e-9) / (max_ms * 1e-3)
    max_gflops = (flop * 1e-9) / (min_ms * 1e-3)
    return gflops, min_gflops, max_gflops
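

# Vector add is memory-bound, so effective bandwidth is often a more telling
# metric than GFLOPS. A minimal sketch (hypothetical helper, not wired into
# the report above): each element reads a and b and writes out, i.e.
# 3 * N * element_size bytes move per call (element_size=2 for bfloat16).
def gbps(N: int, ms: float, element_size: int = 2) -> float:
    return (3 * N * element_size * 1e-9) / (ms * 1e-3)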


torch.manual_seed(42)
correctness()
benchmark_fwd.run(print_data=True, save_path="dump_vec_add")
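
# Optional: inspect the config that won the most recent tuning pass. Recent
# Triton (3.x) sets `best_config` on the Autotuner after the first launch;
# treat the attribute name as an assumption on other versions.
print(f"Best autotuned config: {autotuned_vector_add_kernel.best_config}")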