import gc
from typing import Tuple

import torch
import torch.nn.functional as F
import triton
import triton.language as tl
import triton.testing


# --- Memory Tracking Utilities ---

def get_gpu_memory_usage():
    """Get current GPU memory usage in MB."""
    if torch.cuda.is_available():
        return torch.cuda.memory_allocated() / 1024 / 1024  # Convert bytes to MB
    return 0.0


def get_peak_gpu_memory_usage():
    """Get peak GPU memory usage in MB."""
    if torch.cuda.is_available():
        return torch.cuda.max_memory_allocated() / 1024 / 1024  # Convert bytes to MB
    return 0.0


def reset_gpu_memory_stats():
    """Reset GPU memory statistics."""
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()
    gc.collect()


def measure_memory_usage(fn, warmup=3, repetitions=10):
    """
    Measure peak GPU memory usage of a function.

    Args:
        fn: Function to measure
        warmup: Number of warmup runs
        repetitions: Number of measurement runs

    Returns:
        tuple: (average_peak_memory_mb, max_peak_memory_mb)
    """
    peak_memories = []
    for _ in range(warmup):
        reset_gpu_memory_stats()
        fn()
        torch.cuda.synchronize()
    for _ in range(repetitions):
        reset_gpu_memory_stats()
        fn()
        torch.cuda.synchronize()
        peak_memory = get_peak_gpu_memory_usage()
        peak_memories.append(peak_memory)
    avg_peak = sum(peak_memories) / len(peak_memories)
    max_peak = max(peak_memories)
    return avg_peak, max_peak
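
# A minimal usage sketch for measure_memory_usage (illustrative only; the
# workload below is a made-up example, not one of the benchmarked functions):
#
#     x = torch.randn(1024, 1024, device="cuda")
#     avg_mb, peak_mb = measure_memory_usage(lambda: x @ x, warmup=1, repetitions=3)
#     print(f"avg peak: {avg_mb:.2f} MB, max peak: {peak_mb:.2f} MB")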


# --- Triton Multi-Head JVP Kernel ---

@triton.autotune(
    configs=[
        # Ultra-conservative configs for maximum compatibility
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16}, num_warps=2, num_stages=1),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16}, num_warps=2, num_stages=1),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 32}, num_warps=2, num_stages=1),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32}, num_warps=4, num_stages=1),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 16}, num_warps=4, num_stages=1),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 64}, num_warps=4, num_stages=1),
    ],
    key=['B', 'H', 'L', 'D_head'],
)
@triton.jit
def _flash_attention_jvp_multihead_kernel(
    # Input tensors
    Q, K, V, T_Q, T_K, T_V,
    # Output tensors
    Y, T_Y,
    # Tensor strides
    stride_qb, stride_qh, stride_ql, stride_qd,
    stride_kb, stride_kh, stride_kl, stride_kd,
    stride_vb, stride_vh, stride_vl, stride_vd,
    stride_tqb, stride_tqh, stride_tql, stride_tqd,
    stride_tkb, stride_tkh, stride_tkl, stride_tkd,
    stride_tvb, stride_tvh, stride_tvl, stride_tvd,
    stride_yb, stride_yh, stride_yl, stride_yd,
    stride_tyb, stride_tyh, stride_tyl, stride_tyd,
    # Problem dimensions
    B: tl.constexpr, H: tl.constexpr, L: tl.constexpr, D_head: tl.constexpr,
    # Scale factor
    scale: tl.constexpr,
    # Block sizes
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
""" | |
Flash Attention JVP kernel following the reference implementation pattern. | |
Grid: (B*H, triton.cdiv(L, BLOCK_M)) | |
""" | |
# Get program IDs | |
pid_bh = tl.program_id(0) # Combined batch and head index | |
pid_m = tl.program_id(1) # Query block index | |
# Decompose batch and head indices | |
pid_b = pid_bh // H | |
pid_h = pid_bh % H | |
# Compute offsets | |
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) | |
offs_n = tl.arange(0, BLOCK_N) | |
offs_d = tl.arange(0, D_head) | |
# Base pointers for this (batch, head) | |
q_base = Q + pid_b * stride_qb + pid_h * stride_qh | |
k_base = K + pid_b * stride_kb + pid_h * stride_kh | |
v_base = V + pid_b * stride_vb + pid_h * stride_vh | |
tq_base = T_Q + pid_b * stride_tqb + pid_h * stride_tqh | |
tk_base = T_K + pid_b * stride_tkb + pid_h * stride_tkh | |
tv_base = T_V + pid_b * stride_tvb + pid_h * stride_tvh | |
y_base = Y + pid_b * stride_yb + pid_h * stride_yh | |
ty_base = T_Y + pid_b * stride_tyb + pid_h * stride_tyh | |
# Load query block | |
q_ptrs = q_base + offs_m[:, None] * stride_ql + offs_d[None, :] * stride_qd | |
tq_ptrs = tq_base + offs_m[:, None] * stride_tql + offs_d[None, :] * stride_tqd | |
mask_m = offs_m < L | |
q = tl.load(q_ptrs, mask=mask_m[:, None], other=0.0) | |
tq = tl.load(tq_ptrs, mask=mask_m[:, None], other=0.0) | |
# Initialize accumulators following Flash Attention pattern | |
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") | |
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) | |
acc = tl.zeros([BLOCK_M, D_head], dtype=tl.float32) | |
g_acc = tl.zeros([BLOCK_M, D_head], dtype=tl.float32) | |
mu_i = tl.zeros([BLOCK_M], dtype=tl.float32) | |
p_tv_acc = tl.zeros([BLOCK_M, D_head], dtype=tl.float32) | |
# Scale factor for exp2 optimization (like reference) | |
qk_scale = scale * 1.44269504 # 1/log(2) | |
# Loop over key/value blocks | |
for start_n in range(0, L, BLOCK_N): | |
start_n = tl.multiple_of(start_n, BLOCK_N) | |
offs_n_curr = start_n + offs_n | |
mask_n = offs_n_curr < L | |
# Load key and value blocks | |
k_ptrs = k_base + offs_n_curr[:, None] * stride_kl + offs_d[None, :] * stride_kd | |
v_ptrs = v_base + offs_n_curr[:, None] * stride_vl + offs_d[None, :] * stride_vd | |
tk_ptrs = tk_base + offs_n_curr[:, None] * stride_tkl + offs_d[None, :] * stride_tkd | |
tv_ptrs = tv_base + offs_n_curr[:, None] * stride_tvl + offs_d[None, :] * stride_tvd | |
k = tl.load(k_ptrs, mask=mask_n[:, None], other=0.0) | |
v = tl.load(v_ptrs, mask=mask_n[:, None], other=0.0) | |
tk = tl.load(tk_ptrs, mask=mask_n[:, None], other=0.0) | |
tv = tl.load(tv_ptrs, mask=mask_n[:, None], other=0.0) | |
# Compute attention scores | |
qk = tl.dot(q, tl.trans(k)) | |
tqk = tl.dot(tq, tl.trans(k)) + tl.dot(q, tl.trans(tk)) | |
# Mask invalid positions first | |
qk = tl.where(mask_n[None, :], qk, float('-inf')) | |
tqk = tl.where(mask_n[None, :], tqk, 0.0) | |
# Online softmax computation following Flash Attention | |
m_ij = tl.maximum(m_i, tl.max(qk * scale, 1)) | |
qk = qk * qk_scale - m_ij[:, None] # Scale and subtract max | |
p = tl.math.exp2(qk) # Use exp2 like reference | |
# Correction factor | |
alpha = tl.math.exp2(m_i - m_ij) | |
l_ij = tl.sum(p, 1) | |
# Update normalization | |
l_i = l_i * alpha + l_ij | |
# Update output accumulator | |
acc = acc * alpha[:, None] + tl.dot(p, v) | |
# JVP accumulator: (p * tqk) @ v | |
p_tqk = p * (tqk * scale) # Apply scale to tangent scores | |
g_acc = g_acc * alpha[:, None] + tl.dot(p_tqk, v) | |
# Update mu: sum(p * tqk) | |
mu_ij = tl.sum(p_tqk, 1) | |
mu_i = mu_i * alpha + mu_ij | |
# Update p @ tv accumulator | |
p_tv_acc = p_tv_acc * alpha[:, None] + tl.dot(p, tv) | |
# Update max | |
m_i = m_ij | |
# Final computation - add log normalization and divide | |
m_i += tl.math.log2(l_i) | |
y_out = acc / l_i[:, None] | |
t_p_v = g_acc / l_i[:, None] - (mu_i / l_i)[:, None] * y_out | |
t_y_out = t_p_v + p_tv_acc / l_i[:, None] | |
# Store outputs | |
y_ptrs = y_base + offs_m[:, None] * stride_yl + offs_d[None, :] * stride_yd | |
ty_ptrs = ty_base + offs_m[:, None] * stride_tyl + offs_d[None, :] * stride_tyd | |
tl.store(y_ptrs, y_out, mask=mask_m[:, None]) | |
tl.store(ty_ptrs, t_y_out, mask=mask_m[:, None]) | |
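
# For reference, the dense math that the accumulators above implement (a
# restatement of the kernel's computation, not additional functionality). With
# S = Q @ K^T * scale, P = softmax(S, dim=-1), and tangents t_Q, t_K, t_V:
#
#     t_S = (t_Q @ K^T + Q @ t_K^T) * scale
#     t_P = P * (t_S - rowsum(P * t_S))        # JVP of the row-wise softmax
#     Y   = P @ V
#     t_Y = t_P @ V + P @ t_V
#         = (P * t_S) @ V - rowsum(P * t_S) * Y + P @ t_V
#
# acc, g_acc, mu_i and p_tv_acc hold the unnormalized P @ V, (P * t_S) @ V,
# rowsum(P * t_S) and P @ t_V; dividing each by l_i yields the normalized terms.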


def flash_attention_jvp_multihead_triton_kernel_wrapper(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    t_Q: torch.Tensor,
    t_K: torch.Tensor,
    t_V: torch.Tensor,
    scale: float = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Python wrapper for the multi-head Flash Attention JVP Triton kernel.
    """
    device = Q.device
    dtype = Q.dtype
    B, H, L, D_head = Q.shape
    # Check dimension requirements for the Triton kernel
    if D_head < 16:
        raise ValueError(f"D_head must be >= 16 for an efficient Triton kernel, got {D_head}")
    if D_head & (D_head - 1) != 0:
        # tl.arange(0, D_head) requires a power-of-two extent
        raise ValueError(f"D_head must be a power of two, got {D_head}")
    if scale is None:
        scale = 1.0 / (D_head ** 0.5)
    # Ensure input shapes are consistent
    assert Q.shape == (B, H, L, D_head), f"Q shape mismatch: {Q.shape}"
    assert K.shape == (B, H, L, D_head), f"K shape mismatch: {K.shape}"
    assert V.shape == (B, H, L, D_head), f"V shape mismatch: {V.shape}"
    assert t_Q.shape == (B, H, L, D_head), f"t_Q shape mismatch: {t_Q.shape}"
    assert t_K.shape == (B, H, L, D_head), f"t_K shape mismatch: {t_K.shape}"
    assert t_V.shape == (B, H, L, D_head), f"t_V shape mismatch: {t_V.shape}"
    # Create output tensors
    y = torch.zeros((B, H, L, D_head), dtype=dtype, device=device)
    t_y = torch.zeros((B, H, L, D_head), dtype=dtype, device=device)
    # Make tensors contiguous
    Qc = Q.contiguous()
    Kc = K.contiguous()
    Vc = V.contiguous()
    t_Qc = t_Q.contiguous()
    t_Kc = t_K.contiguous()
    t_Vc = t_V.contiguous()
    # Compute strides
    stride_qb, stride_qh, stride_ql, stride_qd = Qc.stride()
    stride_kb, stride_kh, stride_kl, stride_kd = Kc.stride()
    stride_vb, stride_vh, stride_vl, stride_vd = Vc.stride()
    stride_tqb, stride_tqh, stride_tql, stride_tqd = t_Qc.stride()
    stride_tkb, stride_tkh, stride_tkl, stride_tkd = t_Kc.stride()
    stride_tvb, stride_tvh, stride_tvl, stride_tvd = t_Vc.stride()
    stride_yb, stride_yh, stride_yl, stride_yd = y.stride()
    stride_tyb, stride_tyh, stride_tyl, stride_tyd = t_y.stride()
    # Block-based grid like Flash Attention. BLOCK_M is chosen by the autotuner,
    # so the grid must be a function of the selected meta-parameters; a fixed
    # BLOCK_M here could leave query rows uncovered.
    grid = lambda META: (B * H, triton.cdiv(L, META['BLOCK_M']))
    _flash_attention_jvp_multihead_kernel[grid](
        Qc, Kc, Vc, t_Qc, t_Kc, t_Vc,
        y, t_y,
        stride_qb, stride_qh, stride_ql, stride_qd,
        stride_kb, stride_kh, stride_kl, stride_kd,
        stride_vb, stride_vh, stride_vl, stride_vd,
        stride_tqb, stride_tqh, stride_tql, stride_tqd,
        stride_tkb, stride_tkh, stride_tkl, stride_tkd,
        stride_tvb, stride_tvh, stride_tvl, stride_tvd,
        stride_yb, stride_yh, stride_yl, stride_yd,
        stride_tyb, stride_tyh, stride_tyl, stride_tyd,
        B, H, L, D_head,
        scale,
    )
    return y, t_y
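
# Minimal usage sketch for the wrapper (assumes a CUDA device; the shapes below
# are illustrative, not prescribed by the kernel):
#
#     B, H, L, D_head = 1, 2, 128, 64
#     Q, K, V, tQ, tK, tV = (torch.randn(B, H, L, D_head, device="cuda") for _ in range(6))
#     y, t_y = flash_attention_jvp_multihead_triton_kernel_wrapper(Q, K, V, tQ, tK, tV)
#     # y is the attention output; t_y is its directional derivative along (tQ, tK, tV)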


# --- Naive PyTorch Multi-Head JVP ---

def naive_multihead_attention_jvp(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    t_Q: torch.Tensor,
    t_K: torch.Tensor,
    t_V: torch.Tensor,
    scale: float = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Naive implementation of multi-head attention JVP using torch.func.jvp.
    """
    # Imported locally to avoid issues if torch.func is not available
    from torch.func import jvp

    B, H, L, D_head = Q.shape
    if scale is None:
        scale = 1.0 / (D_head ** 0.5)

    def multihead_attention_forward(Q_param, K_param, V_param):
        scores = torch.matmul(Q_param, K_param.transpose(-2, -1)) * scale
        p = F.softmax(scores, dim=-1)
        y = torch.matmul(p, V_param)
        return y

    y, t_y = jvp(
        multihead_attention_forward,
        (Q, K, V),
        (t_Q, t_K, t_V),
    )
    return y, t_y
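
# A hand-written dense JVP of the same attention forward pass (no torch.func),
# mirroring the identity noted after the kernel. This helper is an illustrative
# sketch with a name of our choosing; it is not part of the benchmarked code.
def manual_multihead_attention_jvp(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    t_Q: torch.Tensor,
    t_K: torch.Tensor,
    t_V: torch.Tensor,
    scale: float = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if scale is None:
        scale = 1.0 / (Q.shape[-1] ** 0.5)
    S = torch.matmul(Q, K.transpose(-2, -1)) * scale                # scores
    t_S = (torch.matmul(t_Q, K.transpose(-2, -1))
           + torch.matmul(Q, t_K.transpose(-2, -1))) * scale        # tangent of the scores
    P = F.softmax(S, dim=-1)
    t_P = P * (t_S - (P * t_S).sum(dim=-1, keepdim=True))           # JVP of row-wise softmax
    y = torch.matmul(P, V)
    t_y = torch.matmul(t_P, V) + torch.matmul(P, t_V)               # product rule
    return y, t_y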


# --- Test Function ---

def test_multihead_correctness():
    """
    Tests the Triton multi-head JVP kernel against the naive PyTorch implementation.
    """
    print("Running Multi-Head JVP Correctness Test...")
    torch.manual_seed(42)
    device = torch.device('cuda')
    test_configs = [
        {'B': 1, 'H': 1, 'L': 8, 'D_head': 16},    # Basic
        {'B': 2, 'H': 2, 'L': 16, 'D_head': 32},   # Larger dimensions
        {'B': 1, 'H': 4, 'L': 32, 'D_head': 32},   # More heads
        {'B': 1, 'H': 1, 'L': 64, 'D_head': 64},   # Larger L, D_head
        {'B': 2, 'H': 3, 'L': 128, 'D_head': 16},  # Mixed dimensions
    ]
    for i, config in enumerate(test_configs):
        B, H, L, D_head = config['B'], config['H'], config['L'], config['D_head']
        print(f"\nTest Case {i+1}: B={B}, H={H}, L={L}, D_head={D_head}")
        Q = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32)
        K = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32)
        V = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32)
        t_Q = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32) * 0.1  # Smaller tangents
        t_K = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32) * 0.1
        t_V = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32) * 0.1
        scale = 1.0 / (D_head ** 0.5)
        # Compute using the naive PyTorch implementation
        y_naive, t_y_naive = naive_multihead_attention_jvp(
            Q, K, V, t_Q, t_K, t_V, scale
        )
        # Compute using the Triton kernel implementation
        try:
            y_triton, t_y_triton = flash_attention_jvp_multihead_triton_kernel_wrapper(
                Q, K, V, t_Q, t_K, t_V, scale
            )
        except Exception as e:
            print(f"  Triton kernel execution failed: {e}")
            # If the Triton kernel fails we cannot compare; treat this config as a failure.
            raise AssertionError(f"Triton kernel failed for config {config}") from e
        # Compare results
        rtol, atol = 1e-2, 1e-3  # Realistic tolerances for Triton vs. PyTorch differences
        try:
            torch.testing.assert_close(
                y_triton, y_naive, rtol=rtol, atol=atol,
                msg=f"Forward output mismatch for config {config}"
            )
            print("  Forward output: PASSED")
        except AssertionError as e:
            print(f"  Forward output: FAILED\n{e}")
        try:
            torch.testing.assert_close(
                t_y_triton, t_y_naive, rtol=rtol, atol=atol,
                msg=f"JVP output mismatch for config {config}"
            )
            print("  JVP output: PASSED")
        except AssertionError as e:
            print(f"  JVP output: FAILED\n{e}")
    print("\nMulti-Head JVP Correctness Test Finished.")


# --- Memory Test Function ---

def test_memory_usage():
    """
    Tests and compares memory usage between the Triton and naive implementations.
    """
    print("Running Memory Usage Comparison...")
    torch.manual_seed(42)
    device = torch.device('cuda')
    test_configs = [
        {'B': 1, 'H': 1, 'L': 128, 'D_head': 64},
        {'B': 2, 'H': 2, 'L': 256, 'D_head': 32},
        {'B': 1, 'H': 4, 'L': 512, 'D_head': 64},
    ]
    for i, config in enumerate(test_configs):
        B, H, L, D_head = config['B'], config['H'], config['L'], config['D_head']
        print(f"\nMemory Test {i+1}: B={B}, H={H}, L={L}, D_head={D_head}")
        Q = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32)
        K = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32)
        V = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32)
        t_Q = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32) * 0.1
        t_K = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32) * 0.1
        t_V = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32) * 0.1
        scale = 1.0 / (D_head ** 0.5)
        # Measure Triton implementation memory usage
        triton_fn = lambda: flash_attention_jvp_multihead_triton_kernel_wrapper(
            Q, K, V, t_Q, t_K, t_V, scale
        )
        triton_avg_mem, triton_max_mem = measure_memory_usage(triton_fn)
        # Measure naive implementation memory usage
        naive_fn = lambda: naive_multihead_attention_jvp(Q, K, V, t_Q, t_K, t_V, scale)
        naive_avg_mem, naive_max_mem = measure_memory_usage(naive_fn)
        print(f"  Triton - Avg: {triton_avg_mem:.2f} MB, Peak: {triton_max_mem:.2f} MB")
        print(f"  Naive  - Avg: {naive_avg_mem:.2f} MB, Peak: {naive_max_mem:.2f} MB")
        memory_ratio = naive_avg_mem / triton_avg_mem if triton_avg_mem > 0 else float('inf')
        print(f"  Memory Efficiency: {memory_ratio:.2f}x ({'Better' if memory_ratio > 1 else 'Worse'})")
    print("\nMemory Usage Comparison Finished.")


# --- Comprehensive Memory Analysis ---

def comprehensive_memory_analysis():
    """
    Comprehensive memory analysis comparing the Triton and naive implementations
    across different problem sizes, with detailed statistics.
    """
    print("\n" + "=" * 60)
    print("COMPREHENSIVE MEMORY ANALYSIS")
    print("=" * 60)
    torch.manual_seed(42)
    device = torch.device('cuda')
    # Extended test configurations for comprehensive analysis
    test_configs = [
        # Small problems
        {'B': 1, 'H': 1, 'L': 64, 'D_head': 32, 'category': 'Small'},
        {'B': 1, 'H': 2, 'L': 128, 'D_head': 64, 'category': 'Small'},
        # Medium problems
        {'B': 2, 'H': 4, 'L': 256, 'D_head': 64, 'category': 'Medium'},
        {'B': 1, 'H': 8, 'L': 512, 'D_head': 32, 'category': 'Medium'},
        # Large problems
        {'B': 4, 'H': 4, 'L': 512, 'D_head': 64, 'category': 'Large'},
        {'B': 2, 'H': 8, 'L': 1024, 'D_head': 32, 'category': 'Large'},
    ]
    memory_results = []
    for i, config in enumerate(test_configs):
        B, H, L, D_head = config['B'], config['H'], config['L'], config['D_head']
        category = config['category']
        print(f"\n[{category.upper()}] Test {i+1}: B={B}, H={H}, L={L}, D_head={D_head}")
        # Calculate theoretical memory requirements
        input_size = B * H * L * D_head * 4  # 4 bytes per float32 element
        total_input_size = input_size * 6    # Q, K, V, t_Q, t_K, t_V
        output_size = input_size * 2         # Y, t_Y
        print(f"  Theoretical Input Memory: {total_input_size / (1024 * 1024):.2f} MB")
        print(f"  Theoretical Output Memory: {output_size / (1024 * 1024):.2f} MB")
        Q = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32)
        K = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32)
        V = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32)
        t_Q = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32) * 0.1
        t_K = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32) * 0.1
        t_V = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32) * 0.1
        scale = 1.0 / (D_head ** 0.5)
        # Measure the Triton implementation
        triton_fn = lambda: flash_attention_jvp_multihead_triton_kernel_wrapper(
            Q, K, V, t_Q, t_K, t_V, scale
        )
        triton_avg_mem, triton_peak_mem = measure_memory_usage(triton_fn, repetitions=10)
        # Measure the naive implementation
        naive_fn = lambda: naive_multihead_attention_jvp(Q, K, V, t_Q, t_K, t_V, scale)
        naive_avg_mem, naive_peak_mem = measure_memory_usage(naive_fn, repetitions=10)
        # Calculate efficiency metrics
        memory_ratio = naive_avg_mem / triton_avg_mem if triton_avg_mem > 0 else float('inf')
        peak_ratio = naive_peak_mem / triton_peak_mem if triton_peak_mem > 0 else float('inf')
        memory_savings = naive_avg_mem - triton_avg_mem
        # Store results
        result = {
            'config': config,
            'triton_avg': triton_avg_mem,
            'triton_peak': triton_peak_mem,
            'naive_avg': naive_avg_mem,
            'naive_peak': naive_peak_mem,
            'ratio': memory_ratio,
            'peak_ratio': peak_ratio,
            'savings': memory_savings,
        }
        memory_results.append(result)
        print(f"  Triton - Avg: {triton_avg_mem:8.2f} MB, Peak: {triton_peak_mem:8.2f} MB")
        print(f"  Naive  - Avg: {naive_avg_mem:8.2f} MB, Peak: {naive_peak_mem:8.2f} MB")
        print(f"  Efficiency Gain: {memory_ratio:.2f}x (avg), {peak_ratio:.2f}x (peak)")
        print(f"  Memory Saved: {memory_savings:.2f} MB")
    # Summary statistics
    print("\n" + "=" * 60)
    print("MEMORY ANALYSIS SUMMARY")
    print("=" * 60)
    avg_ratios = [r['ratio'] for r in memory_results if r['ratio'] != float('inf')]
    peak_ratios = [r['peak_ratio'] for r in memory_results if r['peak_ratio'] != float('inf')]
    total_savings = sum(r['savings'] for r in memory_results)
    print(f"Average Memory Efficiency: {sum(avg_ratios) / len(avg_ratios):.2f}x")
    print(f"Peak Memory Efficiency: {sum(peak_ratios) / len(peak_ratios):.2f}x")
    print(f"Total Memory Saved: {total_savings:.2f} MB")
    print(f"Best Case Efficiency: {max(avg_ratios):.2f}x")
    print(f"Worst Case Efficiency: {min(avg_ratios):.2f}x")
    # Category breakdown
    categories = {}
    for result in memory_results:
        cat = result['config']['category']
        if cat not in categories:
            categories[cat] = []
        categories[cat].append(result['ratio'])
    print("\nEfficiency by Problem Size:")
    for cat, ratios in categories.items():
        avg_ratio = sum(ratios) / len(ratios)
        print(f"  {cat}: {avg_ratio:.2f}x average efficiency")
    print("\n" + "=" * 60)


# --- Benchmark Function and Configurations ---
# Benchmark configurations and functions are defined unconditionally
benchmark_configs = [
    triton.testing.Benchmark(
        x_names=['L'],
        x_vals=[2**i for i in range(5, 15)],  # Sequence lengths: 32 ... 16384
        line_arg='provider',
        line_vals=['triton', 'naive'],
        line_names=['Triton Kernel', 'Naive PyTorch'],
        styles=[('blue', '-'), ('green', '--')],
        ylabel='ms',
        plot_name='multihead_jvp_L_scaling-B1-H2-D32',
        args={'B': 1, 'H': 2, 'D_head': 32, 'dtype': torch.float32, 'device': 'cuda'},
    ),
    triton.testing.Benchmark(
        x_names=['D_head'],
        x_vals=[16, 32, 64, 128, 256],  # Head dimensions (all powers of two)
        line_arg='provider',
        line_vals=['triton', 'naive'],
        line_names=['Triton Kernel', 'Naive PyTorch'],
        styles=[('blue', '-'), ('green', '--')],
        ylabel='ms',
        plot_name='multihead_jvp_D_head_scaling-B1-H2-L1024',  # name matches args (L=1024)
        args={'B': 1, 'H': 2, 'L': 1024, 'dtype': torch.float32, 'device': 'cuda'},
    ),
]

# Memory benchmark configurations
memory_benchmark_configs = [
    triton.testing.Benchmark(
        x_names=['L'],
        x_vals=[2**i for i in range(5, 13)],  # Sequence lengths: 32 ... 4096
        line_arg='provider',
        line_vals=['triton', 'naive'],
        line_names=['Triton Kernel', 'Naive PyTorch'],
        styles=[('blue', '-'), ('green', '--')],
        ylabel='Memory Usage (MB)',
        plot_name='multihead_jvp_memory_L_scaling-B1-H2-D32',
        args={'B': 1, 'H': 2, 'D_head': 32, 'dtype': torch.float32, 'device': 'cuda', 'benchmark_type': 'memory'},
    ),
    triton.testing.Benchmark(
        x_names=['D_head'],
        x_vals=[16, 32, 64, 128, 256],  # Head dimensions
        line_arg='provider',
        line_vals=['triton', 'naive'],
        line_names=['Triton Kernel', 'Naive PyTorch'],
        styles=[('blue', '-'), ('green', '--')],
        ylabel='Memory Usage (MB)',
        plot_name='multihead_jvp_memory_D_head_scaling-B1-H2-L512',
        args={'B': 1, 'H': 2, 'L': 512, 'dtype': torch.float32, 'device': 'cuda', 'benchmark_type': 'memory'},
    ),
    triton.testing.Benchmark(
        x_names=['B'],
        x_vals=[1, 2, 4, 8, 16],  # Batch sizes
        line_arg='provider',
        line_vals=['triton', 'naive'],
        line_names=['Triton Kernel', 'Naive PyTorch'],
        styles=[('blue', '-'), ('green', '--')],
        ylabel='Memory Usage (MB)',
        plot_name='multihead_jvp_memory_B_scaling-H2-L256-D32',
        args={'H': 2, 'L': 256, 'D_head': 32, 'dtype': torch.float32, 'device': 'cuda', 'benchmark_type': 'memory'},
    ),
]


@triton.testing.perf_report(benchmark_configs)
def bench_multihead_jvp(B, H, L, D_head, provider, dtype, device, benchmark_type='time'):
    torch.manual_seed(0)  # Ensure consistent inputs for a fair comparison
    Q = torch.randn(B, H, L, D_head, device=device, dtype=dtype)
    K = torch.randn(B, H, L, D_head, device=device, dtype=dtype)
    V = torch.randn(B, H, L, D_head, device=device, dtype=dtype)
    t_Q = torch.randn(B, H, L, D_head, device=device, dtype=dtype) * 0.1
    t_K = torch.randn(B, H, L, D_head, device=device, dtype=dtype) * 0.1
    t_V = torch.randn(B, H, L, D_head, device=device, dtype=dtype) * 0.1
    scale = 1.0 / (D_head ** 0.5) if D_head > 0 else 1.0
    if provider == 'triton':
        fn = lambda: flash_attention_jvp_multihead_triton_kernel_wrapper(
            Q, K, V, t_Q, t_K, t_V, scale
        )
    elif provider == 'naive':
        fn = lambda: naive_multihead_attention_jvp(Q, K, V, t_Q, t_K, t_V, scale)
    else:
        raise ValueError(f"Unknown provider: {provider}")
    if benchmark_type == 'memory':
        # Measure memory usage
        avg_memory, _max_memory = measure_memory_usage(fn, warmup=3, repetitions=5)
        return avg_memory
    else:
        # Measure time (default)
        ms = triton.testing.do_bench(fn, warmup=10, rep=50)
        return ms


# Memory-specific benchmark function
@triton.testing.perf_report(memory_benchmark_configs)
def bench_multihead_jvp_memory(B, H, L, D_head, provider, dtype, device, benchmark_type='memory'):
    torch.manual_seed(0)  # Ensure consistent inputs for a fair comparison
    Q = torch.randn(B, H, L, D_head, device=device, dtype=dtype)
    K = torch.randn(B, H, L, D_head, device=device, dtype=dtype)
    V = torch.randn(B, H, L, D_head, device=device, dtype=dtype)
    t_Q = torch.randn(B, H, L, D_head, device=device, dtype=dtype) * 0.1
    t_K = torch.randn(B, H, L, D_head, device=device, dtype=dtype) * 0.1
    t_V = torch.randn(B, H, L, D_head, device=device, dtype=dtype) * 0.1
    scale = 1.0 / (D_head ** 0.5) if D_head > 0 else 1.0
    if provider == 'triton':
        fn = lambda: flash_attention_jvp_multihead_triton_kernel_wrapper(
            Q, K, V, t_Q, t_K, t_V, scale
        )
    elif provider == 'naive':
        fn = lambda: naive_multihead_attention_jvp(Q, K, V, t_Q, t_K, t_V, scale)
    else:
        raise ValueError(f"Unknown provider: {provider}")
    # Measure memory usage
    avg_memory, _max_memory = measure_memory_usage(fn, warmup=3, repetitions=5)
    return avg_memory


if __name__ == "__main__":
    test_multihead_correctness()
    print("\nRunning Memory Usage Tests...")
    test_memory_usage()
    print("\nRunning Comprehensive Memory Analysis...")
    comprehensive_memory_analysis()
    # Benchmarks are run unconditionally
    print("\nRunning Multi-Head JVP Time Benchmarks...")
    bench_multihead_jvp.run(print_data=True, save_path='.')
    print("\nRunning Multi-Head JVP Memory Benchmarks...")
    bench_multihead_jvp_memory.run(print_data=True, save_path='.')
    print("\nAll benchmarks finished. Plots saved to the current directory.")