Example of how to use Optuna for automatic hyperparameter optimization with RL and SB3
"""Optuna example that optimizes the hyperparameters of | |
a reinforcement learning agent using PPO implementation from Stable-Baselines3 | |
on a Gymnasium environment. | |
This is a simplified version of what can be found in https://github.com/DLR-RM/rl-baselines3-zoo. | |
You can run this example as follows: | |
$ python optimize_ppo.py | |
""" | |
from typing import Any

import gymnasium
import optuna
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.env_util import make_vec_env
import torch
import torch.nn as nn

N_TRIALS = 500
N_STARTUP_TRIALS = 10
N_EVALUATIONS = 2
N_TIMESTEPS = 40_000
EVAL_FREQ = int(N_TIMESTEPS / N_EVALUATIONS)
N_EVAL_EPISODES = 10
ENV_ID = "Pendulum-v1"
N_ENVS = 5

DEFAULT_HYPERPARAMS = {
    "policy": "MlpPolicy",
}


def sample_ppo_params(trial: optuna.Trial) -> dict[str, Any]:
    """Sampler for PPO hyperparameters."""
    # From 2**5=32 to 2**12=4096
    n_steps_pow = trial.suggest_int("n_steps_pow", 5, 12)
    gamma = trial.suggest_float("gamma", 0.97, 0.9999)
    learning_rate = trial.suggest_float("learning_rate", 3e-5, 3e-3, log=True)
    activation_fn = trial.suggest_categorical("activation_fn", ["tanh", "relu"])

    n_steps = 2**n_steps_pow
    # Display true values
    trial.set_user_attr("n_steps", n_steps)
    # Convert to PyTorch objects
    activation_fn = {"tanh": nn.Tanh, "relu": nn.ReLU}[activation_fn]

    return {
        "n_steps": n_steps,
        "gamma": gamma,
        "learning_rate": learning_rate,
        "policy_kwargs": {
            "activation_fn": activation_fn,
        },
    }


class TrialEvalCallback(EvalCallback):
    """Callback used for evaluating and reporting a trial."""

    def __init__(
        self,
        eval_env: gymnasium.Env,
        trial: optuna.Trial,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        deterministic: bool = True,
        verbose: int = 0,
    ):
        super().__init__(
            eval_env=eval_env,
            n_eval_episodes=n_eval_episodes,
            eval_freq=eval_freq,
            deterministic=deterministic,
            verbose=verbose,
        )
        self.trial = trial
        self.eval_idx = 0
        self.is_pruned = False

    def _on_step(self) -> bool:
        if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
            super()._on_step()
            self.eval_idx += 1
            self.trial.report(self.last_mean_reward, self.eval_idx)
            # Prune trial if needed.
            if self.trial.should_prune():
                self.is_pruned = True
                return False
        return True


def objective(trial: optuna.Trial) -> float:
    vec_env = make_vec_env(ENV_ID, n_envs=N_ENVS)
    kwargs = DEFAULT_HYPERPARAMS.copy()
    # Sample hyperparameters.
    kwargs.update(sample_ppo_params(trial))
    # Create the RL model.
    model = PPO(env=vec_env, **kwargs)

    # Create env used for evaluation.
    eval_env = make_vec_env(ENV_ID, n_envs=N_ENVS)
    # Create the callback that will periodically evaluate and report the performance.
    eval_callback = TrialEvalCallback(
        eval_env,
        trial,
        n_eval_episodes=N_EVAL_EPISODES,
        eval_freq=max(EVAL_FREQ // N_ENVS, 1),
        deterministic=True,
    )

    nan_encountered = False
    try:
        model.learn(N_TIMESTEPS, callback=eval_callback)
    except AssertionError as e:
        # Sometimes, random hyperparams can generate NaN.
        print(e)
        nan_encountered = True
    finally:
        # Free memory.
        model.env.close()
        eval_env.close()

    # Tell the optimizer that the trial failed.
    if nan_encountered:
        return float("nan")

    if eval_callback.is_pruned:
        raise optuna.exceptions.TrialPruned()

    return eval_callback.last_mean_reward


if __name__ == "__main__":
    # Set pytorch num threads to 1 for faster training.
    torch.set_num_threads(1)

    sampler = TPESampler(n_startup_trials=N_STARTUP_TRIALS, multivariate=True)
    # Do not prune before 1/3 of the max budget is used.
    pruner = MedianPruner(
        n_startup_trials=N_STARTUP_TRIALS, n_warmup_steps=N_EVALUATIONS // 3
    )

    study = optuna.create_study(sampler=sampler, pruner=pruner, direction="maximize")
    try:
        study.optimize(objective, n_trials=N_TRIALS, timeout=600)
    except KeyboardInterrupt:
        pass

    print(f"Number of finished trials: {len(study.trials)}")

    print("Best trial:")
    trial = study.best_trial

    print("  Value: ", trial.value)

    print("  Params: ")
    for key, value in trial.params.items():
        print(f"    {key}: {value}")

    print("  User attrs:")
    for key, value in trial.user_attrs.items():
        print(f"    {key}: {value}")