Last active
April 24, 2026 08:35
-
-
Save AmineDiro/2457fbee70662d584a116cc3ca80dd07 to your computer and use it in GitHub Desktop.
accelerate fsdp2
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """Per-layer compile + accelerate FSDP2 = 10% MFU (slow path).""" | |
| # accelerate_config.yaml: | |
| # distributed_type: FSDP | |
| # fsdp_config: | |
| # fsdp_version: 2 | |
| # fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP | |
| # fsdp_cpu_ram_efficient_loading: true | |
| # fsdp_offload_params: false | |
| # num_machines: 2 | |
| # num_processes: 16 | |
| # | |
| # Run with: | |
| # accelerate launch --config_file accelerate_config.yaml \ | |
| # --num_processes 16 --num_machines 2 --machine_rank=$RANK \ | |
| # --main_process_ip=$MASTER --main_process_port=29500 script.py | |
| # | |
| # Result: ~9,900 ms/step, 9.8% MFU | |
| import torch | |
# Permit TensorFloat-32 for both cuBLAS matmuls and cuDNN convolutions
# (chained assignment sets the matmul flag first, then the cuDNN flag).
torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True
| from datasets import Dataset, load_dataset | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from trl import SFTTrainer, SFTConfig | |
| from trl.data_utils import pack_dataset, maybe_convert_to_chatml | |
| from accelerate import PartialState | |
| SEQ_LEN = 16384 | |
def _build_train_dataset(tokenizer):
    """Tokenize THUDM/LongAlign-10k and pack it into SEQ_LEN-token rows.

    Args:
        tokenizer: Tokenizer whose chat template renders each example's
            ``messages`` before tokenization.

    Returns:
        A ``datasets.Dataset`` with exactly two columns, ``input_ids`` and
        ``labels`` (a copy of ``input_ids``), formatted as torch tensors and
        packed with TRL's ``pack_dataset`` using the "wrapped" strategy.
    """
    # Let the main process download/preprocess first; the other ranks then
    # pick up the cached results instead of redoing the work concurrently.
    with PartialState().main_process_first():
        ds = load_dataset("THUDM/LongAlign-10k", split="train")
        ds = ds.map(
            maybe_convert_to_chatml,
            remove_columns="conversations" if "conversations" in ds.column_names else None,
        )
        ds = ds.map(
            lambda ex: tokenizer(
                tokenizer.apply_chat_template(ex["messages"], tokenize=False),
                add_special_tokens=False,
            ),
            desc="Tokenizing",
        )
        ds = ds.select_columns(["input_ids"])
        ds = pack_dataset(ds, SEQ_LEN, "wrapped")
        # Add labels with a batched map rather than pulling the whole packed
        # column into Python lists (twice) and rebuilding via Dataset.from_dict.
        # The map processes bounded batches and, like the steps above, is cached
        # by `datasets`, so non-main ranks do not re-materialize every token.
        ds = ds.map(
            lambda batch: {"labels": batch["input_ids"]},
            batched=True,
            desc="Adding labels",
        )
    # Keep the same two-column schema the training loop previously received.
    ds = ds.select_columns(["input_ids", "labels"])
    ds.set_format("torch")
    return ds


def main():
    """Run a short SFT job that reproduces the per-layer-compile + FSDP2 slow path.

    Builds a packed LongAlign-10k dataset, loads Qwen3-30B-A3B in bfloat16
    with SDPA attention, and trains for 20 steps with ``torch_compile=True``
    so that SFTTrainer compiles each transformer block before accelerate
    applies FSDP2 wrapping.
    """
    model_id = "Qwen/Qwen3-30B-A3B"
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Prepare packed data
    simple_ds = _build_train_dataset(tokenizer)

    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        model_id, dtype=torch.bfloat16, attn_implementation="sdpa"
    )

    args = SFTConfig(
        output_dir="/tmp/compile_test",
        max_steps=20,
        per_device_train_batch_size=1,
        gradient_checkpointing=True,
        save_strategy="no",
        report_to="none",
        logging_steps=5,
        torch_compile=True,  # triggers per-layer compile in SFTTrainer.__init__
        tf32=True,
        max_length=SEQ_LEN,
        packing=True,
        packing_strategy="wrapped",
        include_num_input_tokens_seen=True,
        dataset_kwargs={"skip_prepare_dataset": True},
    )

    # What SFTTrainer.__init__ runs when torch_compile=True (verbatim from
    # trl/trainer/sft_trainer.py:995-1030), BEFORE calling super().__init__()
    # (which is where accelerator.prepare() -> fsdp2_prepare_model() happens):
    #
    #     # Per-layer torch.compile: compile each transformer block individually before FSDP wrapping.
    #     # This avoids graph breaks from FSDP hooks and transformers decorators that occur when compiling
    #     # the whole model after FSDP (the default HF Trainer behavior). See torchtitan's approach:
    #     # https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/compile.py
    #     if args.torch_compile and not isinstance(model, str):
    #         layers = getattr(getattr(model, "model", None), "layers", None)
    #         if layers is None:
    #             print(
    #                 f"[SFTTrainer] Per-layer compile: model.model.layers not found. "
    #                 f"model type={type(model).__name__}, has .model={hasattr(model, 'model')}. "
    #                 f"Falling back to whole-model compile.",
    #                 flush=True,
    #             )
    #         else:
    #             torch._dynamo.config.capture_scalar_outputs = True
    #             torch._dynamo.config.skip_fwd_side_effects_in_bwd_under_checkpoint = True
    #             # Use grouped_mm experts implementation for compile-friendly MoE dispatch.
    #             # The default eager implementation uses a Python loop with data-dependent shapes
    #             # (nonzero + dynamic indexing) that inductor handles poorly.
    #             if getattr(model.config, "_experts_implementation", None) is None:
    #                 if getattr(model.config, "num_local_experts", 0) > 0 or getattr(
    #                     model.config, "num_experts", 0
    #                 ) > 0:
    #                     model.config._experts_implementation = "grouped_mm"
    #             for layer in layers:
    #                 layer.compile(
    #                     backend=args.torch_compile_backend or "inductor",
    #                     fullgraph=not args.use_liger_kernel,
    #                 )
    #             print(
    #                 f"[SFTTrainer] Compiled {len(layers)} transformer blocks with torch.compile "
    #                 f"(per-layer, fullgraph=True)",
    #                 flush=True,
    #             )
    #             # Disable whole-model compile so HF Trainer doesn't also compile after FSDP
    #             args.torch_compile = False
    #
    # So by the time accelerate wraps the model, every transformer block already has
    # layer._compiled_call_impl set — identical to what the fast-path script does manually.
    #
    # NOTE(review): the quoted log line always reports "fullgraph=True", but the
    # actual flag is `not args.use_liger_kernel`; use_liger_kernel is not set in
    # the SFTConfig above.
    trainer = SFTTrainer(
        model=model,
        args=args,
        train_dataset=simple_ds,
        processing_class=tokenizer,
    )
    trainer.train()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment