import gym
from etaler import et
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
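
# Dependencies: OpenAI Gym (classic control, CartPole-v1) and the Etaler HTM library,
# plus numpy/pandas/matplotlib for the running-average plot and CSV export.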

def running_mean(x, N):
    # Simple moving average over a window of N episodes, used for the reward plot.
    cumsum = np.cumsum(np.insert(x, 0, 0))
    return (cumsum[N:] - cumsum[:-N]) / float(N)
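
# The Agent encodes each CartPole observation as a concatenation of 1-D grid cell
# SDRs and uses a SpatialPooler as an associative memory: given the current state
# concatenated with the desired (goal) state, it produces an action SDR whose two
# halves vote for "push left" vs "push right".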
class Agent:
    def __init__(self):
        # The goal of the system: cart centered, 0 velocity, 0 pole angle, 0 angular velocity
        self.desired_state = et.concat([et.encoder.gridCell1d(state*13) for state in (0, 0, 0, 0)])
        self.last_state = None
        self.last_output = None
        self.last_action = None
        self.associator = et.SpatialPooler((self.desired_state.size()*2, ), (1024,), 0.65)
        self.associator.setGlobalDensity(0.1)
        self.associator.setPermanenceInc(0.15)
        self.associator.setPermanenceDec(0.05)

    def step(self, obs, reward):
        state = et.concat([et.encoder.gridCell1d(state*13) for state in obs])
        # Learn that the last action taken transitioned the agent from the previous
        # state into the current one.
        if self.last_state is not None:
            state_transition_sdr = et.concat([self.last_state, state])
            # If the agent survived (reward == 1), reinforce the action it took by masking
            # out the opposite action's half of the SDR. If it died, mask out the action it
            # took instead, so the associator learns to prefer the opposite action.
            masked_action_sdr = self.last_output.copy()
            masked_action_sdr = masked_action_sdr.reshape((2, masked_action_sdr.size()//2))
            if reward == 1:
                masked_action_sdr[1-self.last_action] = False
            else:
                masked_action_sdr[self.last_action] = False
                self.last_state = None
            self.associator.learn(state_transition_sdr, masked_action_sdr.flatten())
        # Generate a new action: the half of the action SDR with more active bits wins.
        action_sdr = self.associator.compute(et.concat([state, self.desired_state]))
        action = np.argmax(action_sdr.reshape((2, action_sdr.size()//2)).sum(1).numpy())
        self.last_output = action_sdr
        self.last_state = state
        self.last_action = int(action)
        return action
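
# Training loop: run up to 1000 episodes of CartPole-v1 (500 steps each). The reward
# passed to the agent is 1 while the pole is balanced and 0 once the episode ends, so
# the first step of the next episode tells the agent that its previous action failed.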
env = gym.make('CartPole-v1')
agent = Agent()
reward = 0
score = []

for episode in range(1000):
    obs = env.reset()
    for t in range(500):
        action = agent.step(obs, reward)
        obs, reward, done, info = env.step(action)
        #env.render()

        # Normalize the cart position and pole angle into roughly [-1, 1] before encoding
        obs[0] = obs[0]/4.8
        obs[2] = obs[2]/0.418
        reward = 1
        if done:
            print("Ep: ", episode, "living frames: ", t)
            reward = 0
            obs = env.reset()
            score.append(t)
            break

plt.title("HTM performance in CartPole-v1")
plt.xlabel("episode")
plt.ylabel("agent reward")
plt.plot(score, label="reward")
plt.plot(running_mean(score, 30), label="reward (average over 30 eps)")
plt.legend(loc="upper right")
pd.DataFrame(score, columns=["score"]).to_csv("result.csv", index=False)
plt.show()