Last active
June 16, 2025 20:56
-
-
Save avivajpeyi/5e62dbdc77aeee008474e7f391a24ebc to your computer and use it in GitHub Desktop.
Bayesian opt animation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
.DS_Store | |
.idea | |
__pycache__ | |
media | |
out | |
*.h5 | |
*.png | |
*.mp4 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Render with, for example:

    manim -pql bayesian_opt_anim.py BODataAnimation

Flags: -ql renders at low quality (fast, for drafts), -qh renders at high
quality, and -p previews the result once rendering has finished.
"""
from manim import * | |
import numpy as np | |
import h5py | |
import sys | |
class BODataAnimation(Scene):
    """Animate a recorded Bayesian-optimisation run with manim.

    ``data_file`` names an HDF5 file containing:

    * ``target_data``: datasets ``x`` and ``y`` plus attributes ``y_range``
      and ``acq_range``.
    * ``iteration_<i>`` groups: datasets ``gp`` (rows: lower band, mean,
      upper band), ``acq``, ``x`` and ``y``, and an optional ``steps``
      attribute.

    Subclasses override ``data_file`` to animate a different run.
    """

    data_file = "bo_data.h5"

    def construct(self):
        """Draw static axes and the target, then animate each stored iteration."""
        print("Running anim for file:", self.data_file)

        # 1. Load data from HDF5
        with h5py.File(self.data_file, 'r') as f:
            # Target curve; datasets may be (N,) or (N, 1), so flatten both.
            target_x = f['target_data/x'][:].reshape(-1)
            target_y = f['target_data/y'][:].reshape(-1)
            # y_range and acq_range come from attributes written by the generator.
            y_range = tuple(f['target_data'].attrs['y_range'])
            acq_range = tuple(f['target_data'].attrs['acq_range'])

            def iter_index(name):
                """Return <i> from 'iteration_<i>'; unparsable names sort last."""
                try:
                    return int(name.split('_')[1])
                except (IndexError, ValueError):
                    return float('inf')

            # Numeric order, so 'iteration_10' follows 'iteration_9'.
            iter_keys = sorted(
                (k for k in f.keys() if k.startswith('iteration_')),
                key=iter_index,
            )

            iterations = []
            for key in iter_keys:
                grp = f[key]
                iterations.append({
                    'gp': grp['gp'][:],  # rows: lower band, mean, upper band
                    'acq': grp['acq'][:].reshape(-1),
                    'x_obs': grp['x'][:].reshape(-1),
                    'y_obs': grp['y'][:].reshape(-1),
                    'steps': grp.attrs.get('steps', iter_index(key)),
                })

        # 2. Axis ranges: x from the target grid, y from the stored ranges.
        x_min, x_max = float(np.min(target_x)), float(np.max(target_x))
        y_bot, y_top = float(y_range[0]), float(y_range[1])
        # 25% headroom above the acquisition maximum (same factor plot_gp uses).
        acq_bot, acq_top = float(acq_range[0]), float(acq_range[1]) * 1.25

        # 3. Axes without ticks or numeric labels; step sizes guard zero-width ranges.
        x_tick = (x_max - x_min) / 4 if x_max != x_min else 1.0
        y_tick = (y_top - y_bot) / 4 if (y_top - y_bot) != 0 else 1.0
        acq_tick = (acq_top - acq_bot) / 4 if (acq_top - acq_bot) != 0 else 1.0

        # Top axes: target + GP band + observations.
        axes_top = Axes(
            x_range=[x_min, x_max, x_tick],
            y_range=[y_bot, y_top, y_tick],
            x_length=7,
            y_length=3,
            axis_config={"include_ticks": False, "include_tip": False},
        ).to_edge(UP, buff=0.5)
        xlabel_top = Text("x", font_size=20).next_to(axes_top.x_axis.get_end(), DOWN, buff=0.1)
        ylabel_top = Text("f(x)", font_size=20).rotate(90 * DEGREES).next_to(
            axes_top.y_axis.get_end(), LEFT, buff=0.1
        )

        # Bottom axes: acquisition function, sharing the top axes' x-range.
        axes_bot = Axes(
            x_range=[x_min, x_max, x_tick],
            y_range=[acq_bot, acq_top, acq_tick],
            x_length=7,
            y_length=2,
            axis_config={"include_ticks": False, "include_tip": False},
        ).next_to(axes_top, DOWN, buff=0.7)
        xlabel_bot = Text("x", font_size=20).next_to(axes_bot.x_axis.get_end(), DOWN, buff=0.1)
        ylabel_bot = Text("Acquisition", font_size=20).rotate(90 * DEGREES).next_to(
            axes_bot.y_axis.get_end(), LEFT, buff=0.1
        )

        self.add(axes_top, xlabel_top, ylabel_top, axes_bot, xlabel_bot, ylabel_bot)

        # 4. Plot target function once (semi-transparent)
        true_curve = axes_top.plot_line_graph(
            x_values=target_x.tolist(),
            y_values=target_y.tolist(),
            add_vertex_dots=False,
            line_color=GRAY,
            stroke_width=3,
            stroke_opacity=0.3,
        )
        true_label = Text("Target", font_size=20, color=GRAY).next_to(true_curve, UR, buff=0.2)
        self.play(Create(true_curve), FadeIn(true_label))
        self.wait(0.5)

        # 5. Helpers
        def create_gp_polygon(gp_data, x_data, axes):
            """Closed polygon tracing the upper band left-to-right, then the lower band back."""
            lower = gp_data[0].reshape(-1)
            upper = gp_data[2].reshape(-1)
            pts = [axes.coords_to_point(float(xi), float(yi)) for xi, yi in zip(x_data, upper)]
            pts += [
                axes.coords_to_point(float(xi), float(yi))
                for xi, yi in zip(x_data[::-1], lower[::-1])
            ]
            return Polygon(*pts, fill_color=ORANGE, fill_opacity=0.3, stroke_width=0)

        def create_smooth_line(x_data, y_data, axes, color, stroke_width=2):
            """Polyline through (x_data, y_data) on ``axes`` without vertex dots."""
            return axes.plot_line_graph(
                x_values=x_data.tolist(),
                y_values=y_data.tolist(),
                add_vertex_dots=False,
                line_color=color,
                stroke_width=stroke_width,
            )

        # 6. On-screen mobjects. After the first iteration these are morphed with
        # Transform, which mutates them in place, so the references stay valid.
        prev_gp_conf = None
        prev_gp_mean = None
        prev_obs_dots = VGroup()
        prev_acq_line = None
        prev_next_dot = None
        prev_next_label = None
        surrogate_label = None

        # 7. Per-iteration sequence: acquisition -> chosen point -> observations -> GP.
        for it in iterations:
            gp = it['gp']
            acq = it['acq']
            x_obs = it['x_obs']
            y_obs = it['y_obs']

            # 7a/7b. Draw (first time) or morph the acquisition curve.
            new_acq_line = create_smooth_line(target_x, acq, axes_bot, PURPLE)
            if prev_acq_line is None:
                self.play(Create(new_acq_line), run_time=1.0)
                prev_acq_line = new_acq_line
            else:
                self.play(Transform(prev_acq_line, new_acq_line), run_time=1.2)

            # 7c. Mark the newest observed x on the acquisition curve.
            # NOTE(review): this is x_obs[-1], not argmax(acq); the matplotlib
            # script's plot_gp marks argmax(acq) as the next guess. Confirm intent.
            next_x = float(x_obs[-1]) if len(x_obs) > 0 else 0.0
            idx_closest = int(np.argmin(np.abs(target_x - next_x)))
            next_y_acq = float(acq[idx_closest])
            new_next_dot = Dot(axes_bot.coords_to_point(next_x, next_y_acq), color=YELLOW, radius=0.08)
            new_next_label = Text("Next", font_size=18, color=YELLOW).next_to(new_next_dot, UR, buff=0.1)
            if prev_next_dot is None:
                self.play(Create(new_next_dot), FadeIn(new_next_label), run_time=0.8)
                prev_next_dot = new_next_dot
                prev_next_label = new_next_label
            else:
                self.play(
                    Transform(prev_next_dot, new_next_dot),
                    Transform(prev_next_label, new_next_label),
                    run_time=0.8,
                )
            # Pause to emphasize the acquisition selection
            self.wait(0.5)

            # 7d. Replace the observation dots.
            new_obs_dots = VGroup(*[
                Dot(axes_top.coords_to_point(float(xi), float(yi)), color=ORANGE, radius=0.06)
                for xi, yi in zip(x_obs, y_obs)
            ])
            obs_anims = [FadeIn(d) for d in new_obs_dots]
            if len(prev_obs_dots) > 0:
                obs_anims.insert(0, FadeOut(prev_obs_dots))
            # Skip play() when there is nothing to animate (no observations yet);
            # manim rejects play() calls without animations.
            if obs_anims:
                self.play(*obs_anims, run_time=0.8)
            prev_obs_dots = new_obs_dots

            # 7e. GP band and mean (solid orange line).
            new_gp_conf = create_gp_polygon(gp, target_x, axes_top)
            new_gp_mean = create_smooth_line(target_x, gp[1].reshape(-1), axes_top, ORANGE, stroke_width=3)
            gp_animations = [
                Create(new_gp_conf) if prev_gp_conf is None else Transform(prev_gp_conf, new_gp_conf),
                Create(new_gp_mean) if prev_gp_mean is None else Transform(prev_gp_mean, new_gp_mean),
            ]
            self.play(*gp_animations, run_time=1.2)

            # Label the surrogate once, next to the first mean curve.
            if surrogate_label is None:
                surrogate_label = Text("Surrogate", font_size=20, color=ORANGE).next_to(new_gp_mean, UL, buff=0.2)
                self.play(FadeIn(surrogate_label))

            if prev_gp_conf is None:
                prev_gp_conf = new_gp_conf
            if prev_gp_mean is None:
                prev_gp_mean = new_gp_mean

            # Pause between iterations
            self.wait(1.0)

        # 8. Final pause and summary
        final_text = Text("Bayesian Optimization Complete!", font_size=24, color=GREEN).to_edge(DOWN)
        self.play(FadeIn(final_text))
        self.wait(3)
class BOBalanced(BODataAnimation):
    """Animation of the run recorded with the ``balanced`` acquisition setting."""

    data_file = "out/balanced/bo_data_balanced.h5"
class BOExplore(BODataAnimation):
    """Animation of the run recorded with the ``explore`` acquisition setting."""

    data_file = "out/explore/bo_data_explore.h5"
class BOExploit(BODataAnimation):
    """Animation of the run recorded with the ``exploit`` acquisition setting."""

    data_file = "out/exploit/bo_data_exploit.h5"
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from bayes_opt import BayesianOptimization | |
from bayes_opt import acquisition | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from matplotlib import gridspec | |
import h5py | |
from dataclasses import dataclass | |
import os | |
import sys | |
# Acquisition function per strategy name; background:
# https://bayesian-optimization.github.io/BayesianOptimization/2.0.4/exploitation_vs_exploration.html
# The two UCB entries differ only in kappa (10 for "explore", 5 for "balanced");
# "exploit" uses Expected Improvement with xi=0.0.
AQ_FN = {
    "explore": acquisition.UpperConfidenceBound(kappa=10),
    "exploit": acquisition.ExpectedImprovement(xi=0.0),
    "balanced": acquisition.UpperConfidenceBound(kappa=5),
}
@dataclass
class GPIterationData:
    """Snapshot of the optimiser state captured after one optimisation step."""

    gp: np.ndarray  # stacked rows: lower 95% band, posterior mean, upper 95% band
    acq: np.ndarray  # acquisition values evaluated on the plotting grid
    x: np.ndarray  # all inputs observed so far (flattened)
    y: np.ndarray  # matching observed target values
    steps: int  # number of points in the optimiser's space at snapshot time
def target(x):
    """Multimodal 1-D test objective: bumps centred at 2 and 6 plus a peak at 0."""
    narrow_bump = np.exp(-(x - 2) ** 2)
    wide_bump = np.exp(-(x - 6) ** 2 / 10)
    central_peak = 1 / (x ** 2 + 1)
    return narrow_bump + wide_bump + central_peak
def posterior(optimizer, grid):
    """Return ``(mean, std)`` of the optimiser's GP evaluated on ``grid``."""
    mean, std = optimizer._gp.predict(grid, return_std=True)
    return mean, std
def collect_gp_data(optimizer, x, y, acq_type):
    """Snapshot the GP posterior and acquisition values on the grid ``x``.

    Parameters
    ----------
    optimizer : BayesianOptimization
        Optimiser whose observations (``optimizer.res``) and GP are read.
    x : np.ndarray
        Evaluation grid passed to the GP and the acquisition function.
    y : np.ndarray
        Target values on ``x``. Not used here; kept for call compatibility.
    acq_type : str
        Strategy name. Kept for call compatibility. The acquisition evaluated
        is ``optimizer.acquisition_function``, i.e. the same object used to
        refit the GP, so the two can never disagree.

    Returns
    -------
    GPIterationData
        ``gp`` stacks (mean - 1.96 std, mean, mean + 1.96 std); ``acq`` holds
        the sign-flipped output of the acquisition's ``_get_acq`` function.
    """
    z_95 = 1.96  # two-sided 95% normal quantile

    x_obs = np.array([[res["params"]["x"]] for res in optimizer.res])
    y_obs = np.array([res["target"] for res in optimizer.res])

    acq_fn = optimizer.acquisition_function
    # Refit the GP on all observations so the snapshot includes the newest point.
    acq_fn._fit_gp(optimizer._gp, optimizer._space)
    mu, sigma = posterior(optimizer, x)
    gp_preds = np.array([mu - z_95 * sigma, mu, mu + z_95 * sigma])

    # Flip the sign of _get_acq's output; presumably it is negated for a
    # minimiser (plot_gp treats the argmax of this as the next guess) -- verify.
    utility = -1 * acq_fn._get_acq(gp=optimizer._gp)(x)

    return GPIterationData(
        gp=gp_preds,
        acq=utility,
        x=x_obs.flatten(),
        y=y_obs,
        steps=len(optimizer.space),
    )
def save_data(acq_type: str, n_itr: int, x: np.ndarray, y: np.ndarray, fname: str):
    """Optimise ``target`` over x in [-2, 10] and store per-step snapshots in HDF5.

    Parameters
    ----------
    acq_type : str
        Key into ``AQ_FN`` selecting the acquisition function.
    n_itr : int
        Number of snapshots to record.
    x, y : np.ndarray
        Plotting grid and target values on it; saved under ``target_data``.
    fname : str
        Output HDF5 path (overwritten).

    File layout: ``target_data/{x,y}`` with attributes ``y_range`` and
    ``acq_range``, plus ``iteration_<i>/{gp,acq,x,y}`` with attribute ``steps``.
    """
    optimizer = BayesianOptimization(
        target, {'x': (-2, 10)},
        acquisition_function=AQ_FN[acq_type],
        random_state=27,
    )

    # Take one step first so the GP has an observation before the first snapshot.
    optimizer.maximize(init_points=0, n_iter=1)
    gp_data = []
    for i in range(n_itr):
        gp_data.append(collect_gp_data(optimizer, x, y, acq_type))
        # A step after the final snapshot would never be recorded; skip it.
        if i < n_itr - 1:
            optimizer.maximize(init_points=0, n_iter=1)

    # Running [min, max] over all snapshots. Each upper bound starts at 0
    # (not -inf), so the stored maximum is never below 0.
    y_range = [np.inf, 0]
    acq_range = [np.inf, 0]

    with h5py.File(fname, 'w') as f:
        target_grp = f.create_group('target_data')
        target_grp.create_dataset('x', data=x)
        target_grp.create_dataset('y', data=y)
        for i, data in enumerate(gp_data):
            grp = f.create_group(f'iteration_{i}')
            grp.create_dataset('gp', data=data.gp)
            grp.create_dataset('acq', data=data.acq)
            grp.create_dataset('x', data=data.x)
            grp.create_dataset('y', data=data.y)
            grp.attrs['steps'] = data.steps
            y_range[0] = min(y_range[0], data.gp.min())
            y_range[1] = max(y_range[1], data.gp.max())
            acq_range[0] = min(acq_range[0], data.acq.min())
            acq_range[1] = max(acq_range[1], data.acq.max())
        target_grp.attrs['y_range'] = y_range
        target_grp.attrs['acq_range'] = acq_range
def plot_gp(gp, gp_acq, gp_x, gp_y, x, y, y_range, acq_range, steps, outdir='out'):
    """Save a two-panel PNG of one optimisation step to ``<outdir>/bo_step_<steps>.png``.

    Top panel: target curve ``(x, y)``, observations ``(gp_x, gp_y)`` and the
    band between ``gp[0]`` and ``gp[2]``, with y-limits from ``y_range``.
    Bottom panel: ``gp_acq`` on ``x`` with its argmax marked as the next best
    guess; y-limits are ``acq_range`` with 25% headroom on the upper bound.
    The x-limits of both panels are taken from the extent of ``x``.
    """
    x_flat = np.ravel(x)
    # Derive x-limits from the grid rather than hard-coding the search bounds.
    x_lims = (float(x_flat.min()), float(x_flat.max()))

    fig = plt.figure(figsize=(8, 8))
    fig.suptitle(f'Step {steps}', fontsize=30)
    gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1])
    axis = plt.subplot(gs[0])
    acq = plt.subplot(gs[1])

    axis.plot(x, y, linewidth=3, label='Target', color='k', alpha=0.3)
    axis.plot(gp_x, gp_y, 'D', markersize=8, label='Observations', color='tab:orange')
    axis.fill_between(
        x_flat, gp[0], gp[2],
        alpha=.3, color='tab:orange', label='GP 95% CI',
    )
    axis.set_xlim(x_lims)
    axis.set_ylim(*y_range)
    axis.set_ylabel('f(x)', fontdict={'size': 20})

    acq.plot(x_flat, gp_acq, label='Acquisition Function', color='purple')
    acq.plot(x_flat[np.argmax(gp_acq)], np.max(gp_acq), 'o', markersize=15,
             label='Next Best Guess', markerfacecolor='gold', markeredgecolor='k', markeredgewidth=1)
    acq.set_xlim(x_lims)
    acq.set_ylim(acq_range[0], acq_range[1] * 1.25)
    acq.set_ylabel('Acquisition', fontdict={'size': 20})
    acq.set_xlabel('x', fontdict={'size': 20})

    axis.legend(loc='upper right', frameon=False)
    acq.legend(loc='upper right', frameon=False)

    # The panels share x, so hide the top panel's x tick labels.
    axis.xaxis.set_ticklabels([])
    # Hide the top, bottom and left spines on both panels (only the right one remains).
    for side in ('top', 'bottom', 'left'):
        axis.spines[side].set_visible(False)
        acq.spines[side].set_visible(False)

    plt.subplots_adjust(hspace=0)
    # NOTE(review): tight_layout runs after subplots_adjust and may reset hspace,
    # so the zero gap requested above may not survive -- verify intended layout.
    plt.tight_layout()
    plt.savefig(os.path.join(outdir, f'bo_step_{steps}.png'), dpi=300)
    plt.close(fig)
def load_and_plot_data(fname='bo_data.h5', outdir='out'):
    """Render one PNG per ``iteration_<i>`` group of ``fname`` via ``plot_gp``.

    Iteration groups are selected by name and processed in numeric order, so
    extra top-level groups or gaps in the numbering no longer cause a KeyError.
    """
    prefix = 'iteration_'
    with h5py.File(fname, 'r') as f:
        target_x = f['target_data/x'][:]
        target_y = f['target_data/y'][:]
        y_range = f['target_data'].attrs['y_range']
        acq_range = f['target_data'].attrs['acq_range']

        iteration_keys = sorted(
            (int(key[len(prefix):]), key)
            for key in f.keys()
            if key.startswith(prefix) and key[len(prefix):].isdigit()
        )
        for _, key in iteration_keys:
            grp = f[key]
            gp = grp['gp'][:]
            acq = grp['acq'][:]
            x_obs = grp['x'][:]
            y_obs = grp['y'][:]
            steps = grp.attrs['steps']
            plot_gp(gp, acq, x_obs, y_obs, target_x, target_y, y_range, acq_range, steps, outdir)
def main():
    """Generate (if not cached) and plot one optimisation run per acquisition strategy."""
    grid = np.linspace(-2, 10, 10000).reshape(-1, 1)
    grid_values = target(grid)
    n_steps = 15

    for strategy in AQ_FN:
        run_dir = f"out/{strategy}"
        os.makedirs(run_dir, exist_ok=True)
        h5_path = os.path.join(run_dir, f'bo_data_{strategy}.h5')
        # Existing files are reused; delete one to force that run to be regenerated.
        if not os.path.exists(h5_path):
            save_data(strategy, n_steps, grid, grid_values, h5_path)
        load_and_plot_data(h5_path, run_dir)


if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Render every strategy's animation in high quality (-qh), one after another.
for scene in BOExplore BOExploit BOBalanced; do
    manim -qh bayesian_opt_anim.py "$scene"
done
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment