@a-r-r-o-w
Created August 31, 2025 12:35
triton autotuning somehow reports slower times
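# A minimal repro: the same vector-add kernel is timed four ways (autotuned
# Triton, statically-configured Triton, eager torch, torch.compile), yet the
# autotuned variant reports lower throughput. One plausible, unconfirmed
# explanation: triton.autotune wraps every launch in Python-level dispatch
# (key construction, cache lookup) that is significant relative to a
# microsecond-scale elementwise kernel, and the first launch per distinct N
# additionally pays for the full config sweep.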
import torch
import torch._dynamo.config
import torch._inductor.config
import triton
import triton.language as tl
torch._dynamo.config.cache_size_limit = 10000
torch._inductor.config.triton.cudagraphs = False
torch._inductor.config.triton.cudagraph_trees = False
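# cache_size_limit is raised because the benchmark below compiles with
# dynamic=False across 15 distinct input sizes, each triggering a fresh
# recompile; past the default limit, dynamo falls back to eager. CUDA graphs
# are disabled in inductor, presumably so the torch.compile provider is timed
# on the same plain-stream launch path as the Triton kernels.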
def get_configs():
    configs = []
    configs.extend([
        triton.Config({"BLOCK_SIZE": BLOCK_SIZE}, num_warps=num_warps, num_stages=1)
        for BLOCK_SIZE in [512, 1024]
        for num_warps in [4, 8]
    ])
    return configs
@triton.jit
def vector_add_kernel(a_ptr, b_ptr, out_ptr, N: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one contiguous BLOCK_SIZE-wide slice.
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N  # guard the tail when N is not a multiple of BLOCK_SIZE
    a = tl.load(a_ptr + offsets, mask=mask, other=0.0)
    b = tl.load(b_ptr + offsets, mask=mask, other=0.0)
    out = a + b
    tl.store(out_ptr + offsets, out, mask=mask)
autotuned_vector_add_kernel = triton.autotune(
    configs=get_configs(),
    key=["N"],
    do_bench=lambda kernel_call, quantiles: triton.testing.do_bench(kernel_call, quantiles=quantiles, warmup=8, rep=32),
)(vector_add_kernel)
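# Autotuning happens lazily: the first launch for each new value of N times
# every config above (with the short warmup=8/rep=32 override) and caches the
# winner. That sweep runs inside whatever region times the first call, and the
# short rep window can also pick a config that is not the true steady-state
# best, so results may differ from run to run.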
def vector_add_autotuned(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    assert a.shape == b.shape
    N = a.numel()
    out = torch.empty_like(a)
    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
    autotuned_vector_add_kernel[grid](a, b, out, N)
    return out
def vector_add_static(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    assert a.shape == b.shape
    N = a.numel()
    out = torch.empty_like(a)
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(N, BLOCK_SIZE),)
    vector_add_kernel[grid](a, b, out, N, BLOCK_SIZE=BLOCK_SIZE, num_warps=4, num_stages=1)
    return out
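# A rough sketch (assumption: host-side wall-clock timing is a fair proxy for
# dispatch overhead) to compare the Python launch cost of the Autotuner
# wrapper against the plain JIT kernel; this measures time per call on the
# host, including CUDA launch but not kernel completion per iteration.
import time

def launch_overhead(fn, iters: int = 1000) -> float:
    a = torch.randn(2**10, device="cuda", dtype=torch.float32)
    b = torch.randn(2**10, device="cuda", dtype=torch.float32)
    fn(a, b)  # warm up; triggers the autotuning sweep once for this N
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(a, b)
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters

# Example: launch_overhead(vector_add_autotuned) vs. launch_overhead(vector_add_static)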
def vector_add_torch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
return a + b
vector_add_torch_compiled = torch.compile(vector_add_torch, mode="default", fullgraph=True, dynamic=False)
def correctness():
    N = 2**20
    a = torch.randn(N, device="cuda", dtype=torch.float32)
    b = torch.randn(N, device="cuda", dtype=torch.float32)
    out = vector_add_autotuned(a, b)
    assert torch.allclose(out, a + b)
    print("Correctness test passed!")
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["N"],
        x_vals=[2**i for i in range(10, 25)],
        line_arg="provider",
        line_vals=["triton_autotuned", "triton_static", "torch_eager", "torch_compile"],
        line_names=["triton_autotuned", "triton_static", "torch_eager", "torch_compile"],
        styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("yellow", "-")],
        ylabel="GFLOPS",
        plot_name="vector_add",
        args={},
    )
)
def benchmark_fwd(N, provider):
    a = torch.randn(N, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(N, device="cuda", dtype=torch.bfloat16)
    if provider == "triton_autotuned":
        fn = lambda: vector_add_autotuned(a, b)
    elif provider == "triton_static":
        fn = lambda: vector_add_static(a, b)
    elif provider == "torch_eager":
        fn = lambda: vector_add_torch(a, b)
    elif provider == "torch_compile":
        fn = lambda: vector_add_torch_compiled(a, b)
    else:
        raise ValueError(f"Unknown provider: {provider}")
    ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=256, rep=1024, quantiles=[0.5, 0.2, 0.8])
    flop = N
    gflops = (flop * 1e-9) / (ms * 1e-3)
    min_gflops = (flop * 1e-9) / (max_ms * 1e-3)
    max_gflops = (flop * 1e-9) / (min_ms * 1e-3)
    return gflops, min_gflops, max_gflops
torch.manual_seed(42)
correctness()
benchmark_fwd.run(print_data=True, save_path="dump_vec_add")
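# Optional: inspect which config won for each N. `cache` is an attribute of
# triton's Autotuner (an implementation detail that may change across Triton
# versions, hence the guarded access).
for key, cfg in getattr(autotuned_vector_add_kernel, "cache", {}).items():
    print(f"key={key}: {cfg}")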
@a-r-r-o-w (Author): [image attachment]