Example of how to use Optuna for automatic hyperparameter optimization with RL and SB3
"""Optuna example that optimizes the hyperparameters of
a reinforcement learning agent using the PPO implementation from Stable-Baselines3
on a Gymnasium environment.
This is a simplified version of what can be found in https://github.com/DLR-RM/rl-baselines3-zoo.
You can run this example as follows:
$ python optimize_ppo.py
"""
from typing import Any
import gymnasium
import optuna
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.env_util import make_vec_env
import torch
import torch.nn as nn

N_TRIALS = 500
N_STARTUP_TRIALS = 10
N_EVALUATIONS = 2
N_TIMESTEPS = 40_000
EVAL_FREQ = int(N_TIMESTEPS / N_EVALUATIONS)
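# Note: EVAL_FREQ is in total environment steps; objective() divides it by
# N_ENVS when creating the callback, since EvalCallback counts vectorized steps.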
N_EVAL_EPISODES = 10
ENV_ID = "Pendulum-v1"
N_ENVS = 5
DEFAULT_HYPERPARAMS = {
"policy": "MlpPolicy",
}


def sample_ppo_params(trial: optuna.Trial) -> dict[str, Any]:
    """Sampler for PPO hyperparameters."""
    # From 2**5=32 to 2**12=4096
    n_steps_pow = trial.suggest_int("n_steps_pow", 5, 12)
    gamma = trial.suggest_float("gamma", 0.97, 0.9999)
    learning_rate = trial.suggest_float("learning_rate", 3e-5, 3e-3, log=True)
    activation_fn = trial.suggest_categorical("activation_fn", ["tanh", "relu"])

    n_steps = 2**n_steps_pow
    # Display true values
    trial.set_user_attr("n_steps", n_steps)

    # Convert to PyTorch objects
    activation_fn = {"tanh": nn.Tanh, "relu": nn.ReLU}[activation_fn]

    return {
        "n_steps": n_steps,
        "gamma": gamma,
        "learning_rate": learning_rate,
        "policy_kwargs": {
            "activation_fn": activation_fn,
        },
    }
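
# For illustration, one sampled configuration might look like the following
# (hypothetical values, each within the ranges suggested above):
#   {"n_steps": 256, "gamma": 0.99, "learning_rate": 3e-4,
#    "policy_kwargs": {"activation_fn": nn.ReLU}}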


class TrialEvalCallback(EvalCallback):
    """Callback used for evaluating and reporting a trial."""

    def __init__(
        self,
        eval_env: gymnasium.Env,
        trial: optuna.Trial,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        deterministic: bool = True,
        verbose: int = 0,
    ):
        super().__init__(
            eval_env=eval_env,
            n_eval_episodes=n_eval_episodes,
            eval_freq=eval_freq,
            deterministic=deterministic,
            verbose=verbose,
        )
        self.trial = trial
        self.eval_idx = 0
        self.is_pruned = False

    def _on_step(self) -> bool:
        if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
            super()._on_step()
            self.eval_idx += 1
            self.trial.report(self.last_mean_reward, self.eval_idx)
            # Prune trial if needed.
            if self.trial.should_prune():
                self.is_pruned = True
                return False
        return True


def objective(trial: optuna.Trial) -> float:
    vec_env = make_vec_env(ENV_ID, n_envs=N_ENVS)
    kwargs = DEFAULT_HYPERPARAMS.copy()
    # Sample hyperparameters.
    kwargs.update(sample_ppo_params(trial))
    # Create the RL model.
    model = PPO(env=vec_env, **kwargs)

    # Create env used for evaluation.
    eval_env = make_vec_env(ENV_ID, n_envs=N_ENVS)
    # Create the callback that will periodically evaluate and report the performance.
    eval_callback = TrialEvalCallback(
        eval_env,
        trial,
        n_eval_episodes=N_EVAL_EPISODES,
        eval_freq=max(EVAL_FREQ // N_ENVS, 1),
        deterministic=True,
    )

    nan_encountered = False
    try:
        model.learn(N_TIMESTEPS, callback=eval_callback)
    except AssertionError as e:
        # Sometimes, random hyperparams can generate NaN.
        print(e)
        nan_encountered = True
    finally:
        # Free memory.
        model.env.close()
        eval_env.close()

    # Tell the optimizer that the trial failed.
    if nan_encountered:
        return float("nan")

    if eval_callback.is_pruned:
        raise optuna.exceptions.TrialPruned()

    return eval_callback.last_mean_reward


if __name__ == "__main__":
    # Set pytorch num threads to 1 for faster training.
    torch.set_num_threads(1)

    sampler = TPESampler(n_startup_trials=N_STARTUP_TRIALS, multivariate=True)
    # Do not prune before 1/3 of the max budget is used.
    pruner = MedianPruner(
        n_startup_trials=N_STARTUP_TRIALS, n_warmup_steps=N_EVALUATIONS // 3
    )

    study = optuna.create_study(sampler=sampler, pruner=pruner, direction="maximize")
    try:
        study.optimize(objective, n_trials=N_TRIALS, timeout=600)
    except KeyboardInterrupt:
        pass

    print(f"Number of finished trials: {len(study.trials)}")

    print("Best trial:")
    trial = study.best_trial

    print("  Value: ", trial.value)

    print("  Params: ")
    for key, value in trial.params.items():
        print(f"    {key}: {value}")

    print("  User attrs:")
    for key, value in trial.user_attrs.items():
        print(f"    {key}: {value}")