Table Q learning
from time import sleep

import gym
import numpy as np
import matplotlib.pyplot as plt
import sys
import os
def val_to_bin(obs):
    """Turn an observation into bin positions in the Q matrix."""
    res = []
    for i in range(len(obs)):
        loc = np.clip(int(round((obs[i] - minimums[i]) / precisions[i])),
                      0,
                      q_elem_num[i] - 1)
        res.append(loc)
    return tuple(res)


def block_print():
    """Silence stdout (hides gym's warning when stepping past the time limit)."""
    sys.stdout = open(os.devnull, 'w')


def enable_print():
    """Restore stdout."""
    sys.stdout = sys.__stdout__
if __name__ == "__main__": | |
env = gym.make('CartPole-v0') | |
env.reset() | |
# Minimum and maximum values observed for each observation variable | |
minimums = [-2.4, -5, -0.7295476, -5] | |
maxes = [2.4, 5, 0.7295476, 5] | |
# What precisions to use for binning observations | |
precisions = [0.5, 1, 0.1, 1] | |
q_elem_num = [] | |
for i in range(4): | |
q_elem_num.append(int(round((maxes[i] - minimums[i]) / precisions[i]))) | |
# Add two possible actions | |
q_elem_num.append(2) | |
# Prints number of bins for each observation + 2 actions | |
print(q_elem_num) | |
# Create a random 5 dimensional matrix with q_elem_num numbers of elements | |
q = np.random.uniform(low=-1, high=1, size=q_elem_num) | |
    # Episode lengths
    best_len = 0
    lengths = []
    curr_avg = 0
    avg_len = []
    q_best = q

    num_episodes = 3000
    # Step size / learning rate
    lr = 5e-2
    discount = 0.99
    for j in range(0, num_episodes):
        # We can keep eps fixed at 0.1, since Q learning is off-policy and will
        # learn the greedy policy regardless of the exploration rate
        eps = 0.1

        observation = env.reset()
        position = val_to_bin(observation)
        values = q[position]
        action = np.argmax(values)
        done = False
        episode_length = 0
        reward = 1

        while episode_length < 10000 and reward > 0:
            # Epsilon-greedy exploration
            if np.random.random() < eps:
                action = env.action_space.sample()

            # Disable WARN from the env: CartPole-v0 is meant to be played for
            # only 200 steps, but we keep playing for up to 10k steps
            block_print()
            new_observation, reward, done, _ = env.step(action)
            enable_print()

            if reward < 1:  # The reward drops below 1 only once the episode has ended
                reward = -300
                done = True
            else:
                done = False

            # Save the best Q table so far for inference
            if episode_length >= best_len:
                best_len = episode_length
                if episode_length > 100:
                    q_best = np.copy(q)

            new_observation = np.array(new_observation)
            new_position = val_to_bin(new_observation)
            new_action = np.argmax(q[new_position])
            new_action_value = q[new_position][new_action]

            if done:
                # Terminal update: no bootstrapping from the next state
                q[position][action] = q[position][action] + lr * reward
            else:
                q[position][action] = q[position][action] + lr * (
                    reward + discount * new_action_value - q[position][action])

            action = new_action
            observation = new_observation
            position = new_position
            episode_length += 1

        lengths.append(episode_length)
        # Moving average of episode length
        curr_avg += (episode_length - curr_avg) / 100
        avg_len.append(curr_avg)

        if j % 100 == 0:
            print("Iter:", j, "Best len", best_len, "Eps:", eps, "Discount:",
                  discount, "Average len:", avg_len[-1])

        # Stop early if the average episode length is high enough
        if curr_avg > 5000:
            break
    # Plot results
    plt.plot(lengths)
    plt.plot(avg_len)
    plt.show()

    # Use the best Q table for inference
    q = q_best

    observation = env.reset()
    position = val_to_bin(observation)
    values = q[position]
    action = np.argmax(values)

    # Run inference indefinitely
    while True:
        observation, reward, done, _ = env.step(action)
        if reward < 1:
            # Episode over: pause briefly, then restart from a fresh observation
            sleep(3)
            observation = env.reset()
        position = val_to_bin(observation)
        values = q[position]
        action = np.argmax(values)
        env.render()
README
Requirements
The script needs gym, numpy, and matplotlib installed.
Environment
The problem consists of balancing a pole attached by a single joint to the top of a moving cart. The only actions are to apply a force of -1 or +1 to the cart, pushing it left or right.
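For reference, the action and observation spaces can be inspected directly; a minimal sketch, assuming the classic pre-0.26 gym API that the script uses:

import gym

env = gym.make('CartPole-v0')
print(env.action_space)       # Discrete(2): 0 pushes the cart left, 1 pushes it right
print(env.observation_space)  # Box(4,): the four continuous observation variables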
The observations are cart position, cart velocity, pole angle, and pole angular velocity.
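Each of these continuous values is clipped and rounded into a bin index before it is used to index the Q table. A minimal sketch of the same binning scheme used in the script (same minimums, maxes, and precisions), applied to one example observation:

import numpy as np

minimums = [-2.4, -5, -0.7295476, -5]
maxes = [2.4, 5, 0.7295476, 5]
precisions = [0.5, 1, 0.1, 1]

def val_to_bin(obs):
    """Map a raw observation to integer bin indices into the Q table."""
    res = []
    for i in range(len(obs)):
        bins = int(round((maxes[i] - minimums[i]) / precisions[i]))
        loc = np.clip(int(round((obs[i] - minimums[i]) / precisions[i])), 0, bins - 1)
        res.append(int(loc))
    return tuple(res)

# Cart slightly right of center, pole leaning slightly to the left
print(val_to_bin([0.1, 0.0, -0.05, 0.0]))  # -> (5, 5, 7, 5)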
Running
Run the code with:
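Assuming the script is saved as, for example, qlearning_cartpole.py (the gist does not specify a filename):

python qlearning_cartpole.py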
This will train the agent for 3000 episodes. When training finishes, it displays a plot of the episode lengths, together with a curve showing the moving average episode length.
After that, it runs inference indefinitely using the Q table that achieved the longest episode, rendering the environment as it goes.
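The non-terminal update applied in the training loop is the standard tabular Q-learning rule. A toy, self-contained sketch of that update, using a made-up 3-state table and the same learning rate and discount as the script:

import numpy as np

lr, discount = 5e-2, 0.99
q = np.zeros((3, 2))                     # toy table: 3 states x 2 actions
state, action, reward, next_state = 0, 1, 1.0, 2

# Move Q(s, a) toward the TD target r + discount * max_a' Q(s', a')
td_target = reward + discount * np.max(q[next_state])
q[state, action] += lr * (td_target - q[state, action])
print(q[state, action])                  # 0.05 after one update from zero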
Results
Most of the time the agent converges to a very good solution that can keep the pole balanced for a long time. The runs where it doesn't can be attributed to several things: