Skip to content

Instantly share code, notes, and snippets.

@avivajpeyi
Last active June 16, 2025 20:56
Show Gist options
  • Save avivajpeyi/5e62dbdc77aeee008474e7f391a24ebc to your computer and use it in GitHub Desktop.
Save avivajpeyi/5e62dbdc77aeee008474e7f391a24ebc to your computer and use it in GitHub Desktop.
Bayesian opt animation
.DS_Store
.idea
__pycache__
media
out
*.h5
*.png
*.mp4
"""
manim -pql bayesian_opt_anim.py BODataAnimation
ql is for quick rendering,
p for previewing the animation.
qh is for high quality rendering.
"""
from manim import *
import numpy as np
import h5py
import sys
class BODataAnimation(Scene):
    """Replay a recorded Bayesian-optimisation run as a two-panel animation.

    The top axes show the target curve, the observed points and the GP mean
    with its confidence band; the bottom axes show the acquisition curve.

    The HDF5 file named by ``data_file`` is expected to contain:

    * ``target_data/x`` and ``target_data/y`` datasets, plus ``y_range`` and
      ``acq_range`` attributes on the ``target_data`` group;
    * one ``iteration_<n>`` group per step with ``gp`` (rows: lower, mean,
      upper), ``acq``, ``x`` and ``y`` datasets and an optional ``steps``
      attribute.

    Subclasses only need to override ``data_file``.
    """

    data_file = "bo_data.h5"

    def construct(self):
        """Load ``self.data_file`` and play the animation step by step."""
        print("Running anim for file:", self.data_file)

        def iter_index(name):
            """Return the integer suffix of an ``iteration_<n>`` key.

            Keys without a parsable integer suffix sort last (``inf``).
            """
            try:
                return int(name.split('_')[1])
            except (IndexError, ValueError):
                return float('inf')

        # 1. Load data from HDF5
        with h5py.File(self.data_file, 'r') as f:
            # Target data, flattened so (N,) and (N, 1) inputs both work
            target_x = f['target_data/x'][:].reshape(-1)
            target_y = f['target_data/y'][:].reshape(-1)
            # y_range and acq_range from attributes
            y_range = tuple(f['target_data'].attrs['y_range'])
            acq_range = tuple(f['target_data'].attrs['acq_range'])
            # Iteration groups, ordered by their numeric suffix
            iter_keys_sorted = sorted(
                (k for k in f.keys() if k.startswith('iteration_')),
                key=iter_index,
            )
            # Copy everything into memory so the file can be closed
            iterations = []
            for key in iter_keys_sorted:
                grp = f[key]
                iterations.append({
                    'gp': grp['gp'][:],                # shape (3, N)
                    'acq': grp['acq'][:].reshape(-1),  # shape (N,)
                    'x_obs': grp['x'][:].reshape(-1),
                    'y_obs': grp['y'][:].reshape(-1),
                    'steps': grp.attrs.get('steps', iter_index(key)),
                })

        # 2. Determine axis ranges
        # X-range from data
        x_min, x_max = float(np.min(target_x)), float(np.max(target_x))
        # Y-range for top axes from y_range
        y_bot, y_top = float(y_range[0]), float(y_range[1])
        # Y-range for bottom axes from acq_range, with headroom multiplier
        acq_bot, acq_top = float(acq_range[0]), float(acq_range[1]) * 1.25

        # 3. Build Axes without ticks or tick labels
        # Fall back to a unit step when a range is degenerate
        x_tick = (x_max - x_min) / 4 if x_max != x_min else 1.0
        y_tick = (y_top - y_bot) / 4 if (y_top - y_bot) != 0 else 1.0
        acq_tick = (acq_top - acq_bot) / 4 if (acq_top - acq_bot) != 0 else 1.0

        # Top axes: target + GP CI + observations
        axes_top = Axes(
            x_range=[x_min, x_max, x_tick],
            y_range=[y_bot, y_top, y_tick],
            x_length=7,
            y_length=3,
            axis_config={"include_ticks": False, "include_tip": False},
        ).to_edge(UP, buff=0.5)
        # Axis labels are placed manually; no numeric labels are added
        xlabel_top = Text("x", font_size=20).next_to(
            axes_top.x_axis.get_end(), DOWN, buff=0.1)
        ylabel_top = Text("f(x)", font_size=20).rotate(90 * DEGREES).next_to(
            axes_top.y_axis.get_end(), LEFT, buff=0.1)

        # Bottom axes: acquisition
        axes_bot = Axes(
            x_range=[x_min, x_max, x_tick],
            y_range=[acq_bot, acq_top, acq_tick],
            x_length=7,
            y_length=2,
            axis_config={"include_ticks": False, "include_tip": False},
        ).next_to(axes_top, DOWN, buff=0.7)
        xlabel_bot = Text("x", font_size=20).next_to(
            axes_bot.x_axis.get_end(), DOWN, buff=0.1)
        ylabel_bot = Text("Acquisition", font_size=20).rotate(90 * DEGREES).next_to(
            axes_bot.y_axis.get_end(), LEFT, buff=0.1)

        # Add axes and labels
        self.add(axes_top, xlabel_top, ylabel_top, axes_bot, xlabel_bot, ylabel_bot)

        # 4. Plot target function once (semi-transparent)
        true_curve = axes_top.plot_line_graph(
            x_values=target_x.tolist(),
            y_values=target_y.tolist(),
            add_vertex_dots=False,
            line_color=GRAY,
            stroke_width=3,
            stroke_opacity=0.3,
        )
        true_label = Text("Target", font_size=20, color=GRAY).next_to(
            true_curve, UR, buff=0.2)
        self.play(Create(true_curve), FadeIn(true_label))
        self.wait(0.5)

        # 5. Helpers that turn raw arrays into mobjects
        def create_gp_polygon(gp_data, x_data, axes):
            """Return a filled polygon between rows 0 (lower) and 2 (upper)."""
            lower = gp_data[0].reshape(-1)
            upper = gp_data[2].reshape(-1)
            # Walk along the upper edge, then back along the lower edge
            pts = [
                axes.coords_to_point(float(xi), float(yi))
                for xi, yi in zip(x_data, upper)
            ]
            pts.extend(
                axes.coords_to_point(float(xi), float(yi))
                for xi, yi in zip(x_data[::-1], lower[::-1])
            )
            return Polygon(*pts, fill_color=ORANGE, fill_opacity=0.3, stroke_width=0)

        def create_smooth_line(x_data, y_data, axes, color, stroke_width=2):
            """Return a line graph through the given samples, without dots."""
            return axes.plot_line_graph(
                x_values=x_data.tolist(),
                y_values=y_data.tolist(),
                add_vertex_dots=False,
                line_color=color,
                stroke_width=stroke_width,
            )

        # 6. Mobjects that persist across iterations. Each is created on the
        # first iteration and afterwards morphed in place with Transform, so
        # the reference must keep pointing at the object first added.
        prev_gp_conf = None
        prev_gp_mean = None
        prev_obs_dots = VGroup()
        prev_acq_line = None
        prev_next_dot = None
        prev_next_label = None
        surrogate_label = None

        # 7. Loop over iterations
        for it in iterations:
            gp = it['gp']      # shape (3, N)
            acq = it['acq']    # shape (N,)
            x_obs = it['x_obs']
            y_obs = it['y_obs']

            # 7a/7b. Acquisition curve first
            new_acq_line = create_smooth_line(target_x, acq, axes_bot, PURPLE)
            if prev_acq_line is None:
                self.play(Create(new_acq_line), run_time=1.0)
                prev_acq_line = new_acq_line
            else:
                self.play(Transform(prev_acq_line, new_acq_line), run_time=1.2)

            # 7c. Marker on the acquisition curve at the latest observed x.
            # NOTE(review): this uses the last observation rather than the
            # optimum of the acquisition curve (which is what plot_gp marks) -
            # confirm which one the "Next" label is meant to show.
            next_x = float(x_obs[-1]) if len(x_obs) > 0 else 0.0
            idx_closest = int(np.argmin(np.abs(target_x - next_x)))
            next_y_acq = float(acq[idx_closest])
            new_next_dot = Dot(
                axes_bot.coords_to_point(next_x, next_y_acq),
                color=YELLOW, radius=0.08)
            new_next_label = Text("Next", font_size=18, color=YELLOW).next_to(
                new_next_dot, UR, buff=0.1)
            if prev_next_dot is None:
                self.play(Create(new_next_dot), FadeIn(new_next_label), run_time=0.8)
                prev_next_dot = new_next_dot
                prev_next_label = new_next_label
            else:
                self.play(
                    Transform(prev_next_dot, new_next_dot),
                    Transform(prev_next_label, new_next_label),
                    run_time=0.8
                )
            # Pause to emphasize the acquisition selection
            self.wait(0.5)

            # 7d. Replace the observation dots
            new_obs_dots = VGroup(*[
                Dot(axes_top.coords_to_point(float(xi), float(yi)),
                    color=ORANGE, radius=0.06)
                for xi, yi in zip(x_obs, y_obs)
            ])
            obs_animations = [FadeIn(dot) for dot in new_obs_dots]
            if len(prev_obs_dots) > 0:
                obs_animations.insert(0, FadeOut(prev_obs_dots))
            # Avoid calling play() with nothing to animate
            if obs_animations:
                self.play(*obs_animations, run_time=0.8)
            prev_obs_dots = new_obs_dots

            # 7e. GP confidence band and mean (row 1 of gp)
            new_gp_conf = create_gp_polygon(gp, target_x, axes_top)
            new_gp_mean = create_smooth_line(
                target_x, gp[1].reshape(-1), axes_top, ORANGE, stroke_width=3)
            first_gp = prev_gp_conf is None
            if first_gp:
                gp_animations = [Create(new_gp_conf), Create(new_gp_mean)]
            else:
                gp_animations = [
                    Transform(prev_gp_conf, new_gp_conf),
                    Transform(prev_gp_mean, new_gp_mean),
                ]
            self.play(*gp_animations, run_time=1.2)

            # Label the surrogate once, next to the first mean curve
            if surrogate_label is None:
                surrogate_label = Text(
                    "Surrogate", font_size=20, color=ORANGE
                ).next_to(new_gp_mean, UL, buff=0.2)
                self.play(FadeIn(surrogate_label))

            if first_gp:
                prev_gp_conf = new_gp_conf
                prev_gp_mean = new_gp_mean

            # Pause between iterations
            self.wait(1.0)

        # 8. Final pause and summary
        final_text = Text(
            "Bayesian Optimization Complete!", font_size=24, color=GREEN
        ).to_edge(DOWN)
        self.play(FadeIn(final_text))
        self.wait(3)
class BOBalanced(BODataAnimation):
    """Same animation as the base scene, reading the ``out/balanced`` data file."""
    data_file = "out/balanced/bo_data_balanced.h5"
class BOExplore(BODataAnimation):
    """Same animation as the base scene, reading the ``out/explore`` data file."""
    data_file = "out/explore/bo_data_explore.h5"
class BOExploit(BODataAnimation):
    """Same animation as the base scene, reading the ``out/exploit`` data file."""
    data_file = "out/exploit/bo_data_exploit.h5"
from bayes_opt import BayesianOptimization
from bayes_opt import acquisition
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec
import h5py
from dataclasses import dataclass
import os
import sys
# Acquisition function used for each named strategy. Background:
# https://bayesian-optimization.github.io/BayesianOptimization/2.0.4/exploitation_vs_exploration.html
AQ_FN = {
    "explore": acquisition.UpperConfidenceBound(kappa=10),
    "exploit": acquisition.ExpectedImprovement(xi=0.0),
    "balanced": acquisition.UpperConfidenceBound(kappa=5),
}
@dataclass
class GPIterationData:
    """Snapshot of one optimisation step, as assembled by ``collect_gp_data``."""
    # Stacked rows: lower bound, mean and upper bound of the GP prediction.
    gp: np.ndarray
    # Acquisition values evaluated on the same grid as ``gp``.
    acq: np.ndarray
    # Flattened x positions of the points observed so far.
    x: np.ndarray
    # Target values recorded for those observed points.
    y: np.ndarray
    # Size of the optimizer's search space record when the snapshot was taken.
    steps: int
