Gambler's Problem: Reinforcement Learning
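A value-iteration solution to the Gambler's Problem (Sutton & Barto, Reinforcement Learning: An Introduction, Example 4.3). A gambler bets on coin flips that come up heads with probability 0.4; the state is the current capital ($1-$99), the allowed stakes run from 0 to min(capital, 100 - capital), and the only reward is +1 for reaching $100. The script runs value iteration until the value estimates converge, then plots them along with the final greedy policy.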
# resulting graphs are here: https://dl.dropboxusercontent.com/u/47591917/demo.png
from matplotlib import pyplot as plt

class State(object):
    def __init__(self, capital):
        self.capital = capital

    def staked(self, stake):
        self.capital += stake

    def get_reward(self):
        # reward is +1 only on reaching the $100 goal, 0 otherwise
        return 1 if self.capital == 100 else 0

    def __eq__(self, other):
        return self.capital == other.capital

    def __hash__(self):
        return hash(self.capital)
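
# State objects are hashable on their capital alone, so they can serve as
# dictionary keys in ValueMap and Policy below.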
class ValueMap(object):
    def __init__(self):
        self.capital = {}

    def initialize_capital(self):
        # The gambler wins on reaching $100.
        # States 0 and 100 are terminal, so they are not added here;
        # every non-terminal capital (1-99) starts with a value estimate of 0.
        for i in range(1, 100):
            state = State(i)
            self.capital[state] = float(0)

    def __getitem__(self, key):
        # missing keys (the terminal states 0 and 100) default to value 0
        return self.capital.get(key, 0)

    def __setitem__(self, key, value):
        self.capital[key] = value

class Policy(object):
    def __init__(self):
        self.optimal_stake = {}

    def __getitem__(self, key):
        return self.optimal_stake.get(key, 0)

    def __setitem__(self, key, value):
        self.optimal_stake[key] = value


def get_allowed_stake(capital):
    # a stake can never exceed the capital in hand, nor the amount
    # still needed to reach $100
    return range(0, min(capital, 100 - capital) + 1)

class Game(object):
    def __init__(self, head_probability):
        self.value_map = ValueMap()
        self.value_map.initialize_capital()
        self.policy = Policy()
        self.head_probability = head_probability
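
    # One-step lookahead: the expected value of staking `action` with the
    # given capital is
    #   q(s, a) = p_head * (r(s + a) + V(s + a)) + (1 - p_head) * (r(s - a) + V(s - a))
    # where the reward r(.) is 1 only when a win lands exactly on the $100 goal.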
    def get_value_action(self, capital, action):
        winState = State(capital.capital + action)
        lossState = State(capital.capital - action)
        return float(self.head_probability * (winState.get_reward() + self.value_map[winState])
                     + (1 - self.head_probability) * (lossState.get_reward() + self.value_map[lossState]))
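
    # Value iteration: sweep every non-terminal state and back its value up to
    # the best one-step lookahead, repeating until the largest change in a sweep
    # falls below the tolerance theta. The greedy stake found during the backups
    # is recorded as the policy.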
    def value_iteration(self, theta):
        diff = theta + 1  # ensure at least one sweep runs
        while diff > theta:
            diff = 0
            capitals = list(self.value_map.capital.keys())
            for capital in capitals:
                current_value = self.value_map[capital]
                _max = current_value
                max_action = self.policy[capital]
                for action in get_allowed_stake(capital.capital):
                    value = self.get_value_action(capital, action)
                    if _max < value:
                        _max = value
                        max_action = action
                diff = max(diff, abs(_max - current_value))
                self.value_map[capital] = _max
                self.policy[capital] = int(max_action)

    def get_policy_value(self):
        # recompute the greedy policy from the converged value estimates
        # (not called by play(); value_iteration already records the policy)
        for capital in self.value_map.capital:
            max_action = 0
            _max = float(-1)
            for action in get_allowed_stake(capital.capital):
                value = self.get_value_action(capital, action)
                if _max <= value:
                    max_action = action
                    _max = value
            self.policy[capital] = int(max_action)

    def print_policy(self):
        print("results graph:")
        plt.figure(1, figsize=(10, 15))
        x = []
        y = []
        for capital in self.value_map.capital:
            x.append(capital.capital)
            y.append(self.value_map[capital])
        plt.subplot(211)
        plt.plot(x, y)
        plt.xlabel('Capital \n')
        plt.ylabel('Value \n Estimates', rotation=0)
        x = []
        y = []
        for capital in self.policy.optimal_stake:
            x.append(capital.capital)
            y.append(self.policy[capital])
        plt.subplot(212)
        plt.step(x, y)
        plt.xlabel('Capital')
        plt.ylabel('Final Policy \n (stake)', rotation=0)
        plt.savefig('demo.png')
        # plt.show()

    def play(self):
        self.value_iteration(1e-9)
        self.print_policy()


game = Game(0.4)
game.play()
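# With head_probability = 0.4 the value curve should rise toward 1 near $100,
# and the final policy plot should roughly match Figure 4.3 in Sutton & Barto,
# with spikes at capitals such as 25, 50, and 75.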