A GRPO training script adapted for multi-GPU training with DeepSpeed, based on https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb
# train_grpo.py
from typing import Optional
import re

import torch
from datasets import load_dataset, Dataset, load_from_disk
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer, TrlParser
from dataclasses import dataclass, field
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type."},
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override some existing default config settings when a model is trained from scratch. Example: "
                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
            )
        },
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use a fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to trust the execution of code from datasets/models defined on the Hub."
                " This option should only be set to `True` for repositories you trust and in which you have read the"
                " code, as it will execute code present on the Hub on your local machine."
            )
        },
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    low_cpu_mem_usage: bool = field(
        default=False,
        metadata={
            "help": (
                "Create the model as an empty shell, then only materialize its parameters when the pretrained "
                "weights are loaded. Setting this to True benefits LLM loading time and RAM consumption."
            )
        },
    )

    def __post_init__(self):
        if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )
# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
def extract_xml_answer(text: str) -> str:
    """Pull the text between <answer> ... </answer> tags."""
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    """GSM8K ground-truth answers end with '#### <answer>'; return that answer."""
    if "####" not in text:
        return None
    return text.split("####")[1].strip()
# Uncomment the middle messages for 1-shot prompting
def get_gsm8k_questions(split="train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split]  # type: ignore
    # data = load_from_disk("path to gsm8k")[split]  # for a local copy
    data = data.map(lambda x: {  # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            #{'role': 'user', 'content': 'What is the largest single-digit prime number?'},
            #{'role': 'assistant', 'content': XML_COT_FORMAT.format(
            #    reasoning="9 is divisible by 3 and 8 is divisible by 2, but 7 is prime.",
            #    answer="7"
            #)},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    })  # type: ignore
    return data  # type: ignore

dataset = get_gsm8k_questions()
# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that exactly match the expected layout, newlines included."""
    # Note: without re.DOTALL, '.' does not match newlines, so the reasoning
    # and answer bodies must each fit on a single line to pass this check.
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that loosely contain the expected tags."""
    # Same caveat as above: '.' does not cross newlines here either.
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    """Give partial credit per well-placed tag; penalize trailing text after </answer>."""
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]
def main(model_args, training_args):
    # use peft at your own risk; not working for me with multi-GPU training
    # peft_config = LoraConfig(
    #     r=16,
    #     lora_alpha=64,
    #     target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
    #     task_type="CAUSAL_LM",
    #     lora_dropout=0.05,
    # )
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        torch_dtype=torch_dtype,
        attn_implementation="flash_attention_2",
    )
    model = model.to("cuda")

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            xmlcount_reward_func,
            soft_format_reward_func,
            strict_format_reward_func,
            int_reward_func,
            correctness_reward_func,
        ],
        args=training_args,
        train_dataset=dataset,
        # peft_config=peft_config,
    )
    trainer.train()

if __name__ == "__main__":
    parser = TrlParser((ModelArguments, GRPOConfig))
    model_args, training_args = parser.parse_args_and_config()
    main(model_args, training_args)
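Before launching a full multi-GPU run, it can be worth sanity-checking the reward functions on a hand-written completion. The sketch below is illustrative: the file name check_rewards.py is hypothetical, and it assumes the functions above are importable from train_grpo.py (importing the module as written also triggers the GSM8K download, since the dataset is built at module level).

# check_rewards.py (hypothetical helper; expected outputs shown in comments)
from train_grpo import (
    count_xml,
    int_reward_func,
    soft_format_reward_func,
    strict_format_reward_func,
)

sample = "<reasoning>\n2 + 2 = 4\n</reasoning>\n<answer>\n4\n</answer>\n"
completions = [[{"role": "assistant", "content": sample}]]

print(strict_format_reward_func(completions=completions))  # [0.5]
print(int_reward_func(completions=completions))            # [0.5]
print(count_xml(sample))                                   # 0.5
# Without re.DOTALL, '.' never crosses newlines, so the "soft" pattern
# rejects this multi-line sample even though every tag is present:
print(soft_format_reward_func(completions=completions))    # [0.0]

Note the last line: the lenient check can score lower than the strict one on multi-line completions, which is worth keeping in mind when reading reward curves. The script below launches training across GPUs with DeepSpeed.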
# Launch script: GRPO training with DeepSpeed ZeRO-2 across multiple GPUs.
export CUDA_HOME=/usr/local/cuda

# Defaults; each can be overridden from the environment.
gpus=${gpus:-0,1,2,3,4,5,6,7}
dataset=${dataset:-sft_train_data}
output_dir=${output_dir:-experiments/$(date +"%Y%m%d_%H%M%S")}
# NOTE: dataset and output_dir are defined but not wired into the command
# below, which hardcodes its own --output_dir.
port=$(shuf -i 10000-20000 -n 1)  # random master port to avoid collisions

deepspeed --include localhost:${gpus} --master_port=$port train_grpo.py \
    --deepspeed "ds_zero2.json" \
    --model_name_or_path "path to Qwen2.5-1.5B-Instruct/" \
    --output_dir outputs/Qwen2.5-1.5B-GRPO-gsm8k \
    --run_name Qwen2.5-1.5B-GRPO-gsm8k \
    --learning_rate 1e-5 \
    --adam_beta1 0.9 \
    --adam_beta2 0.99 \
    --weight_decay 0.1 \
    --warmup_ratio 0.1 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --bf16 True \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 2 \
    --num_generations 16 \
    --max_prompt_length 512 \
    --max_completion_length 768 \
    --num_train_epochs 5 \
    --save_steps 100 \
    --max_grad_norm 0.1 \
    --report_to tensorboard \
    --log_on_each_node False
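Because each variable uses bash's `${var:-default}` expansion, settings can be overridden from the environment without editing the file, e.g. `gpus=0,1 bash run_grpo.sh` (the script name run_grpo.sh is assumed here) to smoke-test on two GPUs. As noted in the comments above, dataset and output_dir are not yet wired into the command, so overriding them has no effect as written.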
Do you mind sharing what you're using for your ds_zero2.json? Just the defaults listed here? https://github.com/huggingface/transformers/blob/main/tests/deepspeed/ds_config_zero2.json

I need it as well. Thanks for your contributions.

> Just the defaults listed here?

Yes, it's a common configuration.
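For reference, a minimal ZeRO-2 configuration in the spirit of that linked test file could look like the sketch below, written as a small Python script that emits ds_zero2.json so it stays in the same language as the trainer code. The exact fields are assumptions, not the verbatim contents of the HF test file (which also configures fp16, an optimizer, and a scheduler); the "auto" values let the HF Trainer fill in matching hyperparameters, and bf16 is enabled to match the --bf16 True flag in the launch script above.

# make_ds_zero2.py (hypothetical helper; verify fields against the DeepSpeed docs)
import json

ds_zero2 = {
    # "auto" defers to the Trainer, which sees --bf16 True from the launch script
    "bf16": {"enabled": "auto"},
    "zero_optimization": {
        "stage": 2,                # ZeRO stage 2: shard optimizer state and gradients
        "allgather_partitions": True,
        "allgather_bucket_size": 2e8,
        "overlap_comm": True,      # overlap gradient reduction with the backward pass
        "reduce_scatter": True,
        "reduce_bucket_size": 2e8,
        "contiguous_gradients": True,
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
}

with open("ds_zero2.json", "w") as f:
    json.dump(ds_zero2, f, indent=2)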