def target(x):
    """Evaluate the 1-D objective: two Gaussian bumps plus a rational peak.

    Works element-wise on scalars and NumPy arrays alike.
    """
    narrow_bump = np.exp(-((x - 2) ** 2))
    wide_bump = np.exp(-((x - 6) ** 2) / 10)
    peak_at_origin = 1 / (x ** 2 + 1)
    return narrow_bump + wide_bump + peak_at_origin
def posterior(optimizer, grid):
    """Return the ``(mean, std)`` pair predicted by the optimizer's GP on ``grid``."""
    gaussian_process = optimizer._gp
    prediction = gaussian_process.predict(grid, return_std=True)
    mean, std = prediction
    return mean, std
def collect_gp_data(optimizer, x, y, acq_type):
    """Capture the optimizer's current state as a ``GPIterationData``.

    Args:
        optimizer: optimizer whose results, GP and space are read.
        x: grid on which the GP and the acquisition function are evaluated.
        y: unused; kept so the call signature stays stable.
        acq_type: key into ``AQ_FN`` selecting the acquisition function.

    Returns:
        GPIterationData with the GP band, acquisition values, the observed
        points and the current size of the optimizer's space.
    """
    observations = optimizer.res
    observed_x = np.array([[obs["params"]["x"]] for obs in observations])
    observed_y = np.array([obs["target"] for obs in observations])

    # Refit the GP on the current space before querying it.
    optimizer.acquisition_function._fit_gp(optimizer._gp, optimizer._space)
    mean, std = posterior(optimizer, x)
    half_width = 1.9600 * std
    band = np.array([mean - half_width, mean, mean + half_width])

    # The raw acquisition callable's output is negated before storing.
    raw_acq = AQ_FN[acq_type]._get_acq(gp=optimizer._gp)
    negated_acq = -1 * raw_acq(x)

    return GPIterationData(
        gp=band,
        acq=negated_acq,
        x=observed_x.flatten(),
        y=observed_y,
        steps=len(optimizer.space),
    )
def save_data(acq_type: str, n_itr: int, x: np.ndarray, y: np.ndarray, fname: str):
    """Run the optimizer for ``n_itr`` steps and write every snapshot to HDF5.

    Args:
        acq_type: key into ``AQ_FN`` selecting the acquisition function.
        n_itr: number of snapshots to record.
        x: grid on which the GP and the acquisition are evaluated.
        y: target values on ``x``; stored alongside the grid.
        fname: path of the HDF5 file to (over)write.

    The file gets a ``target_data`` group (datasets ``x``/``y`` and the
    attributes ``y_range``/``acq_range``) and one ``iteration_<i>`` group per
    snapshot.
    """
    optimizer = BayesianOptimization(
        target,
        {'x': (-2, 10)},
        acquisition_function=AQ_FN[acq_type],
        random_state=27,
    )

    # One step up front so the first snapshot already has an observation.
    optimizer.maximize(init_points=0, n_iter=1)
    snapshots = []
    for _ in range(n_itr):
        snapshots.append(collect_gp_data(optimizer, x, y, acq_type))
        optimizer.maximize(init_points=0, n_iter=1)

    with h5py.File(fname, 'w') as f:
        target_group = f.create_group('target_data')
        target_group.create_dataset('x', data=x)
        target_group.create_dataset('y', data=y)

        for step_idx, snapshot in enumerate(snapshots):
            step_group = f.create_group(f'iteration_{step_idx}')
            for field_name in ('gp', 'acq', 'x', 'y'):
                step_group.create_dataset(field_name, data=getattr(snapshot, field_name))
            step_group.attrs['steps'] = snapshot.steps

        # Global plot ranges across all snapshots; the lower bound is seeded
        # with +inf and the upper bound with 0, so the upper is never negative.
        target_group.attrs['y_range'] = [
            min([np.inf, *(s.gp.min() for s in snapshots)]),
            max([0, *(s.gp.max() for s in snapshots)]),
        ]
        target_group.attrs['acq_range'] = [
            min([np.inf, *(s.acq.min() for s in snapshots)]),
            max([0, *(s.acq.max() for s in snapshots)]),
        ]
def plot_gp(gp, gp_acq, gp_x, gp_y, x, y, y_range, acq_range, steps, outdir='out'):
    """Save a two-panel PNG for one optimisation step.

    The upper panel shows the target, the observations and the GP band; the
    lower panel shows the acquisition curve and its maximum.

    Args:
        gp: array whose rows 0 and 2 bound the shaded GP band.
        gp_acq: acquisition values on the grid ``x``.
        gp_x, gp_y: observed points.
        x, y: grid and target values.
        y_range: y-limits of the upper panel.
        acq_range: (low, high) of the acquisition; ``high`` gets 25% headroom.
        steps: step number used in the title and the file name.
        outdir: directory receiving ``bo_step_<steps>.png``; created if missing.
    """
    # savefig does not create missing directories, so make sure it exists.
    if outdir:
        os.makedirs(outdir, exist_ok=True)

    fig = plt.figure(figsize=(8, 8))
    try:
        fig.suptitle(
            f'Step {steps}',
            fontsize=30
        )
        gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1])
        axis = plt.subplot(gs[0])
        acq = plt.subplot(gs[1])

        axis.plot(x, y, linewidth=3, label='Target', color='k', alpha=0.3)
        axis.plot(gp_x, gp_y, 'D', markersize=8, label=u'Observations', color='tab:orange', )
        axis.fill_between(
            x.ravel(), gp[0], gp[2],
            alpha=.3, color='tab:orange', label='GP 95% CI'
        )
        axis.set_xlim((-2, 10))
        axis.set_ylim(*y_range)
        axis.set_ylabel('f(x)', fontdict={'size': 20})

        x = x.flatten()
        acq.plot(x, gp_acq, label='Acquisition Function', color='purple')
        # Mark the grid point where the acquisition is largest.
        acq.plot(x[np.argmax(gp_acq)], np.max(gp_acq), 'o', markersize=15,
                 label=u'Next Best Guess', markerfacecolor='gold', markeredgecolor='k', markeredgewidth=1)
        acq.set_xlim((-2, 10))
        acq.set_ylim(acq_range[0], acq_range[1] * 1.25)
        acq.set_ylabel('Acquisition', fontdict={'size': 20})
        acq.set_xlabel('x', fontdict={'size': 20})

        axis.legend(loc='upper right', frameon=False)
        acq.legend(loc='upper right', frameon=False)

        # remove x tick labels from the first subplot
        axis.xaxis.set_ticklabels([])
        # hide the top, bottom and left spines of both subplots
        for s in ['top', 'bottom', 'left']:
            axis.spines[s].set_visible(False)
            acq.spines[s].set_visible(False)

        # NOTE(review): tight_layout() runs after subplots_adjust(hspace=0)
        # and may override the zero spacing - confirm the intended layout.
        plt.subplots_adjust(hspace=0)
        plt.tight_layout()
        plt.savefig(f'{outdir}/bo_step_{steps}.png', dpi=300)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
def load_and_plot_data(fname='bo_data.h5', outdir='out'):
    """Read a saved run from ``fname`` and write one PNG per iteration.

    Args:
        fname: HDF5 file holding a ``target_data`` group and
            ``iteration_<n>`` groups.
        outdir: directory passed through to ``plot_gp``.

    Iteration groups are discovered by their ``iteration_`` prefix and
    processed in numeric order, so extra top-level groups or gaps in the
    numbering do not break the loop.
    """
    prefix = 'iteration_'
    with h5py.File(fname, 'r') as f:
        target_x = f['target_data/x'][:]
        target_y = f['target_data/y'][:]
        y_range = f['target_data'].attrs['y_range']
        acq_range = f['target_data'].attrs['acq_range']

        # Pair each key with its integer suffix so 'iteration_10' sorts after
        # 'iteration_9'; keys without a decimal suffix are ignored.
        numbered_keys = sorted(
            (int(key[len(prefix):]), key)
            for key in f.keys()
            if key.startswith(prefix) and key[len(prefix):].isdecimal()
        )
        for _, key in numbered_keys:
            grp = f[key]
            gp = grp['gp'][:]
            acq = grp['acq'][:]
            x_obs = grp['x'][:]
            y_obs = grp['y'][:]
            steps = grp.attrs['steps']
            plot_gp(gp, acq, x_obs, y_obs, target_x, target_y, y_range, acq_range, steps, outdir)
def main():
    """Generate (if missing) and plot the run data for every entry of ``AQ_FN``."""
    n_itr = 15
    # Dense column-vector grid over the search interval and the target on it.
    grid = np.linspace(-2, 10, 10000).reshape(-1, 1)
    truth = target(grid)

    for acq_type in AQ_FN:
        outdir = f"out/{acq_type}"
        os.makedirs(outdir, exist_ok=True)
        fname = os.path.join(outdir, f'bo_data_{acq_type}.h5')
        # An existing data file is reused instead of re-running the optimizer.
        already_saved = os.path.exists(fname)
        if not already_saved:
            save_data(acq_type, n_itr, grid, truth, fname)
        load_and_plot_data(fname, outdir)
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
# Render the three acquisition-strategy scenes with manim's -qh flag.
manim -qh bayesian_opt_anim.py BOExplore
manim -qh bayesian_opt_anim.py BOExploit
manim -qh bayesian_opt_anim.py BOBalanced
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment