A2C PyTorch version
# References:
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as T
import torch.backends.cudnn as cudnn
import numpy as np
import random
import sys, os
import gym
import pandas as pd
from collections import deque
from collections import namedtuple
from skimage.color import rgb2gray
from skimage.transform import resize
parser = argparse.ArgumentParser(
    description="Deep reinforcement learning on Breakout-v0 with the Advantage Actor-Critic (A2C) algorithm.")
#parser.add_argument("-IsDuelingDQN", action="store", type=bool,
#                    default=False, dest="IsDuelingDQN",
#                    help="Whether to use Dueling DQN with the average method")
#parser.add_argument("-IsDoubleDQN", action="store", type=bool,
#                    default=False, dest="IsDoubleDQN",
#                    help="Whether to use Double DQN")
parser.add_argument("-GameName", action="store", type=str,
                    default="Breakout-v0", dest="GameName",
                    help="To run another game, e.g. Berzerk-v0, pass -GameName=Berzerk-v0")
results = parser.parse_args()
#IsDoubleDQN = results.IsDoubleDQN
#IsDuelingDQN = results.IsDuelingDQN
GameName = results.GameName
#env = gym.make('Breakout-v0')
env = gym.make(GameName)
use_cuda = torch.cuda.is_available()
#use_cuda = False
FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor
ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor
Tensor = FloatTensor
if use_cuda:
    print("Use First GPU")
    torch.cuda.set_device(0)
else:
    print("Use CPU")
#cudnn.enabled = False
batch_size = 100000
gamma = 0.99
#eps_start = 1.
#eps_end = 0.1
#eps_decay = 1000000
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward', 'done'))
class Actor(nn.Module):
    def __init__(self, name):
        super(Actor, self).__init__()
        self.name = name
        # Input: 4 stacked 84x84 grayscale frames.
        self.conv1 = nn.Conv2d(4, 1, kernel_size=1, stride=1, bias=False)
        self.conv2 = nn.Conv2d(1, 16, kernel_size=8, stride=4, bias=False)
        self.conv3 = nn.Conv2d(16, 32, kernel_size=4, stride=2, bias=False)
        self.head1 = nn.Linear(9 * 9 * 32, 512, bias=False)
        self.head2 = nn.Linear(512, env.action_space.n, bias=False)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.head1(x.view(x.size(0), -1)))
        # Action probabilities over the discrete action space.
        return F.softmax(self.head2(x), dim=1)
class Critic(nn.Module):
    def __init__(self, name):
        super(Critic, self).__init__()
        self.name = name
        self.conv1 = nn.Conv2d(4, 16, kernel_size=8, stride=4, bias=False)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=4, stride=2, bias=False)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, bias=False)
        self.head1 = nn.Linear(7 * 7 * 64, 512, bias=False)
        self.head2_1 = nn.Linear(512, 1, bias=False)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.head1(x.view(x.size(0), -1)))
        # Scalar state-value estimate V(s).
        self.V = self.head2_1(x)
        return self.V
M1_name = "A2C_Actor"
M2_name = "A2C_Critic"
ActorModel = Actor(name=M1_name).cuda() if use_cuda else Actor(name=M1_name)
CriticModel = Critic(name=M2_name).cuda() if use_cuda else Critic(name=M2_name)
ActorOptimizer = optim.Adam(ActorModel.parameters(), lr=0.00025)
CriticOptimizer = optim.RMSprop(CriticModel.parameters(), lr=0.00025, eps=1e-2, momentum=0.95)
class ReplayMemory():
    def __init__(self):
        #self.max_len = max_len
        self.memory = deque()

    def PutExperience(self, *args):
        """Save state transition.
        *args : state, action, next_state, reward, done
        """
        self.memory.append(Transition(*args))

    def Sample(self):
        samples = self.memory
        return samples

    def __len__(self):
        return len(self.memory)
#memory = ReplayMemory(5000)
## => for mini test
def SelectAction(state):
    # Sample an action from the current policy's action probabilities.
    probs = ActorModel(Variable(state, volatile=True).type(FloatTensor))
    probs = probs.cpu().data.numpy()[0]
    act_t = np.random.choice(np.arange(env.action_space.n), p=probs)
    return act_t

def get_screen(screen):
    # Convert an RGB frame to a single 84x84 grayscale channel.
    screen = rgb2gray(screen)
    screen = np.ascontiguousarray(screen, dtype=np.float32)
    screen = resize(screen, (84, 84), mode='reflect')
    return np.expand_dims(screen, axis=0)
M_name = "A2C"
folder_path = "./Atari"
gamename_path = str(GameName)
mid_path = str(folder_path) + "/" + gamename_path
save_path = mid_path + "/" + str(M_name)
if not os.path.exists(folder_path):
    os.mkdir(folder_path)
if not os.path.exists(mid_path):
    os.mkdir(mid_path)
def batch_func(iterable, n=1):
    # Yield successive mini-batches of size n (only referenced by the commented-out mini-batch loop below).
    l = len(iterable)
    for ndx in range(0, l, n):
        yield np.array(list(iterable))[ndx:min(ndx + n, l)]
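# Summary of the update performed by OptimizeModel below: for each step t of the finished
# episode it uses the Monte-Carlo discounted return R_t = sum_k gamma^k * r_{t+k} as the
# target, the advantage A_t = R_t - V(s_t) (treated as a constant for the policy update), and
#   policy loss = -mean_t[ log pi(a_t|s_t) * A_t + beta * H(pi(.|s_t)) ]
#   value loss  =  mean_t[ (R_t - V(s_t))^2 ]
# where H is the policy entropy and beta = 0.01 its weight.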
def OptimizeModel():
    transitions = memory.Sample()
    batch = Transition(*zip(*transitions))
    ActorOptimizer.zero_grad()
    CriticOptimizer.zero_grad()
    reward_batch = Variable(torch.from_numpy(np.vstack(batch.reward))).type(Tensor)
    # Discounted return from each step to the end of the episode: R_t = sum_k gamma^k * r_{t+k}.
    rollings = []
    for i in range(reward_batch.size()[0]):
        tmp = reward_batch[i:, 0]
        discounted_factors = Variable(torch.from_numpy(np.array([np.power(gamma, k) for k in range(tmp.size()[0])]))).type(FloatTensor)
        rollings.append(torch.sum(tmp * discounted_factors))
        del tmp
    rollings = torch.cat(rollings).unsqueeze(1)
    #if len(batch[1]) < batch_size:
    #    batch_size_t = len(batch[1])
    #else:
    #    batch_size_t = batch_size
    #for orders in batch_func(range(0, len(batch[1])), batch_size_t):
    #    start, end = orders[0], orders[1]
    non_final_mask = Variable(torch.from_numpy(np.vstack(batch.done))).type(Tensor)
    next_state_batch = Variable(torch.from_numpy(np.stack(batch.next_state))).type(FloatTensor)
    state_batch = Variable(torch.from_numpy(np.stack(batch.state))).type(FloatTensor)
    action_batch = Variable(torch.from_numpy(np.vstack(batch.action))).type(LongTensor)
    reward_batch = rollings
    #state_action_values = MainNet(state_batch).gather(1, action_batch)
    probs = ActorModel(state_batch)
    state_values = CriticModel(state_batch)
    #next_state_values = CriticModel(next_state_batch)
    #target_state_values = (non_final_mask * next_state_values * gamma) + reward_batch
    beta = 0.01
    probs = torch.clamp(probs, min=1e-10, max=1.0)
    entropy = (-1 * torch.sum(probs * torch.log(probs), 1)).unsqueeze(1)
    # The advantage is detached (round-trip through numpy) so the policy gradient
    # does not backpropagate into the critic.
    advantages = reward_batch - state_values
    advantages = advantages.cpu().data.numpy()
    policy_loss = -torch.log(probs.gather(1, action_batch)) * Variable(torch.from_numpy(advantages)).type(FloatTensor) - beta * entropy
    policy_loss = torch.mean(policy_loss)
    value_loss = torch.mean(torch.pow(reward_batch - state_values, 2))
    #loss = 0.5 * value_loss + policy_loss
    #loss.backward()
    policy_loss.backward()
    ActorOptimizer.step()
    value_loss.backward()
    CriticOptimizer.step()
ep_reward = 0
recent_100_reward = deque(maxlen=100)
frame = 0
C = 10000
save_C = 500
LogData = []
average_dq = deque()
episode = 0
num_episodes = 50000
for i_episode in range(num_episodes):
    ep_reward = 0
    episode += 1
    state_dq = deque(maxlen=4)
    life_dq = deque(maxlen=2)
    memory = ReplayMemory()
    # Pad the frame stack with zeros, then push the first real frame.
    for i in range(3):
        state_dq.append(np.zeros(shape=[1, 84, 84]))
    curr_frame = get_screen(env.reset())
    state_dq.append(curr_frame)
    done = False
    while not done:
        frame += 1
        curr_state = np.vstack(state_dq)
        action = SelectAction(torch.from_numpy(curr_state).unsqueeze(0))
        # Map NOOP (0) to FIRE (1) so the ball is launched in Breakout.
        if action == 0:
            real_action = 1
        else:
            real_action = action
        next_frame, reward, done, info = env.step(real_action)
        reward_t = reward
        #if int(info['ale.lives']) is not 0:
        life_dq.append(info['ale.lives'])
        if not done:
            if len(life_dq) == 2:
                if life_dq[0] > life_dq[1]:
                    # A life was lost: mark the transition terminal-like with reward -1.
                    done_t = 0
                    reward_t = -1
                else:
                    done_t = 1
                    reward_t = reward
            else:
                done_t = 1 - int(done)
                reward_t = reward
        else:
            done_t = 1 - int(done)
            reward_t = -1
        #else:
        #    if done is False:
        #        done_t = 1
        #        reward_t = reward
        #    else:
        #        done_t = 0
        #        reward_t = -1
        next_frame = get_screen(next_frame)
        state_dq.append(next_frame)
        next_state = np.vstack(state_dq)
        ep_reward += reward
        reward_T = np.clip(reward_t, -1.0, 1.0)
        done_T = int(done_t)
        # Skip transitions whose oldest frame is still the all-zero padding from the episode start.
        if int(np.sum(curr_state[0])) != 0:
            memory.PutExperience(curr_state, action, next_state, reward_T, done_T)
        if done:
            OptimizeModel()
            recent_100_reward.append(ep_reward)
            if episode % save_C == 0:
                torch.save(ActorModel.state_dict(), save_path + "_Actor" + str(episode))
                torch.save(CriticModel.state_dict(), save_path + "_Critic" + str(episode))
                print("Save Model !! : {}".format(episode))
            if episode >= 10:
                print("Episode %1d Done, Frames : %1d, Scores : %.1f, Mean100Ep_Scores : %.5f" % (episode,
                      frame, ep_reward, np.mean(recent_100_reward)))
            LogData.append((episode, frame, ep_reward, np.mean(recent_100_reward)))
            LogDataDF = pd.DataFrame(LogData)
            LogDataDF.columns = ['Episode', 'Frames', 'Scores_per_ep', 'Mean100Ep_Scores']
            LogDataDF.to_csv(save_path + "LogData.csv", index=False, header=True)
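# Example usage (a minimal sketch; the filename a2c_breakout.py is only an assumed name
# for this script, not something defined above):
#   python a2c_breakout.py                        # trains on the default Breakout-v0
#   python a2c_breakout.py -GameName=Berzerk-v0   # trains on Berzerk-v0 instead
# Checkpoints and LogData.csv are written under ./Atari/<GameName>/.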