Cliff World: Q-learning vs SARSA learning
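A side-by-side comparison of Q-learning and SARSA on the classic cliff-walking grid world. Both agents follow an epsilon-greedy behaviour policy; the only difference is the bootstrap target in the value update. Q-learning is off-policy and backs up the greedy action, Q(s,a) ← Q(s,a) + α[r + γ·max_a' Q(s',a') − Q(s,a)], while SARSA is on-policy and backs up the action actually taken next, Q(s,a) ← Q(s,a) + α[r + γ·Q(s',a') − Q(s,a)]. Each step costs -1; stepping into the cliff costs -100 and sends the agent back to the start.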
# https://dl.dropboxusercontent.com/u/47591917/cliff_world_episode_length.png
# https://dl.dropboxusercontent.com/u/47591917/cliff_world_episode_reward.png
# https://dl.dropboxusercontent.com/u/47591917/cliff_world_episode_timestamp.png
# https://dl.dropboxusercontent.com/u/47591917/cliff_world_path.png
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import matplotlib.patches as mpatches
UP = 0
DOWN = 1
RIGHT = 2
LEFT = 3
ALLOWED_ACTIONS = [UP, DOWN, RIGHT, LEFT]

NORMAL_REWARD = -1
CLIFF_REWARD = -100
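# Grid layout: a state is an (x, y) = (row, column) pair. The start is
# (0, 0), the goal is (0, width - 1), and every cell between them in row 0
# is the cliff: entering it yields CLIFF_REWARD and a reset to the start.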
class State(object):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __hash__(self):
        return hash((self.x, self.y))

    def __eq__(self, other):
        return (self.x == other.x) and (self.y == other.y)

    def __str__(self):
        return str((self.x, self.y))
class StateAction(object):
    def __init__(self, state, action):
        self.state = state
        self.action = action

    def __hash__(self):
        return hash((self.state, self.action))

    def __eq__(self, other):
        return (self.state == other.state and self.action == other.action)

    def __str__(self):
        return str((str(self.state), self.action))
class ValueMap(object):
    def __init__(self, width, height, epsilon, discount_factor, alpha):
        self.width = width
        self.height = height
        self.epsilon = epsilon
        self.discount_factor = discount_factor
        self.alpha = alpha
        self.value_map = {}
        self.cliff = {}

    def initialize(self):
        # Row 0 between the start and the goal is the cliff. Its action values
        # are pinned to a large negative number so backups treat falling off
        # as catastrophic.
        for i in range(self.height):
            for j in range(self.width):
                state = State(i, j)
                for action in ALLOWED_ACTIONS:
                    action_state = StateAction(state, action)
                    if i == 0 and self.width - 1 > j > 0:
                        self.cliff[state] = 1
                        self.value_map[action_state] = -9999
                    else:
                        self.cliff[state] = 0
                        self.value_map[action_state] = 0
    def get_greedy_action(self, state):
        max_value = -np.inf
        max_action = None
        for action in ALLOWED_ACTIONS:
            action_state = StateAction(state, action)
            value = self.value_map[action_state]
            if max_value <= value:
                max_value = value
                max_action = action
        return max_action

    def select_next_action(self, current_state, is_greedy=False):
        # Epsilon-greedy: each action gets probability epsilon / 4, and the
        # greedy action gets the remaining 1 - epsilon on top.
        greedy_action = self.get_greedy_action(current_state)
        if is_greedy:
            return greedy_action
        actions_probabilities = np.ones(len(ALLOWED_ACTIONS)) * self.epsilon / len(ALLOWED_ACTIONS)
        actions_probabilities[greedy_action] = actions_probabilities[greedy_action] + (1 - self.epsilon)
        return np.random.choice(np.arange(len(actions_probabilities)), p=actions_probabilities)
    def _get_next_state(self, current_state, current_action):
        x, y = current_state.x, current_state.y
        x_new, y_new = x, y
        if current_action == UP:
            x_new = x + 1
        elif current_action == RIGHT:
            y_new = y + 1
        elif current_action == DOWN:
            x_new = x - 1
        elif current_action == LEFT:
            y_new = y - 1
        # Clamp to the grid so a move off the edge leaves the agent in place.
        x_new = max(0, min(self.height - 1, x_new))
        y_new = max(0, min(self.width - 1, y_new))
        return State(x_new, y_new)

    def _get_reward(self, state):
        if self.cliff[state]:
            return CLIFF_REWARD
        return NORMAL_REWARD
    def update_value_map(self, current_state, current_action, is_sarsa=False, is_greedy=False):
        current_action_state = StateAction(current_state, current_action)
        next_state = self._get_next_state(current_state, current_action)
        reward = self._get_reward(next_state)
        if is_sarsa:
            # SARSA (on-policy): bootstrap from the action the epsilon-greedy
            # behaviour policy actually selects in the next state.
            next_action = self.select_next_action(next_state, is_greedy=is_greedy)
        else:
            # Q-learning (off-policy): bootstrap from the greedy action.
            next_action = self.get_greedy_action(next_state)
        next_action_state = StateAction(next_state, next_action)
        target_value = self.value_map[next_action_state]
        old_value = self.value_map[current_action_state]
        self.value_map[current_action_state] = old_value + self.alpha * (reward + self.discount_factor * target_value - old_value)
        # Falling off the cliff ends the walk; the caller restarts from the start state.
        return (next_state if not self.cliff[next_state] else None), reward
class Game(object):
    def __init__(self, width, height, no_episodes, alpha=0.1, discount_factor=1, epsilon=0.1):
        self.width = width
        self.height = height
        self.value_map = ValueMap(width, height, epsilon, discount_factor, alpha)
        self.no_episodes = no_episodes
        self.start = State(0, 0)
        self.end = State(0, width - 1)
        self.actions_episodes = []
        self.rewards_episodes = []
        self.last_episode_actions = []
        self.sarsa_actions_episodes = []
        self.sarsa_rewards_episodes = []
        self.sarsa_last_episode_actions = []
    def plot(self):
        noshow = True
        labels = []
        labels.append(r'Q learning')
        labels.append(r'SARSA learning')
        # Plot the episode length over time
        fig1 = plt.figure(figsize=(10, 6))
        plt.plot(self.actions_episodes)
        plt.plot(self.sarsa_actions_episodes)
        plt.xlabel("Episode")
        plt.ylabel("Episode Length")
        plt.title("Episode Length over Time")
        plt.legend(labels, ncol=4, loc='center left',
                   bbox_to_anchor=[0.5, 1.1],
                   columnspacing=1.0, labelspacing=0.0,
                   handletextpad=0.0, handlelength=1.5,
                   fancybox=True, shadow=True)
        plt.savefig('cliff_world_episode_length.png')
        if noshow:
            plt.close(fig1)
        else:
            plt.show()
        # Plot the episode reward over time
        fig2 = plt.figure(figsize=(10, 6))
        smoothing_window = 10
        rewards_smoothed = pd.Series(self.rewards_episodes).rolling(smoothing_window, min_periods=smoothing_window).mean()
        plt.plot(rewards_smoothed)
        rewards_smoothed = pd.Series(self.sarsa_rewards_episodes).rolling(smoothing_window, min_periods=smoothing_window).mean()
        plt.plot(rewards_smoothed)
        plt.xlabel("Episode")
        plt.ylabel("Episode Reward (Smoothed)")
        plt.title("Episode Reward over Time (Smoothed over window size {})".format(smoothing_window))
        plt.legend(labels, ncol=4, loc='center left',
                   bbox_to_anchor=[0.5, 1.1],
                   columnspacing=1.0, labelspacing=0.0,
                   handletextpad=0.0, handlelength=1.5,
                   fancybox=True, shadow=True)
        plt.savefig('cliff_world_episode_reward.png')
        if noshow:
            plt.close(fig2)
        else:
            plt.show()
        # Plot episode number against cumulative time steps
        fig3 = plt.figure(figsize=(10, 6))
        plt.plot(np.cumsum(self.actions_episodes), np.arange(len(self.actions_episodes)))
        plt.plot(np.cumsum(self.sarsa_actions_episodes), np.arange(len(self.sarsa_actions_episodes)))
        plt.xlabel("Time Steps")
        plt.ylabel("Episode")
        plt.title("Episode per time step")
        plt.legend(labels, ncol=4, loc='center left',
                   bbox_to_anchor=[0.5, 1.1],
                   columnspacing=1.0, labelspacing=0.0,
                   handletextpad=0.0, handlelength=1.5,
                   fancybox=True, shadow=True)
        plt.savefig('cliff_world_episode_timestamp.png')
        if noshow:
            plt.close(fig3)
        else:
            plt.show()
        # Plot the final-episode paths (plot x axis = column, y axis = row)
        plt.plot(self.start.y, self.start.x, 'x', markersize=20)
        previous_x = self.start.y
        previous_y = self.start.x
        for position in self.last_episode_actions[1:]:
            x, y = position.y, position.x
            plt.arrow(previous_x, previous_y, x - previous_x, y - previous_y, head_width=0.3, head_length=0.3, overhang=0, color='blue', label="Q learning")
            plt.plot(x, y, 'o', markersize=5)
            previous_x = x
            previous_y = y
        previous_x = self.start.y
        previous_y = self.start.x
        for position in self.sarsa_last_episode_actions[1:]:
            x, y = position.y, position.x
            plt.arrow(previous_x, previous_y, x - previous_x, y - previous_y, head_width=0.3, head_length=0.3, overhang=0, color='red', label="SARSA learning")
            plt.plot(x, y, 'o', markersize=5)
            previous_x = x
            previous_y = y
        plt.plot(self.end.y, self.end.x, 'x', markersize=20)
        axes = plt.gca()
        axes.set_xticks(range(-1, self.width + 1))
        axes.set_yticks(range(-1, self.height + 1))
        axes.set_title('Path to reach destination')
        # Shade the cliff cells (row 0, columns 1 .. width - 2); cells are
        # centered on integer coordinates, so offset the rectangle by half a cell.
        axes.add_patch(
            mpatches.Rectangle(
                (0.5, -0.5),        # (x, y) of the lower-left corner
                self.width - 2,     # width
                1,                  # height
                alpha=0.1
            )
        )
        red_patch = mpatches.Patch(color='red', label='SARSA Learning')
        blue_patch = mpatches.Patch(color='blue', label='Q Learning')
        plt.legend(handles=[red_patch, blue_patch])
        plt.grid()
        plt.savefig('cliff_world_path.png')
    def play(self, is_sarsa=False):
        self.value_map.initialize()
        for i in range(self.no_episodes):
            current_state = self.start
            actions = 0
            rewards = 0
            # Act greedily (no exploration) in the final episode so the
            # recorded path reflects the learned policy.
            is_greedy = False
            if i == self.no_episodes - 1:
                is_greedy = True
                if is_sarsa:
                    self.sarsa_last_episode_actions.append(current_state)
                else:
                    self.last_episode_actions.append(current_state)
            while current_state != self.end:
                if actions > 2000:
                    break
                current_action = self.value_map.select_next_action(current_state, is_greedy=is_greedy)
                next_state, reward = self.value_map.update_value_map(current_state, current_action, is_sarsa=is_sarsa, is_greedy=is_greedy)
                if next_state is None:
                    # Fell off the cliff: restart from the start state.
                    next_state = self.start
                current_state = next_state
                actions = actions + 1
                rewards = rewards + reward
                if i == self.no_episodes - 1:
                    print(current_action, current_state)
                    if is_sarsa:
                        self.sarsa_last_episode_actions.append(current_state)
                    else:
                        self.last_episode_actions.append(current_state)
            if is_sarsa:
                self.sarsa_actions_episodes.append(actions)
                self.sarsa_rewards_episodes.append(rewards)
            else:
                self.actions_episodes.append(actions)
                self.rewards_episodes.append(rewards)
# 12 columns x 4 rows, 1000 episodes per method
game = Game(12, 4, 1000)
game.play()                # Q-learning
game.play(is_sarsa=True)   # SARSA
game.plot()
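The run above uses the default hyper-parameters (alpha=0.1, discount_factor=1, epsilon=0.1). As a minimal sketch, not part of the original script, the Game class can be reused to sweep the exploration rate; plot() writes to fixed filenames, so it is skipped here:

# Hypothetical sweep over epsilon values, printing each method's final-episode reward.
for eps in (0.01, 0.1, 0.3):
    g = Game(12, 4, 500, epsilon=eps)
    g.play()                 # Q-learning
    g.play(is_sarsa=True)    # SARSA
    print(eps, g.rewards_episodes[-1], g.sarsa_rewards_episodes[-1])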