Skip to content

Instantly share code, notes, and snippets.

@shunting314
Created June 13, 2025 19:43
Show Gist options
  • Save shunting314/19ba2bfd67573fe0f021042b3bba13c2 to your computer and use it in GitHub Desktop.
Save shunting314/19ba2bfd67573fe0f021042b3bba13c2 to your computer and use it in GitHub Desktop.
# quick fix: https://gist.github.com/shunting314/0d32fe66ba6c771a3cc69574fab359c6
import torch
from triton.testing import do_bench
import functools
from torch._inductor import config
from torch._dynamo.decorators import mark_dynamic
import os
@torch.compile(dynamic=True)
def f(x: torch.Tensor) -> torch.Tensor:
    """Sum ``x`` over dimension 0.

    The function is wrapped with ``torch.compile(dynamic=True)``, so it is
    compiled with dynamic shapes requested rather than specialized to the
    sizes of the first input.

    Args:
        x: Input tensor with at least one dimension.

    Returns:
        The result of ``x.sum(dim=0)``, i.e. ``x`` with dimension 0 reduced
        away.
    """
    return x.sum(dim=0)
# Experiment: compare the runtime of the compiled reduction ``f`` on a large
# input depending on which tensor (small or large) was used for the first,
# compile-triggering call.
N = 512
C = functools.partial(torch.randn, device="cuda")
x_small = C(4096, N)
# 4096 * 1000 rows x 512 columns; roughly 8 GB assuming the default float32
# dtype, so this needs a GPU with enough free memory.
x_large = C(4096 * 1000, N)

# Choose the tensor used for the first call. Set HINT_WITH_SMALL_INPUT=1 to
# use the small tensor; any other value (or unset) uses the large one.
if os.getenv("HINT_WITH_SMALL_INPUT") == "1":
    x = x_small
else:
    x = x_large

# Mark dimension 0 (the reduced dimension) as dynamic before the first call.
mark_dynamic(x, 0)
f(x)

# The timed run always uses the large input, regardless of the hint above.
ms = do_bench(lambda: f(x_large))

# Timings recorded by the author:
# - 4.03ms if hinted with the large input.
#   Output code: https://gist.github.com/shunting314/0be562a0c14f8ec0852b12bbf53d7a15
# - 8.32ms if hinted with the small input.
#   Output code: https://gist.github.com/shunting314/79b924c266d5c562703c3bdfb48d8272
# - 3.92ms if hinted with the small input and min num split is forced.
#   Output code: https://gist.github.com/shunting314/c82917a1849b698bf4d2be2fde2fd2ba
print(ms)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment