Skip to content

Instantly share code, notes, and snippets.

@pyliaorachel
Created June 15, 2018 05:09
Show Gist options
  • Save pyliaorachel/07174cc61a24ea81ba67b2415de982d4 to your computer and use it in GitHub Desktop.
Save pyliaorachel/07174cc61a24ea81ba67b2415de982d4 to your computer and use it in GitHub Desktop.
OpenAI Gym CartPole - Deep Q-Learning (train)
# OpenAI Gym CartPole - Deep Q-Learning (training script).
# Trains a DQN agent (class defined elsewhere in this gist) on CartPole-v0
# using an experience-replay buffer and a periodically-updated target network.
env = gym.make('CartPole-v0')

# Environment parameters
n_actions = env.action_space.n            # number of discrete actions (2 for CartPole)
n_states = env.observation_space.shape[0] # observation vector length (4 for CartPole)

# Hyperparameters
n_hidden = 50               # hidden-layer width of the Q-network
batch_size = 32             # minibatch size sampled from replay memory
lr = 0.01                   # learning rate
epsilon = 0.1               # epsilon-greedy exploration rate
gamma = 0.9                 # reward discount factor
target_replace_iter = 100   # target-network update interval (in learn steps)
memory_capacity = 2000      # replay-memory size (transitions)
n_episodes = 4000           # number of training episodes

# Build the DQN agent
dqn = DQN(n_states, n_actions, n_hidden, batch_size, lr, epsilon, gamma, target_replace_iter, memory_capacity)

# Training loop
for i_episode in range(n_episodes):
    t = 0        # timestep counter within the episode
    rewards = 0  # cumulative reward for the episode
    state = env.reset()
    while True:
        # NOTE(review): rendering every step slows training considerably;
        # comment this out for faster training.
        env.render()

        # Choose an action (epsilon-greedy)
        action = dqn.choose_action(state)
        next_state, reward, done, info = env.step(action)

        # Store the transition in replay memory
        dqn.store_transition(state, action, reward, next_state)

        # Accumulate episode reward
        rewards += reward

        # Start learning only once the replay memory has been filled,
        # so minibatches are sampled from a full buffer.
        if dqn.memory_counter > memory_capacity:
            dqn.learn()

        # Advance to the next state
        state = next_state

        if done:
            print('Episode finished after {} timesteps, total rewards {}'.format(t+1, rewards))
            break

        t += 1

env.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment