Skip to content

Instantly share code, notes, and snippets.

@AmineDiro
Created April 28, 2026 08:17
Show Gist options
  • Select an option

  • Save AmineDiro/53d0772aa1f4812aede8a8444cef4ac0 to your computer and use it in GitHub Desktop.

Select an option

Save AmineDiro/53d0772aa1f4812aede8a8444cef4ac0 to your computer and use it in GitHub Desktop.
test ep+fsdp save + resume
"""EP+FSDP2 save/load correctness test.
Verifies that:

    full4 losses [L0, L1, L2, L3] == save2 losses [L0, L1] ++ load2 losses [L2, L3]

Run as three separate srun/torchrun jobs sharing a checkpoint dir; each job
writes its losses to --losses-out so they can be compared afterwards:

    --phase full4 : train 4 steps from scratch
    --phase save2 : train 2 steps, save state
    --phase load2 : load state, train 2 more steps
"""
import argparse
import json
import os
from pathlib import Path
import torch
import torch.distributed as dist
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from transformers.distributed import DistributedConfig
# Checkpoint identifier passed to `from_pretrained` for both the model and the tokenizer.
MODEL_NAME = "Qwen/Qwen3-30B-A3B"
# Seed handed to `set_seed` at the start of every phase.
SEED = 42
# Learning rate of the AdamW optimizer.
LR = 2e-5
# Maximum number of tokens in the fixed batch (tokenizer truncation length).
SEQ_LEN = 1024
def _build_model_and_accelerator():
    """Load the expert-parallel model and wrap it with FSDP2 through Accelerate.

    The checkpoint named by ``MODEL_NAME`` is loaded in bfloat16 with expert
    parallelism enabled. The modules that own the parameters listed in
    ``model.ep_sharded_param_names`` are handed to the FSDP plugin as
    ``ignored_modules``. An AdamW optimizer over the trainable parameters is
    created, and model and optimizer are passed through ``Accelerator.prepare``.

    Returns:
        A ``(model, optim, acc)`` tuple: the prepared model, the prepared
        optimizer, and the ``Accelerator`` that prepared them.
    """
    ep_config = DistributedConfig(enable_expert_parallel=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        dtype=torch.bfloat16,
        attn_implementation="sdpa",
        distributed_config=ep_config,
    )
    # Synchronize all ranks after loading, before FSDP is configured.
    dist.barrier()

    plugin = FullyShardedDataParallelPlugin(
        fsdp_version=2,
        auto_wrap_policy="transformer_based_wrap",
        cpu_ram_efficient_loading=False,
    )

    # Strip the last dotted component of each parameter name to get the path of
    # the module that owns it; the set removes duplicate owners.
    owner_paths = set()
    for param_name in model.ep_sharded_param_names:
        owner_paths.add(param_name.rsplit(".", 1)[0])
    plugin.ignored_modules = [model.get_submodule(path) for path in owner_paths]

    acc = Accelerator(fsdp_plugin=plugin)
    trainable = [param for param in model.parameters() if param.requires_grad]
    optim = torch.optim.AdamW(trainable, lr=LR)
    model, optim = acc.prepare(model, optim)
    return model, optim, acc
def _fixed_batch(rank, device):
    """Build one token batch on rank 0 and broadcast it to every rank.

    Rank 0 joins the first four LongAlign-10k training rows into one text and
    tokenizes it, truncated to at most ``SEQ_LEN`` tokens. The actual token
    count is broadcast first so that every other rank allocates a receive
    buffer of exactly the same shape: ``dist.broadcast`` needs the same number
    of elements on all ranks, and truncation alone does not guarantee that the
    tokenized text is ``SEQ_LEN`` tokens long.

    Args:
        rank: Global rank of the calling process.
        device: Device the batch is placed on and broadcast over.

    Returns:
        A ``(1, n)`` ``torch.long`` tensor of token ids with ``n <= SEQ_LEN``,
        identical on all ranks.
    """
    if rank == 0:
        # Only rank 0 tokenizes, so only rank 0 needs the tokenizer.
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        ds = load_dataset("THUDM/LongAlign-10k", split="train").select(range(4))
        text = "\n".join(d.get("text", str(d)) for d in ds)
        ids = tokenizer(
            text, return_tensors="pt", truncation=True, max_length=SEQ_LEN
        ).input_ids
        ids = ids.to(device)
        n_tokens = torch.tensor([ids.shape[1]], dtype=torch.long, device=device)
    else:
        ids = None
        n_tokens = torch.zeros(1, dtype=torch.long, device=device)

    # Agree on the sequence length before allocating the receive buffers.
    dist.broadcast(n_tokens, src=0)
    if rank != 0:
        ids = torch.zeros((1, int(n_tokens.item())), dtype=torch.long, device=device)

    dist.broadcast(ids, src=0)
    return ids
def train_steps(model, optim, ids, labels, n_steps, start_step=0):
    """Run ``n_steps`` optimizer updates on one fixed batch and collect the losses.

    Args:
        model: Module called as ``model(input_ids=..., labels=...)`` whose
            output exposes a ``loss`` attribute.
        optim: Optimizer stepped once per iteration.
        ids: Input token ids, reused unchanged at every step.
        labels: Labels, reused unchanged at every step.
        n_steps: Number of updates to perform.
        start_step: Index of the first step; only affects the printed step
            numbers.

    Returns:
        The per-step loss values as Python floats, in step order.
    """
    model.train()
    history = []
    for offset in range(n_steps):
        step = start_step + offset
        optim.zero_grad()
        batch_loss = model(input_ids=ids, labels=labels).loss
        batch_loss.backward()
        optim.step()
        loss_val = batch_loss.detach().float().item()
        history.append(loss_val)
        # Only rank 0 reports progress.
        if dist.get_rank() != 0:
            continue
        print(f" step {step}: loss={loss_val:.6f}", flush=True)
    return history
def main():
    """Run one phase of the save/resume test and write its losses to disk.

    Command-line arguments:
        --phase: ``full4`` trains 4 steps; ``save2`` trains 2 steps and then
            saves the accelerator state; ``load2`` loads the accelerator state
            and then trains 2 steps numbered from 2.
        --ckpt-dir: Directory passed to ``save_state`` / ``load_state``.
        --losses-out: File that rank 0 writes the phase name and losses to,
            as JSON.

    ``RANK`` and ``WORLD_SIZE`` are read from the environment.
    """
    # (n_steps, start_step) for each phase.
    schedule = {"full4": (4, 0), "save2": (2, 0), "load2": (2, 2)}

    cli = argparse.ArgumentParser()
    cli.add_argument("--phase", required=True, choices=["full4", "save2", "load2"])
    cli.add_argument("--ckpt-dir", required=True)
    cli.add_argument("--losses-out", required=True)
    args = cli.parse_args()

    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    is_main = rank == 0

    def announce(message):
        # Progress output comes from rank 0 only.
        if is_main:
            print(message, flush=True)

    set_seed(SEED)
    announce(f"\n=== phase={args.phase} world_size={world_size} seed={SEED} ===\n")

    model, optim, acc = _build_model_and_accelerator()
    ids = _fixed_batch(rank, acc.device)
    labels = ids.clone()

    if args.phase == "load2":
        # Restore the saved state before training resumes.
        announce(f" loading state from {args.ckpt_dir}")
        acc.load_state(args.ckpt_dir)
        announce(f" loaded.")

    n_steps, start_step = schedule[args.phase]
    losses = train_steps(model, optim, ids, labels, n_steps=n_steps, start_step=start_step)

    if args.phase == "save2":
        announce(f" saving state to {args.ckpt_dir}")
        acc.save_state(args.ckpt_dir)
        announce(f" saved.")

    dist.barrier()
    if is_main:
        payload = json.dumps({"phase": args.phase, "losses": losses})
        Path(args.losses_out).write_text(payload)
        announce(f" wrote {args.losses_out}: {losses}")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment