REINFORCE algorithm with multiple worker processes: each worker rolls out episodes with a local copy of the policy and pushes its policy gradients (computed against a per-timestep return baseline) onto a queue, and the main process averages those gradients into a global policy updated with Adam. A short sanity-check snippet for the Policy network follows the script.
# mainly brought from https://github.com/pytorch/examples/blob/master/reinforcement_learning/reinforce.py
import gym
import numpy as np
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.distributions import Categorical
import torch.multiprocessing as mp

log_interval = 10


class Policy(nn.Module):
    """Small two-layer policy network for CartPole (4 observations -> 2 actions)."""
    def __init__(self):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(4, 10)
        self.affine2 = nn.Linear(10, 2)

    def forward(self, x):
        x = torch.relu(self.affine1(x))
        action_scores = self.affine2(x)
        return F.softmax(action_scores, dim=1)


def rollout2(worker, env, param_queue, calc_baseline_batch=5):
    """Run several episodes with the worker policy, build a per-timestep return
    baseline from them, and push the resulting policy gradients onto param_queue."""
    tot_returns = []
    tot_logprobs = []
    # CartPole-v0 episodes are capped at 200 steps, so 1000 columns is plenty.
    baselines = np.zeros((calc_baseline_batch, 1000))
    for batch in range(calc_baseline_batch):
        rewards, logprobs = [], []
        state, ep_reward = env.reset(), 0
        for _ in range(10000):
            state = torch.from_numpy(state).float().unsqueeze(0)
            probs = worker(state)
            m = Categorical(probs)
            action = m.sample()
            logprob = m.log_prob(action)
            action = action.item()
            state, reward, done, _ = env.step(action)
            logprobs.append(logprob)
            rewards.append(reward)
            ep_reward += reward
            if done:
                break
        # Discounted returns (gamma = 0.99), accumulated backwards over the episode.
        R = 0
        returns = []
        for r in rewards[::-1]:
            R = r + 0.99 * R
            returns.insert(0, R)
        ep_len = len(returns)
        returns = np.array(returns)
        baselines[batch, :ep_len] = returns
        tot_returns.append(returns)
        tot_logprobs.append(torch.cat(logprobs))
    # Baseline: per-timestep return averaged over the batch of episodes.
    baseline = np.mean(baselines, axis=0)
    policy_loss = 0.0
    for Return, logprobs in zip(tot_returns, tot_logprobs):
        ret = torch.Tensor(Return - baseline[:len(Return)])
        policy_loss += torch.sum(-(ret * logprobs))
    worker.zero_grad()
    policy_loss.backward()
    grads = [x.grad for x in worker.parameters()]
    # ep_reward is the total reward of the last episode in this batch.
    param_queue.put((grads, ep_reward))
    return grads, ep_reward


def main():
    num_batch = 4
    step_size = 10
    running_reward = 10.0
    param_queue = mp.Queue()
    # One local policy per worker process, plus a global policy that receives
    # the averaged gradients.
    policies = [Policy() for x in range(num_batch)]
    global_policy = Policy()
    global_opt = optim.Adam(global_policy.parameters(), lr=1e-2)
    for p in policies:
        p.load_state_dict(global_policy.state_dict())
    envs = [gym.make('CartPole-v0') for x in range(num_batch)]
    torch.multiprocessing.set_sharing_strategy('file_system')
    for i_episode in range(10000):
        # Launch one rollout process per worker.
        ps = []
        for i in range(num_batch):
            p = mp.Process(
                target=rollout2,
                args=(policies[i], envs[i], param_queue, step_size))
            ps.append(p)
        for p in ps:
            p.start()
        for p in ps:
            p.join()
        # Collect gradients and episode rewards from every worker.
        list_of_grads = []
        rewards = []
        for i in range(num_batch):
            g, r = param_queue.get()
            list_of_grads.append(g)
            rewards.append(r)
        ep_reward = np.mean(rewards)
        running_reward = 0.05 * ep_reward + (1 - 0.05) * running_reward
        # Average the workers' gradients into the global policy and take a step.
        global_opt.zero_grad()
        for grads in list_of_grads:
            for p, g in zip(global_policy.parameters(), grads):
                if p.grad is None:
                    p.grad = g
                else:
                    p.grad += g
        for p in global_policy.parameters():
            p.grad /= len(list_of_grads)
        global_opt.step()
        # Sync the updated global weights back into the worker policies.
        for p in policies:
            p.load_state_dict(global_policy.state_dict())
        if i_episode % log_interval == 0:
            print('Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}'.format(
                i_episode, ep_reward, running_reward))
        if running_reward > envs[0].spec.reward_threshold:
            print("Solved! Running reward is now {} and "
                  "the last episode runs to {} time steps!".format(running_reward, i_episode))
            break


if __name__ == '__main__':
    main()
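As a quick sanity check of the Policy network in isolation, the snippet below samples a single action from a random observation. This is a minimal sketch: the random 4-dimensional input is an assumption standing in for a real CartPole-v0 state, and it assumes the Policy class from the script above is already in scope.

import torch
from torch.distributions import Categorical

policy = Policy()                     # the two-layer policy defined above
obs = torch.randn(1, 4)               # fake observation, batch of one
probs = policy(obs)                   # shape (1, 2); the row sums to 1
action = Categorical(probs).sample()  # sample action 0 or 1
print(probs, action.item())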