Skip to content

Instantly share code, notes, and snippets.

@tadamcz
Created March 5, 2025 22:33
Show Gist options
  • Save tadamcz/7b3a45d0d5d35f758bb6deb8000731bb to your computer and use it in GitHub Desktop.
Save tadamcz/7b3a45d0d5d35f758bb6deb8000731bb to your computer and use it in GitHub Desktop.
Epoch AI implementation of FrontierMath
galois==0.4.4
gmpy2==2.2.1
mpmath==1.3.0
networkx==3.4.2
numpy==2.1.3
pyadic==0.2.3
scipy==1.15.2
sympy==1.13.3
import logging
import textwrap
from typing import Unpack
from inspect_ai.model import ChatMessageUser, ChatMessageTool, GenerateConfigArgs
from inspect_ai.solver import solver, TaskState, Generate, Solver
from inspect_ai.tool import tool, Tool, ToolResult, python, ToolFunction
from inspect_ai.util import store
from bench import PROJECT_ROOT
logger = logging.getLogger(__name__)
PYTHON_TOOL_TIMEOUT = 30
ANSWER_FUNC_TIMEOUT = 30
@solver
def frontiermath_agent(
    token_limit: int = 100_000,
    forced_submit_tokens: int = 66_000,
) -> Solver:
    """Agent solver for FrontierMath problems.

    Runs a generate/tool-call loop in which the model may execute Python in a
    sandbox and must eventually submit an answer via the `submit_answer` tool.
    Two thresholds govern the loop: a soft one (`forced_submit_tokens`) after
    which the model is pushed to submit immediately, and a hard one
    (`token_limit`) after which the loop ends regardless.

    Args:
        token_limit: Hard cap on total (input + output) tokens; also set on
            the TaskState so generate() enforces it.
        forced_submit_tokens: Soft threshold; must be strictly less than
            token_limit (validated at solve time).

    Returns:
        A Solver whose final state's store holds "submitted_answer"
        (the code of the answer function, or None if never submitted).
    """

    async def solve(state: TaskState, generate: Generate) -> TaskState:
        # Validate token limits
        if forced_submit_tokens >= token_limit:
            raise ValueError("forced_submit_tokens must be less than token_limit")
        # Initialize submitted answer in store
        store().set("submitted_answer", None)
        # Set token limit on state, so generate() will respect it
        state.token_limit = token_limit
        # Make both tools available from the start
        state.tools = [python(timeout=PYTHON_TOOL_TIMEOUT), submit_answer()]
        state.tool_choice = "auto"
        # Display initial prompt with instructions
        state.messages = [
            ChatMessageUser(
                content=initial_prompt(
                    question=state.user_prompt.text,
                    answer_type=state.metadata["answer_type"],
                    token_limit=token_limit,
                    forced_submit_tokens=forced_submit_tokens,
                ),
            ),
        ]
        while state.token_usage < token_limit:
            # tool_calls="single": resolve at most one tool call per turn.
            state = await generate(state, tool_calls="single")
            # If an answer was submitted, we're done
            if store().get("submitted_answer") is not None:
                break
            # If the forced submission threshold is reached, force a submit_answer tool call
            if state.token_usage >= forced_submit_tokens:
                # TODO: This fixes the Anthropic error "Thinking may not be enabled when tool_choice forces tool use"
                # As of now, Inspect doesn't let you disable extended thinking in
                # a call to generate (because it happens on the ``is_using_thinking`` method
                # of the provider). So, instead of disabling thinking, we do not force the tool
                # choice in this case. It's annoying that we have to introduce model-specific
                # logic here, but I don't expect this to affect the results.
                # We still prevent the model from using the python tool
                # by setting the tools to only the submit_answer tool.
                # NOTE(review): assumes state.model.name is the bare model id
                # (no provider prefix) — confirm against inspect_ai's Model.
                if state.model.name and state.model.name.startswith("claude-3-7-sonnet"):
                    state.tools = [submit_answer()]
                else:
                    # Force a submit_answer tool call
                    state.tool_choice = ToolFunction("submit_answer")
                warning_message = (
                    f"You have used {state.token_usage:,} tokens, which is above the forced "
                    f"submission threshold of {forced_submit_tokens:,} tokens. (Hard limit: {token_limit:,} tokens). "
                    f"You must submit your answer now using the `submit_answer` tool. "
                )
                # This warning is re-appended on every iteration past the
                # threshold until the model actually submits.
                state.messages.append(ChatMessageUser(content=warning_message))
                continue
            # Below the forced threshold: nudge models that misuse the python
            # tool (e.g. forgot to print) before continuing.
            state = insert_tool_call_help_message(state)
            # Otherwise, continue the conversation
            message = (
                f"You have used {state.token_usage:,} tokens (forced submission: {forced_submit_tokens:,} tokens, "
                f"hard limit: {token_limit:,} tokens). "
                f"Continue exploring the problem, or submit your answer "
                f"using the `submit_answer` tool when you are confident in your answer."
            )
            state.messages.append(ChatMessageUser(content=message))
        return state

    return solve
@tool
def submit_answer() -> Tool:
    # NOTE: the docstring of `execute` below is functional, not documentation —
    # inspect_ai parses it to build the tool description and argument schema
    # shown to the model. Do not edit it without intending a prompt change.
    async def execute(answer_fn_code: str) -> ToolResult:
        """
        Submit your answer as a Python function named 'answer' that:
        - takes no parameters
        - returns your answer
        - prints no output
        - contains no code comments
        For example, when the answer type is a Python integer:
        def answer():
            return 42
        Args:
            answer_fn_code (str): Python code defining the 'answer' function.
                It should only define this function, do not call
                the function or include any other code.
        """
        # Record the raw function source; the agent loop polls this store key
        # to detect submission, and the scorer executes it in a sandbox.
        store().set("submitted_answer", answer_fn_code)
        return "Your answer has been recorded. No feedback is provided."

    return execute
def insert_tool_call_help_message(state: TaskState) -> TaskState:
    """
    Append a corrective hint when the last message is an empty-output call to
    the `python` tool.

    Some models repeatedly run code without printing anything and then see no
    output; this detects that specific situation and explains the fix. The
    state is returned unchanged in every other case.
    """
    last = state.messages[-1]
    # Only tool results are of interest, and only those from the python tool.
    if not isinstance(last, ChatMessageTool):
        return state
    if last.function != "python":
        return state
    # Eliminate other reasons that there could be no output
    # (not sure if output_limit belongs here, but might as well include it)
    if last.error and last.error.type in (
        "parsing",
        "timeout",
        "unicode_decode",
        "output_limit",
    ):
        return state
    # Non-empty output means the model printed something; nothing to do.
    if str(last.content).strip():
        return state
    help_message = (
        "I notice you used the Python tool but got no output. "
        "Remember, the tool only shows stdout (and stderr). You must include "
        "print() statements to see results.\n\n"
        "❌ Incorrect (no output):\n"
        "```python\n"
        "x = 5 * 12 # Calculation happens but not shown\n"
        "```\n\n"
        "✅ Correct (output shown):\n"
        "```python\n"
        "x = 5 * 12\n"
        "print(x) # Shows: 60\n"
        "```\n"
    )
    state.messages.append(ChatMessageUser(content=help_message))
    return state
def initial_prompt(
    question: str, answer_type: str, token_limit: int, forced_submit_tokens: int
) -> str:
    """
    Render the agent's opening prompt from TEMPLATE.

    Fills in the problem statement, expected answer type, both token
    thresholds, the tool timeouts, and the list of available Python packages
    read from the agent requirements file.
    """
    req_path = (
        PROJECT_ROOT / "bench" / "task" / "frontiermath" / "agent-python-requirements.txt"
    )
    # Indent the requirements two tab stops so they sit inside the
    # template's fenced code block.
    requirements = textwrap.indent(req_path.read_text().strip(), "\t" * 2)
    return TEMPLATE.format(
        question=question,
        answer_type=answer_type,
        token_limit=token_limit,
        forced_submit_tokens=forced_submit_tokens,
        requirements=requirements,
        python_tool_timeout=PYTHON_TOOL_TIMEOUT,
        answer_func_timeout=ANSWER_FUNC_TIMEOUT,
    )
# Prompt template rendered by initial_prompt(). The placeholders are filled
# via str.format: {question}, {answer_type}, {token_limit},
# {forced_submit_tokens}, {requirements}, {python_tool_timeout},
# {answer_func_timeout}. This is runtime text shown to the model — changing
# its wording changes the experiment.
TEMPLATE = r"""
You will be solving a challenging mathematics question. Here's how it works:
1. You can:
- Think out loud and explore the problem
- Use the `python` tool to execute arbitrary Python code
- Submit your answer using the `submit_answer` tool when you are confident in your answer.
2. Token limits:
- There is a hard limit of {token_limit:,} tokens. If you exceed this, the conversation will end
immediately (even if you haven't submitted an answer).
- If you reach {forced_submit_tokens:,} tokens (but less than the hard limit of {token_limit:,}),
you will be forced to use the `submit_answer` tool in your next message. This forced submission
stage is designed to give you the best chance of submitting an answer before reaching the
hard token limit. But it is not a guarantee. It is still your responsibility to avoid hitting
the hard limit.
- Both input and output tokens count towards the limits.
3. Scoring:
- If your answer is correct you will get 1 point. If it is incorrect, or if you don't submit an
answer, you will get 0 points.
4. Explain your reasoning to me before submitting an answer.
5. Tips
- I strongly recommend that you start by making a high-level plan for how you will attack the
problem. If you can, think about different approaches that could be used to solve the problem.
To help you stay on track, periodically summarize your key findings and potentially revise
your plan.
- Before submitting, verify your answer satisfies all problem requirements. It may be worth trying
a different approach if you can see that your current answer is not correct.
6. For using the `submit_answer` tool:
- Pass in the code of a Python function named 'answer' that:
- Takes no parameters
- Returns your answer as a {answer_type}
- Prints no output
- Contains no code comments
- When scoring your answer, the maximum runtime for the `answer` function is
{answer_func_timeout} seconds. The code is executed on typical commodity hardware for the year
2025.
7. For using the `python` tool:
* The tool will only return stdout (and stderr), so you must make sure to
use print() to see your results. If you don't get any output from a `python` tool call, you
probably forgot to print.
* Example:
```python
x = 5 * 12
print("The result is", x)
```
In this example, you must include the print statement. Otherwise, you won't see the value of
x.
* The tool is completely stateless and doesn't come with anything pre-imported. This is very
important. If you need modules (e.g. math, sympy), you must import them each time. You cannot
access variables defined in a previous call to `python`, so you must re-define anything you
need in each call.
* You have access to the standard library, and the following libraries (expressed in
`requirements.txt` format):
```
{requirements}
```
* Do not submit your answer using the `python` tool. Use the `submit_answer`
tool when you're ready to submit
* The maximum runtime for a `python` tool call is {python_tool_timeout} seconds. The code is
executed on typical commodity hardware for the year 2025.
Here is the problem to solve. The answer type is {answer_type}.
{question}
"""
# Sandbox container for executing agent-submitted Python code.
services:
  default:
    build:
      context: "."
      dockerfile: "Dockerfile"
    # Keep the container alive idly; code is exec'd into it on demand.
    command: "tail -f /dev/null"
    # Run an init process so exec'd children are reaped properly.
    init: true
    # No network access: the sandbox is fully offline.
    network_mode: none
    # Don't wait around on shutdown — nothing stateful to flush.
    stop_grace_period: 1s
    # Cap memory so runaway submissions can't exhaust the host.
    mem_limit: 512m
# Image for the offline sandbox in which agent code and verification run.
FROM python:3.10-bookworm

# Install system dependencies required for gmpy2
# (MPFR/MPC/GMP headers are needed to build its C extension);
# clear the apt cache afterwards to keep the image small.
RUN apt-get update && apt-get install -y \
    libmpfr-dev \
    libmpc-dev \
    libgmp-dev \
    && rm -rf /var/lib/apt/lists/*

# Install the exact package set advertised to the agent in its prompt.
COPY agent-python-requirements.txt .
RUN pip install -r agent-python-requirements.txt
import inspect
import sys
from inspect_ai.scorer import CORRECT, INCORRECT, Score, Target, accuracy, stderr, scorer, Scorer
from inspect_ai.solver import TaskState
from inspect_ai.util import sandbox
from .agent import ANSWER_FUNC_TIMEOUT
# Wall-clock cap (seconds) for the whole generated scoring script in the sandbox.
VERIFICATION_CODE_TIMEOUT = 120

# Exit codes the generated scoring script uses to report its outcome back to
# the scorer process (chosen >= 100 to avoid clashing with ordinary Python
# exit codes like 0 and 1).
# When `verify()` returns True or False (100-101)
EXIT_VERIFY_TRUE = 100
EXIT_VERIFY_FALSE = 101
# Errors in the `answer()` function (102-103)
EXIT_ANSWER_ERROR = 102
EXIT_ANSWER_TIMEOUT = 103
# Errors in the `verify()` function (104)
EXIT_VERIFY_ERROR = 104
@scorer(metrics=[accuracy(), stderr()], name="verification_code")
def verification_code_scorer() -> Scorer:
    """Score a FrontierMath problem submission.

    Assembles a standalone Python script from the submitted `answer()` code
    and the problem's `verify()` code, runs it inside the sandbox, and maps
    the script's exit code to a Score. The helper classes ExitOnException and
    WithTimeout are embedded by source so the script is self-contained.
    """

    async def score(state: TaskState, target: Target) -> Score:
        # Get the submitted answer
        answer_code = state.store.get("submitted_answer")
        if answer_code is None:
            return Score(value=INCORRECT, explanation="No solution submitted")
        # Fenced copy of the submission, attached to every Score for the log.
        answer_code_backticks = f"```python\n{answer_code}\n```"
        # The generated script: helpers, then answer(), then verify(), then a
        # driver that converts each outcome into one of the EXIT_* codes.
        scoring_code = f"""
import sys
{inspect.getsource(ExitOnException)}
{inspect.getsource(WithTimeout)}
{answer_code}
{state.metadata['verification_code']}
with ExitOnException(exit_code={EXIT_ANSWER_ERROR}), WithTimeout({ANSWER_FUNC_TIMEOUT}, {EXIT_ANSWER_TIMEOUT}):
    a = answer()
with ExitOnException(exit_code={EXIT_VERIFY_ERROR}):
    is_correct = verify(a)
if is_correct:
    sys.exit({EXIT_VERIFY_TRUE})
else:
    sys.exit({EXIT_VERIFY_FALSE})
"""
        try:
            # Execute the verification code in the sandbox
            result = await sandbox().exec(
                ["python3", "-c", scoring_code], timeout=VERIFICATION_CODE_TIMEOUT
            )
        except TimeoutError:
            # The overall sandbox timeout fired (distinct from the per-answer
            # WithTimeout inside the script, which exits with EXIT_ANSWER_TIMEOUT).
            return Score(
                answer=answer_code_backticks,
                value=INCORRECT,
                explanation=f"Verification code execution timed out (timeout: {VERIFICATION_CODE_TIMEOUT} seconds)",
            )
        # Map the script's exit code to a verdict. Only EXIT_VERIFY_TRUE is
        # CORRECT; every other path is INCORRECT with a targeted explanation.
        if result.returncode == EXIT_VERIFY_TRUE:
            return Score(
                answer=answer_code_backticks,
                value=CORRECT,
                explanation="The `verify()` function returned True",
            )
        elif result.returncode == EXIT_VERIFY_FALSE:
            return Score(
                answer=answer_code_backticks,
                value=INCORRECT,
                explanation="The `verify()` function returned False",
            )
        # Answer function errors
        elif result.returncode == EXIT_ANSWER_ERROR:
            return Score(
                answer=answer_code_backticks,
                value=INCORRECT,
                explanation=f"Error while executing the `answer()` function:\n{result.stderr}",
            )
        elif result.returncode == EXIT_ANSWER_TIMEOUT:
            return Score(
                answer=answer_code_backticks,
                value=INCORRECT,
                explanation=f"Error while executing the `answer()` function: Timeout exceeded (timeout: {ANSWER_FUNC_TIMEOUT} seconds)",
            )
        # Verification function errors
        elif result.returncode == EXIT_VERIFY_ERROR:
            return Score(
                answer=answer_code_backticks,
                value=INCORRECT,
                explanation=f"Error while executing the `verify()` function:\n{result.stderr}",
            )
        # Other errors with exit code 1
        elif result.returncode == 1:
            # e.g. a syntax error in the submission or verification code,
            # which aborts the script before any ExitOnException is entered.
            return Score(
                answer=answer_code_backticks,
                value=INCORRECT,
                explanation=f"Error in verification code:\n{result.stderr}",
            )
        else:
            return Score(
                answer=answer_code_backticks,
                value=INCORRECT,
                explanation=f"Unknown failure in verification code:\n{result.stderr}",
            )

    return score
class ExitOnException:
    """
    Context manager that converts any escaping exception (other than
    SystemExit) into a process exit with a configurable code, writing the
    traceback to stderr first.

    SystemExit is deliberately let through so that nested exit-code logic
    (e.g. a timeout handler calling sys.exit) is not masked.

    Usage:
        with ExitOnException(exit_code=2):
            risky_operation()
    """

    def __init__(self, exit_code=1):
        """
        Remember which exit code to use when an exception occurs.

        Args:
            exit_code (int): Process exit code on failure. Defaults to 1.
        """
        self.exit_code = exit_code

    def __enter__(self):
        """No setup needed; expose the manager itself."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        On a non-SystemExit exception: print its traceback and terminate.

        Args:
            exc_type: Exception class, or None on clean exit.
            exc_val: Exception instance.
            exc_tb: Traceback object.

        Returns:
            bool: False, so that SystemExit (and clean exits) propagate.
        """
        # Imports are local so this class stays self-contained when its
        # source is embedded into a generated script via inspect.getsource.
        import sys
        import traceback

        if exc_type is None or exc_type is SystemExit:
            # Nothing to handle; propagate SystemExit / finish normally.
            return False
        traceback.print_exception(exc_type, exc_val, exc_tb, file=sys.stderr)
        # sys.exit raises SystemExit, so control never returns from here.
        sys.exit(self.exit_code)
class WithTimeout:
    """
    Context manager enforcing a wall-clock timeout via SIGALRM (Unix-only).

    If the wrapped block runs longer than ``timeout_sec`` seconds, the alarm
    handler fires, reports to stderr, and exits the process with
    ``exit_code``. On exit the alarm is cancelled and the previously
    installed SIGALRM handler is restored (fix: the original left this
    class's handler installed permanently, clobbering any prior handler).
    """

    def __init__(self, timeout_sec: int, exit_code: int):
        self.timeout_sec = timeout_sec
        self.exit_code = exit_code
        # Set in __enter__; holds the handler that was active before us.
        self._previous_handler = None

    def _handle_alarm(self, signum, frame):
        # Runs when the alarm fires mid-block: report and terminate.
        print("Timeout while calling answer()", file=sys.stderr)
        sys.exit(self.exit_code)

    def __enter__(self):
        # Import locally so the class stays self-contained when its source is
        # embedded into a generated script via inspect.getsource.
        import signal

        # signal.signal returns the previous handler; remember it so we can
        # restore it in __exit__.
        self._previous_handler = signal.signal(signal.SIGALRM, self._handle_alarm)
        signal.alarm(self.timeout_sec)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        import signal

        signal.alarm(0)  # Disable the alarm
        # Restore whatever handler was installed before we entered.
        if self._previous_handler is not None:
            signal.signal(signal.SIGALRM, self._previous_handler)
        return False
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment