Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save MerlinPendragon/672fd7097876c4885469522ea6c2617d to your computer and use it in GitHub Desktop.
Save MerlinPendragon/672fd7097876c4885469522ea6c2617d to your computer and use it in GitHub Desktop.
Deep-Q learning implementation in Tensorflow and Keras (solving CartPole-v0)
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Deep-Q learning implementation in Tensorflow and Keras\n",
"with an example application to solving the `CartPole-v0` environment.\n",
"![dqn](https://user-images.githubusercontent.com/38169187/46908388-63807200-cf22-11e8-99f3-b471405495b3.png)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# import"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import gym\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"import tensorflow as tf\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# replay buffer"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# import numpy as np\n",
"from collections import deque\n",
"import random\n",
"\n",
"class ReplayBuffer:\n",
"    \"\"\"Fixed-size FIFO buffer of experience tuples (s, a, r, done, s2).\n",
"\n",
"    The right side of the internal deque holds the most recent experiences;\n",
"    once `buffer_size` entries are stored, the oldest ones are discarded.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, buffer_size=int(1e5), random_seed=1234):\n",
"        \"\"\"Initialize a ReplayBuffer object.\n",
"\n",
"        Params\n",
"        ======\n",
"            buffer_size: maximum number of experiences kept in the buffer\n",
"            random_seed: seed of the buffer's private random number generator\n",
"        \"\"\"\n",
"        self.buffer_size = buffer_size\n",
"        self.buffer = deque(maxlen=buffer_size)\n",
"        # A private generator keeps sampling reproducible without re-seeding\n",
"        # the global `random` module as a side effect of creating a buffer.\n",
"        self._rng = random.Random(random_seed)\n",
"\n",
"    def __len__(self):\n",
"        \"\"\"Return the number of experiences currently stored.\"\"\"\n",
"        return len(self.buffer)\n",
"\n",
"    def add(self, s, a, r, done, s2):\n",
"        \"\"\"Add a new experience to the buffer.\n",
"\n",
"        Params\n",
"        ======\n",
"            s: one state sample, numpy array shape (s_dim,)\n",
"            a: one action sample, scalar (for DQN)\n",
"            r: one reward sample, scalar\n",
"            done: True/False, scalar\n",
"            s2: one state sample, numpy array shape (s_dim,)\n",
"        \"\"\"\n",
"        self.buffer.append((s, a, r, done, s2))\n",
"\n",
"    def sample_batch(self, batch_size):\n",
"        \"\"\"Randomly sample a batch of experiences, without replacement.\n",
"\n",
"        Params\n",
"        ======\n",
"            batch_size: number of experiences to draw; the buffer must hold\n",
"                at least this many (checked with an assert)\n",
"\n",
"        Returns\n",
"        =======\n",
"            states: numpy array, shape (batch_size, s_dim)\n",
"            actions: numpy array, shape (batch_size,)\n",
"            rewards: numpy array, shape (batch_size,)\n",
"            dones: numpy uint8 array, shape (batch_size,)\n",
"            next_states: numpy array, shape (batch_size, s_dim)\n",
"        \"\"\"\n",
"        # ensure the buffer is large enough for sampling\n",
"        assert len(self.buffer) >= batch_size, (\n",
"            \"replay buffer holds {} experiences, cannot sample {}\".format(len(self.buffer), batch_size))\n",
"\n",
"        # sample a batch with the buffer's own generator\n",
"        batch = self._rng.sample(self.buffer, batch_size)\n",
"\n",
"        # Convert experience tuples to separate arrays for each element (states, actions, rewards, etc.)\n",
"        states, actions, rewards, dones, next_states = zip(*batch)\n",
"        states = np.asarray(states).reshape(batch_size, -1)            # shape (batch_size, s_dim)\n",
"        next_states = np.asarray(next_states).reshape(batch_size, -1)  # shape (batch_size, s_dim)\n",
"        actions = np.asarray(actions)                                  # shape (batch_size,), for DQN, action is an int\n",
"        rewards = np.asarray(rewards)                                  # shape (batch_size,)\n",
"        dones = np.asarray(dones, dtype=np.uint8)                      # shape (batch_size,)\n",
"        return states, actions, rewards, dones, next_states"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# DQN tf summary"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def build_summaries():\n",
"    \"\"\"Create the per-episode TensorBoard summaries used to monitor training.\n",
"\n",
"    Returns\n",
"    =======\n",
"        summary_op: result of tf.summary.merge_all()\n",
"        ph_reward: float32 placeholder behind the \"Reward_ep\" scalar summary\n",
"        ph_Qmax: float32 placeholder behind the \"Qmax_ep\" scalar summary\n",
"    \"\"\"\n",
"    # placeholders fed with the per-episode performance figures\n",
"    ph_reward = tf.placeholder(tf.float32)\n",
"    ph_Qmax = tf.placeholder(tf.float32)\n",
"\n",
"    # one scalar summary per placeholder\n",
"    for tag, placeholder in ((\"Reward_ep\", ph_reward), (\"Qmax_ep\", ph_Qmax)):\n",
"        tf.summary.scalar(tag, placeholder)\n",
"\n",
"    # merge all summary ops (must be done as the last step)\n",
"    summary_op = tf.summary.merge_all()\n",
"    return summary_op, ph_reward, ph_Qmax"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# DQN neural network model"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"import time\n",
"from keras import layers, initializers, regularizers\n",
"from functools import partial\n",
"\n",
"def build_net(model_name, state, a_dim, args, trainable):\n",
"    \"\"\"Build the Q-network: two ReLU hidden layers followed by a linear output.\n",
"\n",
"    Params\n",
"    ======\n",
"        model_name: variable scope under which the layers are created\n",
"        state: input tensor of the network (model input)\n",
"        a_dim: width of the output layer, one unit per action\n",
"        args: dict providing the hidden layer sizes 'h1' and 'h2'\n",
"        trainable: forwarded to every Dense layer\n",
"\n",
"    Returns\n",
"    =======\n",
"        Qhat: output tensor of the network (model output)\n",
"        nn_params: list of the variables found under the `model_name` scope\n",
"    \"\"\"\n",
"    h1 = int(args['h1'])\n",
"    h2 = int(args['h2'])\n",
"\n",
"    my_dense = partial(layers.Dense, trainable=trainable)\n",
"    with tf.variable_scope(model_name):\n",
"        net = my_dense(h1, name=\"l1-dense-{}\".format(h1))(state)\n",
"        net = layers.Activation('relu', name=\"relu1\")(net)\n",
"        net = my_dense(h2, name=\"l2-dense-{}\".format(h2))(net)\n",
"        net = layers.Activation('relu', name=\"relu2\")(net)\n",
"        net = my_dense(a_dim, name=\"l3-dense-{}\".format(a_dim))(net)\n",
"        Qhat = layers.Activation('linear', name=\"Qhat\")(net)\n",
"\n",
"    # Collect every variable of this scope, not only the trainable ones: a\n",
"    # network built with trainable=False (the target network) may register no\n",
"    # trainable variables at all, which would hand the caller an empty list\n",
"    # and silently disable any update built from it.\n",
"    nn_params = tf.global_variables(scope=model_name)\n",
"    return Qhat, nn_params"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# DQN agent"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class DeepQNetwork:\n",
"    \"\"\"Deep-Q learning agent: online Q-network, target Q-network and replay buffer.\n",
"\n",
"    The constructor builds the whole TensorFlow graph; `choose_action` and\n",
"    `learn_a_batch` run parts of it in the session passed to them.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, sess, a_dim, s_dim, args):\n",
"        \"\"\"Build the graph of the agent.\n",
"\n",
"        Params\n",
"        ======\n",
"            sess: TensorFlow session (part of the interface, not used here)\n",
"            a_dim: number of discrete actions\n",
"            s_dim: dimension of a state vector\n",
"            args: dict of hyper-parameters; the keys read are h1, h2,\n",
"                learning_rate, gamma, epsilon_start, epsilon_stop,\n",
"                epsilon_decay, update_target_C, update_target_tau,\n",
"                buffer_size, random_seed and minibatch_size\n",
"        \"\"\"\n",
"        self.a_dim = a_dim\n",
"        self.s_dim = s_dim\n",
"        self.h1 = args[\"h1\"]\n",
"        self.h2 = args[\"h2\"]\n",
"        self.lr = args[\"learning_rate\"]\n",
"        self.gamma = args[\"gamma\"]\n",
"        self.epsilon_start = args[\"epsilon_start\"]\n",
"        self.epsilon_stop = args[\"epsilon_stop\"]\n",
"        self.epsilon_decay = args[\"epsilon_decay\"]\n",
"        self.epsilon = self.epsilon_start  # current exploration probability\n",
"        self.update_target_C = args[\"update_target_C\"]\n",
"        self.update_target_tau = args['update_target_tau']\n",
"        self.learn_step_counter = 0\n",
"\n",
"        # initialize replay buffer\n",
"        self.replay_buffer = ReplayBuffer(int(args['buffer_size']), int(args['random_seed']))\n",
"        self.minibatch_size = int(args['minibatch_size'])\n",
"\n",
"        # graph inputs\n",
"        self.s = tf.placeholder(tf.float32, [None, self.s_dim], name='state')        # input State\n",
"        self.s_ = tf.placeholder(tf.float32, [None, self.s_dim], name='state_next')  # input Next State\n",
"        self.r = tf.placeholder(tf.float32, [None,], name='reward')                  # input Reward\n",
"        self.a = tf.placeholder(tf.int32, [None,], name='action')                    # input Action\n",
"        self.done = tf.placeholder(tf.float32, [None,], name='done')\n",
"\n",
"        # online network (trained) and target network (only changed by update_op);\n",
"        # self.q shape (batch_size, a_dim)\n",
"        self.q, self.nn_params = build_net(\"DQN\", self.s, a_dim, args, trainable=True)\n",
"        self.q_, self.nn_params_ = build_net(\"target_DQN\", self.s_, a_dim, args, trainable=False)\n",
"        for var in self.nn_params:\n",
"            vname = var.name.replace(\"kernel:0\", \"W\").replace(\"bias:0\", \"b\")\n",
"            tf.summary.histogram(vname, var)\n",
"\n",
"        with tf.variable_scope(\"Qmax\"):\n",
"            # max over actions of the target network output for the next state\n",
"            self.Qmax = tf.reduce_max(self.q_, axis=1)  # shape (batch_size,)\n",
"\n",
"        with tf.variable_scope(\"yi\"):\n",
"            # TD target; (1 - done) removes the bootstrap term of terminal transitions\n",
"            self.yi = self.r + self.gamma * self.Qmax * (1 - self.done)  # shape (batch_size,)\n",
"\n",
"        with tf.variable_scope(\"Qa_all\"):\n",
"            # Copy of the latest minibatch of Q values, kept only to feed the\n",
"            # histogram summaries; trainable=False keeps this logging buffer\n",
"            # out of the trainable variables collection.\n",
"            Qa = tf.Variable(tf.zeros([self.minibatch_size, self.a_dim]), trainable=False)\n",
"            for aval in np.arange(self.a_dim):\n",
"                tf.summary.histogram(\"Qa{}\".format(aval), Qa[:, aval])\n",
"            self.Qa_op = Qa.assign(self.q)\n",
"\n",
"        with tf.variable_scope(\"Q_at_a\"):\n",
"            # select the Q value corresponding to the action\n",
"            one_hot_actions = tf.one_hot(self.a, self.a_dim)  # shape (batch_size, a_dim)\n",
"            q_all = tf.multiply(self.q, one_hot_actions)      # shape (batch_size, a_dim)\n",
"            self.q_at_a = tf.reduce_sum(q_all, axis=1)        # shape (batch_size,)\n",
"\n",
"        with tf.variable_scope(\"loss_MSE\"):\n",
"            self.loss = tf.losses.mean_squared_error(labels=self.yi, predictions=self.q_at_a)\n",
"\n",
"        with tf.variable_scope(\"train_DQN\"):\n",
"            # only the online network parameters are optimized\n",
"            self.train_op = tf.train.AdamOptimizer(self.lr).minimize(loss=self.loss, var_list=self.nn_params)\n",
"\n",
"        with tf.variable_scope(\"soft_update\"):\n",
"            # target <- (1 - tau) * target + tau * online\n",
"            TAU = self.update_target_tau\n",
"            self.update_op = [tf.assign(t, (1 - TAU)*t + TAU*e) for t, e in zip(self.nn_params_, self.nn_params)]\n",
"\n",
"    def choose_action(self, sess, observation):\n",
"        \"\"\"Pick an action with the epsilon-greedy policy.\n",
"\n",
"        Params\n",
"        ======\n",
"            sess: TensorFlow session used to evaluate the online network\n",
"            observation: one state, reshaped here to (1, s_dim)\n",
"\n",
"        Returns\n",
"        =======\n",
"            index of the chosen action\n",
"        \"\"\"\n",
"        explore_p = self.epsilon  # exploration probability\n",
"\n",
"        if np.random.uniform() <= explore_p:\n",
"            # Explore: make a random action\n",
"            action = np.random.randint(0, self.a_dim)\n",
"        else:\n",
"            # Exploit: get the action from the Q-network\n",
"            observation = np.reshape(observation, (1, self.s_dim))\n",
"            Qs = sess.run(self.q, feed_dict={self.s: observation})  # shape (1, a_dim)\n",
"            action = np.argmax(Qs[0])\n",
"        return action\n",
"\n",
"    def learn_a_batch(self, sess):\n",
"        \"\"\"Run one learning step on a minibatch sampled from the replay buffer.\n",
"\n",
"        The target network is updated every `update_target_C` learning steps\n",
"        and the exploration probability is decayed after the step.\n",
"\n",
"        Params\n",
"        ======\n",
"            sess: TensorFlow session used to run the update and training ops\n",
"\n",
"        Returns\n",
"        =======\n",
"            largest Q value, over the minibatch, of the actions that were taken\n",
"        \"\"\"\n",
"        # update target every C learning steps\n",
"        if self.learn_step_counter % self.update_target_C == 0:\n",
"            sess.run(self.update_op)\n",
"\n",
"        # Sample a batch\n",
"        s_batch, a_batch, r_batch, done_batch, s2_batch = self.replay_buffer.sample_batch(self.minibatch_size)\n",
"\n",
"        # Train (Qa_op also refreshes the values shown in the Q histograms)\n",
"        _, _, Qhat = sess.run([self.train_op, self.Qa_op, self.q_at_a], feed_dict={\n",
"            self.s: s_batch, self.a: a_batch, self.r: r_batch, self.done: done_batch, self.s_: s2_batch})\n",
"\n",
"        # count learning steps\n",
"        self.learn_step_counter += 1\n",
"\n",
"        # decay exploration probability after each learning step, never going\n",
"        # below the configured floor\n",
"        if self.epsilon > self.epsilon_stop:\n",
"            self.epsilon = max(self.epsilon_stop, self.epsilon * self.epsilon_decay)\n",
"\n",
"        return np.max(Qhat)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# args `CartPole-v0`"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"code_folding": []
},
"outputs": [],
"source": [
"# Hyper-parameters of the DQN agent and of the training run.\n",
"args = {\n",
"    # --- environment and run length\n",
"    \"env\": 'CartPole-v0',\n",
"    \"random_seed\": 1234,\n",
"    \"max_episodes\": 150,        # number of episodes\n",
"    \"max_episode_len\": 200,     # time steps per episode, 200 for CartPole-v0\n",
"    # --- NN params\n",
"    \"h1\": 32,                   # 32\n",
"    \"h2\": 64,                   # 64\n",
"    \"learning_rate\": 0.001,     # 1e-3\n",
"    \"gamma\": 0.9,               # 0.9 (32), 0.95 (34) better than 0.99\n",
"    \"update_target_C\": 1,       # update every C learning steps (C=1 if soft update, C=100 if hard update)\n",
"    \"update_target_tau\": 8e-2,  # soft update (tau=8e-2), hard update (tau=1)\n",
"    # --- exploration prob\n",
"    \"epsilon_start\": 1.0,\n",
"    \"epsilon_stop\": 0.01,       # 0.01\n",
"    \"epsilon_decay\": 0.999,     # 0.999\n",
"    # --- replay buffer\n",
"    \"buffer_size\": 1e5,\n",
"    \"minibatch_size\": 32,       # 32\n",
"    # --- tensorboard logs\n",
"    \"summary_dir\": './results/dqn',\n",
"}\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# main training"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/ying/gym/gym/__init__.py:22: UserWarning: DEPRECATION WARNING: to improve load times, gym no longer automatically loads gym.spaces. Please run \"import gym.spaces\" to load gym.spaces on your own. This warning will turn into an error in a future version of gym.\n",
" warnings.warn('DEPRECATION WARNING: to improve load times, gym no longer automatically loads gym.spaces. Please run \"import gym.spaces\" to load gym.spaces on your own. This warning will turn into an error in a future version of gym.')\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n",
"states: Box(4,)\n",
"actions: Discrete(2)\n",
"episode: 0/150, steps: 23, explore_prob: 1.00, total reward: 23.0\n",
"episode: 10/150, steps: 13, explore_prob: 0.89, total reward: 13.0\n",
"episode: 20/150, steps: 21, explore_prob: 0.69, total reward: 21.0\n",
"episode: 30/150, steps: 17, explore_prob: 0.49, total reward: 17.0\n",
"episode: 40/150, steps: 200, explore_prob: 0.11, total reward: 200.0\n",
"episode: 50/150, steps: 200, explore_prob: 0.01, total reward: 200.0\n",
"episode: 60/150, steps: 200, explore_prob: 0.01, total reward: 200.0\n",
"episode: 70/150, steps: 200, explore_prob: 0.01, total reward: 200.0\n",
"episode: 80/150, steps: 200, explore_prob: 0.01, total reward: 200.0\n",
"episode: 90/150, steps: 200, explore_prob: 0.01, total reward: 200.0\n",
"episode: 100/150, steps: 200, explore_prob: 0.01, total reward: 200.0\n",
"episode: 110/150, steps: 200, explore_prob: 0.01, total reward: 200.0\n",
"episode: 120/150, steps: 200, explore_prob: 0.01, total reward: 200.0\n",
"episode: 130/150, steps: 200, explore_prob: 0.01, total reward: 200.0\n",
"episode: 140/150, steps: 200, explore_prob: 0.01, total reward: 200.0\n"
]
}
],
"source": [
"# Train the DQN agent on the environment selected in `args`; the reward of\n",
"# every episode is collected in `rewards_list` as (episode, reward) pairs.\n",
"sess = tf.InteractiveSession()\n",
"tf.set_random_seed(int(args['random_seed']))\n",
"\n",
"# initialize numpy seed\n",
"np.random.seed(int(args['random_seed']))\n",
"\n",
"# initialize gym env\n",
"env = gym.make(args['env'])\n",
"env.seed(int(args['random_seed']))\n",
"state_size = env.observation_space.shape[0]\n",
"action_size = env.action_space.n\n",
"print(\"states:\", env.observation_space)\n",
"print(\"actions:\", env.action_space)\n",
"\n",
"# initialize DQN agent\n",
"agent = DeepQNetwork(sess, action_size, state_size, args)\n",
"\n",
"# initialize summary (for visualization in tensorboard)\n",
"summary_op, ph_reward, ph_Qmax = build_summaries()\n",
"subdir = time.strftime(\"%Y%m%d-%H%M%S\", time.localtime())  # a sub folder, e.g., yyyymmdd-HHMMSS\n",
"logdir = args['summary_dir'] + '/' + subdir\n",
"writer = tf.summary.FileWriter(logdir, sess.graph)  # must be done after graph is constructed\n",
"\n",
"# initialize the variables that exist in the graph\n",
"sess.run(tf.global_variables_initializer())\n",
"\n",
"# training DQN agent\n",
"rewards_list = []\n",
"num_ep = args['max_episodes']\n",
"max_t = args['max_episode_len']\n",
"# learning starts once the buffer holds more experiences than this\n",
"min_buffer_len = 3 * agent.minibatch_size\n",
"try:\n",
"    for ep in range(num_ep):\n",
"        state = env.reset()  # shape (s_dim,)\n",
"        ep_reward = 0        # total reward per episode\n",
"        ep_qmax = 0\n",
"        t_step = 0\n",
"        done = False\n",
"        while (t_step < max_t) and (not done):\n",
"\n",
"            # choose an action\n",
"            action = agent.choose_action(sess, state)\n",
"\n",
"            # interact with the env\n",
"            next_state, reward, done, _ = env.step(action)\n",
"\n",
"            # add the experience to replay buffer\n",
"            # NOTE(review): if the env also reports done=True when an episode is cut\n",
"            # by a time limit, that transition is stored as terminal - confirm this is intended\n",
"            agent.replay_buffer.add(state, action, reward, done, next_state)\n",
"\n",
"            # learn from a batch of experiences\n",
"            if len(agent.replay_buffer) > min_buffer_len:\n",
"                qmax = agent.learn_a_batch(sess)\n",
"                ep_qmax = max(ep_qmax, qmax)\n",
"\n",
"            # next time step\n",
"            t_step += 1\n",
"            ep_reward += reward\n",
"            state = next_state\n",
"\n",
"        # end of an episode\n",
"        rewards_list.append((ep, ep_reward))\n",
"\n",
"        # write to tensorboard summary\n",
"        summary_str = sess.run(summary_op, feed_dict={ph_reward: ep_reward, ph_Qmax: ep_qmax})\n",
"        writer.add_summary(summary_str, ep)\n",
"        writer.flush()\n",
"\n",
"        if ep % 10 == 0:\n",
"            print(\"episode: {}/{}, steps: {}, explore_prob: {:.2f}, total reward: {}\".format(\n",
"                ep, num_ep, t_step, agent.epsilon, ep_reward))\n",
"finally:\n",
"    # release the event file and the environment even if training is interrupted\n",
"    writer.close()\n",
"    env.close()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Solved Requirements `CartPole-v0`**\n",
"\n",
"https://github.com/openai/gym/wiki/CartPole-v0\n",
"\n",
"Considered solved when the average reward is greater than or equal to **195.0** over 100 consecutive trials."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"episodes before solving: 32\n"
]
}
],
"source": [
"def my_sma(x, N):\n",
"    \"\"\"Simple moving average over a trailing window of N samples.\n",
"\n",
"    Element i of the result is sum(x[max(0, i-N+1) : i+1]) / N, so the first\n",
"    N-1 elements are averages over windows that are padded with zeros.\n",
"\n",
"    Params\n",
"    ======\n",
"        x: 1-D sequence of samples\n",
"        N: window length (positive integer)\n",
"\n",
"    Returns\n",
"    =======\n",
"        numpy array with the same length as x\n",
"    \"\"\"\n",
"    filt = np.ones(N) / N\n",
"    xm = np.convolve(x, filt)  # 'full' convolution, len(x) + N - 1 samples\n",
"    # keep the first len(x) samples; slicing with [:-(N-1)] would return an\n",
"    # empty array for N == 1\n",
"    return xm[:len(x)]\n",
"\n",
"eps, rewards = np.array(rewards_list).T\n",
"\n",
"# plot reward v.s. episode\n",
"plt.plot(eps, rewards)\n",
"plt.xlabel('episode')\n",
"plt.ylabel('reward')\n",
"plt.show()\n",
"\n",
"# check solved requirements\n",
"N = 100\n",
"thr = 195.0\n",
"# Only averages over N complete episodes count: drop the first N-1 entries of\n",
"# the moving average, which cover fewer than N episodes (zero padded).\n",
"full_window_sma = my_sma(rewards, N)[N - 1:]\n",
"solved = np.flatnonzero(full_window_sma >= thr)  # windows where sma >= thr\n",
"if solved.size > 0:\n",
"    # the index of a full window equals the number of episodes that preceded it\n",
"    ep_solve = int(solved[0])\n",
"    print(\"episodes before solving: {}\".format(ep_solve))\n",
"else:\n",
"    ep_solve = None\n",
"    print(\"not solved within {} episodes\".format(len(rewards)))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": false,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {
"height": "calc(100% - 180px)",
"left": "10px",
"top": "150px",
"width": "288px"
},
"toc_section_display": true,
"toc_window_display": true
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment