Created
June 15, 2018 05:09
-
-
Save pyliaorachel/07174cc61a24ea81ba67b2415de982d4 to your computer and use it in GitHub Desktop.
OpenAI Gym CartPole - Deep Q-Learning (train)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Train a DQN agent on OpenAI Gym's CartPole-v0.
env = gym.make('CartPole-v0')

# Environment dimensions
n_actions = env.action_space.n            # size of the discrete action space
n_states = env.observation_space.shape[0] # length of the observation vector

# Hyper parameters
n_hidden = 50             # hidden-layer width of the Q network
batch_size = 32
lr = 0.01                 # learning rate
epsilon = 0.1             # epsilon-greedy exploration rate
gamma = 0.9               # reward discount factor
target_replace_iter = 100 # steps between target-network syncs
memory_capacity = 2000    # replay-buffer size
n_episodes = 4000

# Build the DQN agent
dqn = DQN(n_states, n_actions, n_hidden, batch_size, lr, epsilon, gamma, target_replace_iter, memory_capacity)

# Training loop: one iteration per episode
for episode in range(n_episodes):
    step = 0
    total_reward = 0
    state = env.reset()
    done = False
    while not done:
        env.render()

        # Pick an action and advance the environment one step
        action = dqn.choose_action(state)
        next_state, reward, done, info = env.step(action)

        # Record the transition in the replay buffer
        dqn.store_transition(state, action, reward, next_state)
        total_reward += reward

        # Only start learning once the replay buffer has been filled
        if dqn.memory_counter > memory_capacity:
            dqn.learn()

        state = next_state
        step += 1

    print('Episode finished after {} timesteps, total rewards {}'.format(step, total_reward))

env.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment