10% GSM8K accuracy gain within 15 minutes
## Note
## If you run vLLM on the same GPU, remember to set a low gpu_memory_utilization to avoid OOM.
## For larger models, consider multi-GPU or CPU offloading.
## AnySchedule: https://github.com/KohakuBlueleaf/AnySchedule
## LyCORIS: https://github.com/KohakuBlueleaf/LyCORIS
## The following code performs reasonable training of the Llama-3.2-1B-Instruct model on the GSM8K dataset,
## with a noticeable improvement on each reward function.
from itertools import chain | |
import re | |
import wandb | |
import torch | |
import torch.utils.data as data | |
from tqdm import tqdm | |
from anyschedule import AnySchedule | |
from datasets import load_dataset, Dataset | |
from transformers import ( | |
AutoModelForCausalLM, | |
AutoTokenizer, | |
) | |
from vllm import LLM, SamplingParams | |
from lycoris import create_lycoris, LycorisNetwork | |
from utils import patch_lm_head_for_logp, unpatch_lm_head_for_logp | |
model_name = "Qwen/Qwen2.5-0.5B-Instruct" | |
## Constants | |
GEN_CONFIG = SamplingParams(max_tokens=1024, temperature=0.75, min_p=0.1) | |
DTYPE = torch.bfloat16 | |
DEVICE = "cuda" | |
EPOCHS = 1 | |
MAX_LEN = 1024 | |
MAX_STEPS = 128 | |
GRAD_CKPT = True | |
LR = 5e-6 | |
# GRPO Hyperparameters | |
answer_r = 2.5 | |
format_r = 1.0 | |
length_penalty = 0.001 | |
exceed_penalty = 0.005 | |
BATCH_SIZE = 8 | |
MINI_BATCH_SIZE = 8 | |
GROUP_SIZE = 16 | |
GRPO_ITER = 4 | |
ETA = 0.5 | |
BETA = 0.25 | |
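## As used below: ETA is the PPO-style clip range for the importance ratio
## (torch.clamp(ratio, 1 - ETA, 1 + ETA)), BETA weights the per-token KL penalty
## against the frozen reference model, GROUP_SIZE is the number of completions
## sampled per question for group-relative advantages, and GRPO_ITER is the number
## of optimization passes made over each sampled group before new completions are generated.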
wandb.init( | |
project="HakuGRPO", | |
name="GRPO-QWen-2.5-0.5B-vllm", | |
# mode="offline", | |
) | |
wandb.config.update( | |
{ | |
"learning_rate": LR, | |
"batch_size": BATCH_SIZE, | |
"mini_batch_size": MINI_BATCH_SIZE, | |
"group_size": GROUP_SIZE, | |
"grpo_iter": GRPO_ITER, | |
"eta": ETA, | |
"beta": BETA, | |
} | |
) | |
wandb.run.log_code(".") | |
## Define the GSM8K dataset and reward functions | |
## Modified from: | |
## https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb | |
SYSTEM_PROMPT = """Following this Format, \ | |
Don't forget <response> and don't put anything outside the format: | |
--- | |
Your Step by Step thinking | |
<response> | |
Your Final Answer or Response | |
</response> | |
""" | |
# FAKE_SYSTEM_PROMPT = """ | |
# You are a human, you can think and answer in your own way. | |
# """.strip() | |
FAKE_SYSTEM_PROMPT = SYSTEM_PROMPT | |
XML_COT_FORMAT = """{thinking} | |
<response> | |
{answer} | |
</response> | |
""" | |
def extract_xml_answer(text: str) -> str: | |
answer = text.split("<response>")[-1] | |
# answer = answer.split("</response>")[0] | |
if answer.count("</response>") >= 1: | |
answer = answer.split("</response>")[0] | |
else: | |
return "" | |
    # Take the last number in the answer region as the answer.
    # This allows more verbose answers, e.g. "$300" or "Alex needs to buy 100 apples".
number_answer = re.findall(r"\d+\.\d+", answer) | |
if number_answer: | |
answer = number_answer[-1] | |
int_answer = re.findall(r"\d+", answer) | |
if int_answer: | |
answer = int_answer[-1] | |
return answer.strip() | |
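## Hypothetical example of the extraction above: for a completion like
## "I add 120 and 180.\n<response>\nThe total is $300\n</response>"
## only the region after <response> is kept, the closing tag is required,
## and the last number found in that region is returned, i.e. "300".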
def extract_hash_answer(text: str) -> str | None: | |
if "####" not in text: | |
return None | |
return text.split("####")[1].strip().replace(",", "") | |
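## GSM8K gold answers end with "#### <final number>"; an answer text ending in
## "#### 72" yields "72", and commas inside numbers are stripped.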
def get_gsm8k_questions(split="train") -> Dataset: | |
data = load_dataset("openai/gsm8k", "main")[split] | |
data = data.map( | |
lambda x: { | |
"question": x["question"], | |
"model_input": tokenizer.apply_chat_template( | |
[ | |
{"role": "system", "content": SYSTEM_PROMPT}, | |
                # In-context learning example; this seems to be necessary for smaller models
{"role": "user", "content": "Which is bigger, 9.11 or 9.8?"}, | |
{ | |
"role": "assistant", | |
"content": XML_COT_FORMAT.format( | |
thinking=( | |
"User asking me to compare 9.11 and 9.8, to compare two numbers, " | |
"we need to look at the integer part first, " | |
"9 is the same in both numbers, " | |
"then we look at the decimal part, " | |
"0.11 is smaller than 0.8, so 9.11 is smaller than 9.8." | |
), | |
answer="9.8", | |
), | |
}, | |
{"role": "user", "content": x["question"]}, | |
], | |
tokenize=False, | |
add_generation_prompt=True, | |
), | |
"answer": extract_hash_answer(x["answer"]), | |
} | |
) | |
return data | |
## reward functions | |
def correctness_reward_func(completions, answers, **kwargs) -> list[float]: | |
extracted_responses = [extract_xml_answer(r) for r in completions] | |
return [answer_r if r == a else 0.0 for r, a in zip(extracted_responses, answers)] | |
def format_reward(text) -> float: | |
count = 0.0 | |
if text.count("<response>") == 1: | |
if text.split("<response>")[-1].count("</response>") == 1: | |
count += format_r | |
thinking_length = len(text.split("<response>")[0]) | |
if thinking_length < 500: | |
count += thinking_length * length_penalty | |
else: | |
count -= (thinking_length - 500) * length_penalty | |
count -= len(text.split("</response>")[1]) * exceed_penalty | |
return count | |
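## Reward shaping above: format_r is granted only when the completion contains
## exactly one well-formed <response> ... </response> block; the thinking section
## before <response> earns a small per-character bonus up to 500 characters and a
## penalty beyond that, and exceed_penalty discourages text after the closing tag.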
def format_reward_func(completions, **kwargs) -> list[float]: | |
return [format_reward(c) for c in completions] | |
def total_reward(completions, **kwargs) -> list[float]: | |
rewards = [ | |
correctness_reward_func(completions, **kwargs), | |
format_reward_func(completions, **kwargs), | |
] | |
return [sum(r) for r in zip(*rewards)], rewards | |
## GRPO Utilities | |
def get_advantages(completions, group_size=GROUP_SIZE, **kwargs) -> list[float]: | |
""" | |
completions: list of group completions | |
""" | |
rewards, rewards_list = total_reward(completions, **kwargs) | |
rewards = torch.tensor(rewards, dtype=DTYPE, device=DEVICE) | |
rewards = rewards.view(-1, group_size) | |
mean_group = rewards.mean(dim=1) | |
std_group = rewards.std(dim=1) | |
mean_group = torch.repeat_interleave(mean_group, group_size, dim=0) | |
std_group = torch.repeat_interleave(std_group, group_size, dim=0) | |
advantages = (rewards.flatten() - mean_group) / (std_group + 1e-4) | |
return advantages, rewards.flatten(), rewards_list | |
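## Group-relative advantages as computed above: within each group of GROUP_SIZE
## completions for the same question, A_i = (r_i - mean(group)) / (std(group) + 1e-4),
## where torch.std uses the default unbiased (sample) estimator.
## Hypothetical example for a group of 4 rewards [3.5, 0.0, 2.5, 0.0]:
## mean = 1.5, std ≈ 1.78, advantages ≈ [1.12, -0.84, 0.56, -0.84].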
def logp_per_token(model, inputs, chunked=False): | |
input_ids = inputs["input_ids"][:, 1:] | |
if chunked: | |
patch_lm_head_for_logp(model.lm_head, input_ids, chunk_size=1024) | |
with torch.autocast("cuda", DTYPE): | |
logp_per_token = model(**inputs, return_dict=False)[0] | |
unpatch_lm_head_for_logp(model.lm_head) | |
return logp_per_token | |
else: | |
with torch.autocast("cuda", DTYPE): | |
output = model(**inputs, return_dict=False)[0] | |
return _logp_per_token(output[:, :-1, :], input_ids) | |
@torch.compile | |
def _logp_per_token(logits, input_ids): | |
""" | |
logits: [B, L, V] | |
input_ids: [B, L] | |
L is dynamic | |
""" | |
choosed_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1) | |
logsumexp_values = torch.stack([torch.logsumexp(l, dim=-1) for l in logits]) | |
logp = choosed_logits - logsumexp_values | |
return logp | |
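## Per-token log-probability: logp[b, t] = logits[b, t, y_{b,t}] - logsumexp_V(logits[b, t, :]).
## The chunked branch above patches lm_head (see utils) so hidden states are turned
## into per-token log-probs chunk by chunk and the full [B, L, V] logits tensor is
## never materialized; the plain branch computes the same quantity directly from the logits.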
## Training Setup | |
inf_model = LLM( | |
model_name, | |
gpu_memory_utilization=0.15, | |
max_seq_len_to_capture=1024, | |
max_model_len=1024, | |
max_num_seqs=128, | |
) | |
inf_model.llm_engine.model_executor.driver_worker.model_runner.model.requires_grad_( | |
False | |
) | |
model = ( | |
AutoModelForCausalLM.from_pretrained( | |
model_name, | |
attn_implementation="flash_attention_2", | |
) | |
.requires_grad_(False) | |
.eval() | |
.to(DTYPE) | |
.to(DEVICE) | |
) | |
def vllm_load_state_dict(llm: LLM, state_dict, skip_dicts): | |
llm.llm_engine.model_executor.driver_worker.model_runner.model.load_weights( | |
tuple((k, v) for k, v in state_dict.items() if k not in skip_dicts) | |
) | |
def vllm_model_mapping(inf_model, hf_model): | |
    ## Written specifically for Qwen; may not work for other architectures (a custom implementation would be needed)
vllm_model = inf_model.llm_engine.model_executor.driver_worker.model_runner.model | |
vllm_state_dict = vllm_model.state_dict() | |
config = vllm_model.config | |
stacked_params_mapping = [ | |
# (param_name, shard_name, shard_id) | |
("qkv_proj", "q_proj", "q"), | |
("qkv_proj", "k_proj", "k"), | |
("qkv_proj", "v_proj", "v"), | |
("gate_up_proj", "gate_proj", 0), | |
("gate_up_proj", "up_proj", 1), | |
] | |
for key, param in hf_model.state_dict().items(): | |
if key in vllm_state_dict: | |
param.data = vllm_state_dict[key].data | |
else: | |
for param_name, shard_name, shard_id in stacked_params_mapping: | |
try_key = key.replace(shard_name, param_name) | |
if try_key in vllm_state_dict: | |
current_param = vllm_state_dict[try_key] | |
out_ch = current_param.shape[0] | |
if isinstance(shard_id, int): | |
target_ch = out_ch // 2 | |
param.data = current_param.data[ | |
target_ch * shard_id : target_ch * (shard_id + 1) | |
] | |
else: | |
q_heads = config.num_attention_heads | |
kv_heads = config.num_key_value_heads | |
total = q_heads + kv_heads * 2 | |
dims = { | |
"q": (0, out_ch * q_heads // total), | |
"k": ( | |
out_ch * q_heads // total, | |
out_ch * kv_heads // total, | |
), | |
"v": ( | |
out_ch * q_heads // total + out_ch * kv_heads // total, | |
out_ch * kv_heads // total, | |
), | |
}[shard_id] | |
param.data = current_param.data[dims[0] : dims[0] + dims[1]] | |
break | |
else: | |
pass | |
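## Note on the mapping above: vLLM fuses q/k/v into qkv_proj and gate/up into
## gate_up_proj, so each HF parameter is rebound (via param.data) to the matching
## row-slice of the fused vLLM tensor. As written, the HF model and the vLLM model
## therefore share weight storage; the function has no return value, so the
## skip_dicts assignment below ends up as None and only {} is ever passed to
## vllm_load_state_dict.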
skip_dicts = vllm_model_mapping(inf_model, model) | |
tokenizer = AutoTokenizer.from_pretrained(model_name) | |
if getattr(tokenizer, "pad_token", None) is None: | |
tokenizer.pad_token = tokenizer.eos_token | |
GEN_CONFIG.pad_token_id = tokenizer.pad_token_id | |
LycorisNetwork.apply_preset( | |
{ | |
"target_module": [], | |
"target_name": [ | |
".*proj.*", | |
".*norm.*", | |
".*lm_head" | |
], | |
} | |
) | |
lycoris_net = create_lycoris( | |
model, | |
algo="lokr", | |
full_matrix=True, | |
factor=8, | |
bypass_mode=False, | |
train_norm=True, | |
) | |
lycoris_net.apply_to() | |
lycoris_net.to(DEVICE) | |
print( | |
sum(p.numel() for p in lycoris_net.parameters() if p.requires_grad) / 1e6, | |
"M trainable parameters", | |
) | |
optimizer = torch.optim.AdamW(lycoris_net.parameters(), lr=LR, betas=(0.8, 0.98)) | |
scheduler = AnySchedule( | |
optimizer, | |
config={ | |
# Value is factor for corresponding hyperparam | |
"lr": { | |
"mode": "cosine", | |
"end": MAX_STEPS, | |
"value": 1.0, | |
"min_value": 0.05, | |
# "warmup": int(MAX_STEPS * 0.15), | |
} | |
}, | |
) | |
dataset = get_gsm8k_questions() | |
dataloader = data.DataLoader( | |
dataset, | |
batch_size=BATCH_SIZE, | |
shuffle=True, | |
collate_fn=lambda x: {k: [b[k] for b in x] for k in x[0].keys()}, | |
) | |
torch.cuda.empty_cache() | |
## Main Training Loop | |
## GRPO Quick Note: https://hackmd.io/@KBlueLeaf/SJ5OZ00_1g | |
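## Per-token objective implemented below: with ratio = exp(logp_proxy - logp_old),
##   loss = -min(ratio * A, clip(ratio, 1 - ETA, 1 + ETA) * A) + BETA * KL,
## where KL = exp(logp_ref - logp_proxy) - (logp_ref - logp_proxy) - 1 is the
## estimator against the frozen reference model. The loss is summed over completion
## tokens and divided by token_with_loss (the total number of completion tokens in the batch).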
pbar = tqdm( | |
range(len(dataloader) * EPOCHS * GRPO_ITER), desc="GRPO Training", smoothing=0.01 | |
) | |
rewards_ema = 0 | |
kl_div_ema = 0 | |
loss_ema = 0 | |
for epoch in range(EPOCHS): | |
for idx, batch in enumerate(dataloader): | |
## Sample entries from dataset | |
lycoris_net.restore() | |
questions = list(chain(*([q] * GROUP_SIZE for q in batch["question"]))) | |
answers = list(chain(*([a] * GROUP_SIZE for a in batch["answer"]))) | |
text_inputs = list(chain(*[[i] * GROUP_SIZE for i in batch["model_input"]])) | |
## Generate group results | |
torch.cuda.empty_cache() | |
lycoris_net.onfly_merge() | |
vllm_load_state_dict(inf_model, model.state_dict(), {}) | |
vllm_model_mapping(inf_model, model) | |
outputs = inf_model.generate( | |
text_inputs, sampling_params=GEN_CONFIG, use_tqdm=False | |
) | |
completions = [output.outputs[0].text for output in outputs] | |
lycoris_net.onfly_restore() | |
torch.cuda.empty_cache() | |
## Compute group rewards | |
advantages, rewards, rewards_list = get_advantages( | |
completions, | |
group_size=GROUP_SIZE, | |
answers=answers, | |
questions=questions, | |
) | |
total_rewards = rewards.mean().item() | |
ema_decay = min((pbar.n / GRPO_ITER) / (pbar.n / GRPO_ITER + 1), 0.99) | |
rewards_ema = rewards_ema * ema_decay + total_rewards * (1 - ema_decay) | |
## Log completions | |
table = wandb.Table(columns=["question", "answer", "model output", "reward"]) | |
for idx, (q, a, c, r) in enumerate( | |
zip(questions, answers, completions, rewards) | |
): | |
if idx % GROUP_SIZE == 0: | |
table.add_data(q, a, c, r.item()) | |
wandb.log({"completions": table}, commit=False) | |
## Compute reference log-prob | |
final_inputs_len = [ | |
len( | |
tokenizer.apply_chat_template( | |
[ | |
{"role": "system", "content": FAKE_SYSTEM_PROMPT}, | |
{"role": "user", "content": q}, | |
] | |
) | |
) | |
for q in questions | |
] | |
final_model_input = [ | |
tokenizer.apply_chat_template( | |
[ | |
{"role": "system", "content": FAKE_SYSTEM_PROMPT}, | |
{"role": "user", "content": q}, | |
{"role": "assistant", "content": c}, | |
], | |
tokenize=False, | |
) | |
for q, c in zip(questions, completions) | |
] | |
inputs = tokenizer( | |
final_model_input, | |
return_tensors="pt", | |
padding="max_length", | |
max_length=MAX_LEN, | |
padding_side="left", | |
truncation=True, | |
) | |
final_total_len = (inputs["attention_mask"].sum(dim=1) - 1).tolist() | |
output_len = [ | |
(total - inp_len - 1) | |
if total < MAX_LEN | |
else (min(MAX_LEN, total - inp_len) - 1) | |
for inp_len, total in zip(final_inputs_len, final_total_len) | |
] | |
inputs["input_ids"] = inputs["input_ids"][:, -MAX_LEN:].to(DEVICE) | |
inputs["attention_mask"] = inputs["attention_mask"][:, -MAX_LEN:].to(DEVICE) | |
token_with_loss = sum(output_len) | |
with torch.inference_mode(): | |
logp_refs = [None] * len(inputs["input_ids"]) | |
for i in range(0, len(inputs["input_ids"]), MINI_BATCH_SIZE): | |
current_inputs = { | |
k: v[i : i + MINI_BATCH_SIZE].to(DEVICE) for k, v in inputs.items() | |
} | |
logp_refs[i] = logp_per_token(model, current_inputs) | |
torch.cuda.empty_cache() | |
lycoris_net.apply_to() | |
## GRPO Iterations | |
if GRAD_CKPT: | |
model.gradient_checkpointing_enable() | |
logp_old = [None] * len(inputs["input_ids"]) | |
for _ in range(GRPO_ITER): | |
step_loss = 0 | |
step_kl_div = 0 | |
step_objective = 0 | |
steps = 0 | |
for i in range(0, len(inputs["input_ids"]), MINI_BATCH_SIZE): | |
steps += 1 | |
current_inputs = { | |
k: v[i : i + MINI_BATCH_SIZE].to(DEVICE).clone() | |
for k, v in inputs.items() | |
} | |
with torch.enable_grad(): | |
logp_proxy = logp_per_token(model, current_inputs, chunked=True) | |
if logp_old[i] is None: | |
logp_old[i] = logp_proxy.detach().clone() | |
kl_div = ( | |
torch.exp(logp_refs[i] - logp_proxy) | |
- (logp_refs[i] - logp_proxy) | |
- 1 | |
) | |
adv_scale = torch.exp(logp_proxy - logp_old[i]) | |
clipped_scale = torch.clamp(adv_scale, 1 - ETA, 1 + ETA) | |
scaled_advantages = ( | |
advantages[i : i + MINI_BATCH_SIZE, None] * adv_scale | |
) | |
clipped_advantages = ( | |
advantages[i : i + MINI_BATCH_SIZE, None] * clipped_scale | |
) | |
mask = scaled_advantages < clipped_advantages | |
final_advantages = torch.where( | |
mask, scaled_advantages, clipped_advantages | |
) | |
per_token_loss = -final_advantages + BETA * kl_div | |
loss_mask = current_inputs["attention_mask"][:, :-1] | |
for idx, length in enumerate(output_len[i : i + MINI_BATCH_SIZE]): | |
loss_mask[idx, : -(length - 1)] = 0 | |
loss = ( | |
torch.sum( | |
torch.where( | |
loss_mask.bool(), | |
per_token_loss, | |
torch.zeros_like(per_token_loss), | |
) | |
) | |
/ token_with_loss | |
) | |
loss.backward() | |
step_kl_div += (kl_div * loss_mask).sum().float().item() | |
step_loss += loss.float().item() | |
step_objective += ( | |
((final_advantages - kl_div * BETA) * loss_mask) | |
.sum() | |
.float() | |
.item() | |
) | |
ema_decay = min((pbar.n) / (pbar.n + 1), 0.999) | |
            # step_loss is already normalized by token_with_loss inside the mini-batch loop
step_kl_div = step_kl_div / token_with_loss | |
step_objective = step_objective / token_with_loss | |
loss_ema = loss_ema * ema_decay + step_loss * (1 - ema_decay) | |
kl_div_ema = kl_div_ema * ema_decay + step_kl_div * (1 - ema_decay) | |
pbar.set_postfix( | |
{ | |
"rewards": rewards.float().mean().item(), | |
"loss": step_loss, | |
"kl_div": step_kl_div, | |
"objective": step_objective, | |
} | |
) | |
wandb.log( | |
{ | |
"rewards": rewards.float().mean().item(), | |
"correctness_reward": torch.mean(torch.tensor(rewards_list[0])), | |
"format_reward": torch.mean(torch.tensor(rewards_list[1])), | |
"loss": step_loss, | |
"kl_div": step_kl_div, | |
"objective": step_objective, | |
"rewards_ema": rewards_ema, | |
"loss_ema": loss_ema, | |
"kl_div_ema": kl_div_ema, | |
}, | |
commit=False, | |
) | |
torch.nn.utils.clip_grad_norm_(lycoris_net.parameters(), 0.1) | |
optimizer.step() | |
optimizer.zero_grad() | |
pbar.update() | |
wandb.log({"global_step": pbar.n}) | |
if pbar.n >= MAX_STEPS: | |
break | |
scheduler.step() | |
model.gradient_checkpointing_disable() | |
if pbar.n >= MAX_STEPS: | |
break | |
if pbar.n >= MAX_STEPS: | |
break | |
pbar.close() | |
torch.cuda.empty_cache() | |
## Test after Training | |
lycoris_net.restore() | |
lycoris_net.merge_to() | |
model.save_pretrained(model_name.split("/")[-1] + "-grpo") | |
tokenizer.save_pretrained(model_name.split("/")[-1] + "-grpo") | |
vllm_load_state_dict(inf_model, model.state_dict(), {}) | |
GEN_CONFIG = SamplingParams(max_tokens=4096, temperature=0.25, min_p=0.1) | |
generations = inf_model.generate([i for i in batch["model_input"]], sampling_params=GEN_CONFIG) | |
completions = [ | |
output.outputs[0].text for output in generations | |
] | |
for q, a, c in zip(batch["question"], batch["answer"], completions): | |
print("=" * 30) | |
print("Question:") | |
print(q) | |
print("-" * 10) | |
print("Answer:") | |
print(a) | |
print("-" * 20) | |
print("Completion:") | |
print(c) | |
print("=" * 30) | |
print() |
## | |
# GSM-8K test script | |
# tested results: | |
# - Qwen2.5-0.5B-instruct (with ICL prompt) | |
# before grpo: 0.429 | |
# after grpo : 0.522 | |
# - Qwen2.5-1.5B-instruct (without ICL prompt) | |
# before grpo: 0.583 | |
# after grpo : 0.677 | |
## | |
import re | |
import torch | |
import torch.utils.data as data | |
from tqdm import tqdm | |
from datasets import load_dataset, Dataset | |
from transformers import AutoTokenizer | |
from vllm import LLM, SamplingParams | |
# model_name = "Qwen/Qwen2.5-0.5B-Instruct" | |
# model_name = "./Qwen2.5-0.5B-Instruct-grpo" | |
model_name = "Qwen/Qwen2.5-1.5B-Instruct" | |
# model_name = "./Qwen2.5-1.5B-Instruct-grpo" | |
tokenizer = AutoTokenizer.from_pretrained(model_name) | |
## Constants | |
GEN_CONFIG = SamplingParams(max_tokens=4096, temperature=0.25, min_p=0.1) | |
DTYPE = torch.bfloat16 | |
DEVICE = "cuda" | |
## Define the GSM8K dataset and reward functions | |
## Modified from: | |
## https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb | |
SYSTEM_PROMPT = """Following this Format, \ | |
Don't forget <response> and don't put anything outside the format: | |
--- | |
Your Step by Step thinking | |
<response> | |
Your Final Answer or Response | |
</response> | |
""" | |
# FAKE_SYSTEM_PROMPT = """ | |
# You are a human, you can think and answer in your own way. | |
# """.strip() | |
FAKE_SYSTEM_PROMPT = SYSTEM_PROMPT | |
XML_COT_FORMAT = """{thinking} | |
<response> | |
{answer} | |
</response> | |
""" | |
def extract_xml_answer(text: str) -> str: | |
answer = text.split("<response>")[-1] | |
answer = answer.split("</response>")[0] | |
    # Take the last number in the answer region as the answer.
    # This allows more verbose answers, e.g. "$300" or "Alex needs to buy 100 apples".
number_answer = re.findall(r"\d+\.\d+", answer) | |
if number_answer: | |
answer = number_answer[-1] | |
int_answer = re.findall(r"\d+", answer) | |
if int_answer: | |
answer = int_answer[-1] | |
return answer.strip() | |
def extract_hash_answer(text: str) -> str | None: | |
if "####" not in text: | |
return None | |
return text.split("####")[1].strip().replace(",", "") | |
def get_gsm8k_questions(split="train") -> Dataset: | |
data = load_dataset("openai/gsm8k", "main")[split] | |
data = data.map( | |
lambda x: { | |
"question": x["question"], | |
"model_input": tokenizer.apply_chat_template( | |
[ | |
{"role": "system", "content": SYSTEM_PROMPT}, | |
# In-Context Learning prompt, somehow necessary for smaller model | |
# {"role": "user", "content": "Which is bigger, 9.11 or 9.8?"}, | |
# { | |
# "role": "assistant", | |
# "content": XML_COT_FORMAT.format( | |
# thinking=( | |
# "User asking me to compare 9.11 and 9.8, to compare two numbers, " | |
# "we need to look at the integer part first, " | |
# "9 is the same in both numbers, " | |
# "then we look at the decimal part, " | |
# "0.11 is smaller than 0.8, so 9.11 is smaller than 9.8." | |
# ), | |
# answer="9.8", | |
# ), | |
# }, | |
{"role": "user", "content": x["question"]}, | |
], | |
tokenize=False, | |
add_generation_prompt=True, | |
), | |
"answer": extract_hash_answer(x["answer"]), | |
} | |
) | |
return data | |
## reward functions | |
def correctness_reward_func(completions, answers, **kwargs) -> list[float]: | |
extracted_responses = [extract_xml_answer(r) for r in completions] | |
return [1 if r == a else 0 for r, a in zip(extracted_responses, answers)] | |
## Training Setup | |
inf_model = LLM( | |
model_name, | |
gpu_memory_utilization=0.85, | |
max_seq_len_to_capture=4096, | |
max_model_len=4096, | |
max_num_seqs=256, | |
) | |
dataset = get_gsm8k_questions() | |
dataloader = data.DataLoader( | |
dataset, | |
batch_size=8192, | |
shuffle=True, | |
collate_fn=lambda x: {k: [b[k] for b in x] for k in x[0].keys()}, | |
) | |
total_correct = [] | |
for idx, batch in enumerate(dataloader): | |
## Sample entries from dataset | |
questions = batch["question"] | |
answers = batch["answer"] | |
text_inputs = batch["model_input"] | |
## Generate group results | |
outputs = inf_model.generate(text_inputs, sampling_params=GEN_CONFIG) | |
completions = [output.outputs[0].text for output in outputs] | |
correct = correctness_reward_func(completions, answers) | |
total_correct.extend(correct) | |
print(f"Accuracy: {sum(total_correct) / len(total_correct)}") |
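## utils: chunked per-token log-prob helpers imported by the training script above
## (patch_lm_head_for_logp / unpatch_lm_head_for_logp).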
import torch | |
import torch.nn.functional as F | |
def _logp_per_token(logits, input_ids): | |
""" | |
logits: [B, L, V] | |
input_ids: [B, L] | |
L is dynamic | |
""" | |
choosed_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1) | |
logsumexp_values = torch.stack([torch.logsumexp(l, dim=-1) for l in logits]) | |
logp = choosed_logits - logsumexp_values | |
return logp | |
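## logp[b, t] = logits[b, t, y_{b,t}] - logsumexp over the vocabulary, with the
## logsumexp computed one batch element at a time in the Python loop above.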
class FusedLinearChunkedLogProb(torch.autograd.Function): | |
""" | |
    A fused implementation of hidden -> lm_head -> per-token log-probs,
    written to avoid PyTorch caching the full logits tensor, which is extremely large.
""" | |
@staticmethod | |
def forward( | |
ctx, | |
hidden_state, | |
lm_head_weight, | |
lm_head_bias, | |
input_ids, | |
chunk_size: int = 128, | |
): | |
""" | |
Calculates log probability per token in a chunked manner. | |
Args: | |
# logits (torch.Tensor): [B, L, V] tensor of logits. | |
hidden_state (torch.Tensor): [B, L, H] tensor of hidden states. | |
lm_head_weight (torch.Tensor): [V, H] tensor of lm head weights. | |
lm_head_bias (torch.Tensor): [V] tensor of lm head biases. | |
input_ids (torch.Tensor): [B, L] tensor of token indices. | |
chunk_size (Optional[int]): Size of chunks for processing. If None, defaults to batch size. | |
Returns: | |
torch.Tensor: [B, L] tensor of log probabilities per token. | |
""" | |
B, L, _ = hidden_state.shape | |
ctx.chunk_size = chunk_size | |
logp_chunks = [] | |
@torch.compile | |
def chunk_forward(hidden_chunk, lm_head_weight, lm_head_bias, input_ids_chunk): | |
logits_chunk = F.linear(hidden_chunk, lm_head_weight, lm_head_bias) | |
choosed_logits_chunk = logits_chunk.gather( | |
dim=-1, index=input_ids_chunk.unsqueeze(-1) | |
).squeeze(-1) | |
logsumexp_values_chunk = torch.logsumexp(logits_chunk, dim=-1) | |
logp_chunk = choosed_logits_chunk - logsumexp_values_chunk | |
return logp_chunk | |
for i in range(0, B): | |
logp_chunks_ele = [] | |
for j in range(0, L, chunk_size): | |
hidden_chunk = hidden_state[i, j : j + chunk_size] | |
input_ids_chunk = input_ids[i, j : j + chunk_size] | |
logp_chunk = chunk_forward( | |
hidden_chunk, lm_head_weight, lm_head_bias, input_ids_chunk | |
) | |
logp_chunks_ele.append(logp_chunk) | |
logp_chunks.append(torch.cat(logp_chunks_ele, dim=0)) | |
        # save_for_backward is deferred until after the forward computation above, reusing the same tensors.
ctx.save_for_backward( | |
hidden_state, | |
input_ids, | |
lm_head_weight, | |
lm_head_bias if lm_head_bias is not None else None, | |
) # Save for backward. Offload if needed. | |
return torch.stack(logp_chunks, dim=0) | |
@staticmethod | |
def backward(ctx, grad_output): | |
""" | |
Calculates gradients in a chunked manner. | |
Args: | |
grad_output (torch.Tensor): [B, L] tensor of incoming gradients. | |
Returns: | |
            Tuple: gradients w.r.t. hidden_state, lm_head_weight, and lm_head_bias, plus None for input_ids and chunk_size.
""" | |
hidden_state: torch.Tensor | |
hidden_state, input_ids, lm_head_weight, lm_head_bias = ctx.saved_tensors | |
chunk_size = ctx.chunk_size | |
B, L, _ = hidden_state.shape | |
h_dtype = hidden_state.dtype | |
w_dtype = lm_head_weight.dtype | |
grad_hidden = torch.zeros_like(hidden_state, device=grad_output.device) | |
grad_weight = torch.zeros_like(lm_head_weight, device=grad_output.device) | |
if lm_head_bias is not None: | |
grad_bias = torch.zeros_like(lm_head_bias, device=grad_output.device) | |
else: | |
grad_bias = None | |
@torch.compile | |
def chunk_backward( | |
hidden_chunk, | |
lm_head_weight, | |
lm_head_bias, | |
input_ids_chunk, | |
grad_output_chunk, | |
): | |
logits_chunk = F.linear(hidden_chunk, lm_head_weight, lm_head_bias) | |
# Calculate softmax probabilities (efficient for gradient calculation). | |
probs_chunk = F.softmax(logits_chunk, dim=-1) | |
# 1. Gradient from the logsumexp term (d/dx_i logsumexp(x) = exp(x_i) / sum(exp(x_j)) = softmax(x)_i) | |
grad_logits_chunk = -probs_chunk * grad_output_chunk.unsqueeze(-1) | |
# 2. Gradient from the chosen logits term | |
index_for_scatter = input_ids_chunk.unsqueeze(-1) # Reshape for scatter_add | |
grad_logits_chunk.scatter_add_( | |
dim=-1, index=index_for_scatter, src=grad_output_chunk.unsqueeze(-1) | |
) | |
grad_hidden_chunk = torch.matmul( | |
grad_logits_chunk.to(h_dtype), lm_head_weight.to(h_dtype) | |
) | |
grad_weight_chunk = torch.matmul( | |
grad_logits_chunk.transpose(0, 1).to(w_dtype), hidden_chunk.to(w_dtype) | |
) | |
if lm_head_bias is not None: | |
grad_bias_chunk = torch.sum(grad_logits_chunk.to(w_dtype), dim=0) | |
else: | |
grad_bias_chunk = None | |
return grad_hidden_chunk, grad_weight_chunk, grad_bias_chunk | |
for i in range(0, B): | |
for j in range(0, L, chunk_size): | |
hidden_chunk = hidden_state[i, j : j + chunk_size].to( | |
grad_output.device | |
) | |
input_ids_chunk = input_ids[i, j : j + chunk_size] | |
grad_output_chunk = grad_output[i, j : j + chunk_size] | |
grad_hidden_chunk, grad_weight_chunk, grad_bias_chunk = chunk_backward( | |
hidden_chunk, | |
lm_head_weight, | |
lm_head_bias, | |
input_ids_chunk, | |
grad_output_chunk, | |
) | |
del hidden_chunk | |
grad_hidden[i, j : j + chunk_size] += grad_hidden_chunk | |
grad_weight += grad_weight_chunk | |
if lm_head_bias is not None: | |
grad_bias += grad_bias_chunk | |
return grad_hidden, grad_weight, grad_bias, None, None | |
def _linear_with_chunked_logp(x, lm_head, input_ids, chunk_size: int = 128): | |
return FusedLinearChunkedLogProb.apply( | |
x, lm_head.weight, lm_head.bias, input_ids, chunk_size | |
) | |
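## Back-of-the-envelope motivation (assuming a roughly 150k-token vocabulary and
## bf16 activations): the full logits for an [8, 1024] mini-batch would be about
## 8 * 1024 * 150_000 * 2 bytes ≈ 2.3 GB, and autograd would normally keep them
## for backward. The autograd.Function above only materializes [chunk_size, V]
## logits at a time and recomputes each chunk in backward, trading extra compute
## for a much smaller peak memory footprint.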
def patch_lm_head_for_logp(lm_head, logit_ids, chunk_size: int = 128): | |
lm_head.org_forward = lm_head.forward | |
def new_forward(x, *args, **kwargs): | |
return _linear_with_chunked_logp(x[:, :-1], lm_head, logit_ids, chunk_size) | |
lm_head.forward = new_forward | |
def unpatch_lm_head_for_logp(lm_head): | |
lm_head.forward = getattr(lm_head, "org_forward", lm_head.forward) | |
delattr(lm_head, "org_forward") | |
if __name__ == "__main__": | |
x = torch.randn(4, 1024, 1024).cuda() | |
lm_head = torch.nn.Linear(1024, 128000).cuda() | |
input_ids = torch.randint(0, 128000, (4, 1024)).cuda() | |
logits = lm_head(x) | |
input_ids = torch.randint(0, 128000, (4, 1024)).cuda() | |
logp = _logp_per_token(logits, input_ids) | |
logp_chunked = _linear_with_chunked_logp(x, lm_head, input_ids, chunk_size=128) | |
print(torch.allclose(logp, logp_chunked)) | |
x = x.requires_grad_(True) | |
logits = lm_head(x) | |
logp = _logp_per_token(logits, input_ids) | |
logp.sum().backward() | |
first_grad_x = x.grad.clone() | |
first_grad_w = lm_head.weight.grad.clone() | |
first_grad_b = lm_head.bias.grad.clone() | |
x.grad = None | |
lm_head.weight.grad = None | |
lm_head.bias.grad = None | |
logp_chunked = _linear_with_chunked_logp(x, lm_head, input_ids, chunk_size=128) | |
logp_chunked.sum().backward() | |
second_grad_x = x.grad.clone() | |
second_grad_w = lm_head.weight.grad.clone() | |
second_grad_b = lm_head.bias.grad.clone() | |
x.grad = None | |
lm_head.weight.grad = None | |
lm_head.bias.grad = None | |
print(torch.allclose(first_grad_x, second_grad_x)) | |
print( | |
F.mse_loss(first_grad_x, second_grad_x), | |
torch.max(torch.abs(first_grad_x - second_grad_x)), | |
) | |
print(torch.allclose(first_grad_w, second_grad_w)) | |
print( | |
F.mse_loss(first_grad_w, second_grad_w), | |
torch.max(torch.abs(first_grad_w - second_grad_w)), | |
) | |
print(torch.allclose(first_grad_b, second_grad_b)) | |
print( | |
F.mse_loss(first_grad_b, second_grad_b), | |
torch.max(torch.abs(first_grad_b - second_grad_b)), | |
) |
The current version can reach a high reward easily with fewer than 100 training steps.
Since my training runs with this script keep going badly, I have no idea whether this implementation is correct.
If you find anything weird, please let me know.