Created
April 28, 2026 17:23
-
-
Save AmineDiro/9fd331214626b60e4d421264637b3828 to your computer and use it in GitHub Desktop.
accelerate_pr4022_test_gist.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """Test for PR #4022: per-layer compile + accelerate FSDP2 — slow path FIXED. | |
| Same script as my slow-path repro (https://gist.github.com/AmineDiro/2457fbee70662d584a116cc3ca80dd07); | |
| the only change is adding the `dynamo_config` block to the accelerate yaml — that's | |
| the trigger for `compile_regions_fsdp2` introduced by this PR. | |
| Setup: Qwen3-30B-A3B (MoE, 128 experts, 48 layers) · 2x8 H100 SXM 80GB · | |
| FSDP2 DP=16 · seq_len=16384 · SFTTrainer + grad ckpt + bf16 + packing. | |
| """ | |
| # accelerate_config.yaml (the only diff vs. the slow-path repro is dynamo_config): | |
| # distributed_type: FSDP | |
| # fsdp_config: | |
| #   fsdp_version: 2 | |
| #   fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP | |
| #   fsdp_cpu_ram_efficient_loading: true | |
| #   fsdp_offload_params: false | |
| # dynamo_config: # <-- added | |
| #   backend: inductor # <-- added | |
| #   use_fullgraph: true # <-- added | |
| #   use_regional_compilation: true # <-- the flag that flips this PR on | |
| # num_machines: 2 | |
| # num_processes: 16 | |
| # Run with PR #4022 checked out: | |
| # pip install -e . # in the accelerate clone on PR #4022 | |
| # accelerate launch --config_file accelerate_config.yaml \ | |
| # --num_processes 16 --num_machines 2 --machine_rank=$RANK \ | |
| # --main_process_ip=$MASTER --main_process_port=29500 script.py | |
| # Result (Qwen3-30B-A3B, 2x8 H100 SXM 80GB, FSDP2 DP=16, seq_len=16384): | |
| # | |
| # | Setup | MFU | ms/step | | |
| # |------------------------------------------------------------|-----------|----------| | |
| # | raw fully_shard + per-layer compile (control) | 32.1 % | 3,031 | | |
| # | accelerate fsdp2_prepare_model + per-layer compile (BEFORE)| 9.8 % | 9,900 | | |
| # | accelerate fsdp2_prepare_model + per-layer compile (THIS PR)| ~32 % | ~3,000 | | |
| # | |
| # mfu_window samples on this PR over 3 logging steps: 32.55 / 31.27 / 31.81 %. | |
| # ms/step now matches the raw fully_shard fast path. Regression closed. | |
| import torch | |
| torch.backends.cuda.matmul.allow_tf32 = True | |
| torch.backends.cudnn.allow_tf32 = True | |
| from datasets import Dataset, load_dataset | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from trl import SFTTrainer, SFTConfig | |
| from trl.data_utils import pack_dataset, maybe_convert_to_chatml | |
| from accelerate import PartialState | |
| SEQ_LEN = 16384 | |
def main():
    """Run a short SFT benchmark of Qwen3-30B-A3B on packed LongAlign-10k data.

    The dataset is converted to ChatML, rendered with the tokenizer's chat
    template, tokenized, and packed into ``SEQ_LEN``-token sequences before
    being handed to ``SFTTrainer`` with ``torch_compile=True``. Nothing in
    this function configures FSDP or dynamo; per the header comments those
    settings come from the accelerate config used to launch the script.
    """
    model_id = "Qwen/Qwen3-30B-A3B"
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    def _tokenize(example):
        # Render the chat to plain text first, then tokenize that text with
        # add_special_tokens=False.
        rendered = tokenizer.apply_chat_template(example["messages"], tokenize=False)
        return tokenizer(rendered, add_special_tokens=False)

    dataset = load_dataset("THUDM/LongAlign-10k", split="train")

    # The main process runs the preprocessing first; the other ranks follow.
    # NOTE(review): the extent of this block is reconstructed (the source had
    # lost its indentation) — confirm all four steps belong inside it.
    with PartialState().main_process_first():
        drop_columns = "conversations" if "conversations" in dataset.column_names else None
        dataset = dataset.map(maybe_convert_to_chatml, remove_columns=drop_columns)
        dataset = dataset.map(_tokenize, desc="Tokenizing")
        dataset = dataset.select_columns(["input_ids"])
        dataset = pack_dataset(dataset, SEQ_LEN, "wrapped")

    # Causal-LM training data: labels are the same token ids as the inputs.
    packed = Dataset.from_dict(
        {"input_ids": dataset["input_ids"], "labels": dataset["input_ids"]}
    )
    packed.set_format("torch")

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        dtype=torch.bfloat16,
        attn_implementation="sdpa",
    )

    training_args = SFTConfig(
        # Run length, logging and outputs.
        output_dir="/tmp/pr4022_test",
        max_steps=20,
        logging_steps=5,
        save_strategy="no",
        report_to="none",
        include_num_input_tokens_seen=True,
        # Batch size, memory and precision.
        per_device_train_batch_size=1,
        gradient_checkpointing=True,
        tf32=True,
        # SFTTrainer pre-compiles per layer; PR #4022's compile_regions_fsdp2
        # then re-applies the in-place compile after the FSDP2 wrap so the
        # FSDP hooks survive.
        torch_compile=True,
        # The dataset is already tokenized and packed above, so the trainer is
        # asked to skip its own dataset preparation.
        max_length=SEQ_LEN,
        packing=True,
        packing_strategy="wrapped",
        dataset_kwargs={"skip_prepare_dataset": True},
    )

    SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=packed,
        processing_class=tokenizer,
    ).train()
# Start the benchmark only when this file is executed as a script (for
# example through `accelerate launch`), never as a side effect of an import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment