@araffin, created February 26, 2025
https://gist.github.com/araffin/e069945a68aa0d51fcdff3f01e945c70
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecEnv, VecEnvWrapper
sns.set_theme()


class PlotActionVecEnvWrapper(VecEnvWrapper):
    """
    VecEnv wrapper that plots the distribution of the actions taken during training.

    :param venv: The vectorized environment to wrap.
    :param plot_freq: Plot the action distributions every ``plot_freq``
        calls to ``step_async()``.
    """

    def __init__(self, venv: VecEnv, plot_freq: int = 10_000):
        super().__init__(venv)
        # Action buffer, only continuous (Box) action spaces are supported
        assert isinstance(self.action_space, spaces.Box)
        self.n_actions = self.action_space.shape[0]
        # Circular buffer holding the last plot_freq actions of every env
        self.actions = np.zeros((plot_freq, self.num_envs, self.n_actions))
        self.n_steps = 0
        self.plot_freq = plot_freq

    def reset(self):
        return self.venv.reset()

    def step_wait(self):
        obs, rewards, dones, infos = self.venv.step_wait()
        return obs, rewards, dones, infos

    def step_async(self, actions):
        # Record the actions before forwarding them to the wrapped env
        self.actions[self.n_steps % self.plot_freq] = actions
        self.n_steps += 1
        if self.n_steps % self.plot_freq == 0:
            self.plot()
        self.venv.step_async(actions)

    def plot(self) -> None:
        # Flatten the env dimension: (plot_freq * num_envs, n_actions)
        actions = self.actions.reshape(-1, self.n_actions)
        # Total number of steps, all envs included
        n_steps = self.num_envs * self.n_steps
        # Create a figure with one subplot per action dimension
        n_rows = min(2, self.n_actions // 2 + 1)
        # Ceil division so that n_rows * n_cols >= n_actions
        n_cols = max((self.n_actions + n_rows - 1) // n_rows, 1)
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 10))
        fig.suptitle(
            f"Distribution of Actions per Dimension after {n_steps} steps",
            fontsize=16,
        )
        # Flatten the axes array for easy iteration
        if n_rows > 1:
            axes = axes.flatten()
            # Hide the subplots that have no action dimension to show
            for j in range(self.n_actions, n_rows * n_cols):
                axes[j].set_visible(False)
        else:
            # Special case, n_actions == 1: plt.subplots() returns a single Axes
            axes = [axes]
        # Plot the distribution for each action dimension
        for i in range(self.n_actions):
            sns.histplot(actions[:, i], kde=True, ax=axes[i], stat="density")
            axes[i].set_title(f"Action Dimension {i + 1}")
            axes[i].set_xlabel("Action Value")
            axes[i].set_ylabel("Density")
        # Adjust the layout and display the plot (blocks until the window is closed)
        plt.tight_layout()
        plt.show()
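
# Note: plt.show() blocks training until the plot window is closed. For
# headless runs, one alternative (a sketch using standard Matplotlib calls,
# with an illustrative filename) is to save the figure at the end of plot()
# instead of showing it:
#     fig.savefig(f"actions_after_{n_steps}_steps.png")
#     plt.close(fig)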
vec_env = make_vec_env("Pendulum-v1", n_envs=2)
wrapped_env = PlotActionVecEnvWrapper(vec_env, plot_freq=5_000)
# from sbx import PPO
# from sbx import SAC
# policy_kwargs = dict(log_std_init=-0.5)
model = PPO("MlpPolicy", wrapped_env, gamma=0.98, verbose=1)
model.learn(total_timesteps=1_000_000)
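
# Sketch of the commented-out variants above, assuming the sbx package
# (SBX: Stable Baselines3 + Jax) is installed. For PPO, log_std_init sets
# the initial log standard deviation of the Gaussian policy, so a smaller
# value makes the early action histograms narrower:
#
# from sbx import SAC
# model = SAC("MlpPolicy", wrapped_env, verbose=1)
#
# model = PPO(
#     "MlpPolicy",
#     wrapped_env,
#     gamma=0.98,
#     policy_kwargs=dict(log_std_init=-0.5),
#     verbose=1,
# )
# model.learn(total_timesteps=1_000_000)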