Skip to content

Instantly share code, notes, and snippets.

@AmineDiro
Last active April 24, 2026 08:35
Show Gist options
  • Select an option

  • Save AmineDiro/2457fbee70662d584a116cc3ca80dd07 to your computer and use it in GitHub Desktop.

Select an option

Save AmineDiro/2457fbee70662d584a116cc3ca80dd07 to your computer and use it in GitHub Desktop.
accelerate fsdp2
"""Per-layer compile + accelerate FSDP2 = 10% MFU (slow path)."""
# accelerate_config.yaml:
#   distributed_type: FSDP
#   fsdp_config:
#     fsdp_version: 2
#     fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
#     fsdp_cpu_ram_efficient_loading: true
#     fsdp_offload_params: false
#   num_machines: 2
#   num_processes: 16
#
# Run with:
#   accelerate launch --config_file accelerate_config.yaml \
#       --num_processes 16 --num_machines 2 --machine_rank=$RANK \
#       --main_process_ip=$MASTER --main_process_port=29500 script.py
#
# Result: ~9,900 ms/step, 9.8% MFU
import torch
# Allow TF32 for CUDA matmuls and cuDNN ops. These flags are set right after
# importing torch, before the remaining libraries are imported.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig
from trl.data_utils import pack_dataset, maybe_convert_to_chatml
from accelerate import PartialState
# Token length passed to pack_dataset() and used as SFTConfig.max_length in main().
SEQ_LEN = 16384
def main():
    """Run a 20-step SFT job on pre-packed LongAlign-10k with torch_compile=True.

    Steps:
      1. Tokenize THUDM/LongAlign-10k through the chat template and pack it
         into SEQ_LEN-token sequences ("wrapped" strategy).
      2. Load Qwen/Qwen3-30B-A3B in bfloat16 with SDPA attention.
      3. Train with SFTTrainer, skipping its own dataset preparation because
         the data is already tokenized and packed.
    """
    model_id = "Qwen/Qwen3-30B-A3B"
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    def tokenize(example):
        # Render the conversation to text with the chat template, then
        # tokenize that text; add_special_tokens=False keeps the tokenizer
        # from adding its own special tokens on top of the rendered template.
        rendered = tokenizer.apply_chat_template(example["messages"], tokenize=False)
        return tokenizer(rendered, add_special_tokens=False)

    # Prepare packed data
    raw = load_dataset("THUDM/LongAlign-10k", split="train")
    # NOTE(review): the original indentation was lost; the preprocessing steps
    # are assumed to be the body of this context manager — confirm.
    with PartialState().main_process_first():
        chat = raw.map(
            maybe_convert_to_chatml,
            remove_columns="conversations" if "conversations" in raw.column_names else None,
        )
        tokenized = chat.map(tokenize, desc="Tokenizing")
        tokenized = tokenized.select_columns(["input_ids"])
        packed = pack_dataset(tokenized, SEQ_LEN, "wrapped")

    # Labels are the same token ids as the inputs.
    input_ids = packed["input_ids"]
    simple_ds = Dataset.from_dict({"input_ids": input_ids, "labels": input_ids})
    simple_ds.set_format("torch")

    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        dtype=torch.bfloat16,
        attn_implementation="sdpa",
    )

    args = SFTConfig(
        output_dir="/tmp/compile_test",
        max_steps=20,
        per_device_train_batch_size=1,
        gradient_checkpointing=True,
        save_strategy="no",
        report_to="none",
        logging_steps=5,
        torch_compile=True,  # triggers per-layer compile in SFTTrainer.__init__
        tf32=True,
        max_length=SEQ_LEN,
        packing=True,
        packing_strategy="wrapped",
        include_num_input_tokens_seen=True,
        dataset_kwargs={"skip_prepare_dataset": True},
    )

    # What SFTTrainer.__init__ runs when torch_compile=True (verbatim from
    # trl/trainer/sft_trainer.py:995-1030), BEFORE calling super().__init__()
    # (which is where accelerator.prepare() -> fsdp2_prepare_model() happens):
    #
    #   # Per-layer torch.compile: compile each transformer block individually before FSDP wrapping.
    #   # This avoids graph breaks from FSDP hooks and transformers decorators that occur when compiling
    #   # the whole model after FSDP (the default HF Trainer behavior). See torchtitan's approach:
    #   # https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/compile.py
    #   if args.torch_compile and not isinstance(model, str):
    #       layers = getattr(getattr(model, "model", None), "layers", None)
    #       if layers is None:
    #           print(
    #               f"[SFTTrainer] Per-layer compile: model.model.layers not found. "
    #               f"model type={type(model).__name__}, has .model={hasattr(model, 'model')}. "
    #               f"Falling back to whole-model compile.",
    #               flush=True,
    #           )
    #       else:
    #           torch._dynamo.config.capture_scalar_outputs = True
    #           torch._dynamo.config.skip_fwd_side_effects_in_bwd_under_checkpoint = True
    #           # Use grouped_mm experts implementation for compile-friendly MoE dispatch.
    #           # The default eager implementation uses a Python loop with data-dependent shapes
    #           # (nonzero + dynamic indexing) that inductor handles poorly.
    #           if getattr(model.config, "_experts_implementation", None) is None:
    #               if getattr(model.config, "num_local_experts", 0) > 0 or getattr(
    #                   model.config, "num_experts", 0
    #               ) > 0:
    #                   model.config._experts_implementation = "grouped_mm"
    #           for layer in layers:
    #               layer.compile(
    #                   backend=args.torch_compile_backend or "inductor",
    #                   fullgraph=not args.use_liger_kernel,
    #               )
    #           print(
    #               f"[SFTTrainer] Compiled {len(layers)} transformer blocks with torch.compile "
    #               f"(per-layer, fullgraph=True)",
    #               flush=True,
    #           )
    #           # Disable whole-model compile so HF Trainer doesn't also compile after FSDP
    #           args.torch_compile = False
    #
    # So by the time accelerate wraps the model, every transformer block already has
    # layer._compiled_call_impl set — identical to what the fast-path script does manually.
    trainer = SFTTrainer(
        model=model,
        args=args,
        train_dataset=simple_ds,
        processing_class=tokenizer,
    )
    trainer.train()
# Script entry point: run the training job only when executed directly
# (e.g. via `accelerate launch`), not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment