Created
April 19, 2025 01:25
-
-
Save anadim/7b37b8e3fead6b5aff8171a8d78c562d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ───────────────────────────────────────────────────────────────
# Model‑comparison bar chart • Seaborn context + easy font knobs
# (`tile_margin` adds slack around the blue cluster tiles, and
# `centre_spacing_override` decouples bar thickness from bar spacing)
# ───────────────────────────────────────────────────────────────
import os | |
import matplotlib.pyplot as plt | |
from matplotlib.patches import Rectangle | |
from itertools import cycle | |
import seaborn as sns | |
# ───────────── CONFIG (tweak me) ─────────────
# Seaborn styling
context_choice = "talk"          # one of "notebook", "talk", "poster", "paper"
font_scale = 1.5                 # Seaborn multiplies every font size by this

# Bar geometry (all values in x-axis data units)
bar_width = 2.5                  # horizontal thickness of a single bar
intra_bar_gap = 3                # empty space between adjacent bars in a cluster
# Pinning the centre-to-centre distance lets bar_width changes show up
# visually without also shifting every bar. Use None to fall back to
# bar_width + intra_bar_gap.
centre_spacing_override = intra_bar_gap
cluster_gap = 2.0                # additional space separating benchmark clusters
tile_margin = 0.5                # padding on each side of a blue cluster tile

# Text sizes / offsets
label_fontsize = 12              # value labels printed above each bar
value_font_offset = 2            # height (data units) of a label above its bar top
cluster_title_fontsize = 18      # benchmark name drawn above each tile
legend_fontsize = 16
# ─────────────────────────────────────────────
# 0 MODELS TO HIDE
# Entries are compared (case-sensitively) against canon(model), i.e. the
# name with any parenthetical suffix removed, so list canonical forms only.
EXCLUDE_MODELS = {"claude 3.7 sonnet", "openthinker2‑32b", "o1‑mini"}
# 1 BENCHMARK DATA
# benchmark name -> {model name -> score}. Scores are a mix of proportions
# (0-1) and percentages (0-100); the conversion loop further down rescales
# every value <= 1 to a percentage.
# NOTE(review): AIME 24 "o1" is written as 74.6 while its neighbours are
# proportions, and several figures repeat exactly (0.753 appears four
# times; openthinker2‑32b is 0.58 on both AIMEs). Confirm against the source.
benchmarks = {
    "AIME 24": {
        "phi‑4": 0.12,
        "phi‑4‑reasoning": 0.753,
        "phi‑4‑more‑reasoning": 0.813,
        # "claude 3.7 sonnet": 0.587,
        "deepseek r1 distill 70b": 0.693,
        "deepseek r1": 0.786,
        "o1": 74.6,
        "o3‑mini (jan‑25)": 0.88,
        "o1‑mini": 0.753,
        "openthinker2‑32b": 0.58,
    },
    "AIME 25": {
        "phi‑4": 0.129,
        "phi‑4‑reasoning": 0.629,
        "phi‑4‑more‑reasoning": 0.78,
        # "claude 3.7 sonnet": 0.587,
        "deepseek r1 distill 70b": 0.515,
        "deepseek r1": 0.704,
        "o1": 0.753,
        "o3‑mini (jan‑25)": 0.825,
        "o1‑mini": 0.753,
        "openthinker2‑32b": 0.58,
    },
    # Disabled benchmark, kept for easy re-enabling:
    # "HMMT Feb 25": {
    #     "phi‑4‑reasoning": 54.67, "phi‑4‑more‑reasoning": 58.67,
    #     "claude 3.7 sonnet": 31.67,
    #     "deepseek r1 distill 70b": 33.33, "deepseek r1": 41.67,
    #     "o1": 48.33, "o3‑mini": 67.50,
    # },
    "OmniMath": {
        "phi‑4": 31.9,
        "phi‑4‑reasoning": 76.6,
        "phi‑4‑more‑reasoning": 81.4,
        "claude 3.7 sonnet": 46.4,
        "deepseek r1 distill 70b": 63.4,
        "deepseek r1": 85,
        "o1‑mini": 60.5,
        "o1": 67.5,
        "o3‑mini": 74.6,
    },
    "GPQA": {
        "phi‑4": 54.7,
        "phi‑4‑reasoning": 65.8,
        "phi‑4‑more‑reasoning": 68.9,
        "claude 3.7 sonnet": 76.8,
        "deepseek r1 distill 70b": 66.2,
        "deepseek r1": 73.0,
        "o1": 76.7,
        "o3‑mini (jan‑25)": 77.7,
        "o1‑mini": 60,
        "openthinker2‑32b": 64.1,
    },
    "LCB": {
        "phi‑4": 26.9,
        "phi‑4‑reasoning": 53.8,
        "phi‑4‑more‑reasoning": 53.1,
        "o1": 71,
        "o3‑mini": 69.5,
        "deepseek r1 distill 70b": 57.5,
        "deepseek r1": 62.8,
        "o1‑mini": 52,
    },
}
# 2 COLOUR MAP
# Explicit hex colour per model name. Both raw (dated) and canonical
# spellings are listed where they occur in the data.
colour_map = {
    # phi family
    "phi‑4": "#40826D",
    "phi‑4‑reasoning": "#00563b",
    "phi‑4‑more‑reasoning": "#013220",
    # OpenAI
    "o1‑mini": "#FFECB3",
    "o1‑mini (sept‑24)": "#FFECB3",
    "o1": "#FFD54F",
    "o3‑mini": "#FFC107",
    "o3‑mini (jan‑25)": "#FFC107",
    # DeepSeek
    "deepseek r1 distill 70b": "#8C9BF0",
    "deepseek r1": "#3A59D1",
    # others
    "claude 3.7 sonnet": "#d8ac8c",
    "openthinker2‑32b": "#8c8c8c",
}
_fallback_iter = cycle(plt.cm.tab20.colors)  # colours for models missing from colour_map


def color_for(model: str):
    """Return the plotting colour for *model*, assigning one on first use.

    Lookup order:
    1. an explicit ``colour_map`` entry for the exact name;
    2. the entry for ``canon(model)`` (name with any parenthetical removed),
       so a dated variant such as "x (jan-25)" shares the colour of "x";
    3. the next colour from the ``tab20`` fallback cycle.

    The chosen colour is cached in ``colour_map`` under both the exact and
    the canonical name. Bars are coloured from raw model names while the
    legend is built from canonical names, and caching under both keeps bar
    and legend swatch colours in agreement.
    """
    if model in colour_map:
        return colour_map[model]
    base = canon(model)  # defined later in the module; resolved at call time
    if base not in colour_map:
        colour_map[base] = next(_fallback_iter)
    colour_map[model] = colour_map[base]
    return colour_map[model]
# 3 HELPERS
def canon(name: str) -> str:
    """Return the canonical model name: the text before any "(", stripped.

    >>> canon("o3-mini (jan-25)")
    'o3-mini'
    """
    return name.split("(")[0].strip()


# Display order: bars inside each cluster and the legend entries follow the
# blocks in this sequence, and models inside a block follow list position.
# Entries are compared against canon(name), so list each canonical name
# once. Dated variants are covered automatically.
phi_order = ["phi‑4", "phi‑4‑reasoning", "phi‑4‑more‑reasoning"]
think_order = ["openthinker2‑32b"]
ds_order = ["deepseek r1 distill 70b", "deepseek r1"]
son_order = ["claude 3.7 sonnet"]
o_order = ["o1‑mini", "o1", "o3‑mini"]
priority_blocks = [phi_order, think_order, ds_order, son_order, o_order]


def sort_key(model: str):
    """Return a sort key that orders *model* by its place in priority_blocks.

    Listed models map to ``(block_index, position_in_block)``. Anything not
    listed sorts after all of them as ``(len(priority_blocks), canonical_name)``,
    i.e. alphabetically. Both forms are keyed on ``canon(model)``. The first
    elements of the two forms never coincide, so their int and str second
    elements are never compared with each other.
    """
    name = canon(model)
    for block_idx, block in enumerate(priority_blocks):
        if name in block:
            return (block_idx, block.index(name))
    return (len(priority_blocks), name)
# Convert proportions → percentages, in place. Any non-None score <= 1 is
# treated as a fraction and multiplied by 100. Larger values are assumed to
# already be percentages.
# NOTE(review): a genuine percentage at or below 1% would be rescaled too.
for scores in benchmarks.values():
    rescaled = {
        model: score * 100
        for model, score in scores.items()
        if score is not None and score <= 1
    }
    scores.update(rescaled)
# 4 Build plot
sns.set_context(context_choice, font_scale=font_scale)

# Distance between neighbouring bar centres within a cluster.
if centre_spacing_override is None:
    centre_spacing = bar_width + intra_bar_gap
else:
    centre_spacing = centre_spacing_override
# Keep intra_bar_gap consistent with the spacing actually in use (0 when
# bars overlap). NOTE: nothing below reads it; the bounds use bar_width.
intra_bar_gap = max(centre_spacing - bar_width, 0)

fig, ax = plt.subplots(figsize=(18, 7))

# Flat per-bar lists (fed to ax.bar) plus per-benchmark layout info:
# cluster_mid[bench] = x of the cluster centre,
# bounds[bench] = (left, right) x-extent of its background tile.
x_vals, heights, colours, labels = [], [], [], []
cluster_mid, bounds = {}, {}
cur_x = 0.0
for bench, res in benchmarks.items():
    visible = {
        m: v for m, v in res.items()
        if v is not None and canon(m) not in EXCLUDE_MODELS
    }
    if not visible:
        # Every model is hidden or missing. Laying this cluster out would give
        # a degenerate tile (end < start), a title over no bars and a stray
        # extra gap, so skip it entirely.
        continue
    items = sorted(visible.items(), key=lambda t: sort_key(t[0]))
    start = cur_x
    for m, v in items:
        x_vals.append(cur_x)
        heights.append(v)
        colours.append(color_for(m))
        labels.append(canon(m))
        cur_x += centre_spacing
    end = cur_x - centre_spacing  # centre of the last bar in this cluster
    cluster_mid[bench] = (start + end) / 2
    bounds[bench] = (
        start - bar_width / 2 - tile_margin,
        end + bar_width / 2 + tile_margin,
    )
    cur_x += cluster_gap  # fixed gap before the next cluster
# One shaded tile behind each benchmark cluster, from y = 0 up to axis_top,
# drawn at zorder 0 so the bars sit on top of it.
axis_top = 100
tile_style = dict(facecolor="#F0FAFF", edgecolor="#ADD8E6", linewidth=1.6, zorder=0)
for left, right in bounds.values():
    tile = Rectangle((left, 0), right - left, axis_top, **tile_style)
    ax.add_patch(tile)
# Bars, then a numeric label (one decimal) centred just above each bar.
bars = ax.bar(
    x_vals,
    heights,
    width=bar_width,
    color=colours,
    edgecolor="white",
    zorder=1,  # above the background tiles
)
value_label_kw = dict(
    ha="center",
    va="bottom",
    fontsize=label_fontsize,
    rotation=0,              # degrees; negative = clockwise, positive = CCW
    rotation_mode="anchor",  # rotate about the text's anchor point
)
for rect, height in zip(bars, heights):
    centre_x = rect.get_x() + rect.get_width() / 2
    ax.text(centre_x, height + value_font_offset, f"{height:.1f}", **value_label_kw)
# Axes styling: no x ticks (clusters are labelled by their titles), a fixed
# 0..axis_top y-range, larger y tick labels, and no top/right spines.
ax.set_xticks([])
ax.set_xticklabels([])
ax.set_ylabel("Accuracy (%)", fontsize=25)
ax.set_ylim(0, axis_top)
ax.tick_params(axis="y", which="major", labelsize=20)
ax.set_axisbelow(True)
for side in ["top", "right"]:
    ax.spines[side].set_visible(False)
# Bold benchmark name centred over each cluster. It is anchored at its bottom
# edge on y = axis_top, so it sits just above the plotting area.
for bench_name, centre in cluster_mid.items():
    ax.text(
        centre,
        axis_top,
        bench_name,
        ha="center",
        va="bottom",
        weight="bold",
        fontsize=cluster_title_fontsize,
    )
# Legend: one swatch per distinct canonical model that was plotted, in
# sort_key order, centred below the axes in a framed box.
legend_models = sorted(set(labels), key=sort_key)
legend_patches = [
    Rectangle((0, 0), 1, 1, facecolor=color_for(name))
    for name in legend_models
]
leg = ax.legend(
    legend_patches,
    legend_models,
    loc="upper center",
    bbox_to_anchor=(0.5, -0.05),
    ncol=7,
    frameon=True,
    borderpad=0.5,
    handlelength=1.5,
    columnspacing=1.5,
    fontsize=legend_fontsize,
)
legend_frame = leg.get_frame()
legend_frame.set_facecolor("#F0FAFF")
legend_frame.set_edgecolor("#ADD8E6")
legend_frame.set_linewidth(1.4)
# final layout & save
plt.tight_layout(rect=[0, 0.08, 1, 1])  # reserve the bottom 8% of the figure for the legend
# Write to the current user's Desktop rather than one hard-coded home
# directory. Set MODEL_COMPARISON_PDF to save somewhere else.
save_path = os.environ.get(
    "MODEL_COMPARISON_PDF",
    os.path.expanduser("~/Desktop/model_comparison.pdf"),
)
abs_save_path = os.path.abspath(save_path)
print(f"Attempting to save figure to: {abs_save_path}")
try:
    plt.savefig(save_path, format="pdf", bbox_inches="tight")
except Exception as e:  # best effort: report and still show the figure
    print(f"Error saving figure: {e}")
else:
    print(f"Successfully saved figure to {abs_save_path}")
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment