import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, default_data_collator
from datasets import load_from_disk
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
import numpy as np
import logging
import time
from datetime import datetime
import wandb
import sys
import os
SYSTEM_PROMPT = """Return only the command to be executed as a raw string, no string delimiters wrapping it, no yapping, no markdown, no fenced code blocks, what you return will be passed to the terminal directly. | |
For example, if the user asks "undo last git commit", you return only "git reset --soft HEAD~1" | |
The shell is /bin/bash on linux | |
Use the shortest and most efficient direct command possible. | |
""" | |
# Set up logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Create console handler with formatting
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
console_handler.setFormatter(formatter)

# Create file handler with formatting
file_handler = logging.FileHandler(f'training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(formatter)

# Add both handlers to logger
logger.addHandler(console_handler)
logger.addHandler(file_handler)
# Initialize wandb (config mirrors the hyperparameters used below)
wandb.init(
    project="qwen3-distillation",
    config={
        "model_name": "/mnt/models/Qwen3-0.6B",
        "max_length": 2048,
        "batch_size": 2,
        "grad_accum_steps": 4,
        "learning_rate": 5e-5,
        "num_epochs": 3
    }
)

# Initialize predictions table
predictions_table = wandb.Table(columns=["step", "example", "prediction"])
model_name = "/mnt/models/Qwen3-0.6B"
logger.info(f"Loading model and tokenizer from {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = "<|endoftext|>"
tokenizer.eos_token = "<|im_end|>"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
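
# Thin wrapper that applies dropout to the student's output logits as extra
# regularization during distillation; the underlying HF model stays at
# .base_model so it can be saved with save_pretrained() later.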
class StudentWithDropout(torch.nn.Module):
    def __init__(self, base_model, dropout_prob=0.1):
        super().__init__()
        self.base_model = base_model
        self.dropout = torch.nn.Dropout(dropout_prob)

    def forward(self, **kwargs):
        outputs = self.base_model(**kwargs)
        outputs.logits = self.dropout(outputs.logits)
        return outputs

model = StudentWithDropout(model).to("cuda:0")
logger.info("Model and tokenizer loaded successfully") | |
dataset = load_from_disk("scored_data/combined_scored_data.hf") | |
logger.info(f"Loaded dataset with {len(dataset)} examples") | |
MAX_LENGTH = 2048 | |
logger.info(f"Using max sequence length of {MAX_LENGTH}") | |
def tokenize(example):
    # Define prompt structure matching Qwen3
    system_prompt = SYSTEM_PROMPT.strip()
    user_prompt = example['nl']
    thinking = example["thinking"]
    assistant_response = example['bash']

    input_text = (
        f"<|im_start|>system\n{system_prompt}\n<|im_end|>\n"
        f"<|im_start|>user\n{user_prompt}\n<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )
    target_text = (
        f"<think>\n{thinking}\n</think>\n"
        f"{assistant_response}\n<|im_end|>"
    )

    input_enc = tokenizer(input_text, truncation=True, padding='max_length', max_length=MAX_LENGTH)
    target_enc = tokenizer(target_text, truncation=True, padding='max_length', max_length=MAX_LENGTH)

    labels = target_enc['input_ids']
    labels = [l if l != tokenizer.pad_token_id else -100 for l in labels]

    # Now postprocess to mask trailing '\n' tokens:
    labels = torch.tensor(labels, dtype=torch.long)
    # Create a clone to mask trailing newlines without modifying in-place
    masked_labels = labels.clone()
    # Mask trailing '\n' tokens
    for i in reversed(range(len(masked_labels))):
        token_id = masked_labels[i].item()
        if token_id < 0:
            continue  # already masked
        decoded_token = tokenizer.decode([token_id])
        if decoded_token == '\n':
            masked_labels[i] = -100
        else:
            break
    labels = masked_labels

    # logprob idxs and probs are top 5 from llama.cpp
    ids = example['logprobs']['ids']
    probs = example['logprobs']['probs']
    if not isinstance(ids[0], list):
        ids = [[i] for i in ids]
    if not isinstance(probs[0], list):
        probs = [[p] for p in probs]

    seq_len = len(labels)
    top_k = len(ids[0])
    pad_count = max(0, seq_len - len(ids))
    padded_ids = ids[:seq_len] + [[0] * top_k] * pad_count
    padded_probs = probs[:seq_len] + [[0.0] * top_k] * pad_count

    return {
        'input_ids': torch.tensor(input_enc['input_ids'], dtype=torch.long),
        'attention_mask': torch.tensor(input_enc['attention_mask'], dtype=torch.long),
        'labels': labels,  # already a torch.LongTensor with -100 at ignored positions
        'teacher_topk_ids': torch.tensor(padded_ids, dtype=torch.long),
        'teacher_topk_probs': torch.tensor(padded_probs, dtype=torch.float32),
    }
logger.info("Processing dataset...") | |
dataset = dataset.map(tokenize) | |
dataset = dataset.with_format("torch", columns=["input_ids", "attention_mask", "labels", "teacher_topk_ids", "teacher_topk_probs"]) | |
logger.info("Dataset processing complete") | |
class DistillationLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ce_loss_fn = CrossEntropyLoss(ignore_index=-100)
        logger.info("Initialized DistillationLoss")

    def forward(self, student_logits, labels, topk_ids, topk_probs, ce_weight=1.0, kl_weight=0.0, rlkd_weight=0.0):
        student_logits = student_logits.float()
        topk_probs = topk_probs.float()
        B, S, V = student_logits.shape

        ce_loss = self.ce_loss_fn(student_logits.view(-1, V), labels.view(-1))

        # Initialize as scalar tensors so .item() below stays valid even when
        # no position passes the confidence filter (valid_count == 0)
        kl_loss = student_logits.new_zeros(())
        cosine_loss = student_logits.new_zeros(())
        valid_count = 0

        for b in range(B):
            for t in range(S):
                if labels[b, t] == -100:
                    continue
                teacher_conf = topk_probs[b, t].max()
                # Skip low-confidence steps
                if teacher_conf < 0.1:
                    continue
                valid_count += 1

                # KL Loss
                student_log_probs = F.log_softmax(student_logits[b, t], dim=-1)
                teacher_dist = torch.zeros_like(student_log_probs)
                teacher_dist[topk_ids[b, t]] = topk_probs[b, t]
                kl_loss += F.kl_div(student_log_probs, teacher_dist, reduction="batchmean")

                # Cosine similarity RLKD
                student_probs = F.softmax(student_logits[b, t] / 1.5, dim=-1)
                s_vals = student_probs[topk_ids[b, t]]
                t_vals = topk_probs[b, t]
                cos_sim = F.cosine_similarity(s_vals, t_vals, dim=0)
                cosine_loss += 1 - cos_sim

        if valid_count > 0:
            kl_loss /= valid_count
            cosine_loss /= valid_count

        total_loss = ce_weight * ce_loss + kl_weight * kl_loss + rlkd_weight * cosine_loss
        return total_loss, {
            "ce_loss": ce_loss.item(),
            "kl_loss": kl_loss.item(),
            "rlkd_loss": cosine_loss.item(),
            "total_loss": total_loss.item(),
            "rlkd_valid_count": valid_count,
        }
loss_fn = DistillationLoss()
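
# Loss-weight schedule: the CE weight stays at 1.0, while the KL and RLKD
# weights ramp linearly from 0 to 0.2 and 0.1 over the first 935 optimizer steps.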
def get_weights(step, warmup_steps=935):
    factor = min(step / warmup_steps, 1.0)
    return 1.0, 0.2 * factor, 0.1 * factor

output_dir = "./qwen_distilled"
batch_size = 2
grad_accum_steps = 4
learning_rate = 5e-5
log_steps = 1
# Examples are already padded to MAX_LENGTH and carry precomputed labels and
# teacher top-k tensors, so simple stacking is enough; default_data_collator
# keeps those fields intact (DataCollatorForLanguageModeling would overwrite
# `labels` with a copy of `input_ids`).
data_collator = default_data_collator
train_dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=data_collator
)
logger.info(f"Created dataloader with {len(train_dataloader)} batches") | |
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) | |
logger.info("Initialized AdamW optimizer") | |
model.train() | |
step = 0 | |
optimizer.zero_grad() | |
start_time = time.time() | |
logger.info("Starting training loop") | |
num_epochs = 3 | |
max_steps = len(train_dataloader) | |
logger.info(f"""Training configuration: | |
Output directory: {output_dir} | |
Batch size: {batch_size} | |
Gradient accumulation steps: {grad_accum_steps} | |
Learning rate: {learning_rate} | |
Max steps: {max_steps} | |
Log steps: {log_steps} | |
""") | |
for epoch in range(num_epochs):
    logger.info(f"Starting epoch {epoch+1}/{num_epochs}")
    step = 0
    for batch_idx, batch in enumerate(train_dataloader):
        batch_start = time.time()
        batch = {k: v.to("cuda:0") for k, v in batch.items()}
        ce_w, kl_w, rlkd_w = get_weights(step)

        with torch.autocast("cuda", torch.bfloat16):
            outputs = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
            loss, logs = loss_fn(outputs.logits, batch["labels"], batch["teacher_topk_ids"], batch["teacher_topk_probs"], ce_weight=ce_w, kl_weight=kl_w, rlkd_weight=rlkd_w)

        loss = loss / grad_accum_steps
        loss.backward()

        if (batch_idx + 1) % grad_accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            step += 1

        batch_time = time.time() - batch_start
        total_time = time.time() - start_time

        wandb.log({
            "total_loss": logs["total_loss"],
            "ce_loss": logs["ce_loss"],
            "kl_loss": logs["kl_loss"],
            "rlkd_loss": logs["rlkd_loss"],
            "batch_time": batch_time,
            "total_time": total_time,
            "step": step,
            "epoch": epoch + 1,
            "rlkd_valid_count": logs["rlkd_valid_count"]
        })

        if step % log_steps == 0:
            logger.info(f"Epoch {epoch+1}/{num_epochs} - Step {step}/{max_steps} "
                        f"[{batch_time:.2f}s/batch, Total: {total_time:.2f}s] "
                        f"Loss: {logs['total_loss']:.4f} "
                        f"(CE: {logs['ce_loss']:.4f}, "
                        f"KL: {logs['kl_loss']:.4f}, "
                        f"RLKD: {logs['rlkd_loss']:.4f})")
    checkpoint_dir = os.path.join(output_dir, f"checkpoint-epoch-{epoch+1}")
    os.makedirs(checkpoint_dir, exist_ok=True)
    model.base_model.save_pretrained(checkpoint_dir)
    tokenizer.save_pretrained(checkpoint_dir)
    logger.info(f"Saved checkpoint for epoch {epoch+1}")

logger.info(f"Training completed in {time.time() - start_time:.2f} seconds")
model.base_model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
logger.info("Saved final model")

logger.info("Training script completed")
wandb.finish()
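
# --- Usage sketch (not part of the training run) -----------------------------
# A minimal example of how the distilled checkpoint saved above could be
# loaded for inference; the prompt format mirrors the one built in tokenize().
# Kept commented out so the script remains training-only.
#
# student_tok = AutoTokenizer.from_pretrained(output_dir)
# student = AutoModelForCausalLM.from_pretrained(output_dir, torch_dtype=torch.bfloat16).to("cuda:0")
# prompt = (
#     f"<|im_start|>system\n{SYSTEM_PROMPT.strip()}\n<|im_end|>\n"
#     "<|im_start|>user\nundo last git commit\n<|im_end|>\n"
#     "<|im_start|>assistant\n"
# )
# enc = student_tok(prompt, return_tensors="pt").to("cuda:0")
# out = student.generate(**enc, max_new_tokens=128)
# print(student_tok.decode(out[0][enc["input_ids"].shape[1]:], skip_special_tokens=True))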