Created
April 28, 2026 08:17
-
-
Save AmineDiro/53d0772aa1f4812aede8a8444cef4ac0 to your computer and use it in GitHub Desktop.
test ep+fsdp save + resume
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""EP+FSDP2 save/load correctness test.

Verifies that resuming from a checkpoint reproduces an uninterrupted run:

    full4 losses [L0, L1, L2, L3] == save2 [L0, L1] ++ load2 [L2, L3]

Run as three separate srun/torchrun jobs sharing a checkpoint dir:

    --phase full4 : train 4 steps from scratch
    --phase save2 : train 2 steps, save state
    --phase load2 : load state, train 2 more steps

Every phase seeds with SEED and trains on one fixed batch, which rank 0
tokenizes and broadcasts to the other ranks. save2 must finish before load2
starts, because load2 loads the state that save2 writes to --ckpt-dir.

Each phase writes its per-step losses as JSON to --losses-out. This script
does not compare them; the equality above is checked on those files.
"""
| import argparse | |
| import json | |
| import os | |
| from pathlib import Path | |
| import torch | |
| import torch.distributed as dist | |
| from accelerate import Accelerator | |
| from accelerate.utils import FullyShardedDataParallelPlugin | |
| from datasets import load_dataset | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed | |
| from transformers.distributed import DistributedConfig | |
# Hub id of the expert-parallel (MoE) checkpoint that every phase starts from.
MODEL_NAME = "Qwen/Qwen3-30B-A3B"
# Seed passed to transformers.set_seed at the start of every phase.
SEED = 42
# AdamW learning rate.
LR = 2.0e-5
# Maximum number of tokens kept (via truncation) in the fixed training batch.
SEQ_LEN = 1_024
def _build_model_and_accelerator():
    """Load the model with expert parallelism enabled and prepare it with FSDP2 via Accelerate.

    The modules that own the parameters listed in
    ``model.ep_sharded_param_names`` are passed to FSDP as ``ignored_modules``,
    so FSDP does not wrap them. (Presumably those parameters are already split
    across ranks by the expert-parallel load; confirm against transformers.)
    Every other module is auto-wrapped with the transformer-based policy.

    Returns:
        ``(model, optim, acc)``: the prepared model, the prepared AdamW
        optimizer over all parameters with ``requires_grad``, and the
        ``Accelerator`` used to prepare them.
    """
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        dtype=torch.bfloat16,
        attn_implementation="sdpa",
        distributed_config=DistributedConfig(enable_expert_parallel=True),
    )
    # NOTE(review): assumes from_pretrained with a DistributedConfig has already
    # initialized the default process group; otherwise this barrier raises. Confirm.
    dist.barrier()

    fsdp_plugin = FullyShardedDataParallelPlugin(
        fsdp_version=2,
        auto_wrap_policy="transformer_based_wrap",
        cpu_ram_efficient_loading=False,
    )
    # Parent module of each EP-sharded parameter, e.g. "a.b.weight" -> "a.b".
    # Sorted instead of taken in set order: str hashing is randomized per process
    # (PYTHONHASHSEED), so a bare set can iterate in a different order on each rank.
    ep_module_names = sorted({name.rsplit(".", 1)[0] for name in model.ep_sharded_param_names})
    fsdp_plugin.ignored_modules = [model.get_submodule(name) for name in ep_module_names]

    acc = Accelerator(fsdp_plugin=fsdp_plugin)
    optim = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=LR)
    model, optim = acc.prepare(model, optim)
    return model, optim, acc
def _fixed_batch(rank, device):
    """Build one fixed batch of token ids on rank 0 and broadcast it to every rank.

    Rank 0 joins the first 4 rows of ``THUDM/LongAlign-10k`` into one string
    and tokenizes it, truncating to at most ``SEQ_LEN`` tokens. The resulting
    length is broadcast first, so every rank allocates a receive buffer of the
    same shape. This also covers text that yields fewer than ``SEQ_LEN`` tokens.

    Args:
        rank: Global rank of this process.
        device: Device for the returned tensor and the broadcast buffers.

    Returns:
        A ``(1, seq_len)`` ``torch.long`` tensor of token ids, identical on
        every rank, with ``seq_len <= SEQ_LEN``.
    """
    if rank == 0:
        # Only rank 0 tokenizes, so only rank 0 loads the tokenizer and dataset.
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        ds = load_dataset("THUDM/LongAlign-10k", split="train").select(range(4))
        # Rows without a "text" field fall back to their string representation.
        text = "\n".join(d.get("text", str(d)) for d in ds)
        ids = tokenizer(text, return_tensors="pt", truncation=True, max_length=SEQ_LEN).input_ids
        ids = ids.to(device=device, dtype=torch.long)
        seq_len = torch.tensor([ids.shape[1]], dtype=torch.long, device=device)
    else:
        seq_len = torch.zeros(1, dtype=torch.long, device=device)

    # broadcast requires same-shaped tensors on all ranks, and only rank 0
    # knows the tokenized length, so share the length before the ids.
    dist.broadcast(seq_len, src=0)
    if rank != 0:
        ids = torch.empty((1, int(seq_len.item())), dtype=torch.long, device=device)
    dist.broadcast(ids, src=0)
    return ids
def train_steps(model, optim, ids, labels, n_steps, start_step=0):
    """Run ``n_steps`` optimizer steps, all on the same ``(ids, labels)`` batch.

    Args:
        model: Module called as ``model(input_ids=ids, labels=labels)``. Its
            output must expose a scalar ``.loss`` (``backward()`` is called on it).
        optim: Optimizer over ``model``'s parameters.
        ids: Input token ids, reused unchanged at every step.
        labels: Labels passed to the model, reused unchanged at every step.
        n_steps: Number of steps to run.
        start_step: Index of the first step. It only affects the printed step
            numbers, so a resumed run logs the same indices as a full run.

    Returns:
        The loss of each step as a Python float, in step order.
    """
    # Log from rank 0 only. With no initialized process group (a single-process
    # run), this process is the only one, so it logs.
    should_log = not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0
    losses = []
    model.train()
    for step in range(start_step, start_step + n_steps):
        optim.zero_grad()
        loss = model(input_ids=ids, labels=labels).loss
        loss.backward()
        optim.step()
        loss_val = loss.detach().float().item()
        losses.append(loss_val)
        if should_log:
            print(f" step {step}: loss={loss_val:.6f}", flush=True)
    return losses
def main():
    """Parse CLI args, run the requested phase, and write its losses as JSON.

    Phases (selected with ``--phase``):
        full4: train 4 steps from scratch.
        save2: train 2 steps, then save the Accelerate state to ``--ckpt-dir``.
        load2: load the Accelerate state from ``--ckpt-dir``, then train two
            more steps, logged as steps 2 and 3.

    Rank 0 writes ``{"phase": ..., "losses": [...]}`` to ``--losses-out``.
    Reads ``RANK`` and ``WORLD_SIZE`` from the environment (as set by torchrun).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--phase", required=True, choices=["full4", "save2", "load2"])
    parser.add_argument("--ckpt-dir", required=True)
    parser.add_argument("--losses-out", required=True)
    args = parser.parse_args()
    # Fail fast, before the expensive model load, when there is nothing to resume from.
    if args.phase == "load2" and not Path(args.ckpt_dir).is_dir():
        parser.error(f"--ckpt-dir {args.ckpt_dir!r} does not exist; run --phase save2 first")

    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    set_seed(SEED)
    if rank == 0:
        print(f"\n=== phase={args.phase} world_size={world_size} seed={SEED} ===\n", flush=True)

    model, optim, acc = _build_model_and_accelerator()
    ids = _fixed_batch(rank, acc.device)
    labels = ids.clone()

    if args.phase == "full4":
        losses = train_steps(model, optim, ids, labels, n_steps=4, start_step=0)
    elif args.phase == "save2":
        losses = train_steps(model, optim, ids, labels, n_steps=2, start_step=0)
        if rank == 0:
            print(f" saving state to {args.ckpt_dir}", flush=True)
        acc.save_state(args.ckpt_dir)
        if rank == 0:
            print(" saved.", flush=True)
    else:  # "load2": argparse `choices` rules out any other value.
        if rank == 0:
            print(f" loading state from {args.ckpt_dir}", flush=True)
        acc.load_state(args.ckpt_dir)
        if rank == 0:
            print(" loaded.", flush=True)
        # Continue numbering at step 2 so the log lines up with the full4 run.
        losses = train_steps(model, optim, ids, labels, n_steps=2, start_step=2)

    dist.barrier()
    if rank == 0:
        out_path = Path(args.losses_out)
        # Create the parent directory so a fresh output location cannot crash
        # the run after training has already finished.
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(json.dumps({"phase": args.phase, "losses": losses}))
        print(f" wrote {args.losses_out}: {losses}", flush=True)
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment