# FSDP2 Per-layer compile
# GitHub gist AmineDiro/be1fd63af7f715c9431ee378d5d79dc6 by @AmineDiro
# (last active April 24, 2026 08:34).
"""Per-layer compile + raw fully_shard = 32% MFU (fast path).

Hand-rolled SFT benchmark: Qwen/Qwen3-30B-A3B sharded with FSDP2
(``fully_shard`` per decoder layer plus the root), each decoder layer compiled
individually with inductor, trained on one packed LongAlign-10k sample of
``SEQ_LEN`` tokens. Rank 0 prints tokens/s per GPU and ms/step.
"""
# Run with:
#   torchrun --nproc_per_node=8 --nnodes=2 --node_rank=$RANK \
#       --master_addr=$MASTER --master_port=29500 script.py
#
# Result (author's recorded run): ~3,031 ms/step, 32.1% MFU
import os
import time

import torch
import torch.distributed as dist
from datasets import load_dataset
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl.data_utils import maybe_convert_to_chatml, pack_dataset
from trl.trainer.utils import entropy_from_logits

# Process-wide: allow TF32 for fp32 CUDA matmuls and cuDNN convolutions.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Tokens per packed training sample; also the per-step token count used in
# the throughput calculation.
SEQ_LEN = 16384
# Hugging Face checkpoint used for both the tokenizer and the model.
_MODEL_ID = "Qwen/Qwen3-30B-A3B"


def _build_packed_batch(device):
    """Return one packed ``(input_ids, labels, attention_mask)`` batch on ``device``.

    Identical data pipeline to the slow-path (SFTTrainer) script:
    chat template -> tokenize -> pack_dataset(wrapped) to exactly SEQ_LEN.
    Only the first packed sample is kept; the hand-rolled loop reuses it on
    every step, so shape and content distribution match the slow path.
    All three tensors have a leading batch dimension of 1 (``unsqueeze(0)``).
    """
    tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
    ds = load_dataset("THUDM/LongAlign-10k", split="train")
    ds = ds.map(
        maybe_convert_to_chatml,
        remove_columns="conversations" if "conversations" in ds.column_names else None,
    )
    ds = ds.map(
        lambda ex: tokenizer(
            tokenizer.apply_chat_template(ex["messages"], tokenize=False),
            add_special_tokens=False,
        ),
        desc="Tokenizing",
    )
    ds = ds.select_columns(["input_ids"])
    ds = pack_dataset(ds, SEQ_LEN, "wrapped")

    input_ids = torch.tensor(ds[0]["input_ids"], dtype=torch.long, device=device).unsqueeze(0)
    labels = input_ids.clone()  # plain copy: no -100 positions in this batch
    attn_mask = torch.ones_like(input_ids)
    return input_ids, labels, attn_mask


def _build_model():
    """Load the model in bf16, compile each decoder layer, and shard it with FSDP2."""
    model = AutoModelForCausalLM.from_pretrained(
        _MODEL_ID, torch_dtype=torch.bfloat16, attn_implementation="sdpa"
    )
    model.train()
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

    # Per-layer compile (torchtitan approach). The dynamo flags are set before
    # any layer is compiled so they apply to every layer's graph.
    torch._dynamo.config.capture_scalar_outputs = True
    torch._dynamo.config.skip_fwd_side_effects_in_bwd_under_checkpoint = True
    # NOTE(review): private transformers config attribute selecting the MoE
    # experts kernel; confirm the installed transformers version honours it.
    model.config._experts_implementation = "grouped_mm"
    for layer in model.model.layers:
        layer.compile(backend="inductor", fullgraph=True)

    # Apply FSDP2 — raw fully_shard (THIS IS THE FAST PATH): one FSDP unit per
    # decoder layer, then the root for parameters not owned by any layer.
    # The default policy sets no param/reduce dtype; weights are already bf16.
    mp = MixedPrecisionPolicy()
    for layer in model.model.layers:
        fully_shard(layer, mp_policy=mp)
    fully_shard(model, mp_policy=mp)
    return model


@torch.no_grad()
def _sft_metrics(outputs, labels, attn_mask):
    """Compute the entropy and token-accuracy metrics SFTTrainer logs.

    The per-step metric work is mirrored, including the cross-rank all-reduces
    and the host syncs from ``.item()``, so timing is comparable with the slow
    path.

    Returns:
        ``(mean_entropy, correct_tokens, total_tokens)``. The entropy is
        averaged over ranks; the counts are summed over ranks.
    """
    world_size = dist.get_world_size()

    # Entropy (same as SFT Trainer): masked mean over tokens, then mean over ranks.
    # Assumes entropy_from_logits returns per-token entropy shaped like attn_mask.
    entropy = entropy_from_logits(outputs.logits)
    entropy = (entropy * attn_mask).sum() / attn_mask.sum()
    # all_reduce works in place, so the reduced value must be read back from the
    # very tensor that was passed in.
    entropy_sum = entropy.clone()
    dist.all_reduce(entropy_sum)
    mean_entropy = (entropy_sum / world_size).item()

    # Accuracy (same as SFT Trainer): next-token argmax vs. shifted labels,
    # ignoring -100 positions.
    shift_logits = outputs.logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    valid = shift_labels != -100
    correct = ((shift_logits.argmax(dim=-1) == shift_labels) & valid).sum()
    total = valid.sum()
    dist.all_reduce(correct)
    dist.all_reduce(total)
    return mean_entropy, correct.item(), total.item()


def main():
    """Time FSDP2 + per-layer-compile SFT steps and print throughput on rank 0.

    The script runs ``N_WARMUP`` untimed steps, then exactly ``N_STEPS`` timed
    steps, all on a single packed sequence. Each rank times its own loop, and
    rank 0 prints tokens/s per GPU and ms/step. The process group is always
    torn down, even if a step raises.
    """
    dist.init_process_group("nccl")
    try:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")

        input_ids, labels, attn_mask = _build_packed_batch(device)
        model = _build_model()

        N_WARMUP, N_STEPS = 3, 17
        optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
        # start_factor=1.0 with LinearLR's default end_factor=1.0 keeps the LR
        # constant. total_iters spans the whole run.
        scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=1.0, total_iters=N_WARMUP + N_STEPS
        )

        # Training loop with full SFT logic.
        for step in range(N_WARMUP + N_STEPS):
            if step == N_WARMUP:
                # Start the clock *before* the first timed step, so that exactly
                # N_STEPS steps fall inside the measured window.
                torch.cuda.synchronize()
                dist.barrier()
                t0 = time.perf_counter()

            model.train()
            outputs = model(
                input_ids=input_ids,
                labels=labels,
                attention_mask=attn_mask,
                use_cache=False,
            )
            loss = outputs.loss
            _sft_metrics(outputs, labels, attn_mask)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()
            model.zero_grad()

        torch.cuda.synchronize()
        elapsed = time.perf_counter() - t0
        tps = (SEQ_LEN * world_size * N_STEPS) / elapsed
        if rank == 0:
            print(f"TPS/GPU: {tps/world_size:.0f}, ms/step: {elapsed/N_STEPS*1000:.0f}")
    finally:
        dist.destroy_process_group()
if __name__ == "__main__":
    # Script entry point; meant to be launched under torchrun.
    main()
# End of gist AmineDiro/be1fd63af7f715c9431ee378d5d79dc6.