Last active
February 25, 2025 12:14
-
-
Save tadamcz/faf4681e154be2e4c8a6579d67aca7d3 to your computer and use it in GitHub Desktop.
Epoch AI implementation of OTIS Mock AIME 2024-2025
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Dict, Any | |
from inspect_ai import Task, task, Epochs | |
from inspect_ai.dataset import Sample, hf_dataset | |
from inspect_ai.scorer import scorer, Score, Target, CORRECT, INCORRECT, accuracy, stderr | |
from inspect_ai.solver import generate, prompt_template, TaskState | |
from bench.model import default_grader_model | |
# Prompt for the model being evaluated. It is passed to `prompt_template` in
# the task definition; `{prompt}` is the placeholder for the problem text.
SOLUTION_TEMPLATE = """Please solve this AIME problem step by step. The answer is an integer ranging from
000 to 999, inclusive.
{prompt}
Remember to show your work clearly and end with 'ANSWER: X' where X is your final numerical answer."""

# Prompt for the grader model. `model_graded` fills `{student_solution}` with
# the evaluated model's completion and `{target}` with the sample's target text.
GRADING_TEMPLATE = """You are a mathematics expert tasked with grading AIME solutions. You will be given:
1. A student's complete solution with their reasoning
2. The correct solution
Grade the student solution as either CORRECT or INCORRECT, based only on the **student's final answer**.
Only respond with a single word: either "CORRECT" or "INCORRECT".
Student Solution:
{student_solution}
Correct Solution:
{target}
Grade (CORRECT/INCORRECT):"""
def record_to_sample(record: Dict[str, Any]) -> Sample:
    """Build a ``Sample`` from one HuggingFace dataset record.

    Args:
        record: A mapping that provides the ``"solution"``, ``"input"``,
            ``"target"`` and ``"id"`` keys.

    Returns:
        A ``Sample`` carrying the record's input, target and id, with the
        record's solution stored under ``metadata["solution"]``.

    Raises:
        KeyError: If ``record`` is a dict and lacks one of the keys above.
    """
    # The solution is kept as metadata rather than as the target; the target
    # comes from the record's own "target" field.
    extra = dict(solution=record["solution"])
    fields = {
        "input": record["input"],
        "target": record["target"],
        "id": record["id"],
        "metadata": extra,
    }
    return Sample(**fields)
@scorer(metrics=[accuracy(), stderr()])
def model_graded():
    """Create a scorer that has a grader model judge each completion.

    Returns:
        An async ``score(state, target)`` function. It formats
        ``GRADING_TEMPLATE`` with the evaluated model's completion and the
        target text, sends it to the model returned by
        ``default_grader_model()``, and maps the reply to CORRECT/INCORRECT.
    """

    async def score(state: TaskState, target: Target) -> Score:
        completion = state.output.completion
        prompt = GRADING_TEMPLATE.format(
            student_solution=completion,
            target=target.text,
        )
        reply = (await default_grader_model().generate(prompt)).completion

        # Only a reply that is exactly "CORRECT" (ignoring case and
        # surrounding whitespace) passes; anything else counts as incorrect.
        if reply.strip().upper() == "CORRECT":
            verdict = CORRECT
        else:
            verdict = INCORRECT

        return Score(
            value=verdict,
            answer=completion,
            explanation=f"The grader model responded with:\n{reply}",
        )

    return score
@task(name="OTIS Mock AIME 2024-2025")
def otis_mock_aime_24_25(epochs: int = 16) -> Task:
    """Define the OTIS Mock AIME 2024-2025 evaluation task.

    Args:
        epochs: Value passed to ``Epochs`` as the epoch count. When it is
            greater than 1, a ``pass_at_{epochs}`` reducer is requested in
            addition to ``"mean"``.

    Returns:
        A ``Task`` over the ``train`` split of
        ``EpochAI/otis-mock-aime-24-25``, solved with ``SOLUTION_TEMPLATE``
        followed by generation, and scored by ``model_graded()``.
    """
    # "mean" is always reported; pass@k is only added for multiple epochs.
    reducers = ["mean"]
    if epochs > 1:
        reducers.append(f"pass_at_{epochs}")

    return Task(
        dataset=hf_dataset(
            path="EpochAI/otis-mock-aime-24-25",
            split="train",
            sample_fields=record_to_sample,
        ),
        plan=[prompt_template(SOLUTION_TEMPLATE), generate()],
        scorer=model_graded(),
        epochs=Epochs(epochs, reducers),
        metadata={"inspect-log-public": True},
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment