Created September 18, 2018 09:43
Save araffin/ee9daee110af3b837b0e3a46a6bb403b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
import pytest | |
import numpy as np | |
from stable_baselines import A2C, ACER, ACKTR, DQN, DDPG, PPO1, PPO2, TRPO | |
from stable_baselines.common import set_global_seeds | |
# Model classes fed to test_perf_cartpole through pytest parametrization.
# DDPG is imported at the top of the file but is not part of this list;
# judging by the name, the list holds the algorithms that handle discrete
# action spaces (TODO confirm).
MODEL_LIST_DISCRETE = [A2C, ACER, ACKTR, DQN, PPO1, PPO2, TRPO]
@pytest.mark.parametrize("model_class", MODEL_LIST_DISCRETE)
def test_perf_cartpole(model_class):
    """
    Test if the algorithm (with a given policy)
    can learn something on the simple CartPole environment.

    The model is trained with ``MlpPolicy`` on 'CartPole-v1' for 1e5
    timesteps (seed 0). It is then rolled out for ``n_trials`` steps on its
    own environment. The mean return of the episodes that finished within
    that window must be at least 100.

    :param model_class: (BaseRLModel) A model
    """
    # TODO: multiprocess if possible
    model = model_class(policy="MlpPolicy", env='CartPole-v1',
                        tensorboard_log="/tmp/log/perf/cartpole")
    env = None
    try:
        model.learn(total_timesteps=int(1e5), seed=0)
        env = model.get_env()

        n_trials = 2000
        # Presumably re-seeds the global RNGs so the evaluation rollout is
        # repeatable -- TODO confirm what set_global_seeds covers.
        set_global_seeds(0)
        obs = env.reset()
        episode_rewards = []
        reward_sum = 0
        for _ in range(n_trials):
            action, _ = model.predict(obs)
            obs, reward, done, _ = env.step(action)
            reward_sum += reward
            if done:
                # The env is not reset here; this assumes get_env() returns a
                # vectorized env that resets itself when an episode ends
                # (TODO confirm).
                episode_rewards.append(reward_sum)
                reward_sum = 0
        # The last, unfinished episode is not counted.

        # Fail with a clear message instead of comparing NaN from np.mean([]).
        assert episode_rewards, \
            "no episode finished within {} steps".format(n_trials)
        mean_reward = np.mean(episode_rewards)
        assert mean_reward >= 100, \
            "mean episode reward {} is below 100".format(mean_reward)
    finally:
        # Free memory and release the environment even when an assertion
        # fails, so a failing parametrized case does not keep its model and
        # env alive through the stored traceback.
        if env is not None:
            env.close()
        del model, env
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
# Run the CartPole performance benchmark verbosely (one test case per model class).
pytest -v cartpole_bench.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
# Browse the training logs; the benchmark writes them under /tmp/log/perf/cartpole.
tensorboard --logdir /tmp/log/perf/
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.