# train_grpo.py
#
# See https://github.com/willccbb/verifiers for ongoing developments
#
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

# Load and prep dataset

SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip().replace(",", "").replace("$", "")

# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            #{'role': 'user', 'content': 'What is the largest single-digit prime number?'},
            #{'role': 'assistant', 'content': XML_COT_FORMAT.format(
            #    reasoning="9 is divisible by 3 and 8 is divisible by 2, but 7 is prime.",
            #    answer="7"
            #)},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

#model_name = "meta-llama/Llama-3.2-1B-Instruct"
model_name = "Qwen/Qwen2.5-1.5B-Instruct"

if "Llama" in model_name:
    output_dir = "outputs/Llama-1B-GRPO"
    run_name = "Llama-1B-GRPO-gsm8k"
else:
    output_dir="outputs/Qwen-1.5B-GRPO"
    run_name="Qwen-1.5B-GRPO-gsm8k"

training_args = GRPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    learning_rate=5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_generations=16,
    max_prompt_length=256,
    max_completion_length=786,
    num_train_epochs=1,
    save_steps=100,
    max_grad_norm=0.1,
    report_to="wandb",
    log_on_each_node=False,
)

peft_config = LoraConfig(
    r=16,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map=None
).to("cuda")

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# use peft at your own risk; not working for me with multi-GPU training
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func],
    args=training_args,
    train_dataset=dataset,
    #peft_config=peft_config
)
trainer.train()
I used an A100 (40 GB) on Colab, but it can barely handle training a 1.5B model.
I also found that GRPO does not seem to support parameter-efficient fine-tuning and can only fine-tune all parameters (issue). Can anyone solve this problem? Or are there other reinforcement learning frameworks that support parameter-efficient fine-tuning (PEFT)?
After about 3 hours of fine-tuning, the 1.5B model does perform better than the 0.5B model.

A 1.5B model can be trained for nearly 460 steps in 3 hours on an A100.
Can someone share code for evaluating the trained model on the GSM8K test set? Thanks a lot!
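A minimal sketch of what test-set evaluation could look like, reusing SYSTEM_PROMPT, extract_xml_answer, and extract_hash_answer from the script above. The checkpoint path, greedy decoding, and one-example-at-a-time loop are assumptions, not part of the original gist:

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# Evaluation sketch: greedy-decode each test question and compare the extracted
# <answer> value against the GSM8K reference answer.
eval_model_path = "outputs/Qwen-1.5B-GRPO/checkpoint-100"  # hypothetical checkpoint path
eval_tokenizer = AutoTokenizer.from_pretrained(eval_model_path)
eval_tokenizer.pad_token = eval_tokenizer.eos_token
eval_model = AutoModelForCausalLM.from_pretrained(
    eval_model_path,
    torch_dtype=torch.bfloat16,
).to("cuda")
eval_model.eval()

test_data = load_dataset("openai/gsm8k", "main")["test"]

correct = 0
for example in test_data:
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": example["question"]},
    ]
    input_ids = eval_tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(eval_model.device)
    with torch.no_grad():
        output = eval_model.generate(input_ids, max_new_tokens=786, do_sample=False)
    completion = eval_tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)
    if extract_xml_answer(completion) == extract_hash_answer(example["answer"]):
        correct += 1

print(f"GSM8K test accuracy: {correct / len(test_data):.3f}")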
How should the parameters be adjusted with just two A10s?
@willccbb, why would we prefer separate reward functions instead of having a single unified one in GRPO?
being able to log individual rewards is pretty useful for debugging imo
consolidating them into one shouldn't affect actual training dynamics though
IMHO separate, additive rewards introduce a lot of repetition (e.g., parsing responses) and limit creativity in reward design. For example, I may want the formatting reward to act as a gate for the others, since I may not even want to evaluate a response if its formatting is wrong.
definitely get creative with it! nothing wrong with using if statements + multiplication in your reward functions
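For anyone who wants to try the gated variant described above, here is a rough sketch (an illustration under the same completion format, not part of the gist); it reuses re and extract_xml_answer from the script:

# Gated reward sketch: a completion earns the correctness reward only if it also
# matches the soft XML format; badly formatted responses are never scored further.
def gated_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    rewards = []
    for completion, a in zip(completions, answer):
        response = completion[0]["content"]
        if not re.match(pattern, response, flags=re.DOTALL):
            rewards.append(0.0)  # formatting acts as a gate: skip answer checking entirely
            continue
        rewards.append(2.0 if extract_xml_answer(response) == a else 0.5)  # 0.5 for valid format alone
    return rewards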
Does it also work on smaller models like >3B params model?
When training the Qwen model on the GPU, I encountered the error: "probability tensor contains either inf, nan or element < 0".
Hi @fsxbhyy, did you load the model in torch.bfloat16? I used to hit this issue when I loaded models in torch.float16 instead of bfloat16. I suspect float16 in this context leads to numerical instability, producing NaN probabilities. Hope this helps!
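For reference, a minimal illustration of that suggestion, mirroring the gist's own loading call (model_name as defined in the script):

# bfloat16 has the same exponent range as float32, so logits are much less likely
# to overflow to inf/nan during sampling than with float16.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # not torch.float16
).to("cuda")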
I hit the same problem. I trained a 7B model with batch_size == 1, but it just keeps reporting OOM.
@harrywoo @Tuziking I had the same problem. I then noticed that these values are actually huge for most cases:
max_prompt_length=256,
max_completion_length=786,
786 generated tokens to process per generation requires a lot of memory, especially if your group size (num_generations) is large. Try setting max_completion_length to 150 or 250 and see if it reduces memory usage (see the sketch below). Hope this helps!
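A sketch of the kind of memory-saving adjustment being suggested; the reduced completion length and the gradient-checkpointing flag are illustrative choices, not the gist's settings:

# Illustrative memory-saving tweaks (values are examples, not the original config):
training_args = GRPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    learning_rate=5e-6,
    bf16=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_generations=16,             # lowering the group size also shrinks memory, if your TRL version allows it
    max_prompt_length=256,
    max_completion_length=250,      # far fewer generated tokens to hold and score per completion
    gradient_checkpointing=True,    # trade recomputation for activation memory
    num_train_epochs=1,
    save_steps=100,
    max_grad_norm=0.1,
    report_to="wandb",
    log_on_each_node=False,
)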
Me too!! I also filed my problem as an issue on the TRL GitHub. If you solve the problem, please help me.