Created
March 23, 2026 21:56
-
-
Save AmineDiro/41ac0931616dffbe4c77939bc7f1e974 to your computer and use it in GitHub Desktop.
benchmark grpo liger
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Benchmark: _ChunkedLogProbFunction vs LigerFusedLinearGRPOLoss | |
| Compares the two chunked GRPO loss approaches without a model: | |
| - _ChunkedLogProbFunction: chunks along vocabulary (V) | |
| - LigerFusedLinearGRPOLoss: chunks along batch (B), fused fwd+bwd | |
| Preset model configs: | |
| --preset qwen3-4b → H=2560, V=151936 (Qwen3-4B-Thinking, 262k context) | |
| --preset qwen2.5-7b → H=3584, V=152064 | |
| Usage: | |
| python benchmark_chunked_loss.py --preset qwen3-4b | |
| python benchmark_chunked_loss.py --preset qwen3-4b --seq-lengths 4096 8192 16384 32768 | |
| python benchmark_chunked_loss.py --hidden-size 3584 --vocab-size 152064 --seq-lengths 2048 4096 | |
| """ | |
| import argparse | |
| import gc | |
| import json | |
| import statistics | |
| import torch | |
| from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss | |
| from trl.experimental.async_grpo.chunk_lm_head import ( | |
| _ChunkedLogProbFunction, | |
| _CompiledChunkedLogProbFunction, | |
| compiled_chunked_logprob, | |
| ) | |
def generate_synthetic_inputs(B, S, H, V, device, dtype=torch.bfloat16):
    """Build one deterministic set of random inputs shared by every benchmark.

    Seeds the global torch RNG with 42, so repeated calls with the same arguments return
    identical tensors. The same values are exposed in two layouts: flat ``(B*S, ...)`` tensors
    for the vocab-chunked log-prob wrappers, and ``(B, S, ...)`` clones for
    ``LigerFusedLinearGRPOLoss``.

    Args:
        B: batch size (number of sequences).
        S: sequence length.
        H: hidden size.
        V: vocabulary size.
        device: device on which every tensor is created.
        dtype: dtype of ``weight`` and the hidden states; the other tensors use torch defaults
            (``old_per_token_logps`` is explicitly float32).

    Returns:
        dict with ``weight`` (V, H), ``hidden_flat`` (B*S, H), ``targets_flat`` (B*S,),
        ``mask_flat`` (B*S,), ``advantages_flat`` (B*S,), ``hidden_3d`` (B, S, H),
        ``selected_token_ids`` (B, S), ``attention_mask`` (B, S), ``advantages`` (B,),
        ``old_per_token_logps`` (B, S), plus the integers ``B``, ``S``, ``H`` and ``V``.
    """
    torch.manual_seed(42)
    # NOTE: the order of the random draws below determines the generated values; keep it stable.
    weight = torch.randn(V, H, device=device, dtype=dtype)
    hidden_flat = torch.randn(B * S, H, device=device, dtype=dtype)
    targets_flat = torch.randint(0, V, (B * S,), device=device)

    # Every sequence: first S // 2 positions are prompt (mask 0), the rest completion (mask 1).
    mask_single = torch.ones(S, device=device)
    mask_single[: S // 2] = 0
    mask_flat = mask_single.repeat(B)

    # One scalar advantage per sequence, broadcast to every token for the flat layout.
    advantages = torch.randn(B, device=device)
    advantages_flat = advantages[:, None].expand(-1, S).reshape(-1)

    old_per_token_logps = torch.randn(B, S, device=device, dtype=torch.float32) * 0.1

    return {
        "weight": weight,
        "hidden_flat": hidden_flat,
        "targets_flat": targets_flat,
        "mask_flat": mask_flat,
        "advantages_flat": advantages_flat,
        # Independent batched copies of the same values for the Liger loss.
        "hidden_3d": hidden_flat.detach().clone().reshape(B, S, H),
        "selected_token_ids": targets_flat.detach().clone().reshape(B, S),
        "attention_mask": mask_flat.detach().clone().reshape(B, S),
        "advantages": advantages,
        "old_per_token_logps": old_per_token_logps,
        "B": B,
        "S": S,
        "H": H,
        "V": V,
    }
def chunked_logprob_loss(hidden, weight, targets, mask_flat, advantages_flat, temperature, chunk_size):
    """Scalar GRPO-style loss on top of ``_ChunkedLogProbFunction`` (vocabulary-chunked log-probs).

    Computes ``-sum(advantage * logprob * mask) / max(sum(mask), 1)``, i.e. the negated
    advantage-weighted mean log-prob over masked-in tokens. The entropy output is ignored.
    """
    per_token_logps, _entropy = _ChunkedLogProbFunction.apply(hidden, weight, targets, temperature, chunk_size)
    # Clamp so an all-zero mask does not divide by zero.
    n_active = mask_flat.sum().clamp(min=1.0)
    weighted = advantages_flat * per_token_logps * mask_flat
    return -weighted.sum() / n_active
def bench_chunked_logprob(inputs, vocab_chunk_size, temperature, n_warmup=3, n_iter=10):
    """Benchmark forward+backward of ``chunked_logprob_loss`` (``_ChunkedLogProbFunction``).

    Args:
        inputs: dict from ``generate_synthetic_inputs``; the flat tensors and ``weight`` are used.
        vocab_chunk_size: chunk size forwarded to the log-prob function.
        temperature: forwarded unchanged.
        n_warmup: untimed iterations run before measuring.
        n_iter: timed iterations.

    Returns:
        dict with the method label, chunk size, median step time in ms (CUDA events), the largest
        per-iteration peak of allocated memory in GiB, and the raw per-iteration times.
    """
    device = inputs["hidden_flat"].device
    # Fresh leaf copies so gradients flow into both the hidden states and the weight.
    hidden = inputs["hidden_flat"].detach().clone().requires_grad_(True)
    weight = inputs["weight"].detach().clone().requires_grad_(True)
    # Positional arguments of chunked_logprob_loss that follow (hidden, weight).
    loss_args = (
        inputs["targets_flat"],
        inputs["mask_flat"],
        inputs["advantages_flat"],
        temperature,
        vocab_chunk_size,
    )

    for _ in range(n_warmup):
        loss = chunked_logprob_loss(hidden, weight, *loss_args)
        loss.backward()
        hidden.grad = weight.grad = None
    torch.cuda.synchronize()

    elapsed_ms = []
    peak_gb = []
    for _ in range(n_iter):
        # Reset per iteration: each sample is the high-water mark of one step, which also
        # includes tensors that were already allocated when the step started.
        torch.cuda.reset_peak_memory_stats(device)
        t_start = torch.cuda.Event(enable_timing=True)
        t_end = torch.cuda.Event(enable_timing=True)
        t_start.record()
        loss = chunked_logprob_loss(hidden, weight, *loss_args)
        loss.backward()
        t_end.record()
        torch.cuda.synchronize()  # events must have completed before elapsed_time is read
        elapsed_ms.append(t_start.elapsed_time(t_end))
        peak_gb.append(torch.cuda.max_memory_allocated(device) / 1024**3)
        hidden.grad = weight.grad = None

    return dict(
        method="ChunkedLogProb",
        chunk_size=vocab_chunk_size,
        median_time_ms=statistics.median(elapsed_ms),
        peak_alloc_gb=max(peak_gb),
        all_times=elapsed_ms,
    )
def compiled_step_logprob_loss(hidden, weight, targets, mask_flat, advantages_flat, temperature, chunk_size):
    """Scalar GRPO-style loss on top of ``_CompiledChunkedLogProbFunction``.

    Same reduction as the other wrappers: ``-sum(advantage * logprob * mask) / max(sum(mask), 1)``.
    The entropy output is ignored.
    """
    per_token_logps, _entropy = _CompiledChunkedLogProbFunction.apply(
        hidden, weight, targets, temperature, chunk_size
    )
    # Clamp so an all-zero mask does not divide by zero.
    n_active = mask_flat.sum().clamp(min=1.0)
    weighted = advantages_flat * per_token_logps * mask_flat
    return -weighted.sum() / n_active
def compiled_chunked_logprob_loss(hidden, weight, targets, mask_flat, advantages_flat, temperature, chunk_size):
    """Scalar GRPO-style loss on top of ``compiled_chunked_logprob``.

    Same reduction as the other wrappers: ``-sum(advantage * logprob * mask) / max(sum(mask), 1)``.
    The entropy output is ignored.
    """
    per_token_logps, _entropy = compiled_chunked_logprob(hidden, weight, targets, temperature, chunk_size)
    # Clamp so an all-zero mask does not divide by zero.
    n_active = mask_flat.sum().clamp(min=1.0)
    weighted = advantages_flat * per_token_logps * mask_flat
    return -weighted.sum() / n_active
def bench_compiled_step_chunked(inputs, vocab_chunk_size, temperature, n_warmup=3, n_iter=10):
    """Benchmark _CompiledChunkedLogProbFunction (compiled per-step, fast compile).

    Times forward+backward of ``compiled_step_logprob_loss`` with CUDA events and records the
    per-iteration peak of allocated memory.

    Args:
        inputs: dict from ``generate_synthetic_inputs``; the flat tensors and ``weight`` are used.
        vocab_chunk_size: chunk size forwarded to the log-prob function.
        temperature: forwarded unchanged.
        n_warmup: untimed iterations run before measuring.
        n_iter: timed iterations.

    Returns:
        dict with the method label, chunk size, median step time in ms, the largest
        per-iteration peak of allocated memory in GiB, and the raw per-iteration times.
    """
    device = inputs["hidden_flat"].device
    # Fresh leaf copies so gradients flow into both the hidden states and the weight.
    hidden = inputs["hidden_flat"].detach().clone().requires_grad_(True)
    weight = inputs["weight"].detach().clone().requires_grad_(True)
    # Positional arguments of compiled_step_logprob_loss that follow (hidden, weight).
    loss_args = (
        inputs["targets_flat"],
        inputs["mask_flat"],
        inputs["advantages_flat"],
        temperature,
        vocab_chunk_size,
    )

    for _ in range(n_warmup):
        loss = compiled_step_logprob_loss(hidden, weight, *loss_args)
        loss.backward()
        hidden.grad = weight.grad = None
    torch.cuda.synchronize()

    elapsed_ms = []
    peak_gb = []
    for _ in range(n_iter):
        torch.cuda.reset_peak_memory_stats(device)
        t_start = torch.cuda.Event(enable_timing=True)
        t_end = torch.cuda.Event(enable_timing=True)
        t_start.record()
        loss = compiled_step_logprob_loss(hidden, weight, *loss_args)
        loss.backward()
        t_end.record()
        torch.cuda.synchronize()  # events must have completed before elapsed_time is read
        elapsed_ms.append(t_start.elapsed_time(t_end))
        peak_gb.append(torch.cuda.max_memory_allocated(device) / 1024**3)
        hidden.grad = weight.grad = None

    return dict(
        method="CompiledChunkedLogProbFunction",
        chunk_size=vocab_chunk_size,
        median_time_ms=statistics.median(elapsed_ms),
        peak_alloc_gb=max(peak_gb),
        all_times=elapsed_ms,
    )
def bench_compiled_chunked(inputs, vocab_chunk_size, temperature, n_warmup=3, n_iter=10):
    """Benchmark compiled_chunked_logprob (full compile, slow initial compile).

    Times forward+backward of ``compiled_chunked_logprob_loss`` with CUDA events and records
    the per-iteration peak of allocated memory. The warm-up iterations are where the one-time
    compile cost is expected to land, keeping it out of the timed samples.

    Args:
        inputs: dict from ``generate_synthetic_inputs``; the flat tensors and ``weight`` are used.
        vocab_chunk_size: chunk size forwarded to the log-prob function.
        temperature: forwarded unchanged.
        n_warmup: untimed iterations run before measuring.
        n_iter: timed iterations.

    Returns:
        dict with the method label, chunk size, median step time in ms, the largest
        per-iteration peak of allocated memory in GiB, and the raw per-iteration times.
    """
    device = inputs["hidden_flat"].device
    # Fresh leaf copies so gradients flow into both the hidden states and the weight.
    hidden = inputs["hidden_flat"].detach().clone().requires_grad_(True)
    weight = inputs["weight"].detach().clone().requires_grad_(True)
    # Positional arguments of compiled_chunked_logprob_loss that follow (hidden, weight).
    loss_args = (
        inputs["targets_flat"],
        inputs["mask_flat"],
        inputs["advantages_flat"],
        temperature,
        vocab_chunk_size,
    )

    for _ in range(n_warmup):
        loss = compiled_chunked_logprob_loss(hidden, weight, *loss_args)
        loss.backward()
        hidden.grad = weight.grad = None
    torch.cuda.synchronize()

    elapsed_ms = []
    peak_gb = []
    for _ in range(n_iter):
        torch.cuda.reset_peak_memory_stats(device)
        t_start = torch.cuda.Event(enable_timing=True)
        t_end = torch.cuda.Event(enable_timing=True)
        t_start.record()
        loss = compiled_chunked_logprob_loss(hidden, weight, *loss_args)
        loss.backward()
        t_end.record()
        torch.cuda.synchronize()  # events must have completed before elapsed_time is read
        elapsed_ms.append(t_start.elapsed_time(t_end))
        peak_gb.append(torch.cuda.max_memory_allocated(device) / 1024**3)
        hidden.grad = weight.grad = None

    return dict(
        method="CompiledChunkedAOTFunction",
        chunk_size=vocab_chunk_size,
        median_time_ms=statistics.median(elapsed_ms),
        peak_alloc_gb=max(peak_gb),
        all_times=elapsed_ms,
    )
def bench_liger_grpo(inputs, batch_chunk_size, temperature, n_warmup=3, n_iter=10):
    """Benchmark forward+backward of ``LigerFusedLinearGRPOLoss`` on the batched (B, S, ...) inputs.

    Args:
        inputs: dict from ``generate_synthetic_inputs``; the 3D tensors, ``advantages``,
            ``old_per_token_logps`` and ``weight`` are used.
        batch_chunk_size: ``chunk_size`` passed to the Liger loss module.
        temperature: forwarded to the Liger loss module.
        n_warmup: untimed iterations run before measuring (the module is built with
            ``compiled=True``, so warm-up also absorbs compilation).
        n_iter: timed iterations.

    Returns:
        dict with the method label, chunk size, median step time in ms (CUDA events), the largest
        per-iteration peak of allocated memory in GiB, and the raw per-iteration times.
    """
    device = inputs["hidden_3d"].device
    # Fresh leaf copies so gradients flow into both the hidden states and the weight.
    hidden_3d = inputs["hidden_3d"].detach().clone().requires_grad_(True)
    weight = inputs["weight"].detach().clone().requires_grad_(True)
    # Keyword arguments that are identical for every call of the loss module.
    fixed_kwargs = {
        "selected_token_ids": inputs["selected_token_ids"],
        "attention_mask": inputs["attention_mask"],
        "advantages": inputs["advantages"],
        "old_per_token_logps": inputs["old_per_token_logps"],
    }

    # No reference model (use_ref_model=False, beta=0.0); symmetric clipping epsilon of 0.2.
    liger_loss = LigerFusedLinearGRPOLoss(
        beta=0.0,
        compiled=True,
        use_ref_model=False,
        chunk_size=batch_chunk_size,
        epsilon_low=0.2,
        epsilon_high=0.2,
        loss_type="grpo",
        temperature=temperature,
    )

    for _ in range(n_warmup):
        loss, _metrics = liger_loss(_input=hidden_3d, lin_weight=weight, **fixed_kwargs)
        loss.backward()
        hidden_3d.grad = weight.grad = None
    torch.cuda.synchronize()

    elapsed_ms = []
    peak_gb = []
    for _ in range(n_iter):
        torch.cuda.reset_peak_memory_stats(device)
        t_start = torch.cuda.Event(enable_timing=True)
        t_end = torch.cuda.Event(enable_timing=True)
        t_start.record()
        loss, _metrics = liger_loss(_input=hidden_3d, lin_weight=weight, **fixed_kwargs)
        loss.backward()
        t_end.record()
        torch.cuda.synchronize()  # events must have completed before elapsed_time is read
        elapsed_ms.append(t_start.elapsed_time(t_end))
        peak_gb.append(torch.cuda.max_memory_allocated(device) / 1024**3)
        hidden_3d.grad = weight.grad = None

    return dict(
        method="LigerFusedLinearGRPOLoss",
        chunk_size=batch_chunk_size,
        median_time_ms=statistics.median(elapsed_ms),
        peak_alloc_gb=max(peak_gb),
        all_times=elapsed_ms,
    )
def run_config(B, S, H, V, temperature, vocab_chunk_sizes, batch_chunk_size, n_warmup, n_iter):
    """Run all benchmarks for a single (B, S, H, V) configuration.

    Builds one shared input set on CUDA, then runs, in order: the vocab-chunked log-prob
    benchmark for every vocab chunk size, the per-step-compiled variant for every vocab chunk
    size, and finally the Liger GRPO loss once with ``batch_chunk_size``. Python garbage and the
    CUDA cache are released after each run.

    Returns:
        list of the per-run result dicts, in the order above.
    """
    inputs = generate_synthetic_inputs(B, S, H, V, "cuda")
    results = []

    def _measure(label, bench_fn, chunk_size):
        # Print progress, run one benchmark, keep its result, then release memory.
        print(f" {label} chunk_size={chunk_size} ...", end=" ", flush=True)
        res = bench_fn(inputs, chunk_size, temperature, n_warmup, n_iter)
        print(f"done ({res['median_time_ms']:.1f} ms, {res['peak_alloc_gb']:.2f} GB)")
        results.append(res)
        gc.collect()
        torch.cuda.empty_cache()

    for vcs in vocab_chunk_sizes:
        _measure("ChunkedLogProb", bench_chunked_logprob, vcs)
    for vcs in vocab_chunk_sizes:
        _measure("CompiledStep", bench_compiled_step_chunked, vcs)
    # The full-graph compiled variant (bench_compiled_chunked) is currently disabled; to enable:
    #     for vcs in vocab_chunk_sizes:
    #         _measure("CompiledChunked", bench_compiled_chunked, vcs)
    _measure("Liger", bench_liger_grpo, batch_chunk_size)
    return results
def print_table(all_results, configs):
    """Print one comparison table per benchmarked configuration.

    Args:
        all_results: list of per-configuration result lists, parallel to ``configs``. Each result
            is a dict with ``method``, ``chunk_size``, ``peak_alloc_gb`` and ``median_time_ms``.
        configs: list of ``(B, S, H, V)`` tuples, one per entry of ``all_results``.
    """
    banner = "=" * 90
    print(f"\n{banner}")
    print("BENCHMARK RESULTS")
    print(banner)
    for (B, S, H, V), results in zip(configs, all_results):
        # Method names such as "CompiledChunkedLogProbFunction" exceed the old fixed width of 18,
        # which broke the column alignment; widen the column to the longest name instead.
        method_w = max([18] + [len(str(r["method"])) for r in results])
        print(f"\n--- B={B}, S={S}, H={H}, V={V} ---")
        print(f"{'Method':>{method_w}} | {'Chunk Size':>12} | {'Peak Alloc (GB)':>16} | {'Median Time (ms)':>18}")
        print(f"{'-' * method_w}-+-{'-' * 12}-+-{'-' * 16}-+-{'-' * 18}")
        for r in results:
            print(
                f"{r['method']:>{method_w}} | {r['chunk_size']:>12} | "
                f"{r['peak_alloc_gb']:>16.2f} | {r['median_time_ms']:>18.1f}"
            )
    print(f"\n{banner}")
# CLI presets: hidden size (H), vocabulary size (V) and a display name per model.
MODEL_PRESETS = {
    key: {"hidden_size": hidden, "vocab_size": vocab, "name": label}
    for key, hidden, vocab, label in (
        ("qwen3-4b", 2560, 151936, "Qwen3-4B-Thinking (262k ctx)"),
        ("qwen2.5-7b", 3584, 152064, "Qwen2.5-7B"),
        ("qwen2.5-3b", 2048, 151936, "Qwen2.5-3B"),
        ("llama3-8b", 4096, 128256, "Llama-3.1-8B (128k ctx)"),
    )
}
def main():
    """Parse CLI arguments, run every configured benchmark, then print (and optionally save) results.

    For the chosen hidden size H and each vocabulary size V (several when ``--sweep-vocabs`` is
    set), every sequence length in ``--seq-lengths`` is benchmarked via ``run_config``.
    Configurations whose rough memory estimate exceeds 90% of GPU memory are skipped. A CUDA OOM
    during a run is reported, its memory is released, and the sweep continues.
    """
    parser = argparse.ArgumentParser(description="Benchmark _ChunkedLogProbFunction vs LigerFusedLinearGRPOLoss")
    parser.add_argument("--preset", type=str, choices=list(MODEL_PRESETS.keys()), help="Model preset for H and V")
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--seq-lengths", type=int, nargs="+", default=None)
    parser.add_argument("--hidden-size", type=int, default=None)
    parser.add_argument("--vocab-size", type=int, default=None)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--vocab-chunk-sizes", type=int, nargs="+", default=[1024, 4096, 8192])
    parser.add_argument("--liger-chunk-size", type=int, default=1)
    parser.add_argument("--n-warmup", type=int, default=3)
    parser.add_argument("--n-iter", type=int, default=10)
    parser.add_argument(
        "--sweep-vocabs",
        action="store_true",
        help="Also sweep across different vocab sizes (32k, 64k, 128k, 152k)",
    )
    parser.add_argument("--output-json", type=str, default=None, help="Save results to JSON file")
    args = parser.parse_args()

    # Fail early with a clear message. Every benchmark is timed with torch.cuda events, and the
    # result summaries take median/max over the timed iterations, which fails on an empty list.
    if not torch.cuda.is_available():
        parser.error("a CUDA device is required (benchmarks are timed with torch.cuda events)")
    if args.n_iter < 1:
        parser.error("--n-iter must be >= 1")

    # Apply preset; explicit --hidden-size / --vocab-size override it.
    if args.preset:
        preset = MODEL_PRESETS[args.preset]
        H = args.hidden_size or preset["hidden_size"]
        V = args.vocab_size or preset["vocab_size"]
        print(f"Using preset: {args.preset} ({preset['name']})")
    else:
        H = args.hidden_size or 3584
        V = args.vocab_size or 152064
    B = args.batch_size
    seq_lengths = args.seq_lengths or [1024, 2048, 4096, 8192, 16384, 32768]
    print(f"Device: {torch.cuda.get_device_name()}")
    print(f"Config: B={B}, H={H}, V={V}, temperature={args.temperature}")
    print(f"Sequence lengths: {seq_lengths}")
    print(f"Vocab chunk sizes: {args.vocab_chunk_sizes}")
    print(f"Liger batch chunk size: {args.liger_chunk_size}")
    print(f"Warmup: {args.n_warmup}, Iterations: {args.n_iter}")

    # Build list of (H, V) configs to sweep.
    hv_configs = [(H, V)]
    if args.sweep_vocabs:
        for ev in [32000, 65536, 128256]:
            if ev != V:
                hv_configs.append((H, ev))
        print(f"Sweeping vocab sizes: {[v for _, v in hv_configs]}")

    # Loop-invariant: total memory of the device the benchmarks run on ("cuda" = current device).
    gpu_mem_gb = torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory / 1024**3
    all_results = []
    configs = []
    for cur_H, cur_V in hv_configs:
        for S in seq_lengths:
            # Rough lower bound on resident bf16 tensors during a benchmark. The weight exists as
            # the shared input, the per-benchmark clone and its gradient (3x). The hidden states
            # exist as the flat input, its 3D copy, the per-benchmark clone and its gradient (4x).
            # Intermediate logits chunks are not counted.
            weight_gb = cur_V * cur_H * 2 / 1024**3  # bf16
            hidden_gb = B * S * cur_H * 2 / 1024**3
            est_min_gb = 3 * weight_gb + 4 * hidden_gb
            if est_min_gb > gpu_mem_gb * 0.9:
                print(f"\nSkipping S={S}, H={cur_H}, V={cur_V} (est. {est_min_gb:.1f} GB > {gpu_mem_gb:.0f} GB GPU)")
                continue
            print(f"\n{'=' * 60}")
            print(f"S={S} (B={B}, H={cur_H}, V={cur_V})")
            print(f"{'=' * 60}")
            oom_message = None
            try:
                results = run_config(
                    B,
                    S,
                    cur_H,
                    cur_V,
                    args.temperature,
                    args.vocab_chunk_sizes,
                    args.liger_chunk_size,
                    args.n_warmup,
                    args.n_iter,
                )
            except torch.cuda.OutOfMemoryError as e:
                # Only keep the message. The exception's traceback references the failed run's
                # frames (and their tensors), so memory can only be released once it is dropped.
                oom_message = str(e)
            else:
                all_results.append(results)
                configs.append((B, S, cur_H, cur_V))
            if oom_message is not None:
                print(f" OOM: {oom_message}")
                gc.collect()
                torch.cuda.empty_cache()

    print_table(all_results, configs)

    # Save to JSON
    if args.output_json:
        entries = [
            {
                "B": cfg_B,
                "S": cfg_S,
                "H": cfg_H,
                "V": cfg_V,
                "method": r["method"],
                "chunk_size": r["chunk_size"],
                "peak_alloc_gb": r["peak_alloc_gb"],
                "median_time_ms": r["median_time_ms"],
                "all_times": r["all_times"],
            }
            for (cfg_B, cfg_S, cfg_H, cfg_V), results in zip(configs, all_results)
            for r in results
        ]
        json_data = {
            "device": torch.cuda.get_device_name(),
            "preset": args.preset,
            "batch_size": B,
            "temperature": args.temperature,
            "entries": entries,
        }
        with open(args.output_json, "w", encoding="utf-8") as f:
            json.dump(json_data, f, indent=2)
        print(f"\nResults saved to {args.output_json}")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment