# Gist by @shunting314, created June 13, 2025.
# https://gist.github.com/shunting314/f1cbdb264ac2024e6768cdea0adc3ea4
# quick fix: https://gist.github.com/shunting314/0d32fe66ba6c771a3cc69574fab359c6
import torch
from triton.testing import do_bench
import functools
from torch._inductor import config
from torch._dynamo.decorators import mark_dynamic
import os
@torch.compile(dynamic=True)
def f(x):
    """Return the sum of ``x`` over dim 0.

    For the 2-D ``(rows, N)`` inputs this script uses, that is a column-wise
    sum producing a length-``N`` vector. The function is compiled with
    ``dynamic=True`` so input sizes are traced symbolically instead of being
    specialized to the first input seen.
    """
    return x.sum(dim=0)
if __name__ == "__main__":
    # The script allocates its inputs with device="cuda"; fail with a clear
    # message instead of a low-level RuntimeError when no GPU is present.
    if not torch.cuda.is_available():
        raise SystemExit("This benchmark requires a CUDA device.")

    N = 512  # width of every input (dim 1)
    C = functools.partial(torch.randn, device="cuda")
    x_small = C(4096, N)
    x_large = C(4096 * 1000, N)

    # The first (compiling) call decides which concrete size Inductor sees
    # for the dynamic dim 0. HINT_WITH_SMALL_INPUT=1 hints with 4096 rows;
    # otherwise the hint comes from the 4,096,000-row tensor.
    x = x_small if os.getenv("HINT_WITH_SMALL_INPUT") == "1" else x_large
    mark_dynamic(x, 0)
    f(x)  # compile once, using x's dim-0 size as the hint

    # Always time the large input; only the compile-time hint differs.
    ms = do_bench(lambda: f(x_large))
    # 4.03ms if hint with large input. Output code: https://gist.github.com/shunting314/0be562a0c14f8ec0852b12bbf53d7a15
    # 8.32ms if hint with small input. Output code: https://gist.github.com/shunting314/79b924c266d5c562703c3bdfb48d8272
    print(ms)