Skip to content

Instantly share code, notes, and snippets.

@AmineDiro
Created March 23, 2026 21:56
Show Gist options
  • Select an option

  • Save AmineDiro/41ac0931616dffbe4c77939bc7f1e974 to your computer and use it in GitHub Desktop.

Select an option

Save AmineDiro/41ac0931616dffbe4c77939bc7f1e974 to your computer and use it in GitHub Desktop.
benchmark grpo liger
"""
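Benchmark: _ChunkedLogProbFunction vs LigerFusedLinearGRPOLoss
Compares chunked GRPO loss implementations on synthetic inputs (no model needed),
reporting the median forward+backward time and peak allocated CUDA memory:
- _ChunkedLogProbFunction: chunks along vocabulary (V)
- _CompiledChunkedLogProbFunction: per-step compiled variant of the vocabulary-chunked function
- LigerFusedLinearGRPOLoss: chunks along batch (B), fused fwd+bwd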
Preset model configs:
--preset qwen3-4b → H=2560, V=151936 (Qwen3-4B-Thinking, 262k context)
--preset qwen2.5-7b → H=3584, V=152064
Usage:
python benchmark_chunked_loss.py --preset qwen3-4b
python benchmark_chunked_loss.py --preset qwen3-4b --seq-lengths 4096 8192 16384 32768
python benchmark_chunked_loss.py --hidden-size 3584 --vocab-size 152064 --seq-lengths 2048 4096
"""
import argparse
import gc
import json
import statistics
import torch
from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss
from trl.experimental.async_grpo.chunk_lm_head import (
_ChunkedLogProbFunction,
_CompiledChunkedLogProbFunction,
compiled_chunked_logprob,
)
def generate_synthetic_inputs(B, S, H, V, device, dtype=torch.bfloat16):
    """Build one deterministic set of random inputs shared by every benchmark.

    The global torch RNG is re-seeded with 42 on every call, so repeated calls
    with the same arguments yield identical tensors.

    Args:
        B, S, H, V: batch size, sequence length, hidden size and vocabulary size.
        device: device on which every tensor is allocated.
        dtype: dtype of the lm-head weight and hidden states (bf16 by default).

    Returns:
        dict holding two layouts of the same data:
        * flat ``(B*S, ...)`` tensors for the ``_ChunkedLogProbFunction`` wrappers:
          "hidden_flat", "targets_flat", "mask_flat", "advantages_flat";
        * per-sequence tensors for ``LigerFusedLinearGRPOLoss``:
          "hidden_3d" (B, S, H), "selected_token_ids" (B, S),
          "attention_mask" (B, S), "advantages" (B,), "old_per_token_logps" (B, S);
        plus the shared lm-head "weight" of shape (V, H) and the ints B, S, H, V.
    """
    torch.manual_seed(42)
    # Random draws happen in a fixed order (weight, hidden, targets, advantages,
    # old logps) so the seeded values are reproducible.
    lm_weight = torch.randn(V, H, device=device, dtype=dtype)
    flat_hidden = torch.randn(B * S, H, device=device, dtype=dtype)
    flat_targets = torch.randint(0, V, (B * S,), device=device)

    # Every sequence: the first S // 2 positions are prompt (0), the rest completion (1).
    per_seq_mask = torch.ones(S, device=device)
    per_seq_mask[: S // 2] = 0.0
    flat_mask = per_seq_mask.repeat(B)

    seq_advantages = torch.randn(B, device=device)
    # One advantage per sequence, repeated for each of its S tokens.
    flat_advantages = seq_advantages.repeat_interleave(S)

    old_logps = torch.randn(B, S, device=device, dtype=torch.float32) * 0.1

    return {
        "weight": lm_weight,
        "hidden_flat": flat_hidden,
        "targets_flat": flat_targets,
        "mask_flat": flat_mask,
        "advantages_flat": flat_advantages,
        # Independent copies so the Liger benchmark never shares storage with the flat inputs.
        "hidden_3d": flat_hidden.view(B, S, H).clone(),
        "selected_token_ids": flat_targets.view(B, S).clone(),
        "attention_mask": flat_mask.view(B, S).clone(),
        "advantages": seq_advantages,
        "old_per_token_logps": old_logps,
        "B": B,
        "S": S,
        "H": H,
        "V": V,
    }
def chunked_logprob_loss(hidden, weight, targets, mask_flat, advantages_flat, temperature, chunk_size):
    """Minimal GRPO-style scalar loss on top of ``_ChunkedLogProbFunction``.

    Computes ``-sum(advantage * logprob * mask) / max(sum(mask), 1)``. The
    entropy returned by the autograd function is not used in the loss.
    """
    token_logps, _unused_entropy = _ChunkedLogProbFunction.apply(hidden, weight, targets, temperature, chunk_size)
    numerator = -(advantages_flat * token_logps * mask_flat).sum()
    # clamp avoids division by zero when no token is selected by the mask.
    return numerator / mask_flat.sum().clamp(min=1.0)
def bench_chunked_logprob(inputs, vocab_chunk_size, temperature, n_warmup=3, n_iter=10):
    """Benchmark ``_ChunkedLogProbFunction`` (chunks along the vocabulary) fwd+bwd.

    Each pass runs ``chunked_logprob_loss`` followed by ``backward()`` on
    fresh ``requires_grad`` clones of the hidden states and lm-head weight.

    Args:
        inputs: dict from ``generate_synthetic_inputs``; the flat tensors
            ("hidden_flat", "weight", "targets_flat", "mask_flat",
            "advantages_flat") are used.
        vocab_chunk_size: chunk size passed through to the autograd function.
        temperature: temperature passed through to the autograd function.
        n_warmup: number of untimed forward+backward passes run first.
        n_iter: number of timed passes; must be at least 1.

    Returns:
        dict with "method", "chunk_size", "median_time_ms" (median CUDA-event
        time of one forward+backward), "peak_alloc_gb" (largest per-iteration
        ``torch.cuda.max_memory_allocated`` in GiB, which includes tensors that
        were already resident) and "all_times" (per-iteration ms).

    Raises:
        ValueError: if ``n_iter < 1``. Checked up front; otherwise the failure
            would only surface in ``statistics.median`` after the warmup ran.
    """
    if n_iter < 1:
        raise ValueError(f"n_iter must be >= 1, got {n_iter}")
    device = inputs["hidden_flat"].device
    # Fresh leaf tensors so gradients flow to both the hidden states and the weight.
    hidden = inputs["hidden_flat"].detach().clone().requires_grad_(True)
    weight = inputs["weight"].detach().clone().requires_grad_(True)
    targets = inputs["targets_flat"]
    mask_flat = inputs["mask_flat"]
    advantages_flat = inputs["advantages_flat"]
    # Warmup: untimed passes before measurement.
    for _ in range(n_warmup):
        loss = chunked_logprob_loss(hidden, weight, targets, mask_flat, advantages_flat, temperature, vocab_chunk_size)
        loss.backward()
        hidden.grad = None
        weight.grad = None
    torch.cuda.synchronize()
    # Timed runs
    times = []
    peak_mems = []
    for _ in range(n_iter):
        # Reset so each sample reports the peak of a single forward+backward.
        torch.cuda.reset_peak_memory_stats(device)
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        loss = chunked_logprob_loss(hidden, weight, targets, mask_flat, advantages_flat, temperature, vocab_chunk_size)
        loss.backward()
        end.record()
        # elapsed_time() requires both events to have completed.
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))
        peak_mems.append(torch.cuda.max_memory_allocated(device) / 1024**3)
        # Clear grads so the next backward() does not accumulate into them.
        hidden.grad = None
        weight.grad = None
    return {
        "method": "ChunkedLogProb",
        "chunk_size": vocab_chunk_size,
        "median_time_ms": statistics.median(times),
        "peak_alloc_gb": max(peak_mems),
        "all_times": times,
    }
def compiled_step_logprob_loss(hidden, weight, targets, mask_flat, advantages_flat, temperature, chunk_size):
    """Minimal GRPO-style scalar loss on top of ``_CompiledChunkedLogProbFunction``.

    Computes ``-sum(advantage * logprob * mask) / max(sum(mask), 1)``; the
    entropy output of the autograd function is not used.
    """
    token_logps, _unused_entropy = _CompiledChunkedLogProbFunction.apply(
        hidden, weight, targets, temperature, chunk_size
    )
    numerator = -(advantages_flat * token_logps * mask_flat).sum()
    # clamp avoids division by zero when the mask selects no token.
    return numerator / mask_flat.sum().clamp(min=1.0)
def compiled_chunked_logprob_loss(hidden, weight, targets, mask_flat, advantages_flat, temperature, chunk_size):
    """Minimal GRPO-style scalar loss on top of ``compiled_chunked_logprob``.

    Computes ``-sum(advantage * logprob * mask) / max(sum(mask), 1)``; the
    entropy output is not used.
    """
    token_logps, _unused_entropy = compiled_chunked_logprob(hidden, weight, targets, temperature, chunk_size)
    numerator = -(advantages_flat * token_logps * mask_flat).sum()
    # clamp avoids division by zero when the mask selects no token.
    return numerator / mask_flat.sum().clamp(min=1.0)
def bench_compiled_step_chunked(inputs, vocab_chunk_size, temperature, n_warmup=3, n_iter=10):
    """Benchmark _CompiledChunkedLogProbFunction (compiled per-step, fast compile).

    Each pass runs ``compiled_step_logprob_loss`` followed by ``backward()`` on
    fresh ``requires_grad`` clones of the hidden states and lm-head weight.

    Args:
        inputs: dict from ``generate_synthetic_inputs``; the flat tensors
            ("hidden_flat", "weight", "targets_flat", "mask_flat",
            "advantages_flat") are used.
        vocab_chunk_size: chunk size passed through to the autograd function.
        temperature: temperature passed through to the autograd function.
        n_warmup: number of untimed passes run first (these absorb compilation).
        n_iter: number of timed passes; must be at least 1.

    Returns:
        dict with "method", "chunk_size", "median_time_ms", "peak_alloc_gb"
        (largest per-iteration ``max_memory_allocated`` in GiB, including
        already-resident tensors) and "all_times" (per-iteration ms).

    Raises:
        ValueError: if ``n_iter < 1``. Checked before the (compiling) warmup so
            the caller does not wait for compilation only to hit an empty median.
    """
    if n_iter < 1:
        raise ValueError(f"n_iter must be >= 1, got {n_iter}")
    device = inputs["hidden_flat"].device
    # Fresh leaf tensors so gradients flow to both the hidden states and the weight.
    hidden = inputs["hidden_flat"].detach().clone().requires_grad_(True)
    weight = inputs["weight"].detach().clone().requires_grad_(True)
    targets = inputs["targets_flat"]
    mask_flat = inputs["mask_flat"]
    advantages_flat = inputs["advantages_flat"]
    # Warmup: untimed passes before measurement.
    for _ in range(n_warmup):
        loss = compiled_step_logprob_loss(
            hidden, weight, targets, mask_flat, advantages_flat, temperature, vocab_chunk_size
        )
        loss.backward()
        hidden.grad = None
        weight.grad = None
    torch.cuda.synchronize()
    # Timed runs
    times = []
    peak_mems = []
    for _ in range(n_iter):
        # Reset so each sample reports the peak of a single forward+backward.
        torch.cuda.reset_peak_memory_stats(device)
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        loss = compiled_step_logprob_loss(
            hidden, weight, targets, mask_flat, advantages_flat, temperature, vocab_chunk_size
        )
        loss.backward()
        end.record()
        # elapsed_time() requires both events to have completed.
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))
        peak_mems.append(torch.cuda.max_memory_allocated(device) / 1024**3)
        # Clear grads so the next backward() does not accumulate into them.
        hidden.grad = None
        weight.grad = None
    return {
        "method": "CompiledChunkedLogProbFunction",
        "chunk_size": vocab_chunk_size,
        "median_time_ms": statistics.median(times),
        "peak_alloc_gb": max(peak_mems),
        "all_times": times,
    }
def bench_compiled_chunked(inputs, vocab_chunk_size, temperature, n_warmup=3, n_iter=10):
    """Benchmark compiled_chunked_logprob (full compile, slow initial compile).

    Each pass runs ``compiled_chunked_logprob_loss`` followed by ``backward()``
    on fresh ``requires_grad`` clones of the hidden states and lm-head weight.

    Args:
        inputs: dict from ``generate_synthetic_inputs``; the flat tensors
            ("hidden_flat", "weight", "targets_flat", "mask_flat",
            "advantages_flat") are used.
        vocab_chunk_size: chunk size passed through to ``compiled_chunked_logprob``.
        temperature: temperature passed through to ``compiled_chunked_logprob``.
        n_warmup: number of untimed passes run first (these absorb compilation).
        n_iter: number of timed passes; must be at least 1.

    Returns:
        dict with "method", "chunk_size", "median_time_ms", "peak_alloc_gb"
        (largest per-iteration ``max_memory_allocated`` in GiB, including
        already-resident tensors) and "all_times" (per-iteration ms).

    Raises:
        ValueError: if ``n_iter < 1``. Checked before the slow compiling warmup
            so the caller does not wait for it only to hit an empty median.
    """
    if n_iter < 1:
        raise ValueError(f"n_iter must be >= 1, got {n_iter}")
    device = inputs["hidden_flat"].device
    # Fresh leaf tensors so gradients flow to both the hidden states and the weight.
    hidden = inputs["hidden_flat"].detach().clone().requires_grad_(True)
    weight = inputs["weight"].detach().clone().requires_grad_(True)
    targets = inputs["targets_flat"]
    mask_flat = inputs["mask_flat"]
    advantages_flat = inputs["advantages_flat"]
    # Warmup (critical for torch.compile)
    for _ in range(n_warmup):
        loss = compiled_chunked_logprob_loss(
            hidden, weight, targets, mask_flat, advantages_flat, temperature, vocab_chunk_size
        )
        loss.backward()
        hidden.grad = None
        weight.grad = None
    torch.cuda.synchronize()
    # Timed runs
    times = []
    peak_mems = []
    for _ in range(n_iter):
        # Reset so each sample reports the peak of a single forward+backward.
        torch.cuda.reset_peak_memory_stats(device)
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        loss = compiled_chunked_logprob_loss(
            hidden, weight, targets, mask_flat, advantages_flat, temperature, vocab_chunk_size
        )
        loss.backward()
        end.record()
        # elapsed_time() requires both events to have completed.
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))
        peak_mems.append(torch.cuda.max_memory_allocated(device) / 1024**3)
        # Clear grads so the next backward() does not accumulate into them.
        hidden.grad = None
        weight.grad = None
    return {
        "method": "CompiledChunkedAOTFunction",
        "chunk_size": vocab_chunk_size,
        "median_time_ms": statistics.median(times),
        "peak_alloc_gb": max(peak_mems),
        "all_times": times,
    }
def bench_liger_grpo(inputs, batch_chunk_size, temperature, n_warmup=3, n_iter=10):
    """Benchmark LigerFusedLinearGRPOLoss (chunks along the batch, fused fwd+bwd).

    The loss module is built with ``beta=0.0`` and ``use_ref_model=False`` (no
    reference-model term), ``loss_type="grpo"``, symmetric clipping epsilons of
    0.2 and ``compiled=True``. Each pass calls it and then ``backward()`` on
    fresh ``requires_grad`` clones of the 3-D hidden states and lm-head weight.

    Args:
        inputs: dict from ``generate_synthetic_inputs``; the per-sequence tensors
            ("hidden_3d", "weight", "selected_token_ids", "attention_mask",
            "advantages", "old_per_token_logps") are used.
        batch_chunk_size: Liger ``chunk_size`` (number of sequences per chunk).
        temperature: temperature passed to the Liger loss.
        n_warmup: number of untimed passes run first (these absorb compilation).
        n_iter: number of timed passes; must be at least 1.

    Returns:
        dict with "method", "chunk_size", "median_time_ms", "peak_alloc_gb"
        (largest per-iteration ``max_memory_allocated`` in GiB, including
        already-resident tensors) and "all_times" (per-iteration ms).

    Raises:
        ValueError: if ``n_iter < 1``. Checked before the compiling warmup so
            the caller does not wait for it only to hit an empty median.
    """
    if n_iter < 1:
        raise ValueError(f"n_iter must be >= 1, got {n_iter}")
    device = inputs["hidden_3d"].device
    # Fresh leaf tensors so gradients flow to both the hidden states and the weight.
    hidden_3d = inputs["hidden_3d"].detach().clone().requires_grad_(True)
    weight = inputs["weight"].detach().clone().requires_grad_(True)
    selected_token_ids = inputs["selected_token_ids"]
    attention_mask = inputs["attention_mask"]
    advantages = inputs["advantages"]
    old_per_token_logps = inputs["old_per_token_logps"]
    liger_loss = LigerFusedLinearGRPOLoss(
        beta=0.0,
        compiled=True,
        use_ref_model=False,
        chunk_size=batch_chunk_size,
        epsilon_low=0.2,
        epsilon_high=0.2,
        loss_type="grpo",
        temperature=temperature,
    )
    # Warmup (important for torch.compile)
    for _ in range(n_warmup):
        loss, _metrics = liger_loss(
            _input=hidden_3d,
            lin_weight=weight,
            selected_token_ids=selected_token_ids,
            attention_mask=attention_mask,
            advantages=advantages,
            old_per_token_logps=old_per_token_logps,
        )
        loss.backward()
        hidden_3d.grad = None
        weight.grad = None
    torch.cuda.synchronize()
    # Timed runs
    times = []
    peak_mems = []
    for _ in range(n_iter):
        # Reset so each sample reports the peak of a single forward+backward.
        torch.cuda.reset_peak_memory_stats(device)
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        loss, _metrics = liger_loss(
            _input=hidden_3d,
            lin_weight=weight,
            selected_token_ids=selected_token_ids,
            attention_mask=attention_mask,
            advantages=advantages,
            old_per_token_logps=old_per_token_logps,
        )
        loss.backward()
        end.record()
        # elapsed_time() requires both events to have completed.
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))
        peak_mems.append(torch.cuda.max_memory_allocated(device) / 1024**3)
        # Clear grads so the next backward() does not accumulate into them.
        hidden_3d.grad = None
        weight.grad = None
    return {
        "method": "LigerFusedLinearGRPOLoss",
        "chunk_size": batch_chunk_size,
        "median_time_ms": statistics.median(times),
        "peak_alloc_gb": max(peak_mems),
        "all_times": times,
    }
def run_config(
    B,
    S,
    H,
    V,
    temperature,
    vocab_chunk_sizes,
    batch_chunk_size,
    n_warmup,
    n_iter,
    *,
    include_full_compile=False,
):
    """Run all benchmarks for a single (B, S, H, V) configuration.

    Inputs are generated once on "cuda" and shared by every method. Each method
    is run for each of its chunk sizes: the vocabulary-chunked variants over
    ``vocab_chunk_sizes`` and Liger with ``batch_chunk_size``.

    A method/chunk size that raises ``torch.cuda.OutOfMemoryError`` is reported
    as "OOM" and skipped, so the methods that do fit are still compared. An OOM
    in one method previously discarded the whole configuration, which hid
    exactly the memory comparison the benchmark is for.

    Args:
        B, S, H, V: batch size, sequence length, hidden size, vocabulary size.
        temperature: forwarded to every benchmark.
        vocab_chunk_sizes: chunk sizes for the vocabulary-chunked methods.
        batch_chunk_size: Liger batch chunk size.
        n_warmup, n_iter: forwarded to every benchmark.
        include_full_compile: also benchmark the fully compiled
            ``bench_compiled_chunked`` variant (slow to compile; off by default).

    Returns:
        List of result dicts from the ``bench_*`` functions, one per completed run.

    Raises:
        torch.cuda.OutOfMemoryError: if the shared inputs themselves do not fit.
    """
    device = "cuda"
    inputs = generate_synthetic_inputs(B, S, H, V, device)

    # (progress label, benchmark function, chunk sizes to try), in run order.
    plan = [
        ("ChunkedLogProb", bench_chunked_logprob, vocab_chunk_sizes),
        ("CompiledStep", bench_compiled_step_chunked, vocab_chunk_sizes),
    ]
    if include_full_compile:
        plan.append(("CompiledChunked", bench_compiled_chunked, vocab_chunk_sizes))
    plan.append(("Liger", bench_liger_grpo, [batch_chunk_size]))

    results = []
    for label, bench_fn, chunk_sizes in plan:
        for chunk_size in chunk_sizes:
            print(f" {label} chunk_size={chunk_size} ...", end=" ", flush=True)
            try:
                r = bench_fn(inputs, chunk_size, temperature, n_warmup, n_iter)
            except torch.cuda.OutOfMemoryError:
                print("OOM")
            else:
                print(f"done ({r['median_time_ms']:.1f} ms, {r['peak_alloc_gb']:.2f} GB)")
                results.append(r)
            # Release the finished (or failed) run's tensors before the next one,
            # so its cached blocks don't inflate the next measurement or cause OOM.
            gc.collect()
            torch.cuda.empty_cache()
    return results
def print_table(all_results, configs):
    """Print one comparison table per benchmarked configuration.

    Args:
        all_results: one list per configuration of result dicts with the keys
            "method", "chunk_size", "peak_alloc_gb" and "median_time_ms".
        configs: ``(B, S, H, V)`` tuples, parallel to ``all_results``.
    """
    banner = "=" * 90
    print(f"\n{banner}")
    print("BENCHMARK RESULTS")
    print(banner)
    for (B, S, H, V), results in zip(configs, all_results):
        # Size the method column to the longest name (at least 18, the old fixed
        # width): names like "CompiledChunkedLogProbFunction" overflowed it and
        # pushed the remaining columns out of line.
        method_w = max(18, max((len(str(r["method"])) for r in results), default=0))
        print(f"\n--- B={B}, S={S}, H={H}, V={V} ---")
        print(f"{'Method':>{method_w}} | {'Chunk Size':>12} | {'Peak Alloc (GB)':>16} | {'Median Time (ms)':>18}")
        print(f"{'-' * method_w}-+-{'-' * 12}-+-{'-' * 16}-+-{'-' * 18}")
        for r in results:
            print(
                f"{r['method']:>{method_w}} | {r['chunk_size']:>12} | "
                f"{r['peak_alloc_gb']:>16.2f} | {r['median_time_ms']:>18.1f}"
            )
    print(f"\n{banner}")
# lm-head dimensions selectable with --preset: hidden size (H), vocabulary
# size (V), and a display name printed when the preset is used.
MODEL_PRESETS = {
    preset_key: {"hidden_size": hidden, "vocab_size": vocab, "name": label}
    for preset_key, hidden, vocab, label in (
        ("qwen3-4b", 2560, 151936, "Qwen3-4B-Thinking (262k ctx)"),
        ("qwen2.5-7b", 3584, 152064, "Qwen2.5-7B"),
        ("qwen2.5-3b", 2048, 151936, "Qwen2.5-3B"),
        ("llama3-8b", 4096, 128256, "Llama-3.1-8B (128k ctx)"),
    )
}
def main():
    """Command-line entry point for the benchmark sweep.

    Resolves H and V (explicit flags override a --preset, which overrides the
    built-in defaults). It then benchmarks every method for each sequence
    length, and optionally for extra vocab sizes. Configurations whose rough
    memory lower bound exceeds 90% of the GPU are skipped. Finally it prints
    the comparison tables and optionally writes them to a JSON file.
    """
    parser = argparse.ArgumentParser(description="Benchmark _ChunkedLogProbFunction vs LigerFusedLinearGRPOLoss")
    parser.add_argument("--preset", type=str, choices=list(MODEL_PRESETS.keys()), help="Model preset for H and V")
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--seq-lengths", type=int, nargs="+", default=None)
    parser.add_argument("--hidden-size", type=int, default=None)
    parser.add_argument("--vocab-size", type=int, default=None)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--vocab-chunk-sizes", type=int, nargs="+", default=[1024, 4096, 8192])
    parser.add_argument("--liger-chunk-size", type=int, default=1)
    parser.add_argument("--n-warmup", type=int, default=3)
    parser.add_argument("--n-iter", type=int, default=10)
    parser.add_argument(
        "--sweep-vocabs",
        action="store_true",
        help="Also sweep across different vocab sizes (32k, 64k, 128k, 152k)",
    )
    parser.add_argument("--output-json", type=str, default=None, help="Save results to JSON file")
    args = parser.parse_args()
    # Fail fast. With n_iter < 1 every benchmark would only crash (empty median)
    # after its warmup/compilation had already been paid for.
    if args.n_iter < 1:
        parser.error("--n-iter must be >= 1")
    if args.n_warmup < 0:
        parser.error("--n-warmup must be >= 0")
    if not torch.cuda.is_available():
        raise SystemExit("error: this benchmark requires a CUDA device (torch.cuda.is_available() is False)")
    # Apply preset
    if args.preset:
        preset = MODEL_PRESETS[args.preset]
        H = args.hidden_size or preset["hidden_size"]
        V = args.vocab_size or preset["vocab_size"]
        print(f"Using preset: {args.preset} ({preset['name']})")
    else:
        H = args.hidden_size or 3584
        V = args.vocab_size or 152064
    B = args.batch_size
    seq_lengths = args.seq_lengths or [1024, 2048, 4096, 8192, 16384, 32768]
    print(f"Device: {torch.cuda.get_device_name()}")
    print(f"Config: B={B}, H={H}, V={V}, temperature={args.temperature}")
    print(f"Sequence lengths: {seq_lengths}")
    print(f"Vocab chunk sizes: {args.vocab_chunk_sizes}")
    print(f"Liger batch chunk size: {args.liger_chunk_size}")
    print(f"Warmup: {args.n_warmup}, Iterations: {args.n_iter}")
    # Build list of (H, V) configs to sweep
    hv_configs = [(H, V)]
    if args.sweep_vocabs:
        extra_vocabs = [32000, 65536, 128256]
        for ev in extra_vocabs:
            if ev != V:
                hv_configs.append((H, ev))
        print(f"Sweeping vocab sizes: {[v for _, v in hv_configs]}")

    # Loop-invariant: total memory of the device the benchmarks run on (the
    # current CUDA device, same one "cuda" and get_device_name() refer to).
    gpu_mem_gb = torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory / 1024**3
    bytes_per_elem = 2  # bf16, the default dtype of generate_synthetic_inputs

    all_results = []
    configs = []
    for cur_H, cur_V in hv_configs:
        for S in seq_lengths:
            # Rough lower bound on resident memory before any logits exist:
            #   weight: inputs["weight"] + per-benchmark requires_grad clone + its grad -> 3x
            #   hidden: hidden_flat + hidden_3d + per-benchmark clone + its grad       -> 4x
            weight_gb = cur_V * cur_H * bytes_per_elem / 1024**3
            hidden_gb = B * S * cur_H * bytes_per_elem / 1024**3
            est_min_gb = 3 * weight_gb + 4 * hidden_gb
            if est_min_gb > gpu_mem_gb * 0.9:
                print(f"\nSkipping S={S}, H={cur_H}, V={cur_V} (est. {est_min_gb:.1f} GB > {gpu_mem_gb:.0f} GB GPU)")
                continue
            print(f"\n{'=' * 60}")
            print(f"S={S} (B={B}, H={cur_H}, V={cur_V})")
            print(f"{'=' * 60}")
            try:
                results = run_config(
                    B,
                    S,
                    cur_H,
                    cur_V,
                    args.temperature,
                    args.vocab_chunk_sizes,
                    args.liger_chunk_size,
                    args.n_warmup,
                    args.n_iter,
                )
                all_results.append(results)
                configs.append((B, S, cur_H, cur_V))
            except torch.cuda.OutOfMemoryError as e:
                print(f"  OOM: {e}")
            gc.collect()
            torch.cuda.empty_cache()
    print_table(all_results, configs)
    # Save to JSON
    if args.output_json:
        json_data = {
            "device": torch.cuda.get_device_name(),
            "preset": args.preset,
            "batch_size": B,
            "temperature": args.temperature,
            "entries": [],
        }
        for (cfg_B, cfg_S, cfg_H, cfg_V), results in zip(configs, all_results):
            for r in results:
                json_data["entries"].append(
                    {
                        "B": cfg_B,
                        "S": cfg_S,
                        "H": cfg_H,
                        "V": cfg_V,
                        "method": r["method"],
                        "chunk_size": r["chunk_size"],
                        "peak_alloc_gb": r["peak_alloc_gb"],
                        "median_time_ms": r["median_time_ms"],
                        "all_times": r["all_times"],
                    }
                )
        with open(args.output_json, "w", encoding="utf-8") as f:
            json.dump(json_data, f, indent=2)
        print(f"\nResults saved to {args.output_json}")
# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment