OpenAI Gym CartPole - Deep Q-Learning (dqn learn)
def learn(self):
    # Randomly sample batch_size experiences from replay memory
    # (assumes the memory buffer has already been filled)
    sample_index = np.random.choice(self.memory_capacity, self.batch_size)
    b_memory = self.memory[sample_index, :]
    b_state = torch.FloatTensor(b_memory[:, :self.n_states])
    b_action = torch.LongTensor(b_memory[:, self.n_states:self.n_states+1].astype(int))
    b_reward = torch.FloatTensor(b_memory[:, self.n_states+1:self.n_states+2])
    b_next_state = torch.FloatTensor(b_memory[:, -self.n_states:])

    # Compute the gap between the Q values from the current eval net and the target net
    q_eval = self.eval_net(b_state).gather(1, b_action)  # Q values the eval net currently assigns to the actions taken in these experiences
    q_next = self.target_net(b_next_state).detach()  # detach so gradients don't flow into (and train) the target net
    q_target = b_reward + self.gamma * q_next.max(1)[0].view(self.batch_size, 1)  # Q targets derived from the target net for these experiences
    loss = self.loss_func(q_eval, q_target)

    # Backpropagation
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    # Every target_replace_iter learning steps, update the target net
    # by copying the eval net's weights into it
    self.learn_step_counter += 1
    if self.learn_step_counter % self.target_replace_iter == 0:
        self.target_net.load_state_dict(self.eval_net.state_dict())
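
For reference, learn() relies on attributes (eval_net, target_net, memory, optimizer, loss_func, and the hyperparameters) that are set up elsewhere in the agent. Below is a minimal sketch of how that surrounding DQN class might look; the Net architecture, hyperparameter values, and the store_transition helper are illustrative assumptions, not an exact copy of the full gist.

# Sketch of the agent context assumed by learn() above.
# Net, store_transition, and the default hyperparameters are illustrative.
import numpy as np
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self, n_states, n_actions, n_hidden=50):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(n_states, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_actions),
        )

    def forward(self, x):
        # One Q value per action
        return self.fc(x)


class DQN:
    def __init__(self, n_states, n_actions, batch_size=32, lr=0.01,
                 gamma=0.9, memory_capacity=2000, target_replace_iter=100):
        self.n_states = n_states
        self.n_actions = n_actions
        self.batch_size = batch_size
        self.gamma = gamma
        self.memory_capacity = memory_capacity
        self.target_replace_iter = target_replace_iter

        self.eval_net = Net(n_states, n_actions)
        self.target_net = Net(n_states, n_actions)
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=lr)
        self.loss_func = nn.MSELoss()

        # Each row stores one transition: (state, action, reward, next_state)
        self.memory = np.zeros((memory_capacity, n_states * 2 + 2))
        self.memory_counter = 0
        self.learn_step_counter = 0

    def store_transition(self, state, action, reward, next_state):
        # Overwrite the oldest transition once the buffer is full
        index = self.memory_counter % self.memory_capacity
        self.memory[index, :] = np.hstack((state, [action, reward], next_state))
        self.memory_counter += 1

Note that because learn() samples indices over the full memory_capacity, it should only be called once memory_counter has reached memory_capacity, i.e. after the replay buffer has been filled at least once.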