# accelerate_pr4022_test_gist.py
# GitHub Gist by @AmineDiro, created April 28, 2026 17:23
# (gist id 9fd331214626b60e4d421264637b3828).
"""Test for PR #4022: per-layer compile + accelerate FSDP2 — slow path FIXED.
Same script as my slow-path repro (https://gist.github.com/AmineDiro/2457fbee70662d584a116cc3ca80dd07);
the only change is adding the `dynamo_config` block to the accelerate yaml — that's
the trigger for `compile_regions_fsdp2` introduced by this PR.
Setup: Qwen3-30B-A3B (MoE, 128 experts, 48 layers) · 2x8 H100 SXM 80GB ·
FSDP2 DP=16 · seq_len=16384 · SFTTrainer + grad ckpt + bf16 + packing.
"""
# accelerate_config.yaml (the only diff vs. the slow-path repro is dynamo_config):
#
#   distributed_type: FSDP
#   fsdp_config:
#     fsdp_version: 2
#     fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
#     fsdp_cpu_ram_efficient_loading: true
#     fsdp_offload_params: false
#   dynamo_config:                    # <-- added
#     backend: inductor               # <-- added
#     use_fullgraph: true             # <-- added
#     use_regional_compilation: true  # <-- the flag that flips this PR on
#   num_machines: 2
#   num_processes: 16
#
# NOTE(review): confirm the dynamo_config key names against a file generated by
# `accelerate config` (it may spell them with a `dynamo_` prefix).
# Run with PR #4022 checked out (script.py is this file):
#   pip install -e .   # in the accelerate clone on PR #4022
#   accelerate launch --config_file accelerate_config.yaml \
#     --num_processes 16 --num_machines 2 --machine_rank=$RANK \
#     --main_process_ip=$MASTER --main_process_port=29500 script.py
# Result (Qwen3-30B-A3B, 2x8 H100 SXM 80GB, FSDP2 DP=16, seq_len=16384):
#
# | Setup | MFU | ms/step |
# |------------------------------------------------------------|-----------|----------|
# | raw fully_shard + per-layer compile (control) | 32.1 % | 3,031 |
# | accelerate fsdp2_prepare_model + per-layer compile (BEFORE)| 9.8 % | 9,900 |
# | accelerate fsdp2_prepare_model + per-layer compile (THIS PR)| ~32 % | ~3,000 |
#
# mfu_window samples on this PR over 3 logging steps: 32.55 / 31.27 / 31.81 %.
# ms/step now matches the raw fully_shard fast path. Regression closed.
import torch
from accelerate import PartialState
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer
from trl.data_utils import maybe_convert_to_chatml, pack_dataset

# Allow TF32 for fp32 matmuls and cuDNN at process start. main()'s SFTConfig
# also passes tf32=True; these module-level flags make the choice explicit.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Packed sequence length in tokens: used as the chunk size for pack_dataset and
# as SFTConfig.max_length in main().
SEQ_LEN = 16384
def main():
    """Prepare a packed LongAlign-10k SFT dataset and train Qwen3-30B-A3B for 20 steps.

    Meant to be started with ``accelerate launch`` using the FSDP2 config
    (with ``dynamo_config``) described in the module header. The script only
    tokenizes and packs the data, loads the model in bf16 with SDPA attention,
    and runs ``SFTTrainer.train()``.
    """
    # Single source of truth for the checkpoint used by tokenizer and model.
    model_id = "Qwen/Qwen3-30B-A3B"
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    ds = load_dataset("THUDM/LongAlign-10k", split="train")
    # Preprocess on the main process first while the other ranks wait,
    # presumably so they can reuse its cached map results.
    with PartialState().main_process_first():
        ds = ds.map(
            maybe_convert_to_chatml,
            remove_columns="conversations" if "conversations" in ds.column_names else None,
        )
        ds = ds.map(
            lambda ex: tokenizer(
                tokenizer.apply_chat_template(ex["messages"], tokenize=False),
                add_special_tokens=False,
            ),
            desc="Tokenizing",
        )
        ds = ds.select_columns(["input_ids"])
        ds = pack_dataset(ds, SEQ_LEN, "wrapped")

    # Extract the packed column once and reuse it for both fields; indexing
    # ds["input_ids"] twice would pull the whole column out of Arrow twice.
    input_ids = ds["input_ids"]
    # Labels mirror the input ids.
    simple_ds = Dataset.from_dict({"input_ids": input_ids, "labels": input_ids})
    simple_ds.set_format("torch")

    model = AutoModelForCausalLM.from_pretrained(
        model_id, dtype=torch.bfloat16, attn_implementation="sdpa"
    )

    args = SFTConfig(
        output_dir="/tmp/pr4022_test",
        max_steps=20,
        per_device_train_batch_size=1,
        gradient_checkpointing=True,
        save_strategy="no",
        report_to="none",
        logging_steps=5,
        # Per the author: SFTTrainer pre-compiles per-layer, and this PR's
        # compile_regions_fsdp2 re-applies the in-place compile after the FSDP2
        # wrap so the FSDP hooks survive.
        torch_compile=True,
        tf32=True,
        max_length=SEQ_LEN,
        packing=True,
        packing_strategy="wrapped",
        include_num_input_tokens_seen=True,
        # The dataset above is already tokenized and packed; skip SFTTrainer's own prep.
        dataset_kwargs={"skip_prepare_dataset": True},
    )
    trainer = SFTTrainer(
        model=model,
        args=args,
        train_dataset=simple_ds,
        processing_class=tokenizer,
    )
    trainer.train()
if __name__ == "__main__":
    # Run only when executed as a script (e.g. via `accelerate launch ... script.py`),
    # not when imported.
    main()