# Gist by @muellerzr (last active May 28, 2025)
# https://gist.github.com/muellerzr/020fc2e2e39867279e863e2bec9123dd
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_from_disk
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
import numpy as np
import logging
import time
from datetime import datetime
import wandb
import sys
import os
SYSTEM_PROMPT = """Return only the command to be executed as a raw string, no string delimiters wrapping it, no yapping, no markdown, no fenced code blocks, what you return will be passed to the terminal directly.
For example, if the user asks "undo last git commit", you return only "git reset --soft HEAD~1"
The shell is /bin/bash on linux
Use the shortest and most efficient direct command possible.
"""
# Set up logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Create console handler with formatting
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
console_handler.setFormatter(formatter)
# Create file handler with formatting
file_handler = logging.FileHandler(f'training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(formatter)
# Add both handlers to logger
logger.addHandler(console_handler)
logger.addHandler(file_handler)
# Initialize wandb
wandb.init(
    project="qwen3-distillation",
    config={
        "model_name": "/mnt/models/Qwen3-0.6B",
        "max_length": 2048,
        "batch_size": 4,
        "grad_accum_steps": 4,
        "learning_rate": 5e-5,
        "num_epochs": 1,
    },
)
# Initialize predictions table
predictions_table = wandb.Table(columns=["step", "example", "prediction"])
model_name = "/mnt/models/Qwen3-0.6B"
logger.info(f"Loading model and tokenizer from {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = "<|endoftext|>"
tokenizer.eos_token = "<|im_end|>"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # requires the flash-attn package to be installed
)
class StudentWithDropout(torch.nn.Module):
    """Wraps the student model and applies dropout to its output logits."""

    def __init__(self, base_model, dropout_prob=0.1):
        super().__init__()
        self.base_model = base_model
        self.dropout = torch.nn.Dropout(dropout_prob)

    def forward(self, **kwargs):
        outputs = self.base_model(**kwargs)
        outputs.logits = self.dropout(outputs.logits)
        return outputs


model = StudentWithDropout(model).to("cuda:0")
logger.info("Model and tokenizer loaded successfully")
dataset = load_from_disk("scored_data/combined_scored_data.hf")
logger.info(f"Loaded dataset with {len(dataset)} examples")
MAX_LENGTH = 2048
logger.info(f"Using max sequence length of {MAX_LENGTH}")
def tokenize(example):
    # Define prompt structure matching Qwen3's chat template
    system_prompt = SYSTEM_PROMPT.strip()
    user_prompt = example['nl']
    thinking = example["thinking"]
    assistant_response = example['bash']
    input_text = (
        f"<|im_start|>system\n{system_prompt}\n<|im_end|>\n"
        f"<|im_start|>user\n{user_prompt}\n<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )
    target_text = (
        f"<think>\n{thinking}\n</think>\n"
        f"{assistant_response}\n<|im_end|>"
    )
    input_enc = tokenizer(input_text, truncation=True, padding='max_length', max_length=MAX_LENGTH)
    target_enc = tokenizer(target_text, truncation=True, padding='max_length', max_length=MAX_LENGTH)
    labels = target_enc['input_ids']
    labels = [l if l != tokenizer.pad_token_id else -100 for l in labels]
    # Postprocess to mask trailing '\n' tokens
    labels = torch.tensor(labels, dtype=torch.long)
    # Clone so trailing newlines can be masked without modifying in place
    masked_labels = labels.clone()
    # Mask trailing '\n' tokens
    for i in reversed(range(len(masked_labels))):
        token_id = masked_labels[i].item()
        if token_id < 0:
            continue  # already masked
        decoded_token = tokenizer.decode([token_id])
        if decoded_token == '\n':
            masked_labels[i] = -100
        else:
            break
    labels = masked_labels
    # Teacher logprob ids and probs are the top 5 from llama.cpp
    ids = example['logprobs']['ids']
    probs = example['logprobs']['probs']
    if not isinstance(ids[0], list):
        ids = [[i] for i in ids]
    if not isinstance(probs[0], list):
        probs = [[p] for p in probs]
    # Pad or truncate the teacher top-k lists to the label sequence length
    seq_len = len(labels)
    top_k = len(ids[0])
    pad_count = max(0, seq_len - len(ids))
    padded_ids = ids[:seq_len] + [[0] * top_k] * pad_count
    padded_probs = probs[:seq_len] + [[0.0] * top_k] * pad_count
    return {
        'input_ids': torch.tensor(input_enc['input_ids'], dtype=torch.long),
        'attention_mask': torch.tensor(input_enc['attention_mask'], dtype=torch.long),
        'labels': labels,  # already a long tensor
        'teacher_topk_ids': torch.tensor(padded_ids, dtype=torch.long),
        'teacher_topk_probs': torch.tensor(padded_probs, dtype=torch.float32),
    }
logger.info("Processing dataset...")
dataset = dataset.map(tokenize)
dataset = dataset.with_format("torch", columns=["input_ids", "attention_mask", "labels", "teacher_topk_ids", "teacher_topk_probs"])
logger.info("Dataset processing complete")
class DistillationLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ce_loss_fn = CrossEntropyLoss(ignore_index=-100)
        logger.info("Initialized DistillationLoss")

    def forward(self, student_logits, labels, topk_ids, topk_probs, ce_weight=1.0, kl_weight=0.0, rlkd_weight=0.0):
        student_logits = student_logits.float()
        topk_probs = topk_probs.float()
        B, S, V = student_logits.shape
        ce_loss = self.ce_loss_fn(student_logits.view(-1, V), labels.view(-1))
        kl_loss = 0.0
        cosine_loss = 0.0
        valid_count = 0
        for b in range(B):
            for t in range(S):
                if labels[b, t] == -100:
                    continue
                teacher_conf = topk_probs[b, t].max()
                # Skip low-confidence steps
                if teacher_conf < 0.1:
                    continue
                valid_count += 1
                # KL loss against the sparse teacher top-k distribution
                student_log_probs = F.log_softmax(student_logits[b, t], dim=-1)
                teacher_dist = torch.zeros_like(student_log_probs)
                teacher_dist[topk_ids[b, t]] = topk_probs[b, t]
                kl_loss += F.kl_div(student_log_probs, teacher_dist, reduction="batchmean")
                # Cosine-similarity RLKD over the teacher's top-k entries
                student_probs = F.softmax(student_logits[b, t] / 1.5, dim=-1)
                s_vals = student_probs[topk_ids[b, t]]
                t_vals = topk_probs[b, t]
                cos_sim = F.cosine_similarity(s_vals, t_vals, dim=0)
                cosine_loss += 1 - cos_sim
        if valid_count > 0:
            kl_loss /= valid_count
            cosine_loss /= valid_count
        total_loss = ce_weight * ce_loss + kl_weight * kl_loss + rlkd_weight * cosine_loss
        # float() works whether the accumulators became tensors or stayed at their 0.0 initial value
        logs = {
            "ce_loss": ce_loss.item(),
            "kl_loss": float(kl_loss),
            "rlkd_loss": float(cosine_loss),
            "total_loss": total_loss.item(),
            "rlkd_valid_count": valid_count,
        }
        return total_loss, logs
loss_fn = DistillationLoss()
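# Illustrative smoke test of the loss on tiny random tensors (shapes below are arbitrary,
# not the real batch/sequence/vocab sizes); left commented out so the script is unchanged.
# _logits = torch.randn(1, 4, 32)                 # (batch, seq, vocab)
# _labels = torch.tensor([[5, 7, -100, 2]])
# _ids = torch.randint(0, 32, (1, 4, 5))          # teacher top-5 token ids
# _probs = torch.full((1, 4, 5), 0.2)             # teacher top-5 probabilities
# _loss, _logs = loss_fn(_logits, _labels, _ids, _probs, ce_weight=1.0, kl_weight=0.2, rlkd_weight=0.1)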
def get_weights(step, warmup_steps=935):
    # Linearly warm up the KL and RLKD weights over the first `warmup_steps` steps
    factor = min(step / warmup_steps, 1.0)
    return 1.0, 0.2 * factor, 0.1 * factor
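# Example warmup values with the default warmup_steps=935:
#   get_weights(0)   -> (1.0, 0.0,   0.0)
#   get_weights(467) -> (1.0, ~0.1,  ~0.05)
#   get_weights(935) -> (1.0, 0.2,   0.1)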
output_dir = "./qwen_distilled"
batch_size = 2
grad_accum_steps = 4
learning_rate = 5e-5
log_steps = 1
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
train_dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=data_collator,
)
logger.info(f"Created dataloader with {len(train_dataloader)} batches")
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
logger.info("Initialized AdamW optimizer")
model.train()
step = 0
optimizer.zero_grad()
start_time = time.time()
logger.info("Starting training loop")
num_epochs = 3
max_steps = len(train_dataloader)
logger.info(f"""Training configuration:
Output directory: {output_dir}
Batch size: {batch_size}
Gradient accumulation steps: {grad_accum_steps}
Learning rate: {learning_rate}
Max steps: {max_steps}
Log steps: {log_steps}
""")
for epoch in range(num_epochs):
    logger.info(f"Starting epoch {epoch+1}/{num_epochs}")
    step = 0
    for batch_idx, batch in enumerate(train_dataloader):
        batch_start = time.time()
        batch = {k: v.to("cuda:0") for k, v in batch.items()}
        ce_w, kl_w, rlkd_w = get_weights(step)
        with torch.autocast("cuda", torch.bfloat16):
            outputs = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
            loss, logs = loss_fn(
                outputs.logits,
                batch["labels"],
                batch["teacher_topk_ids"],
                batch["teacher_topk_probs"],
                ce_weight=ce_w,
                kl_weight=kl_w,
                rlkd_weight=rlkd_w,
            )
        loss = loss / grad_accum_steps
        loss.backward()
        if (batch_idx + 1) % grad_accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            step += 1
        batch_time = time.time() - batch_start
        total_time = time.time() - start_time
        wandb.log({
            "total_loss": logs["total_loss"],
            "ce_loss": logs["ce_loss"],
            "kl_loss": logs["kl_loss"],
            "rlkd_loss": logs["rlkd_loss"],
            "batch_time": batch_time,
            "total_time": total_time,
            "step": step,
            "epoch": epoch + 1,
            "rlkd_valid_count": logs["rlkd_valid_count"],
        })
        if step % log_steps == 0:
            logger.info(
                f"Epoch {epoch+1}/{num_epochs} - Step {step}/{max_steps} "
                f"[{batch_time:.2f}s/batch, Total: {total_time:.2f}s] "
                f"Loss: {logs['total_loss']:.4f} "
                f"(CE: {logs['ce_loss']:.4f}, "
                f"KL: {logs['kl_loss']:.4f}, "
                f"RLKD: {logs['rlkd_loss']:.4f})"
            )
    # Save a checkpoint at the end of each epoch
    checkpoint_dir = os.path.join(output_dir, f"checkpoint-epoch-{epoch+1}")
    os.makedirs(checkpoint_dir, exist_ok=True)
    model.base_model.save_pretrained(checkpoint_dir)
    tokenizer.save_pretrained(checkpoint_dir)
    logger.info(f"Saved checkpoint for epoch {epoch+1}")
logger.info(f"Training completed in {time.time() - start_time:.2f} seconds")
model.base_model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
logger.info("Saved final model")
logger.info("Training script completed")
wandb.finish()