Created
March 5, 2025 22:33
-
-
Save tadamcz/7b3a45d0d5d35f758bb6deb8000731bb to your computer and use it in GitHub Desktop.
Epoch AI implementation of FrontierMath
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
galois==0.4.4
gmpy2==2.2.1
mpmath==1.3.0
networkx==3.4.2
numpy==2.1.3
pyadic==0.2.3
scipy==1.15.2
sympy==1.13.3
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import textwrap | |
from typing import Unpack | |
from inspect_ai.model import ChatMessageUser, ChatMessageTool, GenerateConfigArgs | |
from inspect_ai.solver import solver, TaskState, Generate, Solver | |
from inspect_ai.tool import tool, Tool, ToolResult, python, ToolFunction | |
from inspect_ai.util import store | |
from bench import PROJECT_ROOT | |
logger = logging.getLogger(__name__)

# Max seconds a single `python` tool call may run inside the sandbox.
PYTHON_TOOL_TIMEOUT = 30
# Max seconds the submitted `answer()` function may run during scoring
# (also imported by the scorer module).
ANSWER_FUNC_TIMEOUT = 30
@solver
def frontiermath_agent(
    token_limit: int = 100_000,
    forced_submit_tokens: int = 66_000,
) -> Solver:
    """Agent loop for FrontierMath: a `python` tool plus a `submit_answer` tool.

    Args:
        token_limit: Hard cap on total (input + output) tokens; the loop ends
            once usage reaches it, whether or not an answer was submitted.
        forced_submit_tokens: Soft threshold; once usage reaches it, the model
            is pushed to call `submit_answer` on its next turn.

    Raises:
        ValueError: If forced_submit_tokens >= token_limit.
    """
    # Validate token limits at solver construction time. Previously this check
    # lived inside solve(), so a misconfiguration only surfaced (repeatedly,
    # per sample) once the evaluation was already running.
    if forced_submit_tokens >= token_limit:
        raise ValueError("forced_submit_tokens must be less than token_limit")

    async def solve(state: TaskState, generate: Generate) -> TaskState:
        # Initialize submitted answer in store; the scorer reads this key later.
        store().set("submitted_answer", None)
        # Set token limit on state, so generate() will respect it
        state.token_limit = token_limit
        # Make both tools available from the start
        state.tools = [python(timeout=PYTHON_TOOL_TIMEOUT), submit_answer()]
        state.tool_choice = "auto"
        # Display initial prompt with instructions
        state.messages = [
            ChatMessageUser(
                content=initial_prompt(
                    question=state.user_prompt.text,
                    answer_type=state.metadata["answer_type"],
                    token_limit=token_limit,
                    forced_submit_tokens=forced_submit_tokens,
                ),
            ),
        ]
        while state.token_usage < token_limit:
            state = await generate(state, tool_calls="single")
            # If an answer was submitted, we're done
            if store().get("submitted_answer") is not None:
                break
            # If the forced submission threshold is reached, force a submit_answer tool call
            if state.token_usage >= forced_submit_tokens:
                # TODO: This fixes the Anthropic error "Thinking may not be enabled when tool_choice forces tool use"
                # As of now, Inspect doesn't let you disable extended thinking in
                # a call to generate (because it happens on the ``is_using_thinking`` method
                # of the provider). So, instead of disabling thinking, we do not force the tool
                # choice in this case. It's annoying that we have to introduce model-specific
                # logic here, but I don't expect this to affect the results.
                # We still prevent the model from using the python tool
                # by setting the tools to only the submit_answer tool.
                if state.model.name and state.model.name.startswith("claude-3-7-sonnet"):
                    state.tools = [submit_answer()]
                else:
                    # Force a submit_answer tool call
                    state.tool_choice = ToolFunction("submit_answer")
                warning_message = (
                    f"You have used {state.token_usage:,} tokens, which is above the forced "
                    f"submission threshold of {forced_submit_tokens:,} tokens. (Hard limit: {token_limit:,} tokens). "
                    f"You must submit your answer now using the `submit_answer` tool. "
                )
                state.messages.append(ChatMessageUser(content=warning_message))
                continue
            # Below the forced-submit threshold: help models that misuse the
            # python tool, then prompt them to keep going.
            state = insert_tool_call_help_message(state)
            # Otherwise, continue the conversation
            message = (
                f"You have used {state.token_usage:,} tokens (forced submission: {forced_submit_tokens:,} tokens, "
                f"hard limit: {token_limit:,} tokens). "
                f"Continue exploring the problem, or submit your answer "
                f"using the `submit_answer` tool when you are confident in your answer."
            )
            state.messages.append(ChatMessageUser(content=message))
        return state

    return solve
@tool
def submit_answer() -> Tool:
    # Tool that records the model's final answer.
    # NOTE(review): the docstring of `execute` below is sent to the model as
    # the tool description, so it is user-facing text — do not edit casually.
    async def execute(answer_fn_code: str) -> ToolResult:
        """
        Submit your answer as a Python function named 'answer' that:
        - takes no parameters
        - returns your answer
        - prints no output
        - contains no code comments
        For example, when the answer type is a Python integer:
        def answer():
            return 42
        Args:
            answer_fn_code (str): Python code defining the 'answer' function.
                It should only define this function, do not call
                the function or include any other code.
        """
        # Record the submission; the agent loop and the scorer both read this
        # store key ("submitted_answer"). No validation happens here — the
        # verification runs later, in the sandbox, at scoring time.
        store().set("submitted_answer", answer_fn_code)
        return "Your answer has been recorded. No feedback is provided."

    return execute
def insert_tool_call_help_message(state: TaskState) -> TaskState:
    """
    Append a coaching message when the model ran the Python tool but printed
    nothing. Some models are poor at tool use despite explicit instructions;
    an empty stdout almost always means a missing print().
    """
    last = state.messages[-1]
    # Only react to the result of a tool call.
    if not isinstance(last, ChatMessageTool):
        return state
    if last.function != "python":
        return state
    # If the call failed for a known reason, the lack of output is already
    # explained — don't add the print() hint on top.
    # (not sure if output_limit belongs here, but might as well include it)
    known_error_types = ("parsing", "timeout", "unicode_decode", "output_limit")
    if last.error and last.error.type in known_error_types:
        return state
    if str(last.content).strip():
        # There was output; nothing to coach.
        return state
    help_message = (
        "I notice you used the Python tool but got no output. "
        "Remember, the tool only shows stdout (and stderr). You must include "
        "print() statements to see results.\n\n"
        "❌ Incorrect (no output):\n"
        "```python\n"
        "x = 5 * 12 # Calculation happens but not shown\n"
        "```\n\n"
        "✅ Correct (output shown):\n"
        "```python\n"
        "x = 5 * 12\n"
        "print(x) # Shows: 60\n"
        "```\n"
    )
    state.messages.append(ChatMessageUser(content=help_message))
    return state
def initial_prompt(
    question: str, answer_type: str, token_limit: int, forced_submit_tokens: int
) -> str:
    """Render the agent's opening prompt from TEMPLATE.

    Reads the sandbox's pinned Python requirements so the model knows exactly
    which libraries are importable in the `python` tool.
    """
    requirements_path = (
        PROJECT_ROOT / "bench" / "task" / "frontiermath" / "agent-python-requirements.txt"
    )
    # Indent by two tabs so the list sits inside the template's code fence.
    requirements = textwrap.indent(requirements_path.read_text().strip(), "\t" * 2)
    return TEMPLATE.format(
        question=question,
        answer_type=answer_type,
        token_limit=token_limit,
        forced_submit_tokens=forced_submit_tokens,
        requirements=requirements,
        python_tool_timeout=PYTHON_TOOL_TIMEOUT,
        answer_func_timeout=ANSWER_FUNC_TIMEOUT,
    )
# Prompt template rendered by initial_prompt() via str.format(). Placeholders:
# question, answer_type, token_limit, forced_submit_tokens, requirements,
# python_tool_timeout, answer_func_timeout. This text is sent to the model —
# treat it as user-facing and keep format-spec braces (e.g. {token_limit:,})
# intact.
TEMPLATE = r"""
You will be solving a challenging mathematics question. Here's how it works:
1. You can:
- Think out loud and explore the problem
- Use the `python` tool to execute arbitrary Python code
- Submit your answer using the `submit_answer` tool when you are confident in your answer.
2. Token limits:
- There is a hard limit of {token_limit:,} tokens. If you exceed this, the conversation will end
immediately (even if you haven't submitted an answer).
- If you reach {forced_submit_tokens:,} tokens (but less than the hard limit of {token_limit:,}),
you will be forced to use the `submit_answer` tool in your next message. This forced submission
stage is designed to give you the best chance of submitting an answer before reaching the
hard token limit. But it is not a guarantee. It is still your responsibility to avoid hitting
the hard limit.
- Both input and output tokens count towards the limits.
3. Scoring:
- If your answer is correct you will get 1 point. If it is incorrect, or if you don't submit an
answer, you will get 0 points.
4. Explain your reasoning to me before submitting an answer.
5. Tips
- I strongly recommend that you start by making a high-level plan for how you will attack the
problem. If you can, think about different approaches that could be used to solve the problem.
To help you stay on track, periodically summarize your key findings and potentially revise
your plan.
- Before submitting, verify your answer satisfies all problem requirements. It may be worth trying
a different approach if you can see that your current answer is not correct.
6. For using the `submit_answer` tool:
- Pass in the code of a Python function named 'answer' that:
- Takes no parameters
- Returns your answer as a {answer_type}
- Prints no output
- Contains no code comments
- When scoring your answer, the maximum runtime for the `answer` function is
{answer_func_timeout} seconds. The code is executed on typical commodity hardware for the year
2025.
7. For using the `python` tool:
* The tool will only return stdout (and stderr), so you must make sure to
use print() to see your results. If you don't get any output from a `python` tool call, you
probably forgot to print.
* Example:
```python
x = 5 * 12
print("The result is", x)
```
In this example, you must include the print statement. Otherwise, you won't see the value of
x.
* The tool is completely stateless and doesn't come with anything pre-imported. This is very
important. If you need modules (e.g. math, sympy), you must import them each time. You cannot
access variables defined in a previous call to `python`, so you must re-define anything you
need in each call.
* You have access to the standard library, and the following libraries (expressed in
`requirements.txt` format):
```
{requirements}
```
* Do not submit your answer using the `python` tool. Use the `submit_answer`
tool when you're ready to submit
* The maximum runtime for a `python` tool call is {python_tool_timeout} seconds. The code is
executed on typical commodity hardware for the year 2025.
Here is the problem to solve. The answer type is {answer_type}.
{question}
"""
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Sandbox definition for the agent's `python` tool and answer verification.
services:
  default:
    build:
      context: "."
      dockerfile: "Dockerfile"
    # Keep the container alive indefinitely; code is exec'd into it.
    command: "tail -f /dev/null"
    # PID 1 init process so exec'd children are reaped properly.
    init: true
    # No network access inside the sandbox.
    network_mode: none
    # Don't wait long on shutdown — nothing needs graceful teardown.
    stop_grace_period: 1s
    mem_limit: 512m
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Sandbox image: Python 3.10 plus the pinned math libraries the agent may use.
FROM python:3.10-bookworm
# Install system dependencies required for gmpy2
RUN apt-get update && apt-get install -y \
    libmpfr-dev \
    libmpc-dev \
    libgmp-dev \
    && rm -rf /var/lib/apt/lists/*
COPY agent-python-requirements.txt .
RUN pip install -r agent-python-requirements.txt
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import inspect | |
import sys | |
from inspect_ai.scorer import CORRECT, INCORRECT, Score, Target, accuracy, stderr, scorer, Scorer | |
from inspect_ai.solver import TaskState | |
from inspect_ai.util import sandbox | |
from .agent import ANSWER_FUNC_TIMEOUT | |
# Wall-clock limit for the whole sandboxed scoring script (answer + verify).
VERIFICATION_CODE_TIMEOUT = 120

# The generated scoring script reports its outcome via the process exit code.
# Codes 100+ are used to avoid clashing with ordinary Python exit codes (0/1).
# When `verify()` returns True or False (100-101)
EXIT_VERIFY_TRUE = 100
EXIT_VERIFY_FALSE = 101
# Errors in the `answer()` function (102-103)
EXIT_ANSWER_ERROR = 102
EXIT_ANSWER_TIMEOUT = 103
# Errors in the `verify()` function (104)
EXIT_VERIFY_ERROR = 104
@scorer(metrics=[accuracy(), stderr()], name="verification_code")
def verification_code_scorer() -> Scorer:
    """Score a FrontierMath problem submission.

    Builds a standalone Python script containing the submitted `answer()`
    function and the problem's `verify()` code, runs it in the sandbox, and
    maps the process exit code to a Score.
    """

    async def score(state: TaskState, target: Target) -> Score:
        # Get the submitted answer (written to the store by the submit_answer tool)
        answer_code = state.store.get("submitted_answer")
        if answer_code is None:
            return Score(value=INCORRECT, explanation="No solution submitted")
        # Fenced copy of the submission, used as the Score's `answer` field.
        answer_code_backticks = f"```python\n{answer_code}\n```"
        # Assemble a self-contained script: the helper classes are inlined via
        # inspect.getsource() so the sandbox needs no project imports. The
        # script communicates its outcome through sys.exit() codes (see the
        # EXIT_* constants above).
        scoring_code = f"""
import sys
{inspect.getsource(ExitOnException)}
{inspect.getsource(WithTimeout)}
{answer_code}
{state.metadata['verification_code']}
with ExitOnException(exit_code={EXIT_ANSWER_ERROR}), WithTimeout({ANSWER_FUNC_TIMEOUT}, {EXIT_ANSWER_TIMEOUT}):
    a = answer()
with ExitOnException(exit_code={EXIT_VERIFY_ERROR}):
    is_correct = verify(a)
if is_correct:
    sys.exit({EXIT_VERIFY_TRUE})
else:
    sys.exit({EXIT_VERIFY_FALSE})
"""
        try:
            # Execute the verification code in the sandbox
            result = await sandbox().exec(
                ["python3", "-c", scoring_code], timeout=VERIFICATION_CODE_TIMEOUT
            )
        except TimeoutError:
            # The outer sandbox timeout fired before the script finished.
            return Score(
                answer=answer_code_backticks,
                value=INCORRECT,
                explanation=f"Verification code execution timed out (timeout: {VERIFICATION_CODE_TIMEOUT} seconds)",
            )
        # Map the script's exit code back to a scoring outcome.
        if result.returncode == EXIT_VERIFY_TRUE:
            return Score(
                answer=answer_code_backticks,
                value=CORRECT,
                explanation="The `verify()` function returned True",
            )
        elif result.returncode == EXIT_VERIFY_FALSE:
            return Score(
                answer=answer_code_backticks,
                value=INCORRECT,
                explanation="The `verify()` function returned False",
            )
        # Answer function errors
        elif result.returncode == EXIT_ANSWER_ERROR:
            return Score(
                answer=answer_code_backticks,
                value=INCORRECT,
                explanation=f"Error while executing the `answer()` function:\n{result.stderr}",
            )
        elif result.returncode == EXIT_ANSWER_TIMEOUT:
            return Score(
                answer=answer_code_backticks,
                value=INCORRECT,
                explanation=f"Error while executing the `answer()` function: Timeout exceeded (timeout: {ANSWER_FUNC_TIMEOUT} seconds)",
            )
        # Verification function errors
        elif result.returncode == EXIT_VERIFY_ERROR:
            return Score(
                answer=answer_code_backticks,
                value=INCORRECT,
                explanation=f"Error while executing the `verify()` function:\n{result.stderr}",
            )
        # Other errors with exit code 1 (e.g. a syntax error in the submission,
        # which fails before any ExitOnException context is entered)
        elif result.returncode == 1:
            return Score(
                answer=answer_code_backticks,
                value=INCORRECT,
                explanation=f"Error in verification code:\n{result.stderr}",
            )
        else:
            return Score(
                answer=answer_code_backticks,
                value=INCORRECT,
                explanation=f"Unknown failure in verification code:\n{result.stderr}",
            )

    return score
class ExitOnException:
    """
    Context manager that turns any exception (other than SystemExit) into a
    process exit with a chosen status code, printing the traceback to stderr
    first. SystemExit is always allowed to propagate untouched.

    Usage:
        with ExitOnException(exit_code=2):
            risky_operation()
    """

    def __init__(self, exit_code=1):
        """
        Args:
            exit_code (int): Status code used when the managed block raises.
                Defaults to 1.
        """
        self.exit_code = exit_code

    def __enter__(self):
        """Enter the context; nothing to set up."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        On exception, print the traceback and exit with the configured code.

        Returns:
            bool: Always False (propagate); the exceptional path never
            returns because sys.exit() raises SystemExit.
        """
        # Imports are deliberately local: this class is inlined into a
        # standalone scoring script via inspect.getsource(), which copies only
        # the class body — module-level imports would not come along.
        import sys
        import traceback

        # Clean exit from the block, or an intentional SystemExit: pass through.
        if exc_type is None or exc_type is SystemExit:
            return False
        traceback.print_exception(exc_type, exc_val, exc_tb, file=sys.stderr)
        sys.exit(self.exit_code)
class WithTimeout:
    """
    Context manager that arms a SIGALRM-based timeout around its block.

    If the block outlives `timeout_sec`, the alarm handler prints a message to
    stderr and exits the process with `exit_code` (via SystemExit). Relies on
    signal.alarm, so it is Unix-only and — presumably — must run on the main
    thread (signal handler registration); TODO confirm for other platforms.
    """

    def __init__(self, timeout_sec: int, exit_code: int):
        # Whole seconds only: signal.alarm() does not accept fractions.
        self.timeout_sec, self.exit_code = timeout_sec, exit_code

    def _handle_alarm(self, signum, frame):
        # Runs only if the alarm fires before __exit__ cancels it.
        print("Timeout while calling answer()", file=sys.stderr)
        sys.exit(self.exit_code)

    def __enter__(self):
        # Local import on purpose: the class is inlined into a standalone
        # scoring script via inspect.getsource(), so it cannot rely on
        # module-level imports (other than `sys`, which that script provides).
        import signal

        signal.signal(signal.SIGALRM, self._handle_alarm)
        signal.alarm(self.timeout_sec)

    def __exit__(self, exc_type, exc_val, exc_tb):
        import signal

        # Cancel any pending alarm so it can't fire after the block finishes.
        signal.alarm(0)
        return False
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment