"""Optuna example that optimizes the hyperparameters of
a reinforcement learning agent using PPO implementation from Stable-Baselines3
on a Gymnasium environment.
This is a simplified version of what can be found in https://github.com/DLR-RM/rl-baselines3-zoo.
You can run this example as follows:
    $ python optimize_ppo.py
"""
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecEnvWrapper

sns.set_theme()
"""
A simple GUI to collect human feedback.
It writes the rating to a file "gui_value.txt" next to the script.
The rating can be reset by removing or emptying the text file.
Nicegui is the only dependency.
If you use `uv` you can do `uv run feedback_gui.py`.
Author: Antonin Raffin (2024)
MIT License
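The docstring above says the rating is written to `gui_value.txt` and is reset by removing or emptying that file. A consumer of the file might read it roughly as sketched below. The function name `read_rating`, the assumption that the rating is stored as a single number, and the choice to treat malformed content as "no rating" are all hypothetical and not taken from the original script.

```python
from pathlib import Path
from typing import Optional


def read_rating(path: Path) -> Optional[float]:
    """Return the rating stored in `path`, or None if there is no rating.

    Assumption: the file holds one number as plain text. A missing file,
    an empty file, or unparsable content all count as "no rating".
    """
    try:
        text = path.read_text().strip()
    except FileNotFoundError:
        # File removed: the rating was reset.
        return None
    if not text:
        # File emptied: the rating was reset.
        return None
    try:
        return float(text)
    except ValueError:
        # Content is not a number; treat it as no rating.
        return None
```

A training loop could poll `read_rating(Path("gui_value.txt"))` between episodes and skip the feedback signal whenever it returns `None`.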
import sbx
import shimmy
import stable_baselines3 as sb3
from dm_control import suite
from gymnasium.wrappers import FlattenObservation
from stable_baselines3.common.env_checker import check_env

# Available envs:
# suite._DOMAINS
# suite.dog.SUITE
import gymnasium as gym
import numpy as np
from gymnasium.envs.mujoco.mujoco_env import MujocoEnv

# Env initialization
env = gym.make("HalfCheetah-v4", render_mode="human")
# Wrap to have reward statistics
env = gym.wrappers.RecordEpisodeStatistics(env)
mujoco_env = env.unwrapped
n_joints = 6
import gymnasium as gym
import numpy as np
from gymnasium.envs.mujoco.mujoco_env import MujocoEnv

# Env initialization
env = gym.make("Swimmer-v4", render_mode="human")
# Wrap to have reward statistics
env = gym.wrappers.RecordEpisodeStatistics(env)
mujoco_env = env.unwrapped
n_joints = 2
// https://github.com/datalogix/google-fonts-helper
// npm install google-fonts-helper
import { download } from 'google-fonts-helper'

const downloader = download('https://fonts.googleapis.com/css?family=Montserrat:400,700%7CRoboto:400,400italic,700%7CRoboto+Mono&display=swap', {
  base64: false,
  overwriting: false,
  outputDir: './',
  stylePath: 'fonts.css',
  fontsDir: 'fonts',
import gym
import numpy as np
import cma
from collections import OrderedDict
from stable_baselines import A2C


def flatten(params):
    """
import pytest
import numpy as np
from stable_baselines import A2C, ACER, ACKTR, DQN, DDPG, PPO1, PPO2, TRPO
from stable_baselines.common import set_global_seeds

MODEL_LIST_DISCRETE = [
    A2C,
    ACER,
    ACKTR,
tensorboard --logdir /tmp/a2c_cartpole_tensorboard/