Last active
April 24, 2026 08:34
-
-
Save AmineDiro/be1fd63af7f715c9431ee378d5d79dc6 to your computer and use it in GitHub Desktop.
FSDP2 Per-layer compile
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """Per-layer compile + raw fully_shard = 32% MFU (fast path).""" | |
| # Run with: | |
| # torchrun --nproc_per_node=8 --nnodes=2 --node_rank=$RANK \ | |
| # --master_addr=$MASTER --master_port=29500 script.py | |
| # | |
| # Result: ~3,031 ms/step, 32.1% MFU | |
import os
import time

import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl.data_utils import pack_dataset, maybe_convert_to_chatml
from trl.trainer.utils import entropy_from_logits

# Permit TF32 for CUDA matmuls and cuDNN ops. These are process-wide
# backend switches, set once at import time before any model code runs.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Token length that samples are packed to; also used as the per-rank,
# per-step token count in the throughput calculation.
SEQ_LEN = 16384
def main():
    """Benchmark FSDP2 training with per-layer torch.compile.

    Builds a single packed sample of SEQ_LEN tokens, loads
    Qwen/Qwen3-30B-A3B, compiles and shards each decoder layer, then runs
    N_WARMUP untimed steps followed by N_STEPS timed steps. Rank 0 prints
    per-GPU tokens/second and milliseconds per step.

    Must be started by a distributed launcher such as torchrun so that the
    process-group environment (and LOCAL_RANK) is set. Takes no arguments
    and returns None; the only outputs are the printed metrics.
    """
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # Read LOCAL_RANK once; falls back to GPU 0 when the variable is unset.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")

    # Identical data pipeline to the slow-path (SFTTrainer) script:
    # chat template -> tokenize -> pack_dataset(wrapped) to exactly SEQ_LEN.
    # We grab the first packed sample and reuse it across steps since the loop
    # is hand-rolled; shape and content distribution match the slow path.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-30B-A3B")
    ds = load_dataset("THUDM/LongAlign-10k", split="train")
    ds = ds.map(
        maybe_convert_to_chatml,
        remove_columns="conversations" if "conversations" in ds.column_names else None,
    )
    ds = ds.map(
        lambda ex: tokenizer(
            tokenizer.apply_chat_template(ex["messages"], tokenize=False),
            add_special_tokens=False,
        ),
        desc="Tokenizing",
    )
    ds = ds.select_columns(["input_ids"])
    ds = pack_dataset(ds, SEQ_LEN, "wrapped")
    # Batch of one: unsqueeze adds the leading batch dimension.
    input_ids = torch.tensor(ds[0]["input_ids"], dtype=torch.long, device=device).unsqueeze(0)
    # Every token is a target and every position is attended (no padding).
    labels = input_ids.clone()
    attn_mask = torch.ones_like(input_ids)

    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen3-30B-A3B", torch_dtype=torch.bfloat16, attn_implementation="sdpa"
    )
    model.train()
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

    # Per-layer compile (torchtitan approach): each decoder layer is compiled
    # as its own full graph instead of compiling the whole model at once.
    torch._dynamo.config.capture_scalar_outputs = True
    torch._dynamo.config.skip_fwd_side_effects_in_bwd_under_checkpoint = True
    model.config._experts_implementation = "grouped_mm"
    for layer in model.model.layers:
        layer.compile(backend="inductor", fullgraph=True)

    # Apply FSDP2 — raw fully_shard (THIS IS THE FAST PATH).
    # Each layer is sharded as its own unit, then the root module is wrapped
    # to cover the remaining parameters.
    mp = MixedPrecisionPolicy()
    for layer in model.model.layers:
        fully_shard(layer, mp_policy=mp)
    fully_shard(model, mp_policy=mp)

    # Optimizer + scheduler
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, total_iters=20)

    # Training loop with full SFT logic
    N_WARMUP, N_STEPS = 3, 17
    for step in range(N_WARMUP + N_STEPS):
        if step == N_WARMUP:
            # Start the clock BEFORE the first timed step so that exactly
            # N_STEPS iterations are measured. Starting it after this
            # iteration would time only N_STEPS - 1 steps while the metrics
            # below still divide by N_STEPS.
            torch.cuda.synchronize()
            dist.barrier()
            t0 = time.perf_counter()

        model.train()
        outputs = model(input_ids=input_ids, labels=labels, attention_mask=attn_mask, use_cache=False)
        loss = outputs.loss

        # Entropy (same as SFT Trainer)
        with torch.no_grad():
            entropy = entropy_from_logits(outputs.logits)
            entropy = (entropy * attn_mask).sum() / attn_mask.sum()
            # Reduce in place and average across ranks; reducing a clone
            # would throw the reduced value away.
            dist.all_reduce(entropy)
            entropy /= world_size
            # .item() forces the host sync the trainer's metric logging incurs.
            _ = entropy.item()

        # Accuracy (same as SFT Trainer)
        with torch.no_grad():
            # Logits at position t are scored against the token at t + 1.
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            correct = ((shift_logits.argmax(dim=-1) == shift_labels) & (shift_labels != -100)).sum()
            total = (shift_labels != -100).sum()
            dist.all_reduce(correct)
            dist.all_reduce(total)
            _ = correct.item()
            _ = total.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        model.zero_grad()

    # Wait for all queued GPU work so the wall-clock interval is complete.
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - t0
    # Each rank processes one SEQ_LEN-token sample per step.
    tps = (SEQ_LEN * world_size * N_STEPS) / elapsed
    if rank == 0:
        print(f"TPS/GPU: {tps/world_size:.0f}, ms/step: {elapsed/N_STEPS*1000:.0f}")
    dist.destroy_process_group()
# Script entry point: run the benchmark only when this file is executed
# directly (e.g. by torchrun), not when it is imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment