@a-r-r-o-w, created July 8, 2025 08:16
Can be faster than torch.compile when the kernels never need load/store masks! With aligned block sizes this is almost always possible in common transformer scenarios.
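
The kernels below never pass a mask to tl.load / tl.store: every autotuned BLOCK_N (512, 1024) divides the 3072-wide embedding dimension exactly, and the benchmarked sequence lengths are multiples of every BLOCK_M, so all tiles are full. A minimal sketch of the difference this makes, assuming the same pointer setup as the kernels below (illustration only, not part of the script):

    # misaligned case: a boundary mask (and a masked store) would be required
    mask = (n_start + n_offsets) < EMBEDDING_DIM
    x = tl.load(x_ptrs + n_start, mask=mask[None, :], other=0.0)
    # aligned case, as used here: plain loads, no mask bookkeeping
    x = tl.load(x_ptrs + n_start)
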
import torch
import triton
import triton.language as tl
torch._dynamo.config.cache_size_limit = 10000
ENABLE_TRITON = True
ENABLE_DEEP_AUTOTUNE = True

def get_autotune_configs():
    configs = []
    if ENABLE_DEEP_AUTOTUNE:
        for BLOCK_M in [1, 2, 4, 8]:
            for BLOCK_N in [512, 1024]:
                for num_warps in [4, 8]:
                    for num_stages in [1]:
                        configs.append(
                            triton.Config(
                                {"BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N},
                                num_warps=num_warps,
                                num_stages=num_stages,
                            )
                        )
    else:
        configs.append(triton.Config({"BLOCK_M": 1, "BLOCK_N": 1024}, num_warps=4, num_stages=1))
        configs.append(triton.Config({"BLOCK_M": 1, "BLOCK_N": 1024}, num_warps=8, num_stages=1))
        configs.append(triton.Config({"BLOCK_M": 2, "BLOCK_N": 1024}, num_warps=4, num_stages=1))
        configs.append(triton.Config({"BLOCK_M": 2, "BLOCK_N": 1024}, num_warps=8, num_stages=1))
    return configs
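
# The autotuner caches its choice per (BATCH_SIZE, EMBEDDING_DIM) key, so with
# ENABLE_DEEP_AUTOTUNE it re-benchmarks all 16 configs above once for every
# distinct sequence length that appears in the sweep below.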

@triton.autotune(configs=get_autotune_configs(), key=["BATCH_SIZE", "EMBEDDING_DIM"])
@triton.jit
def adaptive_layernorm_zero_kernel(
    ptr_x,
    ptr_shift,
    ptr_scale,
    ptr_out,
    EPS: tl.constexpr,
    BATCH_SIZE,
    EMBEDDING_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    UPCAST: tl.constexpr = True,
):
    pid_m = tl.program_id(0)
    m_offsets = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
    n_offsets = tl.arange(0, BLOCK_N)
    x_ptrs = ptr_x + m_offsets[:, None] * EMBEDDING_DIM + n_offsets[None, :]
    out_ptrs = ptr_out + m_offsets[:, None] * EMBEDDING_DIM + n_offsets[None, :]
    shift_ptrs = ptr_shift + n_offsets
    scale_ptrs = ptr_scale + n_offsets
    mean = tl.zeros((BLOCK_M,), dtype=tl.float32)
    var = tl.zeros((BLOCK_M,), dtype=tl.float32)
    for n_start in range(0, EMBEDDING_DIM, BLOCK_N):
        x = tl.load(x_ptrs + n_start, eviction_policy="evict_last")
        if UPCAST:
            x = x.to(tl.float32)
        mean += tl.sum(x, axis=1)
        var += tl.sum(x * x, axis=1)
    mean = mean / EMBEDDING_DIM
    var = var / EMBEDDING_DIM - mean * mean
    rstd = tl.rsqrt(var + EPS)
    for n_start in range(0, EMBEDDING_DIM, BLOCK_N):
        x = tl.load(x_ptrs + n_start, eviction_policy="evict_last")
        shift = tl.load(shift_ptrs + n_start, eviction_policy="evict_last")
        scale = tl.load(scale_ptrs + n_start, eviction_policy="evict_last")
        if UPCAST:
            x = x.to(tl.float32)
            shift = shift.to(tl.float32)
            scale = scale.to(tl.float32)
        x = (x - mean[:, None]) * rstd[:, None]
        x = x * (1 + scale) + shift
        tl.store(out_ptrs + n_start, x)
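
# In eager PyTorch terms the kernel above computes, per row (illustration only;
# correctness() further down does the real comparison):
#   ref = torch.nn.functional.layer_norm(x, (EMBEDDING_DIM,), eps=EPS)
#   out = ref * (1 + scale) + shift
# Mean and variance come from a single accumulation pass using the identity
# Var[x] = E[x^2] - E[x]^2; the UPCAST flag controls whether x is upcast to
# float32 before the reduction (more accurate for bf16 inputs).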

@triton.jit
def welford_reduce(value, mean, m2, weight, first_iteration):
    if first_iteration:
        new_weight = tl.full(weight.shape, 1, weight.dtype)
        new_mean = value
        new_m2 = tl.zeros_like(m2)
    else:
        delta = value - mean
        new_weight = weight + 1
        new_mean = mean + delta / new_weight
        new_m2 = m2 + delta * (value - new_mean)
    return new_mean, new_m2, new_weight

@triton.jit
def welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
    delta = mean_2 - mean_1
    new_weight = weight_1 + weight_2
    w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight)
    return (
        mean_1 + delta * w2_over_w,
        m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,
        new_weight,
    )

@triton.jit
def welford(mean, m2, weight, dim):
    return tl.reduce((mean, m2, weight), dim, welford_combine)
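
# Reference for the helpers above:
#   welford_reduce is Welford's online update applied lane-wise:
#     delta = x - mean;  n += 1;  mean += delta / n;  M2 += delta * (x - new_mean)
#   welford_combine merges two partial (mean, M2, n) statistics (Chan et al.):
#     delta = mean_b - mean_a
#     mean  = mean_a + delta * n_b / (n_a + n_b)
#     M2    = M2_a + M2_b + delta^2 * n_a * n_b / (n_a + n_b)
#   welford() reduces a whole axis with that combine; the variance is then
#   M2 divided by the total count.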

@triton.autotune(configs=get_autotune_configs(), key=["BATCH_SIZE", "EMBEDDING_DIM"])
@triton.jit
def adaptive_layernorm_zero_welford_kernel(
    ptr_x,
    ptr_shift,
    ptr_scale,
    ptr_out,
    EPS: tl.constexpr,
    BATCH_SIZE,
    EMBEDDING_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid_m = tl.program_id(0)
    m_offsets = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
    n_offsets = tl.arange(0, BLOCK_N)
    x_ptrs = ptr_x + m_offsets[:, None] * EMBEDDING_DIM + n_offsets[None, :]
    out_ptrs = ptr_out + m_offsets[:, None] * EMBEDDING_DIM + n_offsets[None, :]
    shift_ptrs = ptr_shift + n_offsets
    scale_ptrs = ptr_scale + n_offsets
    mean = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    m2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    weight = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for n_start in range(0, EMBEDDING_DIM, BLOCK_N):
        x = tl.load(x_ptrs + n_start, eviction_policy="evict_last").to(tl.float32)
        mean, m2, weight = welford_reduce(x, mean, m2, weight, n_start == 0)
    mean, m2, weight = welford(mean, m2, weight, 1)
    m2 = m2 / EMBEDDING_DIM
    rstd = tl.rsqrt(m2 + EPS)
    for n_start in range(0, EMBEDDING_DIM, BLOCK_N):
        x = tl.load(x_ptrs + n_start, eviction_policy="evict_last").to(tl.float32)
        shift = tl.load(shift_ptrs + n_start, eviction_policy="evict_last").to(tl.float32)
        scale = tl.load(scale_ptrs + n_start, eviction_policy="evict_last").to(tl.float32)
        x = (x - mean[:, None]) * rstd[:, None]
        x = x * (1 + scale) + shift
        tl.store(out_ptrs + n_start, x)
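
# This variant differs from adaptive_layernorm_zero_kernel only in how the
# variance is computed: the Welford/Chan reduction avoids the cancellation that
# E[x^2] - E[x]^2 can suffer when the mean is large relative to the spread.
# The helpers above mirror the welford_reduce / welford_combine routines that
# torch.compile's Inductor backend emits for its own layer-norm reductions,
# which makes this kernel the closest apples-to-apples comparison.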

def adaptive_layernorm_zero_triton(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor, *, upcast: bool) -> torch.Tensor:
    batch_size, seq_len, embedding_dim = x.shape
    # We know this to be true considering the context of what we're optimizing
    assert embedding_dim == 3072, "Embedding dimension must be 3072"
    assert batch_size == 1, "Batch size must be 1"
    assert shift.shape == scale.shape == (1, 1, embedding_dim), "Shift and scale must have shape (1, 1, embedding_dim)"
    effective_batch_size = batch_size * seq_len
    x = x.view(-1, embedding_dim)
    shift = shift.view(embedding_dim)
    scale = scale.view(embedding_dim)
    out = torch.empty_like(x)
    grid = lambda META: (triton.cdiv(effective_batch_size, META["BLOCK_M"]),)
    adaptive_layernorm_zero_kernel[grid](
        x,
        shift,
        scale,
        out,
        EPS=1e-6,
        BATCH_SIZE=effective_batch_size,
        EMBEDDING_DIM=3072,
        UPCAST=upcast,
    )
    return out.view(batch_size, seq_len, embedding_dim)

def adaptive_layernorm_zero_welford_triton(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    batch_size, seq_len, embedding_dim = x.shape
    # We know this to be true considering the context of what we're optimizing
    assert embedding_dim == 3072, "Embedding dimension must be 3072"
    assert batch_size == 1, "Batch size must be 1"
    assert shift.shape == scale.shape == (1, 1, embedding_dim), "Shift and scale must have shape (1, 1, embedding_dim)"
    effective_batch_size = batch_size * seq_len
    x = x.view(-1, embedding_dim)
    shift = shift.view(embedding_dim)
    scale = scale.view(embedding_dim)
    out = torch.empty_like(x)
    grid = lambda META: (triton.cdiv(effective_batch_size, META["BLOCK_M"]),)
    adaptive_layernorm_zero_welford_kernel[grid](
        x,
        shift,
        scale,
        out,
        EPS=1e-6,
        BATCH_SIZE=effective_batch_size,
        EMBEDDING_DIM=3072,
    )
    return out.view(batch_size, seq_len, embedding_dim)
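
# Example call for either wrapper (shapes chosen to satisfy the asserts above):
#   x = torch.randn(1, 4096, 3072, device="cuda", dtype=torch.bfloat16)
#   shift = torch.randn(1, 1, 3072, device="cuda", dtype=torch.bfloat16)
#   scale = torch.randn(1, 1, 3072, device="cuda", dtype=torch.bfloat16)
#   out = adaptive_layernorm_zero_welford_triton(x, shift, scale)  # (1, 4096, 3072)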

class AdaLayerNormZeroEager(torch.nn.Module):
    def __init__(self, embedding_dim: int, bias=True):
        super().__init__()
        self.linear = torch.nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
        self.norm = torch.nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x: torch.Tensor, emb: torch.Tensor):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1)
        x = self.norm(x)
        x = torch.addcmul(shift_msa, x, 1 + scale_msa)
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp

class AdaLayerNormZeroTriton(torch.nn.Module):
    def __init__(self, embedding_dim: int, bias=True, kernel_type: str = "welford"):
        super().__init__()
        self.linear = torch.nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
        self.norm = torch.nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
        self.kernel_type = kernel_type
        assert kernel_type in ["nocast", "upcast", "welford"]

    def forward(self, x: torch.Tensor, emb: torch.Tensor):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1)
        if ENABLE_TRITON:
            if self.kernel_type == "upcast":
                x = adaptive_layernorm_zero_triton(x, shift_msa, scale_msa, upcast=True)
            elif self.kernel_type == "nocast":
                x = adaptive_layernorm_zero_triton(x, shift_msa, scale_msa, upcast=False)
            elif self.kernel_type == "welford":
                x = adaptive_layernorm_zero_welford_triton(x, shift_msa, scale_msa)
        else:
            x = self.norm(x)
            x = torch.addcmul(shift_msa, x, 1 + scale_msa)
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp

device = "cuda"
dtype = torch.bfloat16
batch_size = 1
seq_lens = [4096 + 128 * i for i in range(10)]
embedding_dim = 3072
model_eager = AdaLayerNormZeroEager(embedding_dim).to(device, dtype)
model_triton_upcast = AdaLayerNormZeroTriton(embedding_dim, kernel_type="upcast").to(device, dtype)
model_triton_nocast = AdaLayerNormZeroTriton(embedding_dim, kernel_type="nocast").to(device, dtype)
model_triton_welford = AdaLayerNormZeroTriton(embedding_dim, kernel_type="welford").to(device, dtype)
models = {}
models["eager"] = model_eager
models["triton_upcast"] = model_triton_upcast
models["triton_nocast"] = model_triton_nocast
models["triton_welford"] = model_triton_welford
models["eager_compiled_d"] = torch.compile(model_eager, mode="default", fullgraph=True)
models["eager_compiled_max"] = torch.compile(model_eager, mode="max-autotune", fullgraph=True)

def get_color_and_linestyle(n: int) -> list[tuple[str, str]]:
    color_names = ["red", "blue", "green", "orange", "purple", "brown", "pink", "gray", "olive", "cyan"]
    line_styles = ["-", "--", "-.", ":"]
    if n > len(color_names) * len(line_styles):
        raise ValueError(f"Required {n=} styles but maximum is {len(color_names) * len(line_styles)}")
    styles = []
    for i in range(n):
        color = color_names[i % len(color_names)]
        linestyle = line_styles[i // len(color_names)]
        styles.append((color, linestyle))
    return styles

def get_inputs(seq_len: int):
    x = torch.randn(batch_size, seq_len, embedding_dim, device=device, dtype=dtype)
    emb = torch.randn(batch_size, 1, embedding_dim, device=device, dtype=dtype)
    emb = model_eager.linear(emb)
    return x, emb

def correctness():
    x, emb = get_inputs(4608)
    results = {}
    for name, model in models.items():
        results[name] = model(x, emb)
    reference = results["eager"]
    for name, result in results.items():
        for i, (res, ref) in enumerate(zip(result, reference)):
            diff = res - ref
            absdiff = torch.abs(diff)
            absmax = torch.max(absdiff)
            mae = torch.mean(absdiff)
            mse = torch.mean(diff * diff)
            print(f"{name} {i}: absmax={absmax:.4}, mae={mae:.4}, mse={mse:.4}")

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["seq_len"],
        x_vals=seq_lens,
        x_log=False,
        line_arg="provider",
        line_vals=list(models.keys()),
        line_names=[x.removeprefix("solution_") for x in models.keys()],
        styles=get_color_and_linestyle(len(models)),
        ylabel="time (ms)",
        plot_name="AdaLN Zero",
        args={},
    )
)
def benchmark(seq_len: int, provider: str):
    x, emb = get_inputs(seq_len)
    fn = models[provider]
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: fn(x, emb),
        warmup=8,
        rep=32,
        quantiles=[0.5, 0.2, 0.8],
    )
    return ms, max_ms, min_ms
if __name__ == "__main__":
torch.manual_seed(0)
with torch.inference_mode():
correctness()
benchmark.run(print_data=True, save_path="dump_benchmark_adaln_zero")

[Benchmark plots from the author's comment: A100 and H100 results.]