
@cgpeter96
Created February 7, 2025 13:53
A GRPO modification for DeepSpeed multi-GPU training, based on https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb
# train_grpo.py
from typing import Optional
import re
import torch
from datasets import load_dataset, Dataset, load_from_disk
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer, TrlParser
from dataclasses import dataclass, field


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: "},
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override some existing default config settings when a model is trained from scratch. Example: "
                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
            )
        },
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to trust the execution of code from datasets/models defined on the Hub."
                " This option should only be set to `True` for repositories you trust and in which you have read the"
                " code, as it will execute code present on the Hub on your local machine."
            )
        },
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    low_cpu_mem_usage: bool = field(
        default=False,
        metadata={
            "help": (
                "Create the model as an empty shell, then only materialize its parameters when the pretrained "
                "weights are loaded. Setting this to True benefits LLM loading time and RAM consumption."
            )
        },
    )

    def __post_init__(self):
        if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )

# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""


def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()
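# Example (illustrative): keeps whatever sits between the last <answer> tag pair, e.g.
#   extract_xml_answer("<reasoning>\n2+2=4\n</reasoning>\n<answer>\n4\n</answer>")  ->  "4"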


def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()
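# GSM8K marks the gold final answer with "####", e.g.
#   extract_hash_answer("... so she has 10 - 3 = 7 apples.\n#### 7")  ->  "7"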


# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split="train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split]  # type: ignore
    # data = load_from_disk("path to gsm8k")[split]  # for a local copy of the dataset
    data = data.map(lambda x: {  # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            # {'role': 'user', 'content': 'What is the largest single-digit prime number?'},
            # {'role': 'assistant', 'content': XML_COT_FORMAT.format(
            #     reasoning="9 is divisible by 3 and 8 is divisible by 2, but 7 is prime.",
            #     answer="7"
            # )},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    })  # type: ignore
    return data  # type: ignore


dataset = get_gsm8k_questions()


# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-' * 20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]


def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]


def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion exactly matches the expected format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets .*? span multi-line reasoning; without it only single-line contents match
    matches = [re.match(pattern, r, re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]


def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion loosely matches the expected format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]


def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        # penalize any text trailing the closing </answer> tag
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count
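# Example (illustrative) for a well-formed completion:
#   count_xml("<reasoning>\n...\n</reasoning>\n<answer>\n42\n</answer>\n")
# scores 0.125 * 4 = 0.5, with no penalty since nothing trails the closing tag.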


def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]


def main(model_args, training_args):
    # use peft at your own risk; not working for me with multi-GPU training
    # peft_config = LoraConfig(
    #     r=16,
    #     lora_alpha=64,
    #     target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
    #     task_type="CAUSAL_LM",
    #     lora_dropout=0.05,
    # )
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        torch_dtype=torch_dtype,
        attn_implementation="flash_attention_2",
    )
    model = model.to("cuda")
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            xmlcount_reward_func,
            soft_format_reward_func,
            strict_format_reward_func,
            int_reward_func,
            correctness_reward_func,
        ],
        args=training_args,
        train_dataset=dataset,
        # peft_config=peft_config,
    )
    trainer.train()


if __name__ == "__main__":
    parser = TrlParser((ModelArguments, GRPOConfig))
    model_args, training_args = parser.parse_args_and_config()
    main(model_args, training_args)
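
# run_grpo.sh — launch script (file name assumed); runs train_grpo.py above
# with DeepSpeed ZeRO-2 on 8 GPUs.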
export CUDA_HOME=/usr/local/cuda
gpus=${gpus:-0,1,2,3,4,5,6,7}
dataset=${dataset:-sft_train_data}
output_dir=${output_dir:-experiments/$(date +"%Y%m%d_%H%M%S")}
port=$(shuf -i 10000-20000 -n 1)
deepspeed --include localhost:${gpus} --master_port=$port train_grpo.py \
--deepspeed "ds_zero2.json" \
--model_name_or_path "path to Qwen2.5-1.5B-Instruct/" \
--output_dir outputs/Qwen2.5-1.5B-GRPO-gsm8k \
--run_name Qwen2.5-1.5B-GRPO-gsm8k \
--learning_rate 1e-5 \
--adam_beta1 0.9 \
--adam_beta2 0.99 \
--weight_decay 0.1 \
--warmup_ratio 0.1 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--bf16 True \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 2 \
--num_generations 16 \
--max_prompt_length 512 \
--max_completion_length 768 \
--num_train_epochs 5 \
--save_steps 100 \
--max_grad_norm 0.1 \
--report_to tensorboard \
--log_on_each_node False
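
Usage (illustrative): save the script as run_grpo.sh and launch it as-is, or override the defaulted variables through the environment, e.g.

    gpus=0,1 output_dir=experiments/debug bash run_grpo.sh

This relies on the ${var:-default} expansions at the top of the script; note that `dataset` is defined there but not consumed by the deepspeed command.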
@moaterboat commented Feb 23, 2025

Do you mind sharing what you're using for your ds_zero2.json?

Just the defaults listed here?
https://github.com/huggingface/transformers/blob/main/tests/deepspeed/ds_config_zero2.json

@why6why commented Feb 26, 2025

> Do you mind sharing what you're using for your ds_zero2.json? Just the defaults listed here? https://github.com/huggingface/transformers/blob/main/tests/deepspeed/ds_config_zero2.json

I need it as well. Thanks for your contributions.

@cgpeter96 (Author) commented
> Just the defaults listed here?

Yes, it's a common configuration.
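
For readers who want a concrete starting point: below is a minimal ZeRO-2 sketch mirroring the transformers test config linked above, written as a Python dict so the assumptions can be annotated. It is not the author's exact file; the bucket sizes are guesses to tune, and the "auto" values are filled in by the HF Trainer from the CLI arguments.

# sketch_ds_zero2.py — writes an assumed ds_zero2.json (not the author's exact file)
import json

ds_zero2 = {
    # launch script passes --bf16 True, so enable bf16 rather than fp16
    "bf16": {"enabled": "auto"},
    "optimizer": {
        "type": "AdamW",
        "params": {"lr": "auto", "betas": "auto", "eps": "auto", "weight_decay": "auto"},
    },
    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": True,
        "allgather_bucket_size": 2e8,  # assumption: adjust to your GPU memory
        "overlap_comm": True,
        "reduce_scatter": True,
        "reduce_bucket_size": 2e8,     # assumption: adjust to your GPU memory
        "contiguous_gradients": True,
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
}

with open("ds_zero2.json", "w") as f:
    json.dump(ds_zero2, f, indent=2)

No scheduler block is included, so the HF Trainer's own cosine schedule (from --lr_scheduler_type cosine) is used instead of a DeepSpeed one.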
