Can be faster than torch.compile if you don't need masks! That is almost always possible in common transformer scenarios where the hidden dimension is a multiple of the block size (here, 3072 is a multiple of BLOCK_N = 512 or 1024).
import torch
import triton
import triton.language as tl

torch._dynamo.config.cache_size_limit = 10000

ENABLE_TRITON = True
ENABLE_DEEP_AUTOTUNE = True


def get_autotune_configs():
    configs = []
    if ENABLE_DEEP_AUTOTUNE:
        for BLOCK_M in [1, 2, 4, 8]:
            for BLOCK_N in [512, 1024]:
                for num_warps in [4, 8]:
                    for num_stages in [1]:
                        configs.append(
                            triton.Config(
                                {"BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N},
                                num_warps=num_warps,
                                num_stages=num_stages,
                            )
                        )
    else:
        configs.append(triton.Config({"BLOCK_M": 1, "BLOCK_N": 1024}, num_warps=4, num_stages=1))
        configs.append(triton.Config({"BLOCK_M": 1, "BLOCK_N": 1024}, num_warps=8, num_stages=1))
        configs.append(triton.Config({"BLOCK_M": 2, "BLOCK_N": 1024}, num_warps=4, num_stages=1))
        configs.append(triton.Config({"BLOCK_M": 2, "BLOCK_N": 1024}, num_warps=8, num_stages=1))
    return configs
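
# Every BLOCK_N candidate above (512, 1024) divides the 3072-wide embedding dimension
# evenly, so the kernels below can load and store full tiles without tl.load/tl.store
# masks; this is the "aligned block sizes" condition mentioned in the description.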

@triton.autotune(configs=get_autotune_configs(), key=["BATCH_SIZE", "EMBEDDING_DIM"])
@triton.jit
def adaptive_layernorm_zero_kernel(
    ptr_x,
    ptr_shift,
    ptr_scale,
    ptr_out,
    EPS: tl.constexpr,
    BATCH_SIZE,
    EMBEDDING_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    UPCAST: tl.constexpr = True,
):
    pid_m = tl.program_id(0)
    m_offsets = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
    n_offsets = tl.arange(0, BLOCK_N)
    x_ptrs = ptr_x + m_offsets[:, None] * EMBEDDING_DIM + n_offsets[None, :]
    out_ptrs = ptr_out + m_offsets[:, None] * EMBEDDING_DIM + n_offsets[None, :]
    shift_ptrs = ptr_shift + n_offsets
    scale_ptrs = ptr_scale + n_offsets
    # First pass: accumulate per-row sum and sum of squares.
    mean = tl.zeros((BLOCK_M,), dtype=tl.float32)
    var = tl.zeros((BLOCK_M,), dtype=tl.float32)
    for n_start in range(0, EMBEDDING_DIM, BLOCK_N):
        x = tl.load(x_ptrs + n_start, eviction_policy="evict_last")
        if UPCAST:
            x = x.to(tl.float32)
        mean += tl.sum(x, axis=1)
        var += tl.sum(x * x, axis=1)
    mean = mean / EMBEDDING_DIM
    var = var / EMBEDDING_DIM - mean * mean
    rstd = tl.rsqrt(var + EPS)
    # Second pass: normalize and apply the AdaLN-Zero shift/scale modulation.
    for n_start in range(0, EMBEDDING_DIM, BLOCK_N):
        x = tl.load(x_ptrs + n_start, eviction_policy="evict_last")
        shift = tl.load(shift_ptrs + n_start, eviction_policy="evict_last")
        scale = tl.load(scale_ptrs + n_start, eviction_policy="evict_last")
        if UPCAST:
            x = x.to(tl.float32)
            shift = shift.to(tl.float32)
            scale = scale.to(tl.float32)
        x = (x - mean[:, None]) * rstd[:, None]
        x = x * (1 + scale) + shift
        tl.store(out_ptrs + n_start, x)
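
# Note on the kernel above: it uses the textbook two-pass variance Var[x] = E[x^2] - E[x]^2.
# For example, for a row [1, 2, 3]: mean = 2, E[x^2] = 14/3, so Var = 14/3 - 4 = 2/3.
# Accumulating in float32 is usually fine for bf16 inputs, but this form can lose precision
# when |mean| is large relative to the spread; the Welford variant below is more robust.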

@triton.jit
def welford_reduce(value, mean, m2, weight, first_iteration):
    if first_iteration:
        new_weight = tl.full(weight.shape, 1, weight.dtype)
        new_mean = value
        new_m2 = tl.zeros_like(m2)
    else:
        delta = value - mean
        new_weight = weight + 1
        new_mean = mean + delta / new_weight
        new_m2 = m2 + delta * (value - new_mean)
    return new_mean, new_m2, new_weight


@triton.jit
def welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
    delta = mean_2 - mean_1
    new_weight = weight_1 + weight_2
    w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight)
    return (
        mean_1 + delta * w2_over_w,
        m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,
        new_weight,
    )


@triton.jit
def welford(mean, m2, weight, dim):
    return tl.reduce((mean, m2, weight), dim, welford_combine)
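
# The combine step implements the standard parallel (Chan et al.) variance update: given two
# partial aggregates (mean_a, M2_a, n_a) and (mean_b, M2_b, n_b), with delta = mean_b - mean_a,
#   n    = n_a + n_b
#   mean = mean_a + delta * n_b / n
#   M2   = M2_a + M2_b + delta^2 * n_a * n_b / n
# so that M2 / n is the (biased) variance of the combined samples, as used in the kernel below.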

@triton.autotune(configs=get_autotune_configs(), key=["BATCH_SIZE", "EMBEDDING_DIM"])
@triton.jit
def adaptive_layernorm_zero_welford_kernel(
    ptr_x,
    ptr_shift,
    ptr_scale,
    ptr_out,
    EPS: tl.constexpr,
    BATCH_SIZE,
    EMBEDDING_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid_m = tl.program_id(0)
    m_offsets = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
    n_offsets = tl.arange(0, BLOCK_N)
    x_ptrs = ptr_x + m_offsets[:, None] * EMBEDDING_DIM + n_offsets[None, :]
    out_ptrs = ptr_out + m_offsets[:, None] * EMBEDDING_DIM + n_offsets[None, :]
    shift_ptrs = ptr_shift + n_offsets
    scale_ptrs = ptr_scale + n_offsets
    # First pass: per-lane Welford accumulation, then a reduction across the BLOCK_N lanes.
    mean = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    m2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    weight = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for n_start in range(0, EMBEDDING_DIM, BLOCK_N):
        x = tl.load(x_ptrs + n_start, eviction_policy="evict_last").to(tl.float32)
        mean, m2, weight = welford_reduce(x, mean, m2, weight, n_start == 0)
    mean, m2, weight = welford(mean, m2, weight, 1)
    m2 = m2 / EMBEDDING_DIM
    rstd = tl.rsqrt(m2 + EPS)
    # Second pass: normalize and apply the AdaLN-Zero shift/scale modulation.
    for n_start in range(0, EMBEDDING_DIM, BLOCK_N):
        x = tl.load(x_ptrs + n_start, eviction_policy="evict_last").to(tl.float32)
        shift = tl.load(shift_ptrs + n_start, eviction_policy="evict_last").to(tl.float32)
        scale = tl.load(scale_ptrs + n_start, eviction_policy="evict_last").to(tl.float32)
        x = (x - mean[:, None]) * rstd[:, None]
        x = x * (1 + scale) + shift
        tl.store(out_ptrs + n_start, x)

def adaptive_layernorm_zero_triton(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor, *, upcast: bool) -> torch.Tensor:
    batch_size, seq_len, embedding_dim = x.shape
    # We know this to be true considering the context of what we're optimizing
    assert embedding_dim == 3072, "Embedding dimension must be 3072"
    assert batch_size == 1, "Batch size must be 1"
    assert shift.shape == scale.shape == (1, 1, embedding_dim), "Shift and scale must have shape (1, 1, embedding_dim)"
    effective_batch_size = batch_size * seq_len
    x = x.view(-1, embedding_dim)
    shift = shift.view(embedding_dim)
    scale = scale.view(embedding_dim)
    out = torch.empty_like(x)
    grid = lambda META: (triton.cdiv(effective_batch_size, META["BLOCK_M"]),)
    adaptive_layernorm_zero_kernel[grid](
        x,
        shift,
        scale,
        out,
        EPS=1e-6,
        BATCH_SIZE=effective_batch_size,
        EMBEDDING_DIM=3072,
        UPCAST=upcast,
    )
    return out.view(batch_size, seq_len, embedding_dim)


def adaptive_layernorm_zero_welford_triton(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    batch_size, seq_len, embedding_dim = x.shape
    # We know this to be true considering the context of what we're optimizing
    assert embedding_dim == 3072, "Embedding dimension must be 3072"
    assert batch_size == 1, "Batch size must be 1"
    assert shift.shape == scale.shape == (1, 1, embedding_dim), "Shift and scale must have shape (1, 1, embedding_dim)"
    effective_batch_size = batch_size * seq_len
    x = x.view(-1, embedding_dim)
    shift = shift.view(embedding_dim)
    scale = scale.view(embedding_dim)
    out = torch.empty_like(x)
    grid = lambda META: (triton.cdiv(effective_batch_size, META["BLOCK_M"]),)
    adaptive_layernorm_zero_welford_kernel[grid](
        x,
        shift,
        scale,
        out,
        EPS=1e-6,
        BATCH_SIZE=effective_batch_size,
        EMBEDDING_DIM=3072,
    )
    return out.view(batch_size, seq_len, embedding_dim)
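
# Caveat (not enforced above): neither kernel masks along the row dimension, so
# batch_size * seq_len should be divisible by the largest autotuned BLOCK_M (8) to avoid
# out-of-bounds accesses; the sequence lengths used in the benchmark below satisfy this.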

class AdaLayerNormZeroEager(torch.nn.Module):
    def __init__(self, embedding_dim: int, bias=True):
        super().__init__()
        self.linear = torch.nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
        self.norm = torch.nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x: torch.Tensor, emb: torch.Tensor):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1)
        x = self.norm(x)
        x = torch.addcmul(shift_msa, x, 1 + scale_msa)
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


class AdaLayerNormZeroTriton(torch.nn.Module):
    def __init__(self, embedding_dim: int, bias=True, kernel_type: str = "welford"):
        super().__init__()
        self.linear = torch.nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
        self.norm = torch.nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
        self.kernel_type = kernel_type
        assert kernel_type in ["nocast", "upcast", "welford"]

    def forward(self, x: torch.Tensor, emb: torch.Tensor):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1)
        if ENABLE_TRITON:
            if self.kernel_type == "upcast":
                x = adaptive_layernorm_zero_triton(x, shift_msa, scale_msa, upcast=True)
            elif self.kernel_type == "nocast":
                x = adaptive_layernorm_zero_triton(x, shift_msa, scale_msa, upcast=False)
            elif self.kernel_type == "welford":
                x = adaptive_layernorm_zero_welford_triton(x, shift_msa, scale_msa)
        else:
            # Eager fallback; the Triton kernels already fuse the shift/scale modulation,
            # so it is applied here only on the non-Triton path.
            x = self.norm(x)
            x = torch.addcmul(shift_msa, x, 1 + scale_msa)
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
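
# Usage sketch (illustrative; it mirrors the benchmark setup further down):
#   norm = AdaLayerNormZeroTriton(3072, kernel_type="welford").to("cuda", torch.bfloat16)
#   x_out, gate_msa, shift_mlp, scale_mlp, gate_mlp = norm(x, emb)
# where x has shape (1, seq_len, 3072) and emb is a (1, 1, 6 * 3072) modulation tensor.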
device = "cuda" | |
dtype = torch.bfloat16 | |
batch_size = 1 | |
seq_lens = [4096 + 128 * i for i in range(10)] | |
embedding_dim = 3072 | |
model_eager = AdaLayerNormZeroEager(embedding_dim).to(device, dtype) | |
model_triton_upcast = AdaLayerNormZeroTriton(embedding_dim, kernel_type="upcast").to(device, dtype) | |
model_triton_nocast = AdaLayerNormZeroTriton(embedding_dim, kernel_type="nocast").to(device, dtype) | |
model_triton_welford = AdaLayerNormZeroTriton(embedding_dim, kernel_type="welford").to(device, dtype) | |
models = {} | |
models["eager"] = model_eager | |
models["triton_upcast"] = model_triton_upcast | |
models["triton_nocast"] = model_triton_nocast | |
models["triton_welford"] = model_triton_welford | |
models["eager_compiled_d"] = torch.compile(model_eager, mode="default", fullgraph=True) | |
models["eager_compiled_max"] = torch.compile(model_eager, mode="max-autotune", fullgraph=True) | |

def get_color_and_linestyle(n: int) -> list[tuple[str, str]]:
    color_names = ["red", "blue", "green", "orange", "purple", "brown", "pink", "gray", "olive", "cyan"]
    line_styles = ["-", "--", "-.", ":"]
    if n > len(color_names) * len(line_styles):
        raise ValueError(f"Required {n=} styles but maximum is {len(color_names) * len(line_styles)}")
    styles = []
    for i in range(n):
        color = color_names[i % len(color_names)]
        linestyle = line_styles[i // len(color_names)]
        styles.append((color, linestyle))
    return styles

def get_inputs(seq_len: int):
    x = torch.randn(batch_size, seq_len, embedding_dim, device=device, dtype=dtype)
    emb = torch.randn(batch_size, 1, embedding_dim, device=device, dtype=dtype)
    emb = model_eager.linear(emb)
    return x, emb


def correctness():
    x, emb = get_inputs(4608)
    results = {}
    for name, model in models.items():
        results[name] = model(x, emb)
    reference = results["eager"]
    for name, result in results.items():
        for i, (res, ref) in enumerate(zip(result, reference)):
            diff = res - ref
            absdiff = torch.abs(diff)
            absmax = torch.max(absdiff)
            mae = torch.mean(absdiff)
            mse = torch.mean(diff * diff)
            print(f"{name} {i}: absmax={absmax:.4}, mae={mae:.4}, mse={mse:.4}")


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["seq_len"],
        x_vals=seq_lens,
        x_log=False,
        line_arg="provider",
        line_vals=list(models.keys()),
        line_names=[x.removeprefix("solution_") for x in models.keys()],
        styles=get_color_and_linestyle(len(models)),
        ylabel="time (ms)",
        plot_name="AdaLN Zero",
        args={},
    )
)
def benchmark(seq_len: int, provider: str):
    x, emb = get_inputs(seq_len)
    fn = models[provider]
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: fn(x, emb),
        warmup=8,
        rep=32,
        quantiles=[0.5, 0.2, 0.8],
    )
    return ms, max_ms, min_ms


if __name__ == "__main__":
    torch.manual_seed(0)
    with torch.inference_mode():
        correctness()
        benchmark.run(print_data=True, save_path="dump_benchmark_adaln_zero")
A100: (benchmark plot)
H100: (benchmark plot)