import datasets
import transformers
import vllm
from tqdm import tqdm

model_dir = 'meta-llama/Llama-2-7b-chat-hf'

# system prompt from DeepSeek-R1
system_prompt = "The user will ask you a question, and you should solve it. " \
    "You should first think about the reasoning process in the mind and then provide the user with the answer. " \
    "The reasoning process and answer must be enclosed within <think> </think> and " \
    "<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> " \
    "<answer> answer here </answer>."

# keywords that signal self-reflection; 'wait' leads to many false positives, so be mindful of that
keywords = ['re-evaluate', 're-check', 'wait', 'reevaluate', 'recheck', 'check again', 'try again', 'think again', 'aha']
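
# A minimal sketch (not in the original gist) of stricter matching: whole-word
# regexes avoid substring hits such as 'wait' inside 'awaits' or 'waiter'.
# The main loop below keeps the original substring check; swap in
# matches_keyword(o.text) to use this instead.
import re
keyword_patterns = {kw: re.compile(r'\b' + re.escape(kw) + r'\b') for kw in keywords}

def matches_keyword(text):
    """Return the first keyword found as a whole word/phrase, or None."""
    lowered = text.lower()
    for kw, pattern in keyword_patterns.items():
        if pattern.search(lowered):
            return kw
    return None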

# load tokenizer and model
tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir)
model = vllm.LLM(
    model_dir,
    max_model_len=1024,
    enable_prefix_caching=True,  # all prompts share the system-prompt prefix
    load_format='safetensors',
)

# set sampling parameters: 100 completions per problem at temperature 1
n_samples = 100
sample_params = vllm.SamplingParams(max_tokens=512, temperature=1.0, n=n_samples)
sample_kwargs = {'sampling_params': sample_params}

# build chat prompts for MATH-500 questions
def build_prompt(tokenizer, question):
    messages = [
        {'role': 'system', 'content': system_prompt},
        {'role': 'user', 'content': question},
    ]
    # tokenize=False returns the rendered prompt string; vLLM tokenizes internally
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True,
    )
    return prompt
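
# For reference (an assumption based on the standard Llama-2 chat template, not
# shown in the original gist), the rendered prompt looks roughly like:
#   <s>[INST] <<SYS>>
#   {system_prompt}
#   <</SYS>>
#
#   {question} [/INST]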

# download and process MATH-500
def process_dataset(example):
    prompt = build_prompt(tokenizer, example['problem'])
    return {'prompt': prompt, 'answer': example['answer']}

dataset = datasets.load_dataset('ricdomolm/MATH-500')['test']
test_set = dataset.map(process_dataset)

# loop over problems and print completions that contain a keyword (expect many false positives!)
for i, example in tqdm(enumerate(test_set), total=len(test_set)):
    outputs = model.generate(example['prompt'], **sample_kwargs, use_tqdm=False)
    for o in outputs[0].outputs:
        for kw in keywords:
            if kw in o.text.lower():
                print(i, o.text)
                break  # print each completion at most once
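
# A minimal sketch (an assumption, not part of the original gist) that
# aggregates keyword frequency instead of printing raw completions, to
# quantify how often the model appears to self-reflect.
from collections import Counter

hit_counts = Counter()
total_completions = 0
for example in test_set:
    outputs = model.generate(example['prompt'], **sample_kwargs, use_tqdm=False)
    for o in outputs[0].outputs:
        total_completions += 1
        text = o.text.lower()
        for kw in keywords:
            if kw in text:
                hit_counts[kw] += 1

for kw, count in hit_counts.most_common():
    print(f'{kw}: {count / total_completions:.2%} of completions')

# Note: vllm.LLM.generate also accepts a list of prompts, so the whole test
# set could be generated in one batched call instead of one prompt at a time.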