Skip to content

Instantly share code, notes, and snippets.

@pyliaorachel
Created June 15, 2018 05:06
Show Gist options
  • Save pyliaorachel/8dd5300c0506e4410707b8471f61bf15 to your computer and use it in GitHub Desktop.
Save pyliaorachel/8dd5300c0506e4410707b8471f61bf15 to your computer and use it in GitHub Desktop.
OpenAI Gym CartPole - Deep Q-Learning (DQN framework)
class DQN(object):
    """Skeleton of a Deep Q-Network agent with an evaluation net and a target net.

    Only the constructor is implemented. ``choose_action``,
    ``store_transition`` and ``learn`` are placeholders that currently do
    nothing and return ``None``.

    Args:
        n_states: State-vector size. It is passed to ``Net`` and also sets
            the replay-memory row width.
        n_actions: Number of actions, passed to ``Net``.
        n_hidden: Hidden size, passed to ``Net``.
        batch_size: Stored on the instance (not used by the visible code).
        lr: Learning rate of the Adam optimizer over ``eval_net``.
        epsilon: Stored on the instance (presumably an exploration rate;
            not used by the visible code).
        gamma: Stored on the instance (presumably a discount factor;
            not used by the visible code).
        target_replace_iter: Stored on the instance (presumably how often
            the target network is refreshed; not used by the visible code).
        memory_capacity: Number of rows in the replay memory.
    """

    def __init__(
        self,
        n_states,
        n_actions,
        n_hidden,
        batch_size,
        lr,
        epsilon,
        gamma,
        target_replace_iter,
        memory_capacity,
    ):
        # Keep every hyperparameter on the instance for the future methods.
        self.n_states = n_states
        self.n_actions = n_actions
        self.n_hidden = n_hidden
        self.batch_size = batch_size
        self.lr = lr
        self.epsilon = epsilon
        self.gamma = gamma
        self.target_replace_iter = target_replace_iter
        self.memory_capacity = memory_capacity

        # Two networks with the same constructor arguments. eval_net is built
        # first and target_net second, matching the original order (relevant
        # if Net's weight init draws from a random generator).
        self.eval_net = Net(n_states, n_actions, n_hidden)
        self.target_net = Net(n_states, n_actions, n_hidden)

        # Replay memory: one zero-initialised row per experience. Each row is
        # sized for (state + next state + reward + action), which gives
        # 2 * n_states + 2 columns.
        row_width = 2 * n_states + 2
        self.memory = np.zeros((memory_capacity, row_width))
        self.memory_counter = 0

        # Only eval_net's parameters are handed to the optimizer.
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=lr)
        self.loss_func = nn.MSELoss()

        # Counts learning steps so the target network knows when to update.
        self.learn_step_counter = 0

    def choose_action(self):
        """Placeholder for action selection; not implemented, returns None."""

    def store_transition(self):
        """Placeholder (presumably writes one experience into ``self.memory``); not implemented, returns None."""

    def learn(self):
        """Placeholder for a training step; not implemented, returns None."""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment