# train_qwen35_fullft_v2.py
# GitHub Gist by @celsowm (celsowm/ac0ae9c0681bf6f145e400f459b4329b), last active April 2, 2026.
import os
import math
from typing import Any, Dict, List
# =============================================================================
# ENVIRONMENT AND GPU SETTINGS
# =============================================================================
# Every value is applied with setdefault, so a variable already present in the
# caller's environment takes precedence over the defaults below.
# NOTE(review): these assignments sit above `import torch`; presumably that
# ordering is intentional so the CUDA runtime/allocator reads them at
# initialization — keep them ahead of the torch import unless confirmed otherwise.
# Order CUDA device indices by PCI bus id.
os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
# Default to exposing only device index 7 to this process.
# NOTE(review): machine-specific; on a host with fewer than 8 GPUs no device
# would be visible unless CUDA_VISIBLE_DEVICES is exported before launching.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "7")
# Allocator configuration string consumed by PyTorch's CUDA caching allocator.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
# Turn off the tokenizers library's internal parallelism
# (assumption: avoids its fork warning when datasets uses num_proc — confirm).
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from trl import SFTConfig, SFTTrainer
# =============================================================================
# GENERAL HYPERPARAMETERS AND SETTINGS
# =============================================================================
# Identifier passed to AutoTokenizer / AutoModelForCausalLM.from_pretrained in main().
MODEL_NAME = "Qwen/Qwen3.5-4B"
# Identifier passed to datasets.load_dataset in main(); only its "train" split is loaded.
DATASET_NAME = "celsowm/legal_br_sft"
# Directory for checkpoints and the final model/tokenizer; created by main() if missing.
# NOTE(review): absolute, user-specific path — adjust per machine.
OUTPUT_DIR = "/home/fontesc/qwen35-4b-legal-br-fullft"
# Used for set_seed(), the train/test split, and SFTConfig(seed=...).
SEED = 42
# test_size argument of train_test_split (held-out evaluation portion).
TEST_SIZE = 0.05
# Token limit: conversations whose templated token count exceeds it are filtered
# out in main(), and the same value is passed as SFTConfig(max_length=...).
MAX_LENGTH = 2048
# When True, main() raises RuntimeError if the assistant-mask probe reports no support.
STRICT_ASSISTANT_MASK = True
# The values below are forwarded unchanged to the matching SFTConfig fields in main().
NUM_TRAIN_EPOCHS = 1
LEARNING_RATE = 5e-6
WEIGHT_DECAY = 0.1
WARMUP_RATIO = 0.03
PER_DEVICE_TRAIN_BATCH_SIZE = 4
PER_DEVICE_EVAL_BATCH_SIZE = 4
# With the per-device train batch size above this presumably gives 4 * 4 = 16
# sequences per optimizer step on a single device (confirm against trainer logs).
GRADIENT_ACCUMULATION_STEPS = 4
LOGGING_STEPS = 10
# Evaluation and checkpointing both use the "steps" strategy in main().
EVAL_STEPS = 100
SAVE_STEPS = 100
SAVE_TOTAL_LIMIT = 2
# Passed as SFTConfig(optim=...).
OPTIMIZER = "adamw_torch_fused"
# =============================================================================
# HELPER FUNCTIONS
# =============================================================================
def flatten_content(content: Any) -> str:
    """Collapse a chat message's ``content`` field into one plain string.

    Args:
        content: ``None``, a string, a list of parts, or any other object.
            List parts may be strings or dicts; a dict contributes its
            ``"text"`` value (when ``type == "text"`` or a ``"text"`` key is
            present) or otherwise its ``"content"`` value. Dicts with neither
            key, and parts that are neither ``str`` nor ``dict``, are skipped.

    Returns:
        The parts concatenated with no separator. ``None`` yields ``""``
        rather than the literal text ``"None"``; any other non-list,
        non-string object is converted with ``str()``.
    """
    # A missing/None content must not leak the word "None" into training text.
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts: List[str] = []
        for item in content:
            if isinstance(item, str):
                parts.append(item)
            elif isinstance(item, dict):
                # Recursing guarantees every part is a str before joining,
                # so None / nested-list / numeric payloads no longer make
                # "".join raise TypeError.
                if item.get("type") == "text" or "text" in item:
                    parts.append(flatten_content(item.get("text")))
                elif "content" in item:
                    parts.append(flatten_content(item["content"]))
        return "".join(parts)
    return str(content)
def normalize_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, str]]:
    """Reduce each message to a ``{"role", "content"}`` pair of strings.

    Args:
        messages: Chat messages; each must have a ``"role"`` key (a missing
            role raises ``KeyError``). ``"content"`` defaults to ``""``.

    Returns:
        A new list in the same order. The role is passed through ``str()``,
        the content through ``flatten_content``; any other keys are dropped.
    """
    normalized: List[Dict[str, str]] = []
    for message in messages:
        role = str(message["role"])
        text = flatten_content(message.get("content", ""))
        normalized.append({"role": role, "content": text})
    return normalized
def to_messages_only(example: Dict[str, Any]) -> Dict[str, Any]:
    """Build a record holding only the normalized ``"messages"`` of *example*.

    Args:
        example: A dataset row that contains a ``"messages"`` list.

    Returns:
        A one-key dict, ``{"messages": <normalized messages>}``.
    """
    cleaned = normalize_messages(example["messages"])
    return dict(messages=cleaned)
def fits_context(example: Dict[str, Any], tokenizer: AutoTokenizer, max_length: int) -> bool:
    """Tell whether a conversation's templated token count fits *max_length*.

    Args:
        example: Row with a ``"messages"`` list accepted by the chat template.
        tokenizer: Object exposing ``apply_chat_template``.
        max_length: Maximum number of tokens allowed (inclusive).

    Returns:
        ``True`` when the tokenized conversation has at most *max_length* ids.
    """
    encoded = tokenizer.apply_chat_template(
        example["messages"],
        tokenize=True,
        add_generation_prompt=False,
    )
    # Depending on the installed transformers version the call may return a
    # bare list of ids or a dict-like encoding holding "input_ids". Taking
    # len() of the latter would count its keys, making the filter a no-op.
    if isinstance(encoded, (list, tuple)):
        input_ids = encoded
    else:
        input_ids = encoded["input_ids"]
    return len(input_ids) <= max_length
def count_parameters(model: torch.nn.Module):
    """Count a module's parameters.

    Args:
        model: Module whose ``parameters()`` are inspected.

    Returns:
        ``(total, trainable)``: the summed element counts of all parameters
        and of those with ``requires_grad`` set.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return total, trainable
def chat_template_supports_assistant_mask(tokenizer: AutoTokenizer) -> bool:
    """Probe whether the tokenizer's chat template can mark assistant tokens.

    A fixed three-turn conversation (system, user, assistant) is rendered with
    ``return_assistant_tokens_mask=True``. Support is reported only when the
    result exposes ``"assistant_masks"`` AND that mask flags at least one
    token.

    Args:
        tokenizer: Object exposing ``apply_chat_template``.

    Returns:
        ``True`` if assistant-only loss masking looks usable, else ``False``.
        The probe never raises; any failure is printed and reported as
        ``False``.
    """
    probe_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello."},
        {"role": "assistant", "content": "Hello!"},
    ]
    try:
        rendered = tokenizer.apply_chat_template(
            probe_messages,
            tokenize=True,
            add_generation_prompt=False,
            return_dict=True,
            return_assistant_tokens_mask=True,
        )
    except Exception as exc:
        # Deliberately best-effort (covers unsupported keyword arguments and
        # template rendering errors), but surface the reason for debugging.
        print(f"assistant mask probe failed: {exc!r}")
        return False
    # The encoding may be a dict-like object that is not a `dict` subclass
    # (e.g. a UserDict), so look the key up directly instead of using
    # isinstance(rendered, dict).
    try:
        mask = rendered["assistant_masks"]
    except (KeyError, TypeError, IndexError):
        return False
    # NOTE(review): a template without assistant-span markers may still yield
    # the key with an all-zero mask (confirm for the installed transformers
    # version), so the key's presence alone is not treated as support.
    # Assumes a flat per-token mask for the single probe conversation.
    return any(mask)
# =============================================================================
# MAIN FUNCTION
# =============================================================================
def main():
    """Fully fine-tune ``MODEL_NAME`` on ``DATASET_NAME`` with TRL's ``SFTTrainer``.

    Steps, in order: report the CUDA environment, choose a numeric precision,
    load tokenizer and model, normalize and length-filter the dataset, split
    it into train/eval, train, evaluate (adding perplexity), and save the
    model and tokenizer to ``OUTPUT_DIR``.

    Raises:
        RuntimeError: If ``STRICT_ASSISTANT_MASK`` is set and the chat template
            does not support assistant masks, or if no example survives the
            length filter.
        KeyboardInterrupt: Re-raised after an emergency save to
            ``OUTPUT_DIR/interrupt``.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    set_seed(SEED)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    cuda_available = torch.cuda.is_available()
    print("=" * 80)
    print("CUDA_VISIBLE_DEVICES =", os.environ.get("CUDA_VISIBLE_DEVICES"))
    print("torch :", torch.__version__)
    print("cuda :", cuda_available)
    print("gpus :", torch.cuda.device_count())
    if cuda_available:
        print("gpu :", torch.cuda.get_device_name(0))
        print("bf16 :", torch.cuda.is_bf16_supported())
    print("=" * 80)

    # Precision selection.
    # * bf16 when the GPU supports it: weights are loaded directly in bf16.
    # * otherwise fp16 mixed precision on CUDA, which needs fp32 master
    #   weights: loading the weights in fp16 *and* enabling fp16 AMP makes the
    #   gradient scaler fail when it tries to unscale fp16 gradients.
    # * no half-precision flags at all without CUDA.
    use_bf16 = cuda_available and torch.cuda.is_bf16_supported()
    use_fp16 = cuda_available and not use_bf16
    # tf32=True is rejected by the training arguments on GPUs older than
    # compute capability 8.x, so only request it where it exists.
    use_tf32 = cuda_available and torch.cuda.get_device_capability(0)[0] >= 8
    dtype = torch.bfloat16 if use_bf16 else torch.float32

    # ── TOKENIZER ─────────────────────────────────────────────────────────────
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
        use_fast=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    assistant_mask_supported = chat_template_supports_assistant_mask(tokenizer)
    print("assistant mask support:", assistant_mask_supported)
    if STRICT_ASSISTANT_MASK and not assistant_mask_supported:
        raise RuntimeError(
            "This tokenizer/chat template does not expose assistant_masks. "
            "Either update the checkpoint template to support assistant-only loss "
            "or set STRICT_ASSISTANT_MASK = False."
        )

    # ── MODEL ─────────────────────────────────────────────────────────────────
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )
    # Full fine-tuning: every parameter is trainable.
    for p in model.parameters():
        p.requires_grad = True
    model.config.use_cache = False
    total, trainable = count_parameters(model)
    print(f"Total params : {total:,}")
    print(f"Trainable params : {trainable:,}")

    # ── DATASET ───────────────────────────────────────────────────────────────
    ds = load_dataset(DATASET_NAME, split="train")
    print(f"\nRaw dataset : {len(ds):,} examples")
    ds = ds.map(
        to_messages_only,
        remove_columns=ds.column_names,
        desc="Normalizing messages",
        num_proc=4,
    )
    print("Filtering examples that exceed MAX_LENGTH...")
    before = len(ds)
    ds = ds.filter(
        lambda ex: fits_context(ex, tokenizer, MAX_LENGTH),
        desc="Filtering by length",
        num_proc=4,
    )
    after = len(ds)
    removed = before - after
    # Guard the ratio so an empty dataset reports 0% instead of dividing by zero.
    removed_ratio = removed / before if before else 0.0
    print(f" Removed : {removed:,} ({removed_ratio:.1%})")
    print(f" Remaining : {after:,}")
    if after == 0:
        raise RuntimeError(
            "No examples left after filtering by MAX_LENGTH; nothing to train on."
        )

    split = ds.train_test_split(test_size=TEST_SIZE, seed=SEED)
    train_ds = split["train"]
    eval_ds = split["test"]
    print("\n── Split ───────────────────────────────────────────")
    print(f" Train : {len(train_ds):>7,} ({len(train_ds) / after:.1%})")
    print(f" Eval : {len(eval_ds):>7,} ({len(eval_ds) / after:.1%})")
    print("────────────────────────────────────────────────────\n")
    print("Sample (train_ds[0]):")
    print(train_ds[0])

    # ── SFT CONFIG ───────────────────────────────────────────────────────────
    args = SFTConfig(
        output_dir=OUTPUT_DIR,
        do_train=True,
        do_eval=True,
        num_train_epochs=NUM_TRAIN_EPOCHS,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        warmup_ratio=WARMUP_RATIO,
        lr_scheduler_type="cosine",
        per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
        per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        logging_steps=LOGGING_STEPS,
        eval_strategy="steps",
        eval_steps=EVAL_STEPS,
        save_strategy="steps",
        save_steps=SAVE_STEPS,
        save_total_limit=SAVE_TOTAL_LIMIT,
        bf16=use_bf16,
        fp16=use_fp16,
        tf32=use_tf32,
        gradient_checkpointing=True,
        packing=False,
        max_length=MAX_LENGTH,
        optim=OPTIMIZER,
        report_to="none",
        seed=SEED,
        dataloader_num_workers=0,
        remove_unused_columns=False,
        assistant_only_loss=assistant_mask_supported,
        eos_token=tokenizer.eos_token,
        pad_token=tokenizer.pad_token,
    )

    # ── TRAINER ───────────────────────────────────────────────────────────────
    trainer = SFTTrainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        processing_class=tokenizer,
    )

    # ── TRAINING ──────────────────────────────────────────────────────────────
    try:
        trainer.train()
    except KeyboardInterrupt:
        # Keep whatever progress was made before propagating the interrupt.
        print("\nInterrupted — saving checkpoint...")
        interrupt_dir = os.path.join(OUTPUT_DIR, "interrupt")
        trainer.save_model(interrupt_dir)
        tokenizer.save_pretrained(interrupt_dir)
        raise

    # ── FINAL EVALUATION ──────────────────────────────────────────────────────
    metrics = trainer.evaluate()
    if "eval_loss" in metrics:
        try:
            metrics["perplexity"] = math.exp(metrics["eval_loss"])
        except OverflowError:
            # A very large loss would overflow exp(); report it as infinite.
            metrics["perplexity"] = float("inf")
    print(metrics)

    # ── FINAL SAVE ────────────────────────────────────────────────────────────
    trainer.save_model(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    print("\n✅ FINISHED")


# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
# (End of gist. The GitHub page footer "Sign up for free to join this conversation on GitHub." is not part of the script.)