Last active
February 3, 2025 21:06
-
-
Save tadamcz/a61515465e34a3c66f3a78673502bc3f to your computer and use it in GitHub Desktop.
Epoch AI GPQA implementation (based on inspect_evals)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
from string import ascii_uppercase | |
from typing import Dict, Any | |
from inspect_ai import task, Task | |
from inspect_ai.dataset import Sample | |
from inspect_ai.dataset._sources.hf import hf_dataset | |
import inspect_ai.solver | |
import inspect_ai.scorer | |
def record_to_sample(record: Dict[str, Any]) -> Sample:
    """Convert one raw GPQA record into a multiple-choice ``Sample``.

    The correct answer and the three incorrect answers are collected into a
    single list, shuffled in place with the module-level ``random`` generator,
    and the position at which the correct answer text ends up is translated
    into its letter (A, B, C, ...), which becomes the sample target.

    Args:
        record: Mapping with the keys "Question", "Correct Answer",
            "Incorrect Answer 1", "Incorrect Answer 2", "Incorrect Answer 3"
            and "Record ID".

    Returns:
        A ``Sample`` whose ``choices`` are the shuffled answers and whose
        ``target`` is the letter of the correct answer.

    Raises:
        KeyError: If any of the required keys is missing from ``record``.
    """
    answer_keys = (
        "Correct Answer",
        "Incorrect Answer 1",
        "Incorrect Answer 2",
        "Incorrect Answer 3",
    )
    options = [record[key] for key in answer_keys]
    # Remember the correct answer before its position is randomised.
    correct_answer = options[0]

    random.shuffle(options)

    # Map the post-shuffle position of the correct answer to a letter.
    correct_letter = ascii_uppercase[options.index(correct_answer)]

    return Sample(
        input=record["Question"],
        choices=options,
        target=correct_letter,
        id=record["Record ID"],
    )
# Chain-of-thought prompt template handed to the multiple_choice solver in
# gpqa_diamond(). It contains three placeholders -- {letters}, {question} and
# {choices} -- which are left unformatted here for the consumer of the
# template to fill in. "$LETTER" is literal text shown to the model, not a
# placeholder.
# The trailing .strip() drops the newline that follows the opening quotes and
# any whitespace before the closing quotes.
# NOTE(review): this text is part of the prompt sent to the model; the line
# breaks between the instruction, {question} and {choices} may have lost blank
# lines in transit -- confirm against the upstream file before relying on the
# exact spacing.
SIMPLE_COT_TEMPLATE = r"""
Answer the following multiple choice question. The last line of your response should be of the following format: 'ANSWER: $LETTER' (without quotes) where LETTER is one of {letters}. Think step by step before answering.
{question}
{choices} """.strip()
@task(name="GPQA Diamond")
def gpqa_diamond() -> Task:
    """Build the GPQA Diamond evaluation task.

    Loads the ``gpqa_diamond`` configuration (``train`` split) of the
    ``Idavidrein/gpqa`` Hugging Face dataset, converting every record with
    ``record_to_sample``, and combines it with a chain-of-thought
    multiple-choice solver and the ``choice`` scorer.

    Returns:
        A ``Task`` configured with ``epochs=16``.
    """
    samples = hf_dataset(
        path="Idavidrein/gpqa",
        name="gpqa_diamond",
        split="train",
        sample_fields=record_to_sample,
    )

    # Note: DO NOT use shuffle=True here, as inspect currently modifies the
    # messages (!!) when doing so.
    # See: https://github.com/UKGovernmentBEIS/inspect_ai/issues/63
    # The answer choices are shuffled in record_to_sample instead.
    cot_solver = inspect_ai.solver.multiple_choice(template=SIMPLE_COT_TEMPLATE)

    return Task(
        dataset=samples,
        plan=[cot_solver],
        scorer=inspect_ai.scorer.choice(),
        epochs=16,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Based on https://github.com/tadamcz/inspect_evals/blob/b30a1aab73217e035d5aa22fd0526c70650e4b3e/src/inspect_evals/gpqa/gpqa.py