Created July 24, 2019 22:57
-
-
Save dniku/bf816713238e4f6cfcfae1d0a8e58cfe to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import itertools | |
from contextlib import closing | |
from pathlib import Path | |
import baselines.run | |
import cv2 | |
import numpy as np | |
from baselines.common.cmd_util import make_vec_env | |
from baselines.common.vec_env import VecFrameStack | |
from tqdm.auto import tqdm | |
# Gym environment id used both when loading the checkpoint through baselines
# and when building the evaluation environment (previously duplicated).
_ENV_ID = 'BreakoutNoFrameskip-v4'


def _make_envs(env_name, seed):
    """Build a one-environment Atari VecEnv wrapped with 4-frame stacking."""
    eval_envs = make_vec_env(env_name, 'atari', num_env=1, seed=seed)
    return VecFrameStack(eval_envs, nstack=4)


def _load_model(load_path, seed):
    """Restore a PPO2 CNN policy from ``load_path`` via ``baselines.run.main``.

    ``--num_timesteps 0`` is passed so that (presumably) no training updates
    run and only the checkpoint is restored -- confirm against the installed
    baselines version.
    """
    return baselines.run.main([str(v) for v in [
        '--env', _ENV_ID,
        '--seed', seed,
        '--alg', 'ppo2',
        '--num_timesteps', 0,
        '--network', 'cnn',
        '--num_env', 1,
        '--load_path', load_path,
    ]])


def _record_target_episode(model, eval_envs, target_reward, render):
    """Play episodes until one ends with ``target_reward``; return its frames.

    Before each step the current rendered image (``get_images()[0]``) is
    appended to a buffer. When an episode finishes (its ``'episode'`` info is
    reported), the buffer is returned if the episode reward is close to
    ``target_reward``; otherwise it is discarded and recording starts over.

    Note: this loops forever if the policy never reaches ``target_reward``.
    """
    raw_observations = []
    obs = eval_envs.reset()
    for _ in tqdm(itertools.count(start=0), postfix='playing'):
        actions, _, _, _ = model.step(obs)
        raw_observations.append(eval_envs.get_images()[0])
        if render:
            eval_envs.render()
        obs, _, _, infos = eval_envs.step(actions)
        if 'episode' not in infos[0]:
            continue
        epinfo = infos[0]['episode']
        tqdm.write('finished episode with reward={r}, length={l}, elapsed_time={t}'.format(**epinfo))
        if np.isclose(epinfo['r'], target_reward):
            return raw_observations
        # Not the episode we are looking for: drop its frames and keep playing.
        raw_observations = []


def _write_video(frames, path, fps):
    """Encode ``frames`` losslessly (FFV1) into the video file at ``path``.

    The frame size is taken from the first frame, so any resolution works.

    Raises:
        ValueError: if ``frames`` is empty.
        RuntimeError: if OpenCV cannot open a writer (e.g. FFV1 unavailable).
    """
    if not frames:
        raise ValueError('no frames to write')
    height, width = frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'FFV1')
    # OpenCV takes the frame size as (width, height).
    out = cv2.VideoWriter(str(path), fourcc, fps, (width, height))
    if not out.isOpened():
        raise RuntimeError('could not open video writer for {}'.format(path))
    try:
        for frame in tqdm(frames, postfix='writing video'):
            # get_images() frames are presumably RGB (gym rgb_array); OpenCV's
            # writer expects BGR. This is the same channel swap as BGR2RGB.
            out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        out.release()


def main():
    """Record a lossless video of one episode that reaches a target reward.

    Command line:
        load_path        checkpoint to load into a PPO2 CNN policy.
        --seed           seed for model loading and the evaluation env.
        --render         also render the env on screen while playing.
        --target-reward  episode reward to wait for (default 864).
        --output         output video path (default gym_309.mkv).
        --fps            frame rate of the written video (default 30).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('load_path', type=Path)
    parser.add_argument('--seed', type=int, default=1000)
    parser.add_argument('--render', action='store_true')
    parser.add_argument('--target-reward', type=float, default=864)
    parser.add_argument('--output', type=Path, default=Path('gym_309.mkv'))
    parser.add_argument('--fps', type=float, default=30.0)
    args = parser.parse_args()

    model = _load_model(args.load_path, args.seed)

    with closing(_make_envs(env_name=_ENV_ID, seed=args.seed)) as eval_envs:
        frames = _record_target_episode(model, eval_envs, args.target_reward, args.render)

    _write_video(frames, args.output, args.fps)
# Script entry point: run main() only when executed directly, not on import.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.