# Gist by @RicardoDominguez, created February 7, 2025.
import datasets
import transformers
import vllm
from tqdm import tqdm
model_dir = 'meta-llama/Llama-2-7b-chat-hf'
# system prompt adapted from DeepSeek-R1
system_prompt = "The user will ask you a question, and you should solve it. " \
"You should first think about the reasoning process in the mind and then provide the user with the answer. " \
"The reasoning process and answer must be enclosed within <think> </think> and " \
"<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> " \
"<answer> answer here </answer>."
# 'wait' leads to many false positives (e.g. it also matches 'await', 'waiting'), so be mindful of that
keywords = ['re-evaluate', 're-check', 'wait', 'reevaluate', 'recheck', 'check again', 'try again', 'think again', 'aha']
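# (Aside, not in the original gist: the plain substring test used below
# also fires on words like 'await' or 'rechecked'. A minimal sketch of a
# stricter word-boundary match using the standard-library re module:)
import re
keyword_re = re.compile(r'\b(?:' + '|'.join(re.escape(kw) for kw in keywords) + r')\b')
# usage: keyword_re.search(text.lower()) returns a match object or None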
# load tokenizer and model
tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir)
model = vllm.LLM(
    model_dir,
    max_model_len=1024,
    enable_prefix_caching=True,
    load_format='safetensors',
)
# set sampling parameters
n_samples = 100
sample_params = {'max_tokens': 512, 'temperature': 1., 'n': n_samples}
sample_params = vllm.SamplingParams(**sample_params)
sample_kwargs = {'sampling_params': sample_params}
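# (Aside, not in the original gist: with n=n_samples, each generate call
# returns a list holding one vllm.RequestOutput per prompt, whose .outputs
# attribute holds n_samples CompletionOutput objects; the loop below relies
# on this via output[0].outputs. The dict indirection above is equivalent
# to constructing the params directly, e.g.:)
# sample_params = vllm.SamplingParams(max_tokens=512, temperature=1.0, n=n_samples)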
# download MATH-500 and build prompts
def build_prompt(tokenizer, question):
    messages = [
        {'role': 'system', 'content': system_prompt},
        {'role': 'user', 'content': question},
    ]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, add_special_tokens=True,
    )
    return prompt
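# (Aside, not in the original gist: with tokenize=False the chat template
# renders a plain string; for Llama-2-chat that is the [INST] ... [/INST]
# format with the system prompt inside <<SYS>> markers. A quick sanity
# check, commented out so the script's output is unchanged:)
# print(build_prompt(tokenizer, 'What is 1 + 1?'))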
def process_dataset(example):
    question = build_prompt(tokenizer, example['problem'])
    answer = example['answer']
    return {'prompt': question, 'answer': answer}
dataset = datasets.load_dataset('ricdomolm/MATH-500')['test']
test_set = dataset.map(process_dataset)
# loop and print matches (many false positives!!)
for i, example in tqdm(enumerate(test_set), total=len(test_set)):
    output = model.generate(example['prompt'], **sample_kwargs, use_tqdm=False)
    for o in output[0].outputs:
        for kw in keywords:
            if kw in o.text.lower():
                print(i, o.text)
                break
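# (Aside, not in the original gist: two hedged refinements. vllm.LLM.generate
# also accepts a list of prompts, so the whole test set can be submitted in
# one batched call, and the per-problem fraction of completions containing a
# keyword is often more informative than raw prints. keyword_hit_rate is a
# hypothetical helper, not from the gist.)
def keyword_hit_rate(outputs):
    hits = sum(any(kw in o.text.lower() for kw in keywords) for o in outputs)
    return hits / len(outputs)

all_outputs = model.generate([ex['prompt'] for ex in test_set], sampling_params=sample_params)
for i, request_output in enumerate(all_outputs):
    print(i, f'{keyword_hit_rate(request_output.outputs):.2%}')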