Last active
April 2, 2026 14:58
-
-
Save celsowm/ac0ae9c0681bf6f145e400f459b4329b to your computer and use it in GitHub Desktop.
train_qwen35_fullft_v2.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| import math | |
| from typing import Any, Dict, List | |
# =============================================================================
# ENVIRONMENT AND GPU SETTINGS
# =============================================================================
# Applied before torch/transformers are imported below so that their
# initialization sees these values. setdefault() semantics: anything already
# exported in the calling shell wins over the defaults listed here.
for _env_key, _env_value in (
    ("CUDA_DEVICE_ORDER", "PCI_BUS_ID"),
    ("CUDA_VISIBLE_DEVICES", "7"),
    ("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True"),
    ("TOKENIZERS_PARALLELISM", "false"),
):
    os.environ.setdefault(_env_key, _env_value)
del _env_key, _env_value
| import torch | |
| from datasets import load_dataset | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed | |
| from trl import SFTConfig, SFTTrainer | |
# =============================================================================
# GENERAL HYPERPARAMETERS AND SETTINGS
# =============================================================================
# -- What to train, on what, and where to write it ---------------------------
MODEL_NAME = "Qwen/Qwen3.5-4B"
DATASET_NAME = "celsowm/legal_br_sft"
OUTPUT_DIR = "/home/fontesc/qwen35-4b-legal-br-fullft"

# -- Reproducibility and data split ------------------------------------------
SEED = 42
TEST_SIZE = 0.05  # fraction of the (length-filtered) data held out for eval

# -- Sequence handling --------------------------------------------------------
MAX_LENGTH = 2048  # token budget; longer conversations are filtered out
STRICT_ASSISTANT_MASK = True  # abort if assistant-only loss is unavailable

# -- Optimization -------------------------------------------------------------
NUM_TRAIN_EPOCHS = 1
LEARNING_RATE = 5e-6
WEIGHT_DECAY = 0.1
WARMUP_RATIO = 0.03
OPTIMIZER = "adamw_torch_fused"

# -- Batching (effective train batch per device = 4 * 4 = 16) -----------------
PER_DEVICE_TRAIN_BATCH_SIZE = 4
PER_DEVICE_EVAL_BATCH_SIZE = 4
GRADIENT_ACCUMULATION_STEPS = 4

# -- Logging / evaluation / checkpointing cadence (in optimizer steps) --------
LOGGING_STEPS = 10
EVAL_STEPS = 100
SAVE_STEPS = 100
SAVE_TOTAL_LIMIT = 2
| # ============================================================================= | |
| # HELPER FUNCTIONS | |
| # ============================================================================= | |
def flatten_content(content: Any) -> str:
    """Collapse a chat message ``content`` value into one plain string.

    Accepted shapes:
      * ``None`` -> ``""`` (e.g. a message whose content is null).
      * ``str`` -> returned unchanged.
      * ``list`` of parts, concatenated with no separator, where each part is
          - a ``str``, used as-is, or
          - a ``dict``: its ``"text"`` value when ``type == "text"`` or a
            ``"text"`` key is present, otherwise its ``"content"`` value.
            These values are flattened recursively, so ``None`` or nested
            lists are handled too.
        Any other part (non-text dicts such as images, numbers, ``None``) is
        skipped.
      * anything else -> ``str(content)``.
    """
    if content is None:
        # Without this, str(None) would inject the literal text "None".
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts: List[str] = []
        for item in content:
            if isinstance(item, str):
                parts.append(item)
            elif isinstance(item, dict):
                if item.get("type") == "text" or "text" in item:
                    # Recurse so a None/non-str "text" value can't break join().
                    parts.append(flatten_content(item.get("text")))
                elif "content" in item:
                    # "content" may itself be a list of parts.
                    parts.append(flatten_content(item["content"]))
        return "".join(parts)
    return str(content)
def normalize_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, str]]:
    """Reduce each message to a ``{"role": str, "content": str}`` dict.

    ``role`` is coerced with ``str()``. ``content`` is flattened with
    ``flatten_content``, and a missing ``content`` key is treated as ``""``.
    All other keys are dropped. A missing ``role`` raises ``KeyError``.
    """
    normalized: List[Dict[str, str]] = []
    for message in messages:
        role = str(message["role"])
        text = flatten_content(message.get("content", ""))
        normalized.append({"role": role, "content": text})
    return normalized
def to_messages_only(example: Dict[str, Any]) -> Dict[str, Any]:
    """Map a raw dataset row to a row holding only normalized ``messages``."""
    raw_messages = example["messages"]
    return dict(messages=normalize_messages(raw_messages))
def fits_context(example: Dict[str, Any], tokenizer: "AutoTokenizer", max_length: int) -> bool:
    """Return True if ``example["messages"]`` renders to at most *max_length* tokens.

    The conversation is rendered with the tokenizer's chat template
    (``tokenize=True``, no generation prompt) and the resulting ids are
    counted.

    Args:
        example: A dataset row with a ``"messages"`` list.
        tokenizer: Any object exposing ``apply_chat_template``. The annotation
            is a string so this helper can be defined without transformers.
        max_length: Inclusive token budget.
    """
    from collections.abc import Mapping

    encoded = tokenizer.apply_chat_template(
        example["messages"],
        tokenize=True,
        add_generation_prompt=False,
    )
    # Some transformers versions return a dict-like BatchEncoding instead of a
    # flat id list. len() of a mapping counts its keys (e.g. 2), which would
    # let every example pass, so count the input_ids instead.
    if isinstance(encoded, Mapping):
        encoded = encoded["input_ids"]
    return len(encoded) <= max_length
def count_parameters(model: torch.nn.Module):
    """Return ``(total, trainable)`` element counts over ``model.parameters()``.

    ``trainable`` only includes parameters whose ``requires_grad`` is set.
    """
    total = trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return total, trainable
def chat_template_supports_assistant_mask(tokenizer: "AutoTokenizer") -> bool:
    """Probe whether the chat template can mark assistant tokens for loss masking.

    A tiny system/user/assistant conversation is rendered with
    ``return_dict=True, return_assistant_tokens_mask=True``. Support is
    reported only if the result is a mapping with an ``"assistant_masks"``
    entry that flags at least one token. The probe contains an assistant
    turn, so an all-zero mask means the template cannot locate assistant spans.

    Any exception raised by the tokenizer (for example, an older
    ``apply_chat_template`` that rejects these keyword arguments) is treated
    as "not supported". This is a best-effort capability check.
    """
    from collections.abc import Mapping

    def _has_nonzero(values: Any) -> bool:
        # Accept a flat mask ([0, 1, ...]) or a batched one ([[0, 1, ...]]).
        if isinstance(values, (list, tuple)):
            return any(_has_nonzero(v) for v in values)
        return bool(values)

    probe_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello."},
        {"role": "assistant", "content": "Hello!"},
    ]
    try:
        rendered = tokenizer.apply_chat_template(
            probe_messages,
            tokenize=True,
            add_generation_prompt=False,
            return_dict=True,
            return_assistant_tokens_mask=True,
        )
    except Exception:
        return False

    # return_dict=True yields a BatchEncoding, which is a UserDict-based
    # Mapping but not a dict subclass, so an isinstance(..., dict) check
    # would always fail here.
    if not isinstance(rendered, Mapping) or "assistant_masks" not in rendered:
        return False

    mask = rendered["assistant_masks"]
    if hasattr(mask, "tolist"):  # tensor / ndarray, if tensors were returned
        mask = mask.tolist()
    # A template without assistant-span markers may still produce the key,
    # filled with zeros. That would silently train on no tokens.
    return _has_nonzero(mask)
| # ============================================================================= | |
| # MAIN FUNCTION | |
| # ============================================================================= | |
def main():
    """Run a full-parameter SFT of MODEL_NAME on DATASET_NAME and save the result.

    Pipeline:
      1. Seed everything and enable TF32 matmuls.
      2. Load the tokenizer and probe whether its chat template can emit
         assistant masks (needed for assistant-only loss).
      3. Load the model with every parameter trainable.
      4. Normalize rows to a single ``messages`` column, drop conversations
         longer than MAX_LENGTH tokens, and hold out TEST_SIZE for evaluation.
      5. Train with TRL's SFTTrainer, evaluate once more (adding perplexity),
         and save model + tokenizer to OUTPUT_DIR.

    On Ctrl-C the current model and tokenizer are saved to
    ``OUTPUT_DIR/interrupt`` and the KeyboardInterrupt is re-raised.

    Raises:
        RuntimeError: if STRICT_ASSISTANT_MASK is set but the chat template
            does not expose usable assistant masks, or if no example fits
            within MAX_LENGTH tokens.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    set_seed(SEED)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    cuda_available = torch.cuda.is_available()
    print("=" * 80)
    print("CUDA_VISIBLE_DEVICES =", os.environ.get("CUDA_VISIBLE_DEVICES"))
    print("torch :", torch.__version__)
    print("cuda :", cuda_available)
    print("gpus :", torch.cuda.device_count())
    if cuda_available:
        print("gpu :", torch.cuda.get_device_name(0))
        print("bf16 :", torch.cuda.is_bf16_supported())
    print("=" * 80)

    # -- Precision policy -----------------------------------------------------
    use_bf16 = cuda_available and torch.cuda.is_bf16_supported()
    # fp16 mixed precision needs a GPU. Without CUDA, fall back to plain fp32.
    use_fp16 = cuda_available and not use_bf16
    # TrainingArguments raises on tf32=True unless CUDA is present and the GPU
    # is Ampere or newer (compute capability >= 8).
    use_tf32 = cuda_available and torch.cuda.get_device_capability(0)[0] >= 8
    # fp16 mixed precision keeps fp32 master weights. Loading the model in
    # float16 makes the GradScaler fail with "Attempting to unscale FP16
    # gradients", so only the bf16 path loads reduced-precision weights.
    # NOTE(review): pure-bf16 weights with AdamW at LR=5e-6 can lose small
    # updates to bf16 rounding. Consider fp32 weights if memory allows.
    dtype = torch.bfloat16 if use_bf16 else torch.float32

    # ── TOKENIZER ─────────────────────────────────────────────────────────────
    # NOTE: trust_remote_code=True executes Python shipped with the hub repo.
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
        use_fast=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    assistant_mask_supported = chat_template_supports_assistant_mask(tokenizer)
    print("assistant mask support:", assistant_mask_supported)
    if STRICT_ASSISTANT_MASK and not assistant_mask_supported:
        raise RuntimeError(
            "This tokenizer/chat template does not expose assistant_masks. "
            "Either update the checkpoint template to support assistant-only loss "
            "or set STRICT_ASSISTANT_MASK = False."
        )

    # ── MODEL ─────────────────────────────────────────────────────────────────
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )
    # Full fine-tune: make sure nothing is left frozen.
    for p in model.parameters():
        p.requires_grad = True
    # The KV cache is useless during training and conflicts with gradient
    # checkpointing (enabled in SFTConfig below).
    model.config.use_cache = False

    total, trainable = count_parameters(model)
    print(f"Total params     : {total:,}")
    print(f"Trainable params : {trainable:,}")

    # ── DATASET ───────────────────────────────────────────────────────────────
    ds = load_dataset(DATASET_NAME, split="train")
    print(f"\nRaw dataset : {len(ds):,} examples")

    ds = ds.map(
        to_messages_only,
        remove_columns=ds.column_names,
        desc="Normalizing messages",
        num_proc=4,
    )

    print("Filtering examples that exceed MAX_LENGTH...")
    before = len(ds)
    ds = ds.filter(
        lambda ex: fits_context(ex, tokenizer, MAX_LENGTH),
        desc="Filtering by length",
        num_proc=4,
    )
    after = len(ds)
    removed = before - after
    removed_frac = removed / before if before else 0.0  # avoid 0/0 on empty data
    print(f"  Removed   : {removed:,} ({removed_frac:.1%})")
    print(f"  Remaining : {after:,}")
    if after == 0:
        # Fail with a clear message instead of a ZeroDivisionError / split error.
        raise RuntimeError(
            f"No examples fit within MAX_LENGTH={MAX_LENGTH} tokens; nothing to train on."
        )

    split = ds.train_test_split(test_size=TEST_SIZE, seed=SEED)
    train_ds = split["train"]
    eval_ds = split["test"]

    print("\n── Split ───────────────────────────────────────────")
    print(f"  Train : {len(train_ds):>7,}  ({len(train_ds) / after:.1%})")
    print(f"  Eval  : {len(eval_ds):>7,}  ({len(eval_ds) / after:.1%})")
    print("────────────────────────────────────────────────────\n")

    print("Sample (train_ds[0]):")
    print(train_ds[0])

    # ── SFT CONFIG ───────────────────────────────────────────────────────────
    args = SFTConfig(
        output_dir=OUTPUT_DIR,
        do_train=True,
        do_eval=True,
        num_train_epochs=NUM_TRAIN_EPOCHS,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        warmup_ratio=WARMUP_RATIO,
        lr_scheduler_type="cosine",
        per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
        per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        logging_steps=LOGGING_STEPS,
        eval_strategy="steps",
        eval_steps=EVAL_STEPS,
        save_strategy="steps",
        save_steps=SAVE_STEPS,
        save_total_limit=SAVE_TOTAL_LIMIT,
        bf16=use_bf16,
        fp16=use_fp16,
        tf32=use_tf32,
        gradient_checkpointing=True,
        packing=False,
        max_length=MAX_LENGTH,
        optim=OPTIMIZER,
        report_to="none",
        seed=SEED,
        dataloader_num_workers=0,
        remove_unused_columns=False,
        # Only enabled when the probe above found usable assistant masks.
        assistant_only_loss=assistant_mask_supported,
        eos_token=tokenizer.eos_token,
        pad_token=tokenizer.pad_token,
    )

    # ── TRAINER ───────────────────────────────────────────────────────────────
    trainer = SFTTrainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        processing_class=tokenizer,
    )

    # ── TRAINING ──────────────────────────────────────────────────────────────
    try:
        trainer.train()
    except KeyboardInterrupt:
        print("\nInterrupted — saving checkpoint...")
        interrupt_dir = os.path.join(OUTPUT_DIR, "interrupt")
        trainer.save_model(interrupt_dir)
        tokenizer.save_pretrained(interrupt_dir)
        raise

    # ── FINAL EVALUATION ──────────────────────────────────────────────────────
    metrics = trainer.evaluate()
    if "eval_loss" in metrics:
        try:
            metrics["perplexity"] = math.exp(metrics["eval_loss"])
        except OverflowError:
            # A diverged run (loss > ~709) would otherwise crash before saving.
            metrics["perplexity"] = float("inf")
    print(metrics)

    # ── FINAL SAVE ────────────────────────────────────────────────────────────
    trainer.save_model(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    print("\n✅ FINISHED")
# Entry point: run the training pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment