value_iteration_gridworld.py
#!/usr/bin/env python3
import numpy as np
# define the grid size
size_h = 4
size_w = 4
# define the actions
actions = np.array(["up", "down", "left", "right"])
# define the reward for each action (-1 everywhere for all actions,
# except for the terminal states)
reward = np.full((size_h, size_w, len(actions)), -1.0)
reward[0, 0] = np.zeros((4), dtype=np.float32)
reward[-1, -1] = np.zeros((4), dtype=np.float32)
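# the terminal states are the top-left (0, 0) and bottom-right (-1, -1) corners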
# s'|s,a in this problem is deterministic, so I can just define it as a 4x4x4x2
# lookup table holding the next (y, x) coordinates for every state-action pair
transfer = np.zeros((size_h, size_w, len(actions), 2), dtype=np.int32)
for y in range(size_h):
    for x in range(size_w):
        for a in range(len(actions)):
            if actions[a] == "up":
                if y > 0:
                    transfer[y, x, a, 0] = y - 1
                else:
                    transfer[y, x, a, 0] = y
                transfer[y, x, a, 1] = x
            elif actions[a] == "down":
                if y < size_h - 1:
                    transfer[y, x, a, 0] = y + 1
                else:
                    transfer[y, x, a, 0] = y
                transfer[y, x, a, 1] = x
            elif actions[a] == "left":
                if x > 0:
                    transfer[y, x, a, 1] = x - 1
                else:
                    transfer[y, x, a, 1] = x
                transfer[y, x, a, 0] = y
            elif actions[a] == "right":
                if x < size_w - 1:
                    transfer[y, x, a, 1] = x + 1
                else:
                    transfer[y, x, a, 1] = x
                transfer[y, x, a, 0] = y
# now fill up the transfer at the terminal states
transfer[0, 0] = np.zeros((len(actions), 2))
transfer[-1, -1] = np.full((len(actions), 2), -1)
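# both terminals now map back onto themselves under every action
# (index -1 wraps around to the last row/column)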
# initial value function
value_0 = np.zeros((size_h, size_w), dtype=np.float32)
# run value iteration externally until convergence (NO POLICY IS USED,
# unlike in policy iteration)
iterations_value_iter = 10000
epsilon = 0.0001
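# each sweep below applies one synchronous Bellman optimality backup per state:
#   V_{k+1}(s) = max_a [ R(s, a) + V_k(s') ]
# (discount gamma = 1, and s' = transfer[s, a] since transitions are deterministic)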
for it in range(iterations_value_iter):
    value_t = np.zeros_like(value_0)
    # do one bellman step in each state
    for y in range(value_0.shape[0]):
        for x in range(value_0.shape[1]):
            # define a bogus max for value in this state
            max_v = -float("inf")
            for a, action in enumerate(actions):
                # get the coordinates where I go with this action
                newy, newx = transfer[y, x, a]
                # make one lookahead step for this action
                v = reward[y, x, a] + value_0[newy, newx]
                if v > max_v:
                    max_v = v
            # update the value function with the max
            value_t[y, x] = max_v
    # check convergence for value function, otherwise iterate
    # if value converged, exit
    norm = 0.0
    for y in range(value_t.shape[0]):
        for x in range(value_t.shape[1]):
            norm += np.abs(value_0[y, x] - value_t[y, x])
    norm /= np.array(value_t.shape, dtype=np.float32).sum()
    # print(norm)
    if norm < epsilon:
        print("!" * 80)
        print("Exiting loop because I converged the value")
        print("!" * 80)
        break
    else:
        # if not converged, save current as old to iterate
        value_0 = np.copy(value_t)
# once I have the optimal value function, I can calculate the optimal policy
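# this is one more greedy lookahead step, now with the converged value:
#   pi(s) = argmax_a [ R(s, a) + V*(s') ]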
pi = np.zeros((size_h, size_w, len(actions)), dtype=np.uint32)
for y in range(value_t.shape[0]):
    for x in range(value_t.shape[1]):
        max_v = -float("inf")
        max_v_idx = 0
        for a, action in enumerate(actions):
            # get the coordinates where I go with this action
            newy, newx = transfer[y, x, a]
            # make one lookahead step for this action
            v = reward[y, x, a] + value_t[newy, newx]
            if v > max_v:
                max_v = v
                max_v_idx = a
        # update policy with argmax
        pi[y, x] = np.zeros((len(actions)), dtype=np.float32)
        pi[y, x, max_v_idx] = 1.0
print("-" * 40) | |
print("iterations: ", it + 1) | |
print("value:") | |
print(value_t) | |
print("policy:") | |
print(actions[np.argmax(pi, axis=-1)]) | |
print("careful with reading this, because some members of the policy can be 2 options and I only consider the argmax, so the first one in the actions list is returned") |
RL Course by David Silver, Lecture 3, 1:00:03
David doesn't work through an example at this point in the lecture, but I felt the continuity was better by using the exact same gridworld example.
https://youtu.be/Nd1-UUMVfz4