#!/usr/bin/env python3
"""Tiny standalone AIME25 reproduction script for OpenAI-compatible servers.

No third-party dependencies. Example:

    python3 aime25_repro.py \
        --base-url http://127.0.0.1:8080/v1 \
        --model mlx-community/Qwen3.6-27B-4bit \
        --temperature 0 \
        --max-tokens 8192 \
        --seed 1 \
        --no-think
"""
|
|
|
from __future__ import annotations |
|
|
|
import argparse |
|
import json |
|
import os |
|
import re |
|
import sys |
|
import time |
|
import urllib.error |
|
import urllib.request |
|
from dataclasses import dataclass |
|
from pathlib import Path |
|
from typing import Any |
|
|
|
|
|
# Prompt sent as the single user message for every case; "{problem}" is filled
# in with the case's problem text via str.format().
# NOTE(review): the quoted format example reads `"Answer: "` with nothing after
# the colon - presumably a placeholder such as "<number>" was lost at some
# point. Confirm against the reference harness before changing it, because any
# edit here changes the prompt (and therefore the prompt token count).
PROMPT_TEMPLATE = (
    "Solve this competition math problem. Show your reasoning, then give "
    'the final answer on the last line by itself in the format "Answer: ". '
    "The answer is always an integer between 0 and 999 inclusive.\n\n"
    "{problem}"
)
|
|
|
|
|
# Embedded benchmark cases. Each entry has:
#   "id":      short identifier, used in progress lines and output file names
#   "answer":  expected answer as a decimal string
#   "problem": problem statement (LaTeX markup) substituted into PROMPT_TEMPLATE
# The problem strings are sent to the model verbatim, so even whitespace or
# punctuation changes alter the prompt.
CASES = [
    {
        "id": "aime25_0000",
        "answer": "70",
        "problem": "Find the sum of all integer bases $b>9$ for which $17_b$ is a divisor of $97_b.$",
    },
    {
        "id": "aime25_0001",
        "answer": "588",
        "problem": (
            "In $\\triangle ABC$ points $D$ and $E$ lie on $\\overline{AB}$ so that "
            "$AD < AE < AB$, while points $F$ and $G$ lie on $\\overline{AC}$ so that "
            "$AF < AG < AC$. Suppose $AD = 4$, $DE = 16$, $EB = 8$, $AF = 13$, "
            "$FG = 52$, and $GC = 26$. Let $M$ be the reflection of $D$ through $F$, "
            "and let $N$ be the reflection of $G$ through $E$. The area of quadrilateral "
            "$DEGF$ is $288$. Find the area of heptagon $AFNBCEM$."
        ),
    },
    {
        "id": "aime25_0002",
        "answer": "16",
        "problem": (
            "The $9$ members of a baseball team went to an ice-cream parlor after their game. "
            "Each player had a single scoop cone of chocolate, vanilla, or strawberry ice cream. "
            "At least one player chose each flavor, and the number of players who chose chocolate "
            "was greater than the number of players who chose vanilla, which was greater than the "
            "number of players who chose strawberry. Let $N$ be the number of different assignments "
            "of flavors to players that meet these conditions.\n"
            "Find the remainder when $N$ is divided by $1000.$"
        ),
    },
    {
        "id": "aime25_0003",
        "answer": "117",
        "problem": (
            "Find the number of ordered pairs $(x,y)$, where both $x$ and $y$ are integers "
            "between $-100$ and $100$ inclusive, such that $12x^2-xy-6y^2=0$."
        ),
    },
    {
        "id": "aime25_0004",
        "answer": "279",
        # NOTE(review): the trailing "$" after the final period looks like a
        # stray math delimiter; presumably it is kept to match the upstream
        # dataset text - confirm before "fixing" it, since that changes the prompt.
        "problem": (
            "There are $8!= 40320$ eight-digit positive integers that use each of the digits "
            "$1, 2, 3, 4, 5, 6, 7, 8$ exactly once. Let $N$ be the number of these integers "
            "that are divisible by $22$. Find the difference between $N$ and $2025$.$"
        ),
    },
]
|
|
|
|
|
@dataclass
class Row:
    """One line of the final results table (one benchmark case)."""

    # Outcome label; main() stores "PASSED", "FAILED" or "ERR".
    state: str
    # Prompt token count taken from the completion result (0 for ERR rows).
    prompt_tokens: int
    # Generated token count taken from the completion result (0 for ERR rows).
    gen_tokens: int
    # Answer parsed from the model output; print_table shows "-" when empty.
    given: str
    # Expected answer for the case.
    correct: str
    # Test identifier of the form "aime25/<case id>".
    test: str
    # Error text for ERR rows; empty when the request succeeded.
    error: str = ""
|
|
|
|
|
def strip_thinking(text: str) -> str:
    """Return *text* without ``<think>...</think>`` spans, trimmed of outer whitespace.

    Tags are matched case-insensitively and a span may cover several lines.
    Only complete open/close pairs are removed; an opening tag without a
    matching closing tag is left untouched.
    """
    think_span = re.compile(r"<think>.*?</think>", re.IGNORECASE | re.DOTALL)
    without_spans = think_span.sub("", text)
    return without_spans.strip()
|
|
|
|
|
def parse_numeric_answer(response: str) -> str:
    """Extract the final numeric answer from a model response.

    The response is first passed through ``strip_thinking``. The text after
    the last occurrence of an answer marker is searched for numbers and the
    last one found there is returned. When no marker is present, or no number
    follows it, the last number anywhere in the cleaned response is returned.

    Returns:
        The number as a string (thousands separators removed, a trailing "."
        dropped), or "" when the response contains no number at all.
    """

    def last_number(fragment: str) -> str:
        # Commas are removed first so "1,234" is read as a single number.
        found = re.findall(r"-?\d+\.?\d*", fragment.replace(",", ""))
        return found[-1].rstrip(".") if found else ""

    body = strip_thinking(response)

    # Markers are tried in priority order; the first one that occurs anywhere
    # in the text wins. All searches are case-insensitive.
    # NOTE: "the answer is" and "so the answer" can never be the winning
    # marker, because any text containing them already matches "Answer".
    markers = (
        r"Answer",
        r"the answer is",
        r"the result is",
        r"therefore",
        r"thus",
        r"finally",
        r"so the answer",
        r"we get",
    )
    hits: list = []
    for marker in markers:
        hits = list(re.finditer(marker, body, flags=re.IGNORECASE))
        if hits:
            break

    if hits:
        tail = body[hits[-1].end() :].strip()
        if tail.startswith(":"):
            tail = tail[1:].strip()
        candidate = last_number(tail)
        if candidate:
            return candidate

    return last_number(body)
|
|
|
|
|
def normalize_number(value: str) -> str:
    """Return a canonical string form of a decimal number for comparison.

    Thousands separators and surrounding whitespace are removed, trailing
    zeros after a decimal point are dropped ("70.0" -> "70"), and leading
    zeros are dropped ("070" -> "70") so that answers written in the
    three-digit AIME style compare equal to the expected value. Negative
    zero is reported as "0".

    Args:
        value: Number as text, e.g. the output of ``parse_numeric_answer``.

    Returns:
        The canonical string. Text that is not a plain decimal number only
        gets the comma/whitespace removal and trailing-zero trimming, and an
        empty input yields "".
    """
    value = value.replace(",", "").strip()
    match = re.fullmatch(r"(-?)(\d+)(?:\.(\d*))?", value)
    if match is None:
        # Not a plain decimal: keep the previous best-effort trimming only.
        if "." in value:
            value = value.rstrip("0").rstrip(".")
        return value

    sign, whole, fraction = match.groups()
    whole = whole.lstrip("0") or "0"
    fraction = (fraction or "").rstrip("0")
    canonical = f"{whole}.{fraction}" if fraction else whole
    # Avoid reporting "-0" for inputs such as "-0" or "-0.0".
    if sign and canonical != "0":
        canonical = "-" + canonical
    return canonical
|
|
|
|
|
def post_json(url: str, payload: dict[str, Any], timeout: float, api_key: str | None) -> Any:
    """POST *payload* as UTF-8 JSON to *url* and return the opened response.

    Args:
        url: Target URL.
        payload: JSON-serialisable request body.
        timeout: Timeout in seconds handed to ``urllib.request.urlopen``.
        api_key: When truthy, sent as an ``Authorization: Bearer`` header.

    Returns:
        The object returned by ``urllib.request.urlopen``; it is not read or
        closed here.
    """
    body = json.dumps(payload).encode("utf-8")
    req = urllib.request.Request(url, data=body, method="POST")
    req.add_header("Content-Type", "application/json")
    if api_key:
        req.add_header("Authorization", f"Bearer {api_key}")
    return urllib.request.urlopen(req, timeout=timeout)
|
|
|
|
|
def print_request_summary(url: str, payload: dict[str, Any]) -> None:
    """Write a summary of an outgoing request to stderr without the prompt text.

    Every payload field except "messages" is shown as-is. Each message is
    reduced to its role and the character length of its content.
    """
    message_stats = []
    for msg in payload.get("messages", []):
        message_stats.append(
            {
                "role": msg.get("role"),
                "content_chars": len(msg.get("content") or ""),
            }
        )

    report = {name: setting for name, setting in payload.items() if name != "messages"}
    report["messages"] = message_stats

    rendered = json.dumps(report, indent=2, sort_keys=True)
    print(f"POST {url}", rendered, sep="\n", file=sys.stderr)
|
|
|
|
|
def decode_sse_line(raw: bytes) -> str | None: |
|
line = raw.decode("utf-8", errors="replace").strip() |
|
if not line.startswith("data:"): |
|
return None |
|
data = line[5:].strip() |
|
return None if data == "[DONE]" else data |
|
|
|
|
|
def chat_completion(
    args: argparse.Namespace,
    prompt: str,
    *,
    progress_label: str = "",
) -> dict[str, Any]:
    """Stream one chat completion and return the accumulated result.

    Args:
        args: Parsed command-line options. Reads ``base_url``, ``model``,
            ``temperature``, ``top_p``, ``seed``, ``max_tokens``, ``no_think``,
            ``print_request``, ``api_key``, ``timeout``, ``quiet`` and
            ``progress_interval``.
        prompt: Text sent as the single user message.
        progress_label: Prefix for the in-place progress line on stdout; no
            progress is printed when it is empty or ``args.quiet`` is set.

    Returns:
        A dict with keys ``output`` (answer text), ``reasoning`` (text from
        ``reasoning_content``/``reasoning`` deltas), ``prompt_tokens``,
        ``completion_tokens``, ``received_chunks`` and ``finish_reason``.

    Raises:
        RuntimeError: On an HTTP error response that cannot be resolved by
            dropping an optional request field, or when the stream carries
            an error event.
    """
    url = args.base_url.rstrip("/") + "/chat/completions"
    payload: dict[str, Any] = {
        "model": args.model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": args.temperature,
        "top_p": args.top_p,
        "seed": args.seed,
        "max_tokens": args.max_tokens,
        "stream": True,
        "stream_options": {"include_usage": True},
    }
    if args.no_think:
        # This matches OpenAI SDK's extra_body behavior used by lbench:
        # extra_body fields are flattened into the request JSON body.
        payload["chat_template_kwargs"] = {"enable_thinking": False}

    if args.print_request:
        print_request_summary(url, payload)

    api_key = args.api_key or os.environ.get("OPENAI_API_KEY") or "not-needed"

    # Some servers reject optional fields with 400/422. Drop the field named in
    # the error body and retry. Each retry removes one field that is still in
    # the payload, so the loop issues at most three requests.
    while True:
        try:
            response = post_json(url, payload, args.timeout, api_key)
            break
        except urllib.error.HTTPError as exc:
            body = exc.read().decode("utf-8", errors="replace")
            rejected = None
            if exc.code in (400, 422):
                if "stream_options" in payload and (
                    "stream_options" in body or "include_usage" in body
                ):
                    rejected = "stream_options"
                elif "chat_template_kwargs" in payload and "chat_template_kwargs" in body:
                    rejected = "chat_template_kwargs"
            if rejected is None:
                raise RuntimeError(f"HTTP {exc.code}: {body}") from exc
            payload.pop(rejected)
            if not args.quiet:
                # Dropping a field changes what is being measured, so say so.
                print(
                    f"warning: server rejected {rejected!r}; retrying without it",
                    file=sys.stderr,
                )

    output_parts: list[str] = []
    reasoning_parts: list[str] = []
    prompt_tokens = 0
    completion_tokens = 0
    finish_reason = ""
    received_chunks = 0
    last_progress = time.perf_counter()

    # The context manager closes the connection even if parsing fails midway.
    with response:
        for raw in response:
            data = decode_sse_line(raw)
            if not data:
                # Not a data line, the [DONE] sentinel, or an empty data line.
                continue
            chunk = json.loads(data)
            stream_error = chunk.get("error")
            if stream_error:
                # Presumably how the server reports a failure after the stream
                # has started; surface it rather than returning empty output.
                raise RuntimeError(f"stream error: {json.dumps(stream_error)}")
            usage = chunk.get("usage")
            if usage:
                prompt_tokens = int(usage.get("prompt_tokens") or 0)
                completion_tokens = int(usage.get("completion_tokens") or 0)
            choices = chunk.get("choices") or []
            if not choices:
                continue
            finish_reason = choices[0].get("finish_reason") or finish_reason
            delta = choices[0].get("delta") or {}
            content_delta = delta.get("content") or ""
            reasoning_delta = delta.get("reasoning_content") or delta.get("reasoning") or ""
            if content_delta:
                output_parts.append(content_delta)
            if reasoning_delta:
                reasoning_parts.append(reasoning_delta)
            if content_delta or reasoning_delta:
                received_chunks += 1
            now = time.perf_counter()
            if progress_label and not args.quiet and now - last_progress >= args.progress_interval:
                print(
                    f"\r{progress_label} streaming {received_chunks}/{args.max_tokens}",
                    end="",
                    flush=True,
                )
                last_progress = now

    output = "".join(output_parts)
    reasoning = "".join(reasoning_parts)

    if completion_tokens == 0:
        # Only a fallback for servers that do not report usage in streaming.
        # Use reported usage when available for token-count comparisons.
        completion_tokens = received_chunks or len(re.findall(r"\S+", output + reasoning))

    return {
        "output": output,
        "reasoning": reasoning,
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "received_chunks": received_chunks,
        "finish_reason": finish_reason,
    }
|
|
|
|
|
def fmt_hhmm(seconds: float) -> str:
    """Format a duration in seconds as ``HHh:MM:SSs``.

    The duration is rounded to whole seconds first; each field is zero-padded
    to at least two digits.
    """
    whole_seconds = int(round(seconds))
    whole_minutes, secs = divmod(whole_seconds, 60)
    hours, minutes = divmod(whole_minutes, 60)
    return f"{hours:02d}h:{minutes:02d}:{secs:02d}s"
|
|
|
|
|
def print_table(rows: list[Row], runtime_s: float) -> None:
    """Print the pass/fail summary line and an aligned per-case table.

    Args:
        rows: One entry per executed case; may be empty (for example when the
            script is run with ``--cases 0``), in which case only the summary
            line and the table header are printed.
        runtime_s: Total wall-clock runtime in seconds.

    Error messages of rows that carry one are written to stderr after the
    table.
    """
    passed = sum(1 for row in rows if row.state == "PASSED")
    total = len(rows)
    print(f"\naime25: {passed}/{total} passed, {total - passed} failed, runtime {fmt_hhmm(runtime_s)}")

    headers = ("#", "state", "prompt", "gen", "total", "given", "correct", "test")
    table = [
        (
            str(i),
            row.state,
            str(row.prompt_tokens),
            str(row.gen_tokens),
            str(row.prompt_tokens + row.gen_tokens),
            row.given or "-",
            row.correct,
            row.test,
        )
        for i, row in enumerate(rows, start=1)
    ]

    # The header is part of the measured lines, so max() always receives a
    # non-empty iterable - even when there are no data rows.
    widths = [
        max(len(cells[col]) for cells in (headers, *table))
        for col in range(len(headers))
    ]
    right = {0, 2, 3, 4}  # numeric columns are right-aligned

    def fmt_row(cells: tuple[str, ...]) -> str:
        parts = [
            cell.rjust(widths[i]) if i in right else cell.ljust(widths[i])
            for i, cell in enumerate(cells)
        ]
        return " ".join(parts).rstrip()

    print(fmt_row(headers))
    for cells in table:
        print(fmt_row(cells))

    for row in rows:
        if row.error:
            print(f"\n{row.test} error: {row.error}", file=sys.stderr)
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse the process arguments.

    Returns:
        The populated namespace. ``--model`` is the only required option.
    """
    cli = argparse.ArgumentParser(description="Standalone AIME25 OpenAI-compatible repro.")
    add = cli.add_argument

    # Endpoint and credentials.
    add("--base-url", default="http://127.0.0.1:8080/v1")
    add("-m", "--model", required=True)
    add("--api-key", default=None)

    # Workload size and sampling parameters.
    add("--cases", type=int, default=5, help="Number of embedded cases to run, max 5.")
    add("--temperature", type=float, default=0.0)
    add("--top-p", type=float, default=1.0)
    add("--seed", type=int, default=1)
    add("--max-tokens", type=int, default=8192)
    add("--timeout", type=float, default=900)
    add("--no-think", action="store_true", help="Send enable_thinking=false via extra_body.")

    # Console output and diagnostics.
    add("--quiet", action="store_true", help="Only print the final summary table.")
    add(
        "--progress-interval",
        type=float,
        default=1.0,
        help="Seconds between live streaming progress updates.",
    )
    add(
        "--save-outputs",
        default=None,
        help="Optional directory where raw per-case outputs are written as .txt files.",
    )
    add(
        "--print-request",
        action="store_true",
        help="Print the effective request parameters before each API call, excluding prompt text.",
    )

    return cli.parse_args()
|
|
|
|
|
def main() -> int:
    """Run the selected embedded cases and print the results table.

    Each case is sent to the server, its answer is parsed from the model
    output and compared with the expected answer. When ``--save-outputs`` is
    given, the raw reasoning/content text of every successful request is
    written to ``<dir>/<case id>.txt``.

    Returns:
        0 when every executed case passed (also when no case was run),
        otherwise 1.
    """
    args = parse_args()
    cases = CASES[: max(0, min(args.cases, len(CASES)))]
    rows: list[Row] = []
    started = time.perf_counter()
    output_dir = Path(args.save_outputs) if args.save_outputs else None
    if output_dir:
        output_dir.mkdir(parents=True, exist_ok=True)

    # Blanks the in-place "\r" progress line that chat_completion writes to
    # stdout, so the next message starts on a clean line.
    clear_progress = "\r" + " " * 80 + "\r"

    for idx, case in enumerate(cases, start=1):
        label = f"[{idx}/{len(cases)}] {case['id']}"
        test_name = f"aime25/{case['id']}"
        correct = normalize_number(case["answer"])
        prompt = PROMPT_TEMPLATE.format(problem=case["problem"])
        case_started = time.perf_counter()
        if not args.quiet:
            print(f"{label} running...", flush=True)

        try:
            result = chat_completion(args, prompt, progress_label=label)
            if output_dir:
                text = result["output"]
                if result["reasoning"]:
                    text = result["reasoning"] + "\n\n--- content ---\n\n" + text
                (output_dir / f"{case['id']}.txt").write_text(text, encoding="utf-8")
            given = normalize_number(parse_numeric_answer(result["output"]))
        except Exception as exc:
            # Per-case boundary: record the failure and continue with the
            # remaining cases instead of aborting the whole run.
            elapsed = time.perf_counter() - case_started
            rows.append(
                Row(
                    state="ERR",
                    prompt_tokens=0,
                    gen_tokens=0,
                    given="-",
                    correct=correct,
                    test=test_name,
                    error=str(exc),
                )
            )
            if not args.quiet:
                # A request that fails mid-stream leaves a progress line behind.
                print(clear_progress, end="", flush=True)
                print(
                    f"{label} ERR "
                    f"runtime={fmt_hhmm(elapsed)} error={exc}",
                    flush=True,
                    file=sys.stderr,
                )
            continue

        elapsed = time.perf_counter() - case_started
        state = "PASSED" if given and given == correct else "FAILED"
        rows.append(
            Row(
                state=state,
                prompt_tokens=result["prompt_tokens"],
                gen_tokens=result["completion_tokens"],
                given=given,
                correct=correct,
                test=test_name,
            )
        )
        if not args.quiet:
            print(clear_progress, end="", flush=True)
            print(
                f"{label} {state} "
                f"given={given or '-'} correct={correct} "
                f"prompt={result['prompt_tokens']} gen={result['completion_tokens']} "
                f"runtime={fmt_hhmm(elapsed)}",
                flush=True,
            )

    print_table(rows, time.perf_counter() - started)
    return 0 if all(row.state == "PASSED" for row in rows) else 1
|
|
|
|
|
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())