from typing import Optional, Tuple
import gc
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
import triton.testing
# --- Memory Tracking Utilities ---
def get_gpu_memory_usage():
    """Get current GPU memory usage in MB."""
    if torch.cuda.is_available():
        return torch.cuda.memory_allocated() / 1024 / 1024  # Convert to MB
    return 0.0
def get_peak_gpu_memory_usage():
    """Get peak GPU memory usage in MB."""
    if torch.cuda.is_available():
        return torch.cuda.max_memory_allocated() / 1024 / 1024  # Convert to MB
    return 0.0
def reset_gpu_memory_stats():
    """Reset GPU memory statistics."""
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()
    gc.collect()
def measure_memory_usage(fn, warmup=3, repetitions=10):
    """
    Measure peak memory usage of a function.
    Args:
        fn: Function to measure
        warmup: Number of warmup runs
        repetitions: Number of measurement runs
    Returns:
        tuple: (average_peak_memory_mb, max_peak_memory_mb)
    """
    peak_memories = []
    for _ in range(warmup):
        reset_gpu_memory_stats()
        fn()
        torch.cuda.synchronize()
    for _ in range(repetitions):
        reset_gpu_memory_stats()
        fn()
        torch.cuda.synchronize()
        peak_memory = get_peak_gpu_memory_usage()
        peak_memories.append(peak_memory)
    avg_peak = sum(peak_memories) / len(peak_memories)
    max_peak = max(peak_memories)
    return avg_peak, max_peak
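# Example usage (illustrative only; any GPU-allocating callable works):
#
#   fn = lambda: torch.randn(1024, 1024, device='cuda') @ torch.randn(1024, 1024, device='cuda')
#   avg_mb, max_mb = measure_memory_usage(fn, warmup=3, repetitions=10)
#
# Note that torch.cuda.max_memory_allocated() reports the caching allocator's peak,
# which includes tensors already resident (e.g. the kernel inputs), so the numbers
# below are most meaningful as a relative comparison between implementations.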
# --- Triton Multi-Head JVP Kernel ---
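# Math sketch of what the kernel below computes (derived from the code, in its own
# notation): with S = Q @ K^T * scale, P = softmax(S, dim=-1) and Y = P @ V, the
# forward-mode (JVP) derivative along tangents (t_Q, t_K, t_V) is
#
#   t_S = (t_Q @ K^T + Q @ t_K^T) * scale
#   t_P = P * t_S - P * rowsum(P * t_S)        # softmax Jacobian applied to t_S
#   t_Y = t_P @ V + P @ t_V
#
# The kernel accumulates the pieces online, one query block at a time:
#   acc      ~ (unnormalized P) @ V            -> Y after dividing by l_i
#   g_acc    ~ (unnormalized P * t_S) @ V      -> first term of t_P @ V
#   mu_i     ~ rowsum(unnormalized P * t_S)    -> weight of the -rowsum(...) * Y term
#   p_tv_acc ~ (unnormalized P) @ t_V          -> P @ t_V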
@triton.autotune(
    configs=[
        # Ultra-conservative configs for maximum compatibility
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16}, num_warps=2, num_stages=1),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16}, num_warps=2, num_stages=1),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 32}, num_warps=2, num_stages=1),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32}, num_warps=4, num_stages=1),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 16}, num_warps=4, num_stages=1),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 64}, num_warps=4, num_stages=1),
    ],
    key=['B', 'H', 'L', 'D_head'],
)
@triton.jit
def _flash_attention_jvp_multihead_kernel(
    # Input tensors
    Q, K, V, T_Q, T_K, T_V,
    # Output tensors
    Y, T_Y,
    # Tensor strides
    stride_qb, stride_qh, stride_ql, stride_qd,
    stride_kb, stride_kh, stride_kl, stride_kd,
    stride_vb, stride_vh, stride_vl, stride_vd,
    stride_tqb, stride_tqh, stride_tql, stride_tqd,
    stride_tkb, stride_tkh, stride_tkl, stride_tkd,
    stride_tvb, stride_tvh, stride_tvl, stride_tvd,
    stride_yb, stride_yh, stride_yl, stride_yd,
    stride_tyb, stride_tyh, stride_tyl, stride_tyd,
    # Problem dimensions
    B: tl.constexpr, H: tl.constexpr, L: tl.constexpr, D_head: tl.constexpr,
    # Scale factor
    scale: tl.constexpr,
    # Block sizes
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    """
    Flash Attention JVP kernel following the reference implementation pattern.
    Grid: (B*H, triton.cdiv(L, BLOCK_M))
    """
    # Get program IDs
    pid_bh = tl.program_id(0)  # Combined batch and head index
    pid_m = tl.program_id(1)   # Query block index
    # Decompose batch and head indices
    pid_b = pid_bh // H
    pid_h = pid_bh % H
    # Compute offsets
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, D_head)
    # Base pointers for this (batch, head)
    q_base = Q + pid_b * stride_qb + pid_h * stride_qh
    k_base = K + pid_b * stride_kb + pid_h * stride_kh
    v_base = V + pid_b * stride_vb + pid_h * stride_vh
    tq_base = T_Q + pid_b * stride_tqb + pid_h * stride_tqh
    tk_base = T_K + pid_b * stride_tkb + pid_h * stride_tkh
    tv_base = T_V + pid_b * stride_tvb + pid_h * stride_tvh
    y_base = Y + pid_b * stride_yb + pid_h * stride_yh
    ty_base = T_Y + pid_b * stride_tyb + pid_h * stride_tyh
    # Load query block
    q_ptrs = q_base + offs_m[:, None] * stride_ql + offs_d[None, :] * stride_qd
    tq_ptrs = tq_base + offs_m[:, None] * stride_tql + offs_d[None, :] * stride_tqd
    mask_m = offs_m < L
    q = tl.load(q_ptrs, mask=mask_m[:, None], other=0.0)
    tq = tl.load(tq_ptrs, mask=mask_m[:, None], other=0.0)
    # Initialize accumulators following Flash Attention pattern
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, D_head], dtype=tl.float32)
    g_acc = tl.zeros([BLOCK_M, D_head], dtype=tl.float32)
    mu_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    p_tv_acc = tl.zeros([BLOCK_M, D_head], dtype=tl.float32)
    # Scale factor for the exp2-based softmax: exp(s * scale) == exp2(s * scale * log2(e))
    qk_scale = scale * 1.44269504  # log2(e) = 1 / ln(2)
    # Loop over key/value blocks
    for start_n in range(0, L, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        offs_n_curr = start_n + offs_n
        mask_n = offs_n_curr < L
        # Load key and value blocks
        k_ptrs = k_base + offs_n_curr[:, None] * stride_kl + offs_d[None, :] * stride_kd
        v_ptrs = v_base + offs_n_curr[:, None] * stride_vl + offs_d[None, :] * stride_vd
        tk_ptrs = tk_base + offs_n_curr[:, None] * stride_tkl + offs_d[None, :] * stride_tkd
        tv_ptrs = tv_base + offs_n_curr[:, None] * stride_tvl + offs_d[None, :] * stride_tvd
        k = tl.load(k_ptrs, mask=mask_n[:, None], other=0.0)
        v = tl.load(v_ptrs, mask=mask_n[:, None], other=0.0)
        tk = tl.load(tk_ptrs, mask=mask_n[:, None], other=0.0)
        tv = tl.load(tv_ptrs, mask=mask_n[:, None], other=0.0)
        # Compute attention scores and their tangent
        qk = tl.dot(q, tl.trans(k))
        tqk = tl.dot(tq, tl.trans(k)) + tl.dot(q, tl.trans(tk))
        # Mask invalid key positions first
        qk = tl.where(mask_n[None, :], qk, float('-inf'))
        tqk = tl.where(mask_n[None, :], tqk, 0.0)
        # Online softmax computation following Flash Attention. The running max m_i is
        # kept in the same base-2 units as qk * qk_scale so the subtraction below is a
        # true max-shift.
        m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
        qk = qk * qk_scale - m_ij[:, None]  # Scale and subtract running max
        p = tl.math.exp2(qk)  # Use exp2 like the reference kernel
        # Correction factor for previously accumulated blocks
        alpha = tl.math.exp2(m_i - m_ij)
        l_ij = tl.sum(p, 1)
        # Update normalization
        l_i = l_i * alpha + l_ij
        # Update output accumulator
        acc = acc * alpha[:, None] + tl.dot(p, v)
        # JVP accumulator: (p * t_S) @ v, where t_S = tqk * scale
        p_tqk = p * (tqk * scale)  # Apply scale to the tangent scores
        g_acc = g_acc * alpha[:, None] + tl.dot(p_tqk, v)
        # Update mu: rowsum(p * t_S)
        mu_ij = tl.sum(p_tqk, 1)
        mu_i = mu_i * alpha + mu_ij
        # Update p @ tv accumulator
        p_tv_acc = p_tv_acc * alpha[:, None] + tl.dot(p, tv)
        # Update running max
        m_i = m_ij
    # Final computation: fold the normalization into the accumulators
    m_i += tl.math.log2(l_i)  # log-sum-exp; kept for parity with the reference kernel, not stored
    y_out = acc / l_i[:, None]
    t_p_v = g_acc / l_i[:, None] - (mu_i / l_i)[:, None] * y_out
    t_y_out = t_p_v + p_tv_acc / l_i[:, None]
    # Store outputs
    y_ptrs = y_base + offs_m[:, None] * stride_yl + offs_d[None, :] * stride_yd
    ty_ptrs = ty_base + offs_m[:, None] * stride_tyl + offs_d[None, :] * stride_tyd
    tl.store(y_ptrs, y_out, mask=mask_m[:, None])
    tl.store(ty_ptrs, t_y_out, mask=mask_m[:, None])
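# Invariant maintained by the loop above: after each key/value block, acc, g_acc, mu_i
# and p_tv_acc all carry the same exp2(-m_i) shift as l_i (the alpha rescaling keeps
# them in sync), so the final division by l_i recovers the softmax-normalized
# quantities regardless of the running-max value.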
def flash_attention_jvp_multihead_triton_kernel_wrapper(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    t_Q: torch.Tensor,
    t_K: torch.Tensor,
    t_V: torch.Tensor,
    scale: Optional[float] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Python wrapper for the multi-head Flash Attention JVP Triton kernel.
    Expects Q, K, V and their tangents with shape (B, H, L, D_head).
    """
    device = Q.device
    dtype = Q.dtype
    B, H, L, D_head = Q.shape
    # Check minimum dimension requirements for Triton
    # (the kernel also assumes D_head is a power of two, since it uses tl.arange(0, D_head))
    if D_head < 16:
        raise ValueError(f"D_head must be >= 16 for efficient Triton kernel, got {D_head}")
    if scale is None:
        scale = 1.0 / (D_head ** 0.5)
    # Ensure input shapes are correct
    assert Q.shape == (B, H, L, D_head), f"Q shape mismatch: {Q.shape}"
    assert K.shape == (B, H, L, D_head), f"K shape mismatch: {K.shape}"
    assert V.shape == (B, H, L, D_head), f"V shape mismatch: {V.shape}"
    assert t_Q.shape == (B, H, L, D_head), f"t_Q shape mismatch: {t_Q.shape}"
    assert t_K.shape == (B, H, L, D_head), f"t_K shape mismatch: {t_K.shape}"
    assert t_V.shape == (B, H, L, D_head), f"t_V shape mismatch: {t_V.shape}"
    # Create output tensors
    y = torch.zeros((B, H, L, D_head), dtype=dtype, device=device)
    t_y = torch.zeros((B, H, L, D_head), dtype=dtype, device=device)
    # Make tensors contiguous
    Qc = Q.contiguous()
    Kc = K.contiguous()
    Vc = V.contiguous()
    t_Qc = t_Q.contiguous()
    t_Kc = t_K.contiguous()
    t_Vc = t_V.contiguous()
    # Compute strides
    stride_qb, stride_qh, stride_ql, stride_qd = Qc.stride()
    stride_kb, stride_kh, stride_kl, stride_kd = Kc.stride()
    stride_vb, stride_vh, stride_vl, stride_vd = Vc.stride()
    stride_tqb, stride_tqh, stride_tql, stride_tqd = t_Qc.stride()
    stride_tkb, stride_tkh, stride_tkl, stride_tkd = t_Kc.stride()
    stride_tvb, stride_tvh, stride_tvl, stride_tvd = t_Vc.stride()
    stride_yb, stride_yh, stride_yl, stride_yd = y.stride()
    stride_tyb, stride_tyh, stride_tyl, stride_tyd = t_y.stride()
    # Block-based grid like Flash Attention. BLOCK_M is chosen by the autotuner, so the
    # grid must be a callable that reads the selected meta-parameters; a fixed BLOCK_M
    # here could leave query rows uncovered when the autotuner picks a smaller block.
    grid = lambda META: (B * H, triton.cdiv(L, META['BLOCK_M']))
    _flash_attention_jvp_multihead_kernel[grid](
        Qc, Kc, Vc, t_Qc, t_Kc, t_Vc,
        y, t_y,
        stride_qb, stride_qh, stride_ql, stride_qd,
        stride_kb, stride_kh, stride_kl, stride_kd,
        stride_vb, stride_vh, stride_vl, stride_vd,
        stride_tqb, stride_tqh, stride_tql, stride_tqd,
        stride_tkb, stride_tkh, stride_tkl, stride_tkd,
        stride_tvb, stride_tvh, stride_tvl, stride_tvd,
        stride_yb, stride_yh, stride_yl, stride_yd,
        stride_tyb, stride_tyh, stride_tyl, stride_tyd,
        B, H, L, D_head,
        scale,
    )
    return y, t_y
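# Example call (illustrative shapes; requires a CUDA device):
#
#   B, H, L, D_head = 1, 2, 128, 64
#   Q, K, V, t_Q, t_K, t_V = (torch.randn(B, H, L, D_head, device='cuda') for _ in range(6))
#   y, t_y = flash_attention_jvp_multihead_triton_kernel_wrapper(Q, K, V, t_Q, t_K, t_V)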
# --- Naive PyTorch Multi-Head JVP ---
def naive_multihead_attention_jvp(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    t_Q: torch.Tensor,
    t_K: torch.Tensor,
    t_V: torch.Tensor,
    scale: Optional[float] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Naive implementation of multi-head attention JVP using torch.func.jvp.
    """
    from torch.func import jvp  # imported locally in case torch.func is unavailable
    B, H, L, D_head = Q.shape
    if scale is None:
        scale = 1.0 / (D_head ** 0.5)
    def multihead_attention_forward(Q_param, K_param, V_param):
        scores = torch.matmul(Q_param, K_param.transpose(-2, -1)) * scale
        p = F.softmax(scores, dim=-1)
        y = torch.matmul(p, V_param)
        return y
    y, t_y = jvp(
        multihead_attention_forward,
        (Q, K, V),
        (t_Q, t_K, t_V)
    )
    return y, t_y
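# Note on the baseline: torch.func.jvp evaluates the forward pass with dual numbers,
# so both the primal and the tangent computation materialize the full (B, H, L, L)
# attention matrix. Its memory therefore grows quadratically in L, whereas the Triton
# kernel streams over key/value blocks and keeps only O(BLOCK_M x D_head) state.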
# --- Test Function ---
def test_multihead_correctness():
    """
    Tests the Triton multi-head JVP kernel against the naive PyTorch implementation.
    """
    print("Running Multi-Head JVP Correctness Test...")
    torch.manual_seed(42)
    device = torch.device('cuda')
    test_configs = [
        {'B': 1, 'H': 1, 'L': 8, 'D_head': 16},    # Basic
        {'B': 2, 'H': 2, 'L': 16, 'D_head': 32},   # Larger dimensions
        {'B': 1, 'H': 4, 'L': 32, 'D_head': 32},   # More heads
        {'B': 1, 'H': 1, 'L': 64, 'D_head': 64},   # Larger L, D_head
        {'B': 2, 'H': 3, 'L': 128, 'D_head': 16},  # Mixed dimensions
    ]
    for i, config in enumerate(test_configs):
        B, H, L, D_head = config['B'], config['H'], config['L'], config['D_head']
        print(f"\nTest Case {i+1}: B={B}, H={H}, L={L}, D_head={D_head}")
        Q = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32)
        K = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32)
        V = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32)
        t_Q = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32) * 0.1  # Smaller tangents
        t_K = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32) * 0.1
        t_V = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32) * 0.1
        scale = 1.0 / (D_head ** 0.5)
        # Compute using naive PyTorch implementation
        y_naive, t_y_naive = naive_multihead_attention_jvp(
            Q, K, V, t_Q, t_K, t_V, scale
        )
        # Compute using Triton kernel implementation
        try:
            y_triton, t_y_triton = flash_attention_jvp_multihead_triton_kernel_wrapper(
                Q, K, V, t_Q, t_K, t_V, scale
            )
        except Exception as e:
            print(f" Triton kernel execution failed: {e}")
            # If the Triton kernel fails, there is nothing to compare for this config.
            raise AssertionError(f"Triton kernel failed for config {config}") from e
        # Compare results
        rtol, atol = 1e-2, 1e-3  # Realistic tolerances for Triton vs. PyTorch differences
        try:
            torch.testing.assert_close(
                y_triton, y_naive, rtol=rtol, atol=atol,
                msg=f"Forward output mismatch for config {config}"
            )
            print(" Forward output: PASSED")
        except AssertionError as e:
            print(f" Forward output: FAILED\n{e}")
        try:
            torch.testing.assert_close(
                t_y_triton, t_y_naive, rtol=rtol, atol=atol,
                msg=f"JVP output mismatch for config {config}"
            )
            print(" JVP output: PASSED")
        except AssertionError as e:
            print(f" JVP output: FAILED\n{e}")
    print("\nMulti-Head JVP Correctness Test Finished.")
# --- Memory Test Function ---
def test_memory_usage():
    """
    Tests and compares memory usage between Triton and naive implementations.
    """
    print("Running Memory Usage Comparison...")
    torch.manual_seed(42)
    device = torch.device('cuda')
    test_configs = [
        {'B': 1, 'H': 1, 'L': 128, 'D_head': 64},
        {'B': 2, 'H': 2, 'L': 256, 'D_head': 32},
        {'B': 1, 'H': 4, 'L': 512, 'D_head': 64},
    ]
    for i, config in enumerate(test_configs):
        B, H, L, D_head = config['B'], config['H'], config['L'], config['D_head']
        print(f"\nMemory Test {i+1}: B={B}, H={H}, L={L}, D_head={D_head}")
        Q = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32)
        K = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32)
        V = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32)
        t_Q = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32) * 0.1
        t_K = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32) * 0.1
        t_V = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32) * 0.1
        scale = 1.0 / (D_head ** 0.5)
        # Test Triton implementation memory usage
        triton_fn = lambda: flash_attention_jvp_multihead_triton_kernel_wrapper(
            Q, K, V, t_Q, t_K, t_V, scale
        )
        triton_avg_mem, triton_max_mem = measure_memory_usage(triton_fn)
        # Test naive implementation memory usage
        naive_fn = lambda: naive_multihead_attention_jvp(Q, K, V, t_Q, t_K, t_V, scale)
        naive_avg_mem, naive_max_mem = measure_memory_usage(naive_fn)
        print(f" Triton - Avg: {triton_avg_mem:.2f} MB, Peak: {triton_max_mem:.2f} MB")
        print(f" Naive - Avg: {naive_avg_mem:.2f} MB, Peak: {naive_max_mem:.2f} MB")
        memory_ratio = naive_avg_mem / triton_avg_mem if triton_avg_mem > 0 else float('inf')
        print(f" Memory Efficiency: {memory_ratio:.2f}x ({'Better' if memory_ratio > 1 else 'Worse'})")
    print("\nMemory Usage Comparison Finished.")
# --- Comprehensive Memory Analysis ---
def comprehensive_memory_analysis():
    """
    Comprehensive memory analysis comparing Triton vs Naive implementations
    across different problem sizes with detailed statistics.
    """
    print("\n" + "="*60)
    print("COMPREHENSIVE MEMORY ANALYSIS")
    print("="*60)
    torch.manual_seed(42)
    device = torch.device('cuda')
    # Extended test configurations for comprehensive analysis
    test_configs = [
        # Small problems
        {'B': 1, 'H': 1, 'L': 64, 'D_head': 32, 'category': 'Small'},
        {'B': 1, 'H': 2, 'L': 128, 'D_head': 64, 'category': 'Small'},
        # Medium problems
        {'B': 2, 'H': 4, 'L': 256, 'D_head': 64, 'category': 'Medium'},
        {'B': 1, 'H': 8, 'L': 512, 'D_head': 32, 'category': 'Medium'},
        # Large problems
        {'B': 4, 'H': 4, 'L': 512, 'D_head': 64, 'category': 'Large'},
        {'B': 2, 'H': 8, 'L': 1024, 'D_head': 32, 'category': 'Large'},
    ]
    memory_results = []
    for i, config in enumerate(test_configs):
        B, H, L, D_head = config['B'], config['H'], config['L'], config['D_head']
        category = config['category']
        print(f"\n[{category.upper()}] Test {i+1}: B={B}, H={H}, L={L}, D_head={D_head}")
        # Calculate theoretical memory requirements
        input_size = B * H * L * D_head * 4  # 4 bytes per float32
        total_input_size = input_size * 6    # Q, K, V, t_Q, t_K, t_V
        output_size = input_size * 2         # Y, t_Y
        print(f" Theoretical Input Memory: {total_input_size / (1024*1024):.2f} MB")
        print(f" Theoretical Output Memory: {output_size / (1024*1024):.2f} MB")
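        # Worked example for the largest config above (B=2, H=8, L=1024, D_head=32):
        # each input tensor is 2*8*1024*32*4 B = 2 MiB, so the six inputs total ~12 MiB
        # and the two outputs ~4 MiB, while a single materialized (B, H, L, L) attention
        # matrix would already take 2*8*1024*1024*4 B = 64 MiB; this gap is what the
        # measurements below are meant to expose.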
        Q = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32)
        K = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32)
        V = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32)
        t_Q = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32) * 0.1
        t_K = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32) * 0.1
        t_V = torch.randn(B, H, L, D_head, device=device, dtype=torch.float32) * 0.1
        scale = 1.0 / (D_head ** 0.5)
        # Test Triton implementation
        triton_fn = lambda: flash_attention_jvp_multihead_triton_kernel_wrapper(
            Q, K, V, t_Q, t_K, t_V, scale
        )
        triton_avg_mem, triton_peak_mem = measure_memory_usage(triton_fn, repetitions=10)
        # Test naive implementation
        naive_fn = lambda: naive_multihead_attention_jvp(Q, K, V, t_Q, t_K, t_V, scale)
        naive_avg_mem, naive_peak_mem = measure_memory_usage(naive_fn, repetitions=10)
        # Calculate efficiency metrics
        memory_ratio = naive_avg_mem / triton_avg_mem if triton_avg_mem > 0 else float('inf')
        peak_ratio = naive_peak_mem / triton_peak_mem if triton_peak_mem > 0 else float('inf')
        memory_savings = naive_avg_mem - triton_avg_mem
        # Store results
        result = {
            'config': config,
            'triton_avg': triton_avg_mem,
            'triton_peak': triton_peak_mem,
            'naive_avg': naive_avg_mem,
            'naive_peak': naive_peak_mem,
            'ratio': memory_ratio,
            'peak_ratio': peak_ratio,
            'savings': memory_savings
        }
        memory_results.append(result)
        print(f" Triton - Avg: {triton_avg_mem:8.2f} MB, Peak: {triton_peak_mem:8.2f} MB")
        print(f" Naive - Avg: {naive_avg_mem:8.2f} MB, Peak: {naive_peak_mem:8.2f} MB")
        print(f" Efficiency Gain: {memory_ratio:.2f}x (avg), {peak_ratio:.2f}x (peak)")
        print(f" Memory Saved: {memory_savings:.2f} MB")
    # Summary statistics
    print("\n" + "="*60)
    print("MEMORY ANALYSIS SUMMARY")
    print("="*60)
    avg_ratios = [r['ratio'] for r in memory_results if r['ratio'] != float('inf')]
    peak_ratios = [r['peak_ratio'] for r in memory_results if r['peak_ratio'] != float('inf')]
    total_savings = sum(r['savings'] for r in memory_results)
    print(f"Average Memory Efficiency: {sum(avg_ratios)/len(avg_ratios):.2f}x")
    print(f"Peak Memory Efficiency: {sum(peak_ratios)/len(peak_ratios):.2f}x")
    print(f"Total Memory Saved: {total_savings:.2f} MB")
    print(f"Best Case Efficiency: {max(avg_ratios):.2f}x")
    print(f"Worst Case Efficiency: {min(avg_ratios):.2f}x")
    # Category breakdown
    categories = {}
    for result in memory_results:
        cat = result['config']['category']
        if cat not in categories:
            categories[cat] = []
        categories[cat].append(result['ratio'])
    print("\nEfficiency by Problem Size:")
    for cat, ratios in categories.items():
        avg_ratio = sum(ratios) / len(ratios)
        print(f" {cat}: {avg_ratio:.2f}x average efficiency")
    print("\n" + "="*60)
# --- Benchmark Function and Configurations ---
# Benchmark configurations and functions are defined unconditionally at module scope
benchmark_configs = [
    triton.testing.Benchmark(
        x_names=['L'],
        x_vals=[2**i for i in range(5, 15)],  # Sequence lengths: 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384
        line_arg='provider',
        line_vals=['triton', 'naive'],
        line_names=['Triton Kernel', 'Naive PyTorch'],
        styles=[('blue', '-'), ('green', '--')],
        ylabel='ms',
        plot_name='multihead_jvp_L_scaling-B1-H2-D32',
        args={'B': 1, 'H': 2, 'D_head': 32, 'dtype': torch.float32, 'device': 'cuda'}
    ),
    triton.testing.Benchmark(
        x_names=['D_head'],
        x_vals=[16, 32, 64, 128, 256],  # Head dimensions (all powers of 2)
        line_arg='provider',
        line_vals=['triton', 'naive'],
        line_names=['Triton Kernel', 'Naive PyTorch'],
        styles=[('blue', '-'), ('green', '--')],
        ylabel='ms',
        plot_name='multihead_jvp_D_head_scaling-B1-H2-L1024',
        args={'B': 1, 'H': 2, 'L': 1024, 'dtype': torch.float32, 'device': 'cuda'}
    )
]
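# Note: triton.testing.perf_report calls the decorated function once per x value and
# provider, merging the Benchmark's `args` with the swept parameter, i.e. roughly
#   bench_multihead_jvp(B=1, H=2, D_head=32, L=32, provider='triton',
#                       dtype=torch.float32, device='cuda')
# and collects the returned values into the named plot and CSV.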
# Memory benchmark configurations
memory_benchmark_configs = [
    triton.testing.Benchmark(
        x_names=['L'],
        x_vals=[2**i for i in range(5, 13)],  # Sequence lengths: 32, 64, 128, 256, 512, 1024, 2048, 4096
        line_arg='provider',
        line_vals=['triton', 'naive'],
        line_names=['Triton Kernel', 'Naive PyTorch'],
        styles=[('blue', '-'), ('green', '--')],
        ylabel='Memory Usage (MB)',
        plot_name='multihead_jvp_memory_L_scaling-B1-H2-D32',
        args={'B': 1, 'H': 2, 'D_head': 32, 'dtype': torch.float32, 'device': 'cuda', 'benchmark_type': 'memory'}
    ),
    triton.testing.Benchmark(
        x_names=['D_head'],
        x_vals=[16, 32, 64, 128, 256],  # Head dimensions
        line_arg='provider',
        line_vals=['triton', 'naive'],
        line_names=['Triton Kernel', 'Naive PyTorch'],
        styles=[('blue', '-'), ('green', '--')],
        ylabel='Memory Usage (MB)',
        plot_name='multihead_jvp_memory_D_head_scaling-B1-H2-L512',
        args={'B': 1, 'H': 2, 'L': 512, 'dtype': torch.float32, 'device': 'cuda', 'benchmark_type': 'memory'}
    ),
    triton.testing.Benchmark(
        x_names=['B'],
        x_vals=[1, 2, 4, 8, 16],  # Batch sizes
        line_arg='provider',
        line_vals=['triton', 'naive'],
        line_names=['Triton Kernel', 'Naive PyTorch'],
        styles=[('blue', '-'), ('green', '--')],
        ylabel='Memory Usage (MB)',
        plot_name='multihead_jvp_memory_B_scaling-H2-L256-D32',
        args={'H': 2, 'L': 256, 'D_head': 32, 'dtype': torch.float32, 'device': 'cuda', 'benchmark_type': 'memory'}
    )
]
@triton.testing.perf_report(benchmark_configs)
def bench_multihead_jvp(B, H, L, D_head, provider, dtype, device, benchmark_type='time'):
    torch.manual_seed(0)  # Ensure consistent inputs for fair comparison
    Q = torch.randn(B, H, L, D_head, device=device, dtype=dtype)
    K = torch.randn(B, H, L, D_head, device=device, dtype=dtype)
    V = torch.randn(B, H, L, D_head, device=device, dtype=dtype)
    t_Q = torch.randn(B, H, L, D_head, device=device, dtype=dtype) * 0.1
    t_K = torch.randn(B, H, L, D_head, device=device, dtype=dtype) * 0.1
    t_V = torch.randn(B, H, L, D_head, device=device, dtype=dtype) * 0.1
    scale = 1.0 / (D_head ** 0.5) if D_head > 0 else 1.0
    if provider == 'triton':
        fn = lambda: flash_attention_jvp_multihead_triton_kernel_wrapper(
            Q, K, V, t_Q, t_K, t_V, scale
        )
    elif provider == 'naive':
        fn = lambda: naive_multihead_attention_jvp(Q, K, V, t_Q, t_K, t_V, scale)
    else:
        raise ValueError(f"Unknown provider: {provider}")
    if benchmark_type == 'memory':
        # Measure memory usage
        avg_memory, max_memory = measure_memory_usage(fn, warmup=3, repetitions=5)
        return avg_memory
    else:
        # Measure time (default)
        ms = triton.testing.do_bench(fn, warmup=10, rep=50)
        return ms
# Memory-specific benchmark function
@triton.testing.perf_report(memory_benchmark_configs)
def bench_multihead_jvp_memory(B, H, L, D_head, provider, dtype, device, benchmark_type='memory'):
    torch.manual_seed(0)  # Ensure consistent inputs for fair comparison
    Q = torch.randn(B, H, L, D_head, device=device, dtype=dtype)
    K = torch.randn(B, H, L, D_head, device=device, dtype=dtype)
    V = torch.randn(B, H, L, D_head, device=device, dtype=dtype)
    t_Q = torch.randn(B, H, L, D_head, device=device, dtype=dtype) * 0.1
    t_K = torch.randn(B, H, L, D_head, device=device, dtype=dtype) * 0.1
    t_V = torch.randn(B, H, L, D_head, device=device, dtype=dtype) * 0.1
    scale = 1.0 / (D_head ** 0.5) if D_head > 0 else 1.0
    if provider == 'triton':
        fn = lambda: flash_attention_jvp_multihead_triton_kernel_wrapper(
            Q, K, V, t_Q, t_K, t_V, scale
        )
    elif provider == 'naive':
        fn = lambda: naive_multihead_attention_jvp(Q, K, V, t_Q, t_K, t_V, scale)
    else:
        raise ValueError(f"Unknown provider: {provider}")
    # Measure memory usage
    avg_memory, max_memory = measure_memory_usage(fn, warmup=3, repetitions=5)
    return avg_memory
if __name__ == "__main__":
    test_multihead_correctness()
    print("\nRunning Memory Usage Tests...")
    test_memory_usage()
    print("\nRunning Comprehensive Memory Analysis...")
    comprehensive_memory_analysis()
    # Benchmarks are run unconditionally after the tests
    print("\nRunning Multi-Head JVP Time Benchmarks...")
    bench_multihead_jvp.run(print_data=True, save_path='.')
    print("\nRunning Multi-Head JVP Memory Benchmarks...")
    bench_multihead_jvp_memory.run(print_data=True, save_path='.')
    print("\nAll benchmarks finished. Plots saved to current directory.")