@KohakuBlueleaf
Last active April 29, 2025 11:34
~10% GSM8K accuracy gain within 15 minutes
## Note
## If you run vLLM on the same GPU as training, remember to set a low gpu_memory_utilization to avoid OOM.
## For larger models, consider multi-GPU or CPU offloading.
## AnySchedule: https://github.com/KohakuBlueleaf/AnySchedule
## LyCORIS: https://github.com/KohakuBlueleaf/LyCORIS
## The following code performs reasonable training of the Llama-3.2-1B-Instruct model on the GSM8K dataset,
## with a noticeable improvement on each reward function.
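## A minimal sketch of the vLLM options referred to above (values are illustrative, not tuned):
##   LLM(model_name, gpu_memory_utilization=0.15)  # leave most of the VRAM for the HF training model on the same GPU
##   LLM(model_name, tensor_parallel_size=2)       # or shard a larger model across 2 GPUs instead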
from itertools import chain
import re
import wandb
import torch
import torch.utils.data as data
from tqdm import tqdm
from anyschedule import AnySchedule
from datasets import load_dataset, Dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
)
from vllm import LLM, SamplingParams
from lycoris import create_lycoris, LycorisNetwork
from utils import patch_lm_head_for_logp, unpatch_lm_head_for_logp
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
## Constants
GEN_CONFIG = SamplingParams(max_tokens=1024, temperature=0.75, min_p=0.1)
DTYPE = torch.bfloat16
DEVICE = "cuda"
EPOCHS = 1
MAX_LEN = 1024
MAX_STEPS = 128
GRAD_CKPT = True
LR = 5e-6
# GRPO Hyperparameters
answer_r = 2.5
format_r = 1.0
length_penalty = 0.001
exceed_penalty = 0.005
BATCH_SIZE = 8
MINI_BATCH_SIZE = 8
GROUP_SIZE = 16
GRPO_ITER = 4
ETA = 0.5
BETA = 0.25
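# GROUP_SIZE: completions sampled per question (the group used for advantage normalization)
# GRPO_ITER:  policy-update passes that reuse the same sampled group (PPO-style inner epochs)
# ETA:        clipping range for the importance ratio exp(logp_proxy - logp_old)
# BETA:       weight of the per-token KL penalty against the frozen reference model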
wandb.init(
project="HakuGRPO",
name="GRPO-QWen-2.5-0.5B-vllm",
# mode="offline",
)
wandb.config.update(
{
"learning_rate": LR,
"batch_size": BATCH_SIZE,
"mini_batch_size": MINI_BATCH_SIZE,
"group_size": GROUP_SIZE,
"grpo_iter": GRPO_ITER,
"eta": ETA,
"beta": BETA,
}
)
wandb.run.log_code(".")
## Define the GSM8K dataset and reward functions
## Modified from:
## https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb
SYSTEM_PROMPT = """Following this Format, \
Don't forget <response> and don't put anything outside the format:
---
Your Step by Step thinking
<response>
Your Final Answer or Response
</response>
"""
# FAKE_SYSTEM_PROMPT = """
# You are a human, you can think and answer in your own way.
# """.strip()
FAKE_SYSTEM_PROMPT = SYSTEM_PROMPT
XML_COT_FORMAT = """{thinking}
<response>
{answer}
</response>
"""
def extract_xml_answer(text: str) -> str:
answer = text.split("<response>")[-1]
# answer = answer.split("</response>")[0]
if answer.count("</response>") >= 1:
answer = answer.split("</response>")[0]
else:
return ""
    # take the last number (integer or decimal) in the answer region as the answer
    # Allow more verbose answers, like "$300" or "Alex needs to buy 100 apples" (for example)
    number_answer = re.findall(r"\d+(?:\.\d+)?", answer)
    if number_answer:
        answer = number_answer[-1]
    return answer.strip()
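# e.g. (illustrative) extract_xml_answer("...<response>The total cost is $300.</response>") -> "300"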
def extract_hash_answer(text: str) -> str | None:
if "####" not in text:
return None
return text.split("####")[1].strip().replace(",", "")
def get_gsm8k_questions(split="train") -> Dataset:
data = load_dataset("openai/gsm8k", "main")[split]
data = data.map(
lambda x: {
"question": x["question"],
"model_input": tokenizer.apply_chat_template(
[
{"role": "system", "content": SYSTEM_PROMPT},
                    # In-context learning example; seems necessary for smaller models
{"role": "user", "content": "Which is bigger, 9.11 or 9.8?"},
{
"role": "assistant",
"content": XML_COT_FORMAT.format(
thinking=(
"User asking me to compare 9.11 and 9.8, to compare two numbers, "
"we need to look at the integer part first, "
"9 is the same in both numbers, "
"then we look at the decimal part, "
"0.11 is smaller than 0.8, so 9.11 is smaller than 9.8."
),
answer="9.8",
),
},
{"role": "user", "content": x["question"]},
],
tokenize=False,
add_generation_prompt=True,
),
"answer": extract_hash_answer(x["answer"]),
}
)
return data
## reward functions
def correctness_reward_func(completions, answers, **kwargs) -> list[float]:
extracted_responses = [extract_xml_answer(r) for r in completions]
return [answer_r if r == a else 0.0 for r, a in zip(extracted_responses, answers)]
def format_reward(text) -> float:
count = 0.0
if text.count("<response>") == 1:
if text.split("<response>")[-1].count("</response>") == 1:
count += format_r
thinking_length = len(text.split("<response>")[0])
if thinking_length < 500:
count += thinking_length * length_penalty
else:
count -= (thinking_length - 500) * length_penalty
count -= len(text.split("</response>")[1]) * exceed_penalty
return count
def format_reward_func(completions, **kwargs) -> list[float]:
return [format_reward(c) for c in completions]
def total_reward(completions, **kwargs) -> list[float]:
rewards = [
correctness_reward_func(completions, **kwargs),
format_reward_func(completions, **kwargs),
]
return [sum(r) for r in zip(*rewards)], rewards
## GRPO Utilities
def get_advantages(completions, group_size=GROUP_SIZE, **kwargs) -> list[float]:
"""
completions: list of group completions
"""
rewards, rewards_list = total_reward(completions, **kwargs)
rewards = torch.tensor(rewards, dtype=DTYPE, device=DEVICE)
rewards = rewards.view(-1, group_size)
mean_group = rewards.mean(dim=1)
std_group = rewards.std(dim=1)
mean_group = torch.repeat_interleave(mean_group, group_size, dim=0)
std_group = torch.repeat_interleave(std_group, group_size, dim=0)
advantages = (rewards.flatten() - mean_group) / (std_group + 1e-4)
return advantages, rewards.flatten(), rewards_list
def logp_per_token(model, inputs, chunked=False):
input_ids = inputs["input_ids"][:, 1:]
if chunked:
patch_lm_head_for_logp(model.lm_head, input_ids, chunk_size=1024)
with torch.autocast("cuda", DTYPE):
logp_per_token = model(**inputs, return_dict=False)[0]
unpatch_lm_head_for_logp(model.lm_head)
return logp_per_token
else:
with torch.autocast("cuda", DTYPE):
output = model(**inputs, return_dict=False)[0]
return _logp_per_token(output[:, :-1, :], input_ids)
@torch.compile
def _logp_per_token(logits, input_ids):
"""
logits: [B, L, V]
input_ids: [B, L]
L is dynamic
"""
choosed_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
logsumexp_values = torch.stack([torch.logsumexp(l, dim=-1) for l in logits])
logp = choosed_logits - logsumexp_values
return logp
## Training Setup
inf_model = LLM(
model_name,
gpu_memory_utilization=0.15,
max_seq_len_to_capture=1024,
max_model_len=1024,
max_num_seqs=128,
)
inf_model.llm_engine.model_executor.driver_worker.model_runner.model.requires_grad_(
False
)
model = (
AutoModelForCausalLM.from_pretrained(
model_name,
attn_implementation="flash_attention_2",
)
.requires_grad_(False)
.eval()
.to(DTYPE)
.to(DEVICE)
)
def vllm_load_state_dict(llm: LLM, state_dict, skip_dicts):
llm.llm_engine.model_executor.driver_worker.model_runner.model.load_weights(
tuple((k, v) for k, v in state_dict.items() if k not in skip_dicts)
)
def vllm_model_mapping(inf_model, hf_model):
    ## Specific to the Qwen architecture; other architectures may need a custom mapping implementation.
vllm_model = inf_model.llm_engine.model_executor.driver_worker.model_runner.model
vllm_state_dict = vllm_model.state_dict()
config = vllm_model.config
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
for key, param in hf_model.state_dict().items():
if key in vllm_state_dict:
param.data = vllm_state_dict[key].data
else:
for param_name, shard_name, shard_id in stacked_params_mapping:
try_key = key.replace(shard_name, param_name)
if try_key in vllm_state_dict:
current_param = vllm_state_dict[try_key]
out_ch = current_param.shape[0]
if isinstance(shard_id, int):
target_ch = out_ch // 2
param.data = current_param.data[
target_ch * shard_id : target_ch * (shard_id + 1)
]
else:
q_heads = config.num_attention_heads
kv_heads = config.num_key_value_heads
total = q_heads + kv_heads * 2
dims = {
"q": (0, out_ch * q_heads // total),
"k": (
out_ch * q_heads // total,
out_ch * kv_heads // total,
),
"v": (
out_ch * q_heads // total + out_ch * kv_heads // total,
out_ch * kv_heads // total,
),
}[shard_id]
param.data = current_param.data[dims[0] : dims[0] + dims[1]]
break
else:
pass
skip_dicts = vllm_model_mapping(inf_model, model)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if getattr(tokenizer, "pad_token", None) is None:
tokenizer.pad_token = tokenizer.eos_token
GEN_CONFIG.pad_token_id = tokenizer.pad_token_id
LycorisNetwork.apply_preset(
{
"target_module": [],
"target_name": [
".*proj.*",
".*norm.*",
".*lm_head"
],
}
)
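# The preset above targets, by name, every *proj* linear layer, every *norm* layer, and the lm_head.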
lycoris_net = create_lycoris(
model,
algo="lokr",
full_matrix=True,
factor=8,
bypass_mode=False,
train_norm=True,
)
lycoris_net.apply_to()
lycoris_net.to(DEVICE)
print(
sum(p.numel() for p in lycoris_net.parameters() if p.requires_grad) / 1e6,
"M trainable parameters",
)
optimizer = torch.optim.AdamW(lycoris_net.parameters(), lr=LR, betas=(0.8, 0.98))
scheduler = AnySchedule(
optimizer,
config={
# Value is factor for corresponding hyperparam
"lr": {
"mode": "cosine",
"end": MAX_STEPS,
"value": 1.0,
"min_value": 0.05,
# "warmup": int(MAX_STEPS * 0.15),
}
},
)
dataset = get_gsm8k_questions()
dataloader = data.DataLoader(
dataset,
batch_size=BATCH_SIZE,
shuffle=True,
collate_fn=lambda x: {k: [b[k] for b in x] for k in x[0].keys()},
)
torch.cuda.empty_cache()
## Main Training Loop
## GRPO Quick Note: https://hackmd.io/@KBlueLeaf/SJ5OZ00_1g
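## Per-token objective implemented in the loop below (a summary matching the code, not a derivation):
##   ratio  = exp(logp_proxy - logp_old)                          # importance ratio vs. the sampling policy
##   obj    = min(ratio * A, clip(ratio, 1 - ETA, 1 + ETA) * A)   # clipped surrogate with group-relative advantage A
##   kl_div = exp(logp_ref - logp_proxy) - (logp_ref - logp_proxy) - 1   # low-variance estimator of KL(proxy || ref)
##   loss   = sum over completion tokens of (-obj + BETA * kl_div) / token_with_loss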
pbar = tqdm(
range(len(dataloader) * EPOCHS * GRPO_ITER), desc="GRPO Training", smoothing=0.01
)
rewards_ema = 0
kl_div_ema = 0
loss_ema = 0
for epoch in range(EPOCHS):
for idx, batch in enumerate(dataloader):
## Sample entries from dataset
lycoris_net.restore()
questions = list(chain(*([q] * GROUP_SIZE for q in batch["question"])))
answers = list(chain(*([a] * GROUP_SIZE for a in batch["answer"])))
text_inputs = list(chain(*[[i] * GROUP_SIZE for i in batch["model_input"]]))
## Generate group results
torch.cuda.empty_cache()
lycoris_net.onfly_merge()
vllm_load_state_dict(inf_model, model.state_dict(), {})
vllm_model_mapping(inf_model, model)
outputs = inf_model.generate(
text_inputs, sampling_params=GEN_CONFIG, use_tqdm=False
)
completions = [output.outputs[0].text for output in outputs]
lycoris_net.onfly_restore()
torch.cuda.empty_cache()
## Compute group rewards
advantages, rewards, rewards_list = get_advantages(
completions,
group_size=GROUP_SIZE,
answers=answers,
questions=questions,
)
total_rewards = rewards.mean().item()
ema_decay = min((pbar.n / GRPO_ITER) / (pbar.n / GRPO_ITER + 1), 0.99)
rewards_ema = rewards_ema * ema_decay + total_rewards * (1 - ema_decay)
## Log completions
table = wandb.Table(columns=["question", "answer", "model output", "reward"])
for idx, (q, a, c, r) in enumerate(
zip(questions, answers, completions, rewards)
):
if idx % GROUP_SIZE == 0:
table.add_data(q, a, c, r.item())
wandb.log({"completions": table}, commit=False)
## Compute reference log-prob
final_inputs_len = [
len(
tokenizer.apply_chat_template(
[
{"role": "system", "content": FAKE_SYSTEM_PROMPT},
{"role": "user", "content": q},
]
)
)
for q in questions
]
final_model_input = [
tokenizer.apply_chat_template(
[
{"role": "system", "content": FAKE_SYSTEM_PROMPT},
{"role": "user", "content": q},
{"role": "assistant", "content": c},
],
tokenize=False,
)
for q, c in zip(questions, completions)
]
inputs = tokenizer(
final_model_input,
return_tensors="pt",
padding="max_length",
max_length=MAX_LEN,
padding_side="left",
truncation=True,
)
final_total_len = (inputs["attention_mask"].sum(dim=1) - 1).tolist()
output_len = [
(total - inp_len - 1)
if total < MAX_LEN
else (min(MAX_LEN, total - inp_len) - 1)
for inp_len, total in zip(final_inputs_len, final_total_len)
]
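        # output_len ≈ number of completion (assistant) tokens per sample, i.e. the tokens that receive loss
        # after the 1-token shift; used below to build loss_mask and to normalize the loss.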
inputs["input_ids"] = inputs["input_ids"][:, -MAX_LEN:].to(DEVICE)
inputs["attention_mask"] = inputs["attention_mask"][:, -MAX_LEN:].to(DEVICE)
token_with_loss = sum(output_len)
with torch.inference_mode():
logp_refs = [None] * len(inputs["input_ids"])
for i in range(0, len(inputs["input_ids"]), MINI_BATCH_SIZE):
current_inputs = {
k: v[i : i + MINI_BATCH_SIZE].to(DEVICE) for k, v in inputs.items()
}
logp_refs[i] = logp_per_token(model, current_inputs)
torch.cuda.empty_cache()
lycoris_net.apply_to()
## GRPO Iterations
if GRAD_CKPT:
model.gradient_checkpointing_enable()
logp_old = [None] * len(inputs["input_ids"])
for _ in range(GRPO_ITER):
step_loss = 0
step_kl_div = 0
step_objective = 0
steps = 0
for i in range(0, len(inputs["input_ids"]), MINI_BATCH_SIZE):
steps += 1
current_inputs = {
k: v[i : i + MINI_BATCH_SIZE].to(DEVICE).clone()
for k, v in inputs.items()
}
with torch.enable_grad():
logp_proxy = logp_per_token(model, current_inputs, chunked=True)
if logp_old[i] is None:
logp_old[i] = logp_proxy.detach().clone()
kl_div = (
torch.exp(logp_refs[i] - logp_proxy)
- (logp_refs[i] - logp_proxy)
- 1
)
adv_scale = torch.exp(logp_proxy - logp_old[i])
clipped_scale = torch.clamp(adv_scale, 1 - ETA, 1 + ETA)
scaled_advantages = (
advantages[i : i + MINI_BATCH_SIZE, None] * adv_scale
)
clipped_advantages = (
advantages[i : i + MINI_BATCH_SIZE, None] * clipped_scale
)
mask = scaled_advantages < clipped_advantages
final_advantages = torch.where(
mask, scaled_advantages, clipped_advantages
)
per_token_loss = -final_advantages + BETA * kl_div
loss_mask = current_inputs["attention_mask"][:, :-1]
for idx, length in enumerate(output_len[i : i + MINI_BATCH_SIZE]):
loss_mask[idx, : -(length - 1)] = 0
loss = (
torch.sum(
torch.where(
loss_mask.bool(),
per_token_loss,
torch.zeros_like(per_token_loss),
)
)
/ token_with_loss
)
loss.backward()
step_kl_div += (kl_div * loss_mask).sum().float().item()
step_loss += loss.float().item()
step_objective += (
((final_advantages - kl_div * BETA) * loss_mask)
.sum()
.float()
.item()
)
ema_decay = min((pbar.n) / (pbar.n + 1), 0.999)
step_kl_div = step_kl_div / token_with_loss
step_objective = step_objective / token_with_loss
loss_ema = loss_ema * ema_decay + step_loss * (1 - ema_decay)
kl_div_ema = kl_div_ema * ema_decay + step_kl_div * (1 - ema_decay)
pbar.set_postfix(
{
"rewards": rewards.float().mean().item(),
"loss": step_loss,
"kl_div": step_kl_div,
"objective": step_objective,
}
)
wandb.log(
{
"rewards": rewards.float().mean().item(),
"correctness_reward": torch.mean(torch.tensor(rewards_list[0])),
"format_reward": torch.mean(torch.tensor(rewards_list[1])),
"loss": step_loss,
"kl_div": step_kl_div,
"objective": step_objective,
"rewards_ema": rewards_ema,
"loss_ema": loss_ema,
"kl_div_ema": kl_div_ema,
},
commit=False,
)
torch.nn.utils.clip_grad_norm_(lycoris_net.parameters(), 0.1)
optimizer.step()
optimizer.zero_grad()
pbar.update()
wandb.log({"global_step": pbar.n})
if pbar.n >= MAX_STEPS:
break
scheduler.step()
model.gradient_checkpointing_disable()
if pbar.n >= MAX_STEPS:
break
if pbar.n >= MAX_STEPS:
break
pbar.close()
torch.cuda.empty_cache()
## Test after Training
lycoris_net.restore()
lycoris_net.merge_to()
model.save_pretrained(model_name.split("/")[-1] + "-grpo")
tokenizer.save_pretrained(model_name.split("/")[-1] + "-grpo")
vllm_load_state_dict(inf_model, model.state_dict(), {})
GEN_CONFIG = SamplingParams(max_tokens=4096, temperature=0.25, min_p=0.1)
generations = inf_model.generate(list(batch["model_input"]), sampling_params=GEN_CONFIG)
completions = [
output.outputs[0].text for output in generations
]
for q, a, c in zip(batch["question"], batch["answer"], completions):
print("=" * 30)
print("Question:")
print(q)
print("-" * 10)
print("Answer:")
print(a)
print("-" * 20)
print("Completion:")
print(c)
print("=" * 30)
print()
##
# GSM-8K test script
# tested results:
# - Qwen2.5-0.5B-instruct (with ICL prompt)
# before grpo: 0.429
# after grpo : 0.522
# - Qwen2.5-1.5B-instruct (without ICL prompt)
# before grpo: 0.583
# after grpo : 0.677
##
import re
import torch
import torch.utils.data as data
from tqdm import tqdm
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
# model_name = "Qwen/Qwen2.5-0.5B-Instruct"
# model_name = "./Qwen2.5-0.5B-Instruct-grpo"
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
# model_name = "./Qwen2.5-1.5B-Instruct-grpo"
tokenizer = AutoTokenizer.from_pretrained(model_name)
## Constants
GEN_CONFIG = SamplingParams(max_tokens=4096, temperature=0.25, min_p=0.1)
DTYPE = torch.bfloat16
DEVICE = "cuda"
## Define the GSM8K dataset and reward functions
## Modified from:
## https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb
SYSTEM_PROMPT = """Following this Format, \
Don't forget <response> and don't put anything outside the format:
---
Your Step by Step thinking
<response>
Your Final Answer or Response
</response>
"""
# FAKE_SYSTEM_PROMPT = """
# You are a human, you can think and answer in your own way.
# """.strip()
FAKE_SYSTEM_PROMPT = SYSTEM_PROMPT
XML_COT_FORMAT = """{thinking}
<response>
{answer}
</response>
"""
def extract_xml_answer(text: str) -> str:
answer = text.split("<response>")[-1]
answer = answer.split("</response>")[0]
    # take the last number (integer or decimal) in the answer region as the answer
    # Allow more verbose answers, like "$300" or "Alex needs to buy 100 apples" (for example)
    number_answer = re.findall(r"\d+(?:\.\d+)?", answer)
    if number_answer:
        answer = number_answer[-1]
    return answer.strip()
def extract_hash_answer(text: str) -> str | None:
if "####" not in text:
return None
return text.split("####")[1].strip().replace(",", "")
def get_gsm8k_questions(split="train") -> Dataset:
data = load_dataset("openai/gsm8k", "main")[split]
data = data.map(
lambda x: {
"question": x["question"],
"model_input": tokenizer.apply_chat_template(
[
{"role": "system", "content": SYSTEM_PROMPT},
                    # In-context learning example (disabled here); seems necessary for smaller models
# {"role": "user", "content": "Which is bigger, 9.11 or 9.8?"},
# {
# "role": "assistant",
# "content": XML_COT_FORMAT.format(
# thinking=(
# "User asking me to compare 9.11 and 9.8, to compare two numbers, "
# "we need to look at the integer part first, "
# "9 is the same in both numbers, "
# "then we look at the decimal part, "
# "0.11 is smaller than 0.8, so 9.11 is smaller than 9.8."
# ),
# answer="9.8",
# ),
# },
{"role": "user", "content": x["question"]},
],
tokenize=False,
add_generation_prompt=True,
),
"answer": extract_hash_answer(x["answer"]),
}
)
return data
## reward functions
def correctness_reward_func(completions, answers, **kwargs) -> list[float]:
extracted_responses = [extract_xml_answer(r) for r in completions]
return [1 if r == a else 0 for r, a in zip(extracted_responses, answers)]
## Evaluation Setup
inf_model = LLM(
model_name,
gpu_memory_utilization=0.85,
max_seq_len_to_capture=4096,
max_model_len=4096,
max_num_seqs=256,
)
# NOTE: get_gsm8k_questions() defaults to the "train" split; pass "test" to evaluate on the GSM8K test set.
dataset = get_gsm8k_questions()
dataloader = data.DataLoader(
dataset,
batch_size=8192,
shuffle=True,
collate_fn=lambda x: {k: [b[k] for b in x] for k in x[0].keys()},
)
total_correct = []
for idx, batch in enumerate(dataloader):
## Sample entries from dataset
questions = batch["question"]
answers = batch["answer"]
text_inputs = batch["model_input"]
    ## Generate results
outputs = inf_model.generate(text_inputs, sampling_params=GEN_CONFIG)
completions = [output.outputs[0].text for output in outputs]
correct = correctness_reward_func(completions, answers)
total_correct.extend(correct)
print(f"Accuracy: {sum(total_correct) / len(total_correct)}")
import torch
import torch.nn.functional as F
def _logp_per_token(logits, input_ids):
"""
logits: [B, L, V]
input_ids: [B, L]
L is dynamic
"""
choosed_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
logsumexp_values = torch.stack([torch.logsumexp(l, dim=-1) for l in logits])
logp = choosed_logits - logsumexp_values
return logp
class FusedLinearChunkedLogProb(torch.autograd.Function):
"""
    A fused implementation of hidden_state -> lm_head -> per-token log-prob,
    designed to avoid PyTorch caching the full logits tensor, which is extremely large.
"""
@staticmethod
def forward(
ctx,
hidden_state,
lm_head_weight,
lm_head_bias,
input_ids,
chunk_size: int = 128,
):
"""
Calculates log probability per token in a chunked manner.
Args:
# logits (torch.Tensor): [B, L, V] tensor of logits.
hidden_state (torch.Tensor): [B, L, H] tensor of hidden states.
lm_head_weight (torch.Tensor): [V, H] tensor of lm head weights.
lm_head_bias (torch.Tensor): [V] tensor of lm head biases.
input_ids (torch.Tensor): [B, L] tensor of token indices.
chunk_size (Optional[int]): Size of chunks for processing. If None, defaults to batch size.
Returns:
torch.Tensor: [B, L] tensor of log probabilities per token.
"""
B, L, _ = hidden_state.shape
ctx.chunk_size = chunk_size
logp_chunks = []
@torch.compile
def chunk_forward(hidden_chunk, lm_head_weight, lm_head_bias, input_ids_chunk):
logits_chunk = F.linear(hidden_chunk, lm_head_weight, lm_head_bias)
choosed_logits_chunk = logits_chunk.gather(
dim=-1, index=input_ids_chunk.unsqueeze(-1)
).squeeze(-1)
logsumexp_values_chunk = torch.logsumexp(logits_chunk, dim=-1)
logp_chunk = choosed_logits_chunk - logsumexp_values_chunk
return logp_chunk
for i in range(0, B):
logp_chunks_ele = []
for j in range(0, L, chunk_size):
hidden_chunk = hidden_state[i, j : j + chunk_size]
input_ids_chunk = input_ids[i, j : j + chunk_size]
logp_chunk = chunk_forward(
hidden_chunk, lm_head_weight, lm_head_bias, input_ids_chunk
)
logp_chunks_ele.append(logp_chunk)
logp_chunks.append(torch.cat(logp_chunks_ele, dim=0))
# save for backward is moved to the end to utilize the calculation within forward.
ctx.save_for_backward(
hidden_state,
input_ids,
lm_head_weight,
lm_head_bias if lm_head_bias is not None else None,
) # Save for backward. Offload if needed.
return torch.stack(logp_chunks, dim=0)
@staticmethod
def backward(ctx, grad_output):
"""
Calculates gradients in a chunked manner.
Args:
grad_output (torch.Tensor): [B, L] tensor of incoming gradients.
Returns:
            Tuple: gradients w.r.t. hidden_state, lm_head_weight, and lm_head_bias, plus None for input_ids and chunk_size.
"""
hidden_state: torch.Tensor
hidden_state, input_ids, lm_head_weight, lm_head_bias = ctx.saved_tensors
chunk_size = ctx.chunk_size
B, L, _ = hidden_state.shape
h_dtype = hidden_state.dtype
w_dtype = lm_head_weight.dtype
grad_hidden = torch.zeros_like(hidden_state, device=grad_output.device)
grad_weight = torch.zeros_like(lm_head_weight, device=grad_output.device)
if lm_head_bias is not None:
grad_bias = torch.zeros_like(lm_head_bias, device=grad_output.device)
else:
grad_bias = None
@torch.compile
def chunk_backward(
hidden_chunk,
lm_head_weight,
lm_head_bias,
input_ids_chunk,
grad_output_chunk,
):
logits_chunk = F.linear(hidden_chunk, lm_head_weight, lm_head_bias)
# Calculate softmax probabilities (efficient for gradient calculation).
probs_chunk = F.softmax(logits_chunk, dim=-1)
# 1. Gradient from the logsumexp term (d/dx_i logsumexp(x) = exp(x_i) / sum(exp(x_j)) = softmax(x)_i)
grad_logits_chunk = -probs_chunk * grad_output_chunk.unsqueeze(-1)
# 2. Gradient from the chosen logits term
index_for_scatter = input_ids_chunk.unsqueeze(-1) # Reshape for scatter_add
grad_logits_chunk.scatter_add_(
dim=-1, index=index_for_scatter, src=grad_output_chunk.unsqueeze(-1)
)
grad_hidden_chunk = torch.matmul(
grad_logits_chunk.to(h_dtype), lm_head_weight.to(h_dtype)
)
grad_weight_chunk = torch.matmul(
grad_logits_chunk.transpose(0, 1).to(w_dtype), hidden_chunk.to(w_dtype)
)
if lm_head_bias is not None:
grad_bias_chunk = torch.sum(grad_logits_chunk.to(w_dtype), dim=0)
else:
grad_bias_chunk = None
return grad_hidden_chunk, grad_weight_chunk, grad_bias_chunk
for i in range(0, B):
for j in range(0, L, chunk_size):
hidden_chunk = hidden_state[i, j : j + chunk_size].to(
grad_output.device
)
input_ids_chunk = input_ids[i, j : j + chunk_size]
grad_output_chunk = grad_output[i, j : j + chunk_size]
grad_hidden_chunk, grad_weight_chunk, grad_bias_chunk = chunk_backward(
hidden_chunk,
lm_head_weight,
lm_head_bias,
input_ids_chunk,
grad_output_chunk,
)
del hidden_chunk
grad_hidden[i, j : j + chunk_size] += grad_hidden_chunk
grad_weight += grad_weight_chunk
if lm_head_bias is not None:
grad_bias += grad_bias_chunk
return grad_hidden, grad_weight, grad_bias, None, None
def _linear_with_chunked_logp(x, lm_head, input_ids, chunk_size: int = 128):
return FusedLinearChunkedLogProb.apply(
x, lm_head.weight, lm_head.bias, input_ids, chunk_size
)
def patch_lm_head_for_logp(lm_head, logit_ids, chunk_size: int = 128):
lm_head.org_forward = lm_head.forward
def new_forward(x, *args, **kwargs):
return _linear_with_chunked_logp(x[:, :-1], lm_head, logit_ids, chunk_size)
lm_head.forward = new_forward
def unpatch_lm_head_for_logp(lm_head):
    lm_head.forward = getattr(lm_head, "org_forward", lm_head.forward)
    if hasattr(lm_head, "org_forward"):
        delattr(lm_head, "org_forward")
if __name__ == "__main__":
x = torch.randn(4, 1024, 1024).cuda()
lm_head = torch.nn.Linear(1024, 128000).cuda()
input_ids = torch.randint(0, 128000, (4, 1024)).cuda()
logits = lm_head(x)
input_ids = torch.randint(0, 128000, (4, 1024)).cuda()
logp = _logp_per_token(logits, input_ids)
logp_chunked = _linear_with_chunked_logp(x, lm_head, input_ids, chunk_size=128)
print(torch.allclose(logp, logp_chunked))
x = x.requires_grad_(True)
logits = lm_head(x)
logp = _logp_per_token(logits, input_ids)
logp.sum().backward()
first_grad_x = x.grad.clone()
first_grad_w = lm_head.weight.grad.clone()
first_grad_b = lm_head.bias.grad.clone()
x.grad = None
lm_head.weight.grad = None
lm_head.bias.grad = None
logp_chunked = _linear_with_chunked_logp(x, lm_head, input_ids, chunk_size=128)
logp_chunked.sum().backward()
second_grad_x = x.grad.clone()
second_grad_w = lm_head.weight.grad.clone()
second_grad_b = lm_head.bias.grad.clone()
x.grad = None
lm_head.weight.grad = None
lm_head.bias.grad = None
print(torch.allclose(first_grad_x, second_grad_x))
print(
F.mse_loss(first_grad_x, second_grad_x),
torch.max(torch.abs(first_grad_x - second_grad_x)),
)
print(torch.allclose(first_grad_w, second_grad_w))
print(
F.mse_loss(first_grad_w, second_grad_w),
torch.max(torch.abs(first_grad_w - second_grad_w)),
)
print(torch.allclose(first_grad_b, second_grad_b))
print(
F.mse_loss(first_grad_b, second_grad_b),
torch.max(torch.abs(first_grad_b - second_grad_b)),
)
@KohakuBlueleaf (Author)

Since my training runs with this script kept turning out badly, I'm not sure whether this implementation is correct.
If you find anything weird, please let me know.

@KohakuBlueleaf (Author)

The current version can easily reach a high reward with fewer than 100 training steps.
