Created February 26, 2025 18:35
Save araffin/e069945a68aa0d51fcdff3f01e945c70 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt | |
import numpy as np | |
import seaborn as sns | |
from gymnasium import spaces | |
from stable_baselines3 import PPO | |
from stable_baselines3.common.env_util import make_vec_env | |
from stable_baselines3.common.vec_env import VecEnvWrapper | |
# Apply seaborn's default theme globally so every figure drawn below uses it
# (these keyword values are seaborn's own defaults for set_theme).
sns.set_theme(context="notebook", style="darkgrid", palette="deep")
class PlotActionVecEnvWrapper(VecEnvWrapper):
    """
    VecEnv wrapper for plotting the taken actions.

    Every action batch passed to ``step_async()`` is stored in a ring buffer
    holding the last ``plot_freq`` vectorized steps. Each time ``plot_freq``
    new steps have been recorded, ``plot()`` draws one density histogram
    (with KDE) per action dimension, pooling the actions of all
    sub-environments.

    :param venv: The vectorized environment to wrap. Its action space must be
        a ``gymnasium.spaces.Box``; it is expected to be 1-D, since
        ``shape[0]`` values per env are buffered.
    :param plot_freq: Number of ``step_async()`` calls between two plots,
        which is also the buffer length. Must be positive.
    :raises ValueError: If ``plot_freq`` is not positive.
    """

    def __init__(self, venv, plot_freq: int = 10_000):
        # Fail early: plot_freq is used both as the buffer length and as the
        # modulus in step_async(), where 0 would raise ZeroDivisionError.
        if plot_freq <= 0:
            raise ValueError(f"plot_freq must be a positive integer, got {plot_freq}")
        super().__init__(venv)
        # Only continuous (Box) actions can be histogrammed per dimension.
        assert isinstance(self.action_space, spaces.Box)
        self.n_actions = self.action_space.shape[0]
        # Ring buffer of shape (plot_freq, num_envs, n_actions).
        self.actions = np.zeros((plot_freq, self.num_envs, self.n_actions))
        # Number of step_async() calls so far (vectorized steps, not transitions).
        self.n_steps = 0
        self.plot_freq = plot_freq

    def reset(self):
        """Reset the wrapped VecEnv; the recorded actions are left untouched."""
        return self.venv.reset()

    def step_wait(self):
        """Forward ``step_wait()`` to the wrapped VecEnv unchanged."""
        obs, rewards, dones, infos = self.venv.step_wait()
        return obs, rewards, dones, infos

    def step_async(self, actions):
        """
        Record ``actions`` in the ring buffer, plot every ``plot_freq`` steps,
        then forward the actions to the wrapped VecEnv.

        :param actions: One action per sub-environment,
            i.e. something assignable to an array of shape (num_envs, n_actions).
        """
        self.actions[self.n_steps % self.plot_freq] = actions
        self.n_steps += 1
        # The buffer has just been completely (re)filled: show its content.
        if self.n_steps % self.plot_freq == 0:
            self.plot()
        self.venv.step_async(actions)

    def plot(self) -> None:
        """
        Plot the distribution of the buffered actions, one subplot per
        action dimension, with the env dimension flattened away.
        """
        # (plot_freq, num_envs, n_actions) -> (plot_freq * num_envs, n_actions)
        actions = self.actions.reshape(-1, self.n_actions)
        # Total number of transitions seen so far, across all sub-envs.
        n_steps = self.num_envs * self.n_steps

        # At most two rows, and enough columns (ceiling division) to give every
        # action dimension its own subplot. Flooring here would leave odd
        # n_actions >= 3 one subplot short.
        n_rows = min(2, max(self.n_actions, 1))
        n_cols = max((self.n_actions + n_rows - 1) // n_rows, 1)
        # squeeze=False always returns a 2D array of Axes, so a single
        # flatten() covers the n_actions == 1 case as well.
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 10), squeeze=False)
        fig.suptitle(
            f"Distribution of Actions per Dimension after {n_steps} steps", fontsize=16
        )
        axes = axes.flatten()

        # Plot the distribution for each action dimension.
        for i, ax in enumerate(axes[: self.n_actions]):
            sns.histplot(actions[:, i], kde=True, ax=ax, stat="density")
            ax.set_title(f"Action Dimension {i + 1}")
            ax.set_xlabel("Action Value")
            ax.set_ylabel("Density")
        # Hide the empty subplot left over when n_actions is odd.
        for ax in axes[self.n_actions :]:
            ax.set_visible(False)

        # Adjust the layout and display the plot.
        # NOTE(review): the figure is never closed here; with a non-interactive
        # backend (where show() does not block) figures accumulate across
        # calls. Consider plt.close(fig) if that applies.
        plt.tight_layout()
        plt.show()
# Two vectorized Pendulum-v1 environments. The wrapper plots the action
# distribution every 5_000 vectorized steps (i.e. 2 * 5_000 transitions).
vec_env = make_vec_env(env_id="Pendulum-v1", n_envs=2)
wrapped_env = PlotActionVecEnvWrapper(venv=vec_env, plot_freq=5_000)

# Disabled alternatives, kept for quick experimentation:
# from sbx import PPO
# from sbx import SAC
# policy_kwargs = dict(log_std_init=-0.5)

model = PPO(policy="MlpPolicy", env=wrapped_env, gamma=0.98, verbose=1)
model.learn(total_timesteps=1_000_000)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.