class AC_Network():
    def __init__(self,s_size,a_size,scope,trainer):
        ....
        ....
        ....
        if scope != 'global':
            self.actions = tf.placeholder(shape=[None],dtype=tf.int32)
            self.actions_onehot = tf.one_hot(self.actions,a_size,dtype=tf.float32)
            self.target_v = tf.placeholder(shape=[None],dtype=tf.float32)
            self.advantages = tf.placeholder(shape=[None],dtype=tf.float32)

            self.responsible_outputs = tf.reduce_sum(self.policy * self.actions_onehot, [1])

            # Loss functions
            self.value_loss = 0.5 * tf.reduce_sum(tf.square(self.target_v - tf.reshape(self.value,[-1])))
            self.entropy = - tf.reduce_sum(self.policy * tf.log(self.policy))
            self.policy_loss = -tf.reduce_sum(tf.log(self.responsible_outputs)*self.advantages)
            self.loss = 0.5 * self.value_loss + self.policy_loss - self.entropy * 0.01

            # Get gradients from local network using local losses
            local_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
            self.gradients = tf.gradients(self.loss,local_vars)
            self.var_norms = tf.global_norm(local_vars)
            grads,self.grad_norms = tf.clip_by_global_norm(self.gradients,40.0)

            # Apply local gradients to global network
            global_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'global')
            self.apply_grads = trainer.apply_gradients(zip(grads,global_vars))
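The clipped gradients above are applied to the variables in the 'global' scope, so each worker computes its loss on a local copy of the network but trains the shared global network. The op that copies the updated global parameters back into a worker's local network is not shown in this excerpt; a minimal sketch of such a sync helper (the name and exact placement are assumptions, not part of the gist) could be:

def update_local_network(from_scope, to_scope):
    # Build assign ops that copy the global network's trainable variables
    # (from_scope) into a worker's local network (to_scope), pairing the
    # variables in collection order.
    from_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, from_scope)
    to_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, to_scope)
    return [to_var.assign(from_var) for from_var, to_var in zip(from_vars, to_vars)]

A worker would run these assign ops (for example, sess.run(update_local_network('global', self.name))) before collecting each new rollout.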
class Worker():
    ....
    ....
    ....
    def train(self,global_AC,rollout,sess,gamma,bootstrap_value):
        rollout = np.array(rollout)
        observations = rollout[:,0]
        actions = rollout[:,1]
        rewards = rollout[:,2]
        next_observations = rollout[:,3]
        values = rollout[:,5]

        # Here we take the rewards and values from the rollout, and use them to
        # generate the advantage and discounted returns.
        # The advantage function uses "Generalized Advantage Estimation"
        self.rewards_plus = np.asarray(rewards.tolist() + [bootstrap_value])
        discounted_rewards = discount(self.rewards_plus,gamma)[:-1]
        self.value_plus = np.asarray(values.tolist() + [bootstrap_value])
        advantages = rewards + gamma * self.value_plus[1:] - self.value_plus[:-1]
        advantages = discount(advantages,gamma)

        # Update the global network using gradients from loss
        # Generate network statistics to periodically save
        rnn_state = self.local_AC.state_init
        feed_dict = {self.local_AC.target_v:discounted_rewards,
                     self.local_AC.inputs:np.vstack(observations),
                     self.local_AC.actions:actions,
                     self.local_AC.advantages:advantages,
                     self.local_AC.state_in[0]:rnn_state[0],
                     self.local_AC.state_in[1]:rnn_state[1]}
        v_l,p_l,e_l,g_n,v_n,_ = sess.run([self.local_AC.value_loss,
                                          self.local_AC.policy_loss,
                                          self.local_AC.entropy,
                                          self.local_AC.grad_norms,
                                          self.local_AC.var_norms,
                                          self.local_AC.apply_grads],
                                         feed_dict=feed_dict)
        return v_l / len(rollout),p_l / len(rollout),e_l / len(rollout),g_n,v_n
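train calls a discount helper that is not defined in this excerpt; it is applied both to the bootstrapped rewards (to produce discounted returns) and to the one-step TD errors (to produce the GAE-style advantages). A minimal sketch, assuming it simply computes a discounted cumulative sum over a 1-D array, might be:

import numpy as np

def discount(x, gamma):
    # Discounted cumulative sum:
    # out[i] = x[i] + gamma*x[i+1] + gamma^2*x[i+2] + ...
    out = np.zeros(len(x), dtype=np.float64)
    running = 0.0
    for i in reversed(range(len(x))):
        running = x[i] + gamma * running
        out[i] = running
    return out

Note that the same gamma is used for both the returns and the advantages in the code above.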
class Worker():
    ....
    ....
    ....
    def train(self,global_AC,rollout,sess,gamma,bootstrap_value):

Why do I need to pass global_AC to this function? It's not used, is it?