Created
February 3, 2025 20:55
-
-
Save tadamcz/e639d7df0663a9c2e15ac7f97275b01b to your computer and use it in GitHub Desktop.
Epoch AI MATH implementation (based on inspect_evals)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Optional, List | |
from inspect_ai._eval.registry import task | |
from inspect_ai._eval.task import Task | |
from inspect_ai.dataset._sources.hf import hf_dataset | |
from bench.model import DEFAULT_GRADER_MODEL | |
from bench.task.hendrycks_math.dataset import filter_dataset, record_to_sample | |
from bench.task.hendrycks_math.scorer import ( | |
normalized_string_match, | |
sympy_equiv, | |
model_graded_equiv, | |
) | |
from bench.task.hendrycks_math.solver import math_solver | |
@task
def hendrycks_math(
    levels: Optional[List[str]] = None,
    subjects: Optional[List[str]] = None,
    fewshot: int = 0,
    fewshot_seed: int = 42,
) -> Task:
    """Build the MATH benchmark task.

    Arguments:
        levels: Difficulty levels to keep; ``None`` or an empty list keeps all.
        subjects: Subjects to keep; ``None`` or an empty list keeps all.
        fewshot: Number of few-shot examples forwarded to ``math_solver``.
        fewshot_seed: Seed forwarded to ``math_solver`` for few-shot sampling.

    Returns:
        A ``Task`` over the (optionally filtered) test split, scored by the
        string-match, sympy and model-graded scorers, run for 8 epochs.
    """
    # None means "no filter"; filter_dataset treats an empty list the same way.
    wanted_levels = [] if levels is None else levels
    wanted_subjects = [] if subjects is None else subjects

    test_split = hf_dataset(
        # TODO: change this back to the official dataset if DMCA issues are resolved
        "tadamcz/hendrycks___competition_math",
        split="test",
        trust=True,
        sample_fields=record_to_sample,
        shuffle=True,
        auto_id=True,
    )

    # Subset the data based on levels and/or subjects
    selected = filter_dataset(
        dataset=test_split,
        levels=wanted_levels,
        subjects=wanted_subjects,
    )

    # Every sample is scored three ways, from strictest to most lenient.
    graders = [
        normalized_string_match(),
        sympy_equiv(),
        model_graded_equiv(model=DEFAULT_GRADER_MODEL),
    ]

    return Task(
        dataset=selected,
        plan=math_solver(fewshot=fewshot, fewshot_seed=fewshot_seed),
        scorer=graders,
        epochs=8,
        metadata={"inspect-log-public": True},
    )
@task(name="MATH level 5")
def hendrycks_math_lvl_5() -> Task:
    """Build the MATH task restricted to level-5 problems."""
    only_level_five = ["5"]
    return hendrycks_math(levels=only_level_five)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Dict | |
from inspect_ai.dataset import Dataset, Sample | |
from bench.task.hendrycks_math.scorer import remove_boxed, last_boxed_only_string | |
def filter_dataset(dataset: Dataset, levels: list, subjects: list) -> Dataset:
    """Filters the MATH dataset by levels and/or subjects.

    Arguments:
        dataset (Dataset): Dataset object to be filtered.
        levels (List): Levels to keep, 1 to 5. A bare value is wrapped in a
            list and every entry is compared as a string. Empty means no
            level filtering.
        subjects (List): Subjects to keep, compared verbatim against
            ``sample.metadata["subject"]``. A bare value is wrapped in a list.
            Empty means no subject filtering.

    Returns:
        The dataset after applying whichever filters were requested.
    """
    # Accept a single value as well as a list, and normalize levels to strings.
    raw_levels = levels if isinstance(levels, list) else [levels]
    wanted_levels = [str(level) for level in raw_levels]
    wanted_subjects = subjects if isinstance(subjects, list) else [subjects]

    def level_matches(sample) -> bool:
        # Samples without metadata can never match a filter.
        return sample.metadata is not None and sample.metadata["level"] in wanted_levels

    def subject_matches(sample) -> bool:
        return sample.metadata is not None and sample.metadata["subject"] in wanted_subjects

    # Filter dataset by levels, if required
    if wanted_levels:
        dataset = dataset.filter(predicate=level_matches)

    # Filter dataset by subjects, if required
    if wanted_subjects:
        dataset = dataset.filter(predicate=subject_matches)

    return dataset
def record_to_sample(record: Dict) -> Sample:
    """Convert one raw MATH record into an inspect ``Sample``.

    Arguments:
        record (Dict): Row with the keys ``problem``, ``solution``, ``level``
            and ``type``.

    Returns:
        A ``Sample`` whose input is the problem text, whose target is the
        content of the last ``\\boxed`` expression in the solution, and whose
        metadata holds the normalized level, the lower-cased subject and the
        full worked solution.
    """
    # "Level 5" -> "5". removeprefix drops the literal prefix only, whereas
    # lstrip("level ") would strip any leading run of the characters
    # l/e/v/space and could eat into the value itself.
    level = record["level"].lower().removeprefix("level ")

    return Sample(
        input=record["problem"],
        target=remove_boxed(last_boxed_only_string(record["solution"])),
        metadata={
            "level": level,
            "subject": record["type"].lower(),
            "solution": record["solution"],
        },
    )
def sample_to_fewshot(sample: Sample) -> str:
    """Render a sample as a PROBLEM / SOLUTION / ANSWER few-shot example.

    Raises AssertionError when the sample carries no solution in its metadata.
    """
    # Based on https://arxiv.org/pdf/2206.14858 - Appendix D.2
    # Tags are capitalized to match the format of the user prompt
    problem_part = f"PROBLEM:\n{sample.input}"

    worked_solution = None
    if sample.metadata is not None:
        worked_solution = sample.metadata["solution"]
    assert (
        worked_solution is not None
    ), "Solution not found in sample, make sure to include it in the 'sample.metadata' dict."

    solution_part = f"SOLUTION:\n{worked_solution}"
    answer_part = f"ANSWER: {sample.target}"
    return problem_part + "\n\n" + solution_part + "\n" + answer_part
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import re | |
import signal | |
import sympy # type: ignore | |
from inspect_ai.model import Model | |
from inspect_ai.model import get_model | |
from inspect_ai.scorer import ( | |
Score, | |
AnswerPattern, | |
CORRECT, | |
INCORRECT, | |
) | |
from inspect_ai.scorer import ( | |
Target, | |
accuracy, | |
scorer, | |
stderr, | |
) | |
from inspect_ai.solver import TaskState | |
from sympy.parsing.latex import parse_latex # type: ignore | |
logger = logging.getLogger(__name__) | |
# Prompt sent to the grader model to judge whether two answers are equivalent.
# It is filled with %-formatting, so it needs a mapping with the keys
# "expression1" and "expression2". The trailing .strip() removes the newlines
# that surround the triple-quoted text.
EQUIVALANCE_TEMPLATE = r"""
Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications
Examples:
Expression 1: $2x+3$
Expression 2: $3+2x$
Yes
Expression 1: 3/2
Expression 2: 1.5
Yes
Expression 1: $x^2+2x+1$
Expression 2: $y^2+2y+1$
No
Expression 1: $x^2+2x+1$
Expression 2: $(x+1)^2$
Yes
Expression 1: 3245/5
Expression 2: 649
No
(these are actually equal, don't mark them equivalent if you need to
do nontrivial simplifications)
Expression 1: 2/(-3)
Expression 2: -2/3
Yes
(trivial simplifications are allowed)
Expression 1: 72 degrees
Expression 2: 72
Yes
(give benefit of the doubt to units)
Expression 1: 64
Expression 2: 64 square feet
Yes
(give benefit of the doubt to units)
---
YOUR TASK
Respond with only "Yes" or "No" (without quotes). Do not include a rationale.
Expression 1: %(expression1)s
Expression 2: %(expression2)s
""".strip()
# Ordered (before, after) string replacements applied by
# normalize_final_answer(). Order matters: later entries see the output of
# earlier ones (e.g. spaces are removed before the "\\text{and}" rules run).
SUBSTITUTIONS = [
    ("an ", ""),
    ("a ", ""),
    (".$", "$"),
    ("\\$", ""),
    (r"\ ", ""),
    (" ", ""),
    ("mbox", "text"),
    (",\\text{and}", ","),
    ("\\text{and}", ","),
    ("\\text{m}", "\\text{}"),
]
# Substrings that normalize_final_answer() deletes outright, after the
# SUBSTITUTIONS pass: mostly unit words plus LaTeX spacing/decoration tokens.
# They are removed in list order.
REMOVED_EXPRESSIONS = [
    "square",
    "ways",
    "integers",
    "dollars",
    "mph",
    "inches",
    "ft",
    "hours",
    "km",
    "units",
    "\\ldots",
    "sue",
    "points",
    "feet",
    "minutes",
    "digits",
    "cents",
    "degrees",
    "cm",
    "gm",
    "pounds",
    "meters",
    "meals",
    "edges",
    "students",
    "childrentickets",
    "multiples",
    "\\text{s}",
    "\\text{.}",
    "\\text{\ns}",
    "\\text{}^2",
    "\\text{}^3",
    "\\text{\n}",
    "\\text{}",
    r"\mathrm{th}",
    r"^\circ",
    r"^{\circ}",
    r"\;",
    r",\!",
    "{,}",
    '"',
    "\\dots",
]
def _extract_answer_helper(text: str) -> str | None: | |
"""Helper function to extract answer from a text string after finding 'ANSWER'.""" | |
try: | |
# Find start of answer after "ANSWER" | |
answer_start = text.index("ANSWER") + 6 | |
# Skip optional colon and spaces | |
while answer_start < len(text) and ( | |
text[answer_start] == ":" or text[answer_start].isspace() | |
): | |
answer_start += 1 | |
answer = text[answer_start:].strip() | |
return answer if answer else None | |
except ValueError: | |
return None | |
def extract_answer(completion: str) -> str | None:
    """Extract answer from model completion using string manipulation.

    Three patterns are tried in order: a LaTeX ``\\text{ANSWER ...}`` span,
    a markdown-bold line containing ``ANSWER``, and finally any line
    containing ``ANSWER``.

    Args:
        completion: The model completion text to extract from

    Returns:
        The extracted answer string, or None if no answer found
    """
    # Try LaTeX text pattern first
    latex_marker = "\\text{ANSWER"
    marker_at = completion.find(latex_marker)
    if marker_at != -1:
        closing_at = completion.find("}", marker_at)
        # Without a closing brace, fall through to the line-based patterns.
        if closing_at != -1:
            # Prefer an answer written inside the braces
            inside = _extract_answer_helper(completion[marker_at:closing_at])
            if inside:
                return inside
            # Otherwise use the rest of the text up to the first newline
            remainder = completion[closing_at + 1 :].strip()
            first_line = remainder.partition("\n")[0]
            return first_line or None

    rows = completion.split("\n")

    # Try markdown bold pattern
    for row in rows:
        candidate = row.strip()
        is_bold = candidate.startswith("**") and candidate.endswith("**")
        if not (is_bold and "ANSWER" in candidate):
            continue
        # Remove the markdown bold markers
        found = _extract_answer_helper(candidate[2:-2].strip())
        if found:
            return found

    # Try simple line pattern
    for row in rows:
        if "ANSWER" not in row:
            continue
        found = _extract_answer_helper(row)
        if found:
            return found

    return None
async def score_helper(
    state: TaskState,
    target: Target,
    model_graded: bool,
    use_sympy: bool = False,
    model: Model | None = None,
) -> Score:
    """Score one completion against its target.

    Args:
        state: Task state; the answer is extracted from ``state.output.completion``.
        target: Reference answer; ``target.text`` is compared against.
        model_graded: If True, ask ``model`` to judge equivalence; otherwise
            use ``match_helper``.
        use_sympy: Forwarded to ``match_helper`` when not model graded.
        model: Grader model, required when ``model_graded`` is True.

    Returns:
        A ``Score`` that is CORRECT or INCORRECT. When no answer can be
        extracted the score is INCORRECT and has no ``answer``. Model-graded
        scores carry the grader's usage under ``metadata["grader_model_usage"]``.

    Raises:
        ValueError: If ``model_graded`` is True but ``model`` is None.
    """
    completion = state.output.completion
    answer = extract_answer(completion)

    if not answer:
        return Score(
            value=INCORRECT,
            explanation="Answer not found in model output: " + f"{completion}",
        )

    extra_fields = {}
    if model_graded:
        # Ask grader model to judge equivalence
        if model is None:
            raise ValueError("Model is required for model graded scoring")
        prompt = EQUIVALANCE_TEMPLATE % ({"expression1": target.text, "expression2": answer})
        result = await model.generate(prompt)
        correct = result.completion.strip().lower() == "yes"
        # Supply the metadata at construction time. The previous code built
        # the Score without metadata and then only updated it when
        # score.metadata was not None, so the usage was dropped unless the
        # library happened to default metadata to a dict.
        extra_fields["metadata"] = {"grader_model_usage": result.usage}
    else:
        correct = await match_helper(
            answer=answer,
            target=target,
            use_sympy=use_sympy,
        )

    return Score(
        value=CORRECT if correct else INCORRECT,
        explanation=completion,
        answer=answer,
        **extra_fields,
    )
# Everything from here through normalize_final_answer() is adapted from:
# https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/minerva_math/utils.py#L144
class timeout:
    """Context manager that raises TimeoutError if its body runs too long.

    Implemented with ``signal.alarm`` and a SIGALRM handler. On exit the alarm
    is cancelled and the SIGALRM handler that was installed before entry is
    put back, so using this does not permanently replace a handler owned by
    other code.

    Args:
        seconds: Whole seconds to allow before the alarm fires.
        error_message: Message carried by the raised TimeoutError.
    """

    def __init__(self, seconds=1, error_message="Timeout"):
        self.seconds = seconds
        self.error_message = error_message
        # Handler that was active before __enter__; restored in __exit__.
        self._previous_handler = None

    def handle_timeout(self, signum, frame):
        """SIGALRM handler: abort the guarded body."""
        raise TimeoutError(self.error_message)

    def __enter__(self):
        # signal.signal returns the handler it replaces.
        self._previous_handler = signal.signal(signal.SIGALRM, self.handle_timeout)
        signal.alarm(self.seconds)

    def __exit__(self, type, value, traceback):
        # Cancel any pending alarm first so it cannot fire mid-restore.
        signal.alarm(0)
        # A previous handler of None means it was not installed from Python
        # and cannot be re-installed through signal.signal.
        if self._previous_handler is not None:
            signal.signal(signal.SIGALRM, self._previous_handler)
            self._previous_handler = None
async def is_equiv_sympy(x1: str, x2: str) -> bool:
    """Return True if two normalized LaTeX strings are symbolically equal.

    Both strings are parsed with ``parse_latex`` and their difference is
    simplified; they are equivalent when that simplifies to 0. Parsing,
    subtraction or simplification failures, a recursion error, or exceeding
    the 5-second ``timeout`` guard all yield False.
    """
    try:
        with timeout(seconds=5):
            # Step 1: parse both sides.
            try:
                lhs = parse_latex(x1)
                rhs = parse_latex(x2)
            except (
                sympy.parsing.latex.errors.LaTeXParsingError,
                sympy.SympifyError,
                TypeError,
            ) as e:
                logger.debug(f"Couldn't parse one of {x1} or {x2}: {e}")
                return False

            # Step 2: form the difference.
            try:
                difference = lhs - rhs
            except TypeError:
                logger.debug(f"Couldn't subtract {x1} and {x2}")
                return False

            # Step 3: equal iff the difference simplifies to zero.
            try:
                simplifies_to_zero = sympy.simplify(difference) == 0
            except (ValueError, TypeError, ZeroDivisionError):
                logger.debug(f"Had some trouble simplifying when comparing {x1} and {x2}")
                return False
            return True if simplifies_to_zero else False
    except TimeoutError:
        logger.debug(f"Timed out comparing {x1} and {x2}")
        return False
    except RecursionError:
        logger.debug(f"Recursion error comparing {x1} and {x2}")
        return False
async def normalize_final_answer(final_answer: str) -> str:
    """
    Normalize a final answer to a quantitative reasoning question.

    Keeps only the text after the last "=", applies SUBSTITUTIONS and
    REMOVED_EXPRESSIONS, unwraps math/text/bold/overline/boxed wrappers,
    expands shorthand \\frac and \\sqrt arguments, and drops thousands
    separators from purely numeric answers.

    Copied character for character from appendix D of Lewkowycz et al. (2022)
    """
    text = final_answer.rpartition("=")[2]

    for before, after in SUBSTITUTIONS:
        text = text.replace(before, after)
    for expr in REMOVED_EXPRESSIONS:
        text = text.replace(expr, "")

    # Extract answer that is in LaTeX math, is bold,
    # is surrounded by a box, etc.
    unwrap_rules = (
        (r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$"),
        (r"(\\text\{)(.*?)(\})", "\\2"),
        (r"(\\textbf\{)(.*?)(\})", "\\2"),
        (r"(\\overline\{)(.*?)(\})", "\\2"),
        (r"(\\boxed\{)(.*)(\})", "\\2"),
    )
    for pattern, replacement in unwrap_rules:
        text = re.sub(pattern, replacement, text)

    # If surrounded by `\(` and `\)`, remove them
    if text.startswith("\\(") and text.endswith("\\)"):
        text = text[2:-2]

    # Normalize shorthand TeX:
    # \fracab -> \frac{a}{b}
    # \frac{abc}{bef} -> \frac{abc}{bef}
    # \fracabc -> \frac{a}{b}c
    # \sqrta -> \sqrt{a}
    # \sqrtab -> sqrt{a}b
    text = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", text)
    text = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", text)
    text = text.replace("$", "")

    # Normalize 100,000 -> 100000
    without_commas = text.replace(",", "")
    if without_commas.isdigit():
        text = without_commas

    return text
async def is_equiv(str1: str | None, str2: str | None) -> bool:
    """Compare two answers after canonicalizing each with ``strip_string``.

    Two ``None`` values count as equal (a debug message is logged); exactly
    one ``None`` counts as not equal.
    """
    if str1 is None or str2 is None:
        both_missing = str1 is None and str2 is None
        if both_missing:
            logger.debug("WARNING: Both None")
        return both_missing

    canonical_first = await strip_string(str1)
    canonical_second = await strip_string(str2)
    return canonical_first == canonical_second
async def strip_string(string):
    """Canonicalize a LaTeX answer string for exact string comparison.

    Removes line breaks, spacing and sizing commands, degree and percent
    markers, escaped dollar signs and trailing units, adds a leading zero to
    bare decimals, drops a short "x =" prefix, and normalizes \\sqrt, \\frac
    and simple a/b fractions to braced form.

    Args:
        string: Answer text to canonicalize.

    Returns:
        The canonicalized string (possibly empty).
    """
    # linebreaks
    string = string.replace("\n", "")

    # remove inverse spaces
    string = string.replace("\\!", "")

    # replace \\ with \
    string = string.replace("\\\\", "\\")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")

    # remove units (on the right)
    string = await remove_right_units(string)

    # remove percentage. A second replace of "\%" used to follow: that
    # literal is an invalid escape sequence denoting the very same two
    # characters as "\\%", so it only produced a warning.
    string = string.replace("\\%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")

    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # to consider: get rid of e.g. "k = " or "q = " at beginning
    if len(string.split("=")) == 2:
        if len(string.split("=")[0]) <= 2:
            string = string.split("=")[1]

    # fix sqrt3 --> sqrt{3}
    string = await fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
    string = await fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = await fix_a_slash_b(string)

    return string
async def fix_fracs(string):
    """Rewrite shorthand ``\\frac`` arguments into braced form.

    ``\\frac12`` becomes ``\\frac{1}{2}`` and ``\\frac1{72}`` becomes
    ``\\frac{1}{72}``; already braced fractions are left alone.

    Args:
        string: LaTeX text to normalize.

    Returns:
        The normalized text, or the input unchanged when it has no ``\\frac``
        or when a ``\\frac`` is followed by fewer than two characters
        (including a trailing ``\\frac``, which used to raise IndexError).
    """
    head, *pieces = string.split("\\frac")
    rebuilt = [head]
    for piece in pieces:
        rebuilt.append("\\frac")
        if piece.startswith("{"):
            # Numerator already braced: keep the rest verbatim.
            rebuilt.append(piece)
            continue
        if len(piece) < 2:
            # Not enough characters for a numerator and a denominator.
            return string
        numerator, denominator, rest = piece[0], piece[1], piece[2:]
        if denominator != "{":
            rebuilt.append("{" + numerator + "}{" + denominator + "}" + rest)
        else:
            # Denominator is already braced: only wrap the numerator.
            rebuilt.append("{" + numerator + "}" + denominator + rest)
    return "".join(rebuilt)
async def fix_a_slash_b(string):
    """Turn a plain integer ratio ``a/b`` into ``\\frac{a}{b}``.

    Args:
        string: Candidate text.

    Returns:
        ``\\frac{a}{b}`` when the whole string is exactly two canonically
        written integers separated by one slash; otherwise the input unchanged.
    """
    parts = string.split("/")
    if len(parts) != 2:
        return string
    try:
        a = int(parts[0])
        b = int(parts[1])
    except ValueError:
        return string
    # Require an exact round trip so inputs such as "01/2" or " 1/2" are left
    # alone. This is an explicit check rather than an assert, so the result
    # does not change when Python runs with -O.
    if string != "{}/{}".format(a, b):
        return string
    return "\\frac{" + str(a) + "}{" + str(b) + "}"
async def remove_right_units(string):
    """Drop a trailing units annotation introduced by ``\\text{ ``.

    Args:
        string: Answer text.

    Returns:
        Everything before the first ``\\text{ `` marker, or the input
        unchanged when the marker is absent. Text with several markers is cut
        at the first one rather than failing an assertion, since model output
        is not guaranteed to contain the marker only once.
    """
    # "\\text{ " only ever occurs (at least in the val set) when describing units
    before_units, _, _ = string.partition("\\text{ ")
    return before_units
async def fix_sqrt(string):
    """Brace the single-character argument of shorthand ``\\sqrt``.

    ``\\sqrt3`` becomes ``\\sqrt{3}``; ``\\sqrt{...}`` is left as is.

    Args:
        string: LaTeX text to normalize.

    Returns:
        The normalized text. A ``\\sqrt`` with nothing after it (at the end
        of the string or directly before another ``\\sqrt``) is kept verbatim
        instead of raising IndexError.
    """
    head, *pieces = string.split("\\sqrt")
    rebuilt = [head]
    for piece in pieces:
        if piece and piece[0] != "{":
            rebuilt.append("\\sqrt{" + piece[0] + "}" + piece[1:])
        else:
            # Already braced, or no argument at all.
            rebuilt.append("\\sqrt" + piece)
    return "".join(rebuilt)
def remove_boxed(s):
    """Strip the ``\\boxed`` wrapper from a boxed expression.

    Handles both ``\\boxed value`` and ``\\boxed{value}``. Raises
    AssertionError when *s* does not start with the wrapper or, for the
    braced form, does not end with ``}``.
    """
    spaced_prefix = "\\boxed "
    if spaced_prefix in s:
        assert s.startswith(spaced_prefix)
        return s[len(spaced_prefix) :]

    braced_prefix = "\\boxed{"
    assert s.startswith(braced_prefix)
    assert s.endswith("}")
    return s[len(braced_prefix) : -1]
def last_boxed_only_string(string):
    """Return the last ``\\boxed`` (or ``\\fbox``) expression in *string*.

    For the space form ``\\boxed value`` the value runs up to the next ``$``.
    For the braced form the matching closing brace is located by counting
    brace depth. Returns None when no box is found or the braces never
    balance.
    """
    spaced_marker = "\\boxed "
    if spaced_marker in string:
        after_last = string.rpartition(spaced_marker)[2]
        return spaced_marker + after_last.partition("$")[0]

    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
    if start < 0:
        return None

    # Walk forward until the brace opened after the command closes again.
    depth = 0
    for position in range(start, len(string)):
        char = string[position]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return string[start : position + 1]
    return None
@scorer(metrics=[accuracy(), stderr()])
def model_graded_equiv(model: Model):
    """Scorer that asks *model* whether the answer is equivalent to the target."""

    async def score(state: TaskState, target: Target):
        graded = await score_helper(
            state=state,
            target=target,
            model_graded=True,
            model=model,
        )
        return graded

    return score
@scorer(metrics=[accuracy(), stderr()])
def sympy_equiv():
    """Scorer that compares the normalized answer and target with sympy."""

    async def score(state: TaskState, target: Target):
        graded = await score_helper(
            state=state,
            target=target,
            use_sympy=True,
            model_graded=False,
        )
        return graded

    return score
@scorer(metrics=[accuracy(), stderr()])
def normalized_string_match():
    """Scorer that compares the answer and target as normalized strings."""

    async def score(state: TaskState, target: Target):
        graded = await score_helper(
            state=state,
            target=target,
            use_sympy=False,
            model_graded=False,
        )
        return graded

    return score
async def match_helper(
    answer: str,
    target: Target,
    use_sympy: bool = False,
) -> bool:
    """Decide whether *answer* matches ``target.text`` without a grader model.

    An exact string match short-circuits to True. Otherwise both sides are
    passed through ``normalize_final_answer`` and compared with
    ``is_equiv_sympy`` when ``use_sympy`` is set, else with ``is_equiv``.
    """
    expected = target.text

    # If the strings already match exactly, we can return True immediately
    if answer == expected:
        return True

    norm_answer = await normalize_final_answer(answer)
    norm_target = await normalize_final_answer(expected)

    # Use sympy library for exact match based on https://arxiv.org/pdf/2206.14858
    compare = is_equiv_sympy if use_sympy else is_equiv
    return await compare(norm_answer, norm_target)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from inspect_ai.dataset import hf_dataset | |
from inspect_ai.solver import ( | |
Solver, | |
generate, | |
prompt_template, | |
system_message, | |
) | |
from .dataset import record_to_sample, sample_to_fewshot | |
# Few-shot prompt template partially based on https://arxiv.org/pdf/2206.14858 - Appendix D.2
# The {examples} placeholder is filled with str.format() in math_solver().
SYSTEM_W_EXAMPLES_PROMPT_TEMPLATE = """
You will be asked to solve a math problem. Some examples of problems and solutions are provided below.
{examples}
""".strip()
# Setup for problem + instructions for providing answer
# Passed to prompt_template() in math_solver(); {prompt} marks where the
# problem text goes. "\\boxed" is an escaped backslash in this non-raw string,
# so the model sees "\boxed".
USER_PROMPT_TEMPLATE = """
Solve the following math problem step by step. The last line of your response should be of the form "ANSWER: $ANSWER" (without quotes) where $ANSWER is the answer to the problem.
{prompt}
Remember to put your answer on its own line at the end in the form "ANSWER: $ANSWER" (without quotes) where $ANSWER is the answer to the problem, and you do not need to use a \\boxed command.
""".strip()
def math_solver(
    fewshot: int,
    fewshot_seed: int,
) -> list[Solver]:
    """Build solver for MATH task.

    The base pipeline is the user prompt template followed by generation.
    When ``fewshot`` is non-zero a system message containing that many
    rendered training examples is placed in front of it.

    Arguments:
        fewshot (int): Number of few shot examples to use.
        fewshot_seed (int): Random seed for sampling few shot examples.
    """
    pipeline = [prompt_template(USER_PROMPT_TEMPLATE), generate()]
    if not fewshot:
        return pipeline

    # NOTE(review): the test split is loaded from a mirror of this dataset
    # because of DMCA issues; this id still points at the original
    # repository. Confirm it remains loadable.
    example_pool = hf_dataset(
        "hendrycks/competition_math",
        split="train",
        trust=True,
        sample_fields=record_to_sample,
        shuffle=True,
        seed=fewshot_seed,
        limit=fewshot,
    )
    rendered_examples = "\n\n".join(
        sample_to_fewshot(sample=example) for example in example_pool
    )
    examples_message = system_message(
        SYSTEM_W_EXAMPLES_PROMPT_TEMPLATE.format(examples=rendered_examples)
    )
    pipeline.insert(0, examples_message)
    return pipeline
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Based on https://github.com/UKGovernmentBEIS/inspect_evals/tree/a247353cdc5788d726c03822ee3281b6a5a0a745/src/inspect_evals/mathematics