Last active
July 17, 2021 11:05
-
-
Save kngwyu/58c2aedcc7d72d866cb7c8e9c6388f32 to your computer and use it in GitHub Desktop.
PPO with Gaussian Policy implemented by JAX
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import dataclasses | |
import functools | |
import typing as t | |
import chex | |
import distrax | |
import gym | |
import haiku as hk | |
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
import optax | |
import rlax | |
import typer | |
from chex import Array as JaxArray | |
# Type aliases shared by the rollout and agent code.
AgentOutput = t.Any  # auxiliary per-step output of an actor (left untyped)
Observation = np.ndarray  # a single environment observation
Action = np.ndarray  # an action handed to the environment
# An actor maps an observation to an action plus its auxiliary output.
Actor = t.Callable[[Observation], t.Tuple[Action, AgentOutput]]
@dataclasses.dataclass
class ReturnReporter:
    """Accumulate per-step rewards into per-episode returns.

    Call ``experience`` once per environment step. When an episode ends, the
    accumulated return is printed, appended to ``episode_returns``, and the
    running sum is reset.
    """

    # Running sum of rewards in the current, unfinished episode.
    reward_sum: float = 0.0
    # Returns of every completed episode, in completion order.
    episode_returns: t.List[float] = dataclasses.field(default_factory=list)

    def experience(self, reward: float, done: bool) -> None:
        """Add ``reward`` to the running sum and close the episode if ``done``."""
        self.reward_sum += reward
        if done:
            print(f"Episodic return: {self.reward_sum}")
            self.episode_returns.append(self.reward_sum)
            # Reset to 0.0 (not int 0) so the field keeps its declared float type.
            self.reward_sum = 0.0
@dataclasses.dataclass
class RolloutResult:
    """Trajectory data collected by ``rollout``.

    ``observations`` must be supplied at construction (``rollout`` seeds it
    with the starting observation). The other lists start empty and are
    appended to once per environment step. ``dataclasses.astuple`` yields the
    fields in the order declared below, so keep that order stable.
    """

    observations: t.List[Observation]
    # Per-step records, aligned by index.
    actions: t.List[Action] = dataclasses.field(default_factory=list)
    rewards: t.List[float] = dataclasses.field(default_factory=list)
    terminals: t.List[bool] = dataclasses.field(default_factory=list)
    # Auxiliary actor output for each step.
    outputs: t.List[AgentOutput] = dataclasses.field(default_factory=list)
def rollout(
    env: gym.Env,
    initial_obs: Observation,
    n_steps: int,
    actor: Actor,
    reporter: t.Callable[[float, bool], None],
) -> t.Tuple[Observation, RolloutResult]:
    """Run ``actor`` in ``env`` for ``n_steps`` steps and record the trajectory.

    Args:
        env: Environment whose ``step`` returns ``(obs, reward, done, info)``
            and whose ``reset()`` returns the first observation.
        initial_obs: Observation the first action is chosen from.
        n_steps: Number of environment steps to take.
        actor: Called with a batch holding one observation; returns the
            action and the auxiliary output for that step.
        reporter: Called with ``(reward, terminal)`` after every step.

    Returns:
        ``(last_obs, result)``. ``last_obs`` is the observation to act on next
        (already reset if the final step was terminal). ``result.observations``
        has ``n_steps + 1`` entries: entry ``i`` is the observation that
        ``result.actions[i]`` was chosen from, and the last entry is ``last_obs``.
    """
    last_obs = initial_obs
    result = RolloutResult(observations=[last_obs])
    for _ in range(n_steps):
        jnp_obs = jnp.array(np.expand_dims(last_obs, axis=0))
        action, output = jax.lax.stop_gradient(actor(jnp_obs))
        # Hand the environment a host-side NumPy array, not a device array.
        action = np.asarray(action)
        obs, reward, terminal, _ = env.step(action)
        result.actions.append(action)
        result.rewards.append(reward)
        result.terminals.append(terminal)
        result.outputs.append(output)
        reporter(reward, terminal)
        if terminal:
            # The terminal observation is never acted on; the episode restarts.
            obs = env.reset()
        # Flatten both step and reset observations so their shapes agree.
        last_obs = obs.flatten()
        # Record the observation the *next* action is chosen from. This keeps
        # observations[i] paired with actions[i] across episode resets.
        result.observations.append(last_obs)
    return last_obs, result
class GaussianPiAndVNet(hk.Module):
    """MLP with a diagonal-Gaussian policy head and a scalar value head.

    The shared torso (Flatten -> Linear(128) -> ReLU -> Linear(64) -> ReLU)
    feeds one linear layer for the policy mean and one for the value
    baseline. The policy log-std is a free parameter of shape
    ``(1, action_dim)``, initialised to zeros and independent of the input.
    """

    def __init__(self, action_dim: int) -> None:
        super().__init__()
        self._action_dim = action_dim

    def __call__(
        self,
        observation: np.ndarray,
    ) -> t.Tuple[JaxArray, JaxArray, JaxArray]:
        """Return ``(pi_mu, pi_log_std, baseline)`` for a batch of observations.

        ``baseline`` is the value layer's output with its trailing size-1
        axis removed.
        """
        # Haiku derives parameter names from module construction order, so the
        # layers are created in a fixed order: torso, mean head, log-std, value head.
        torso_layers = [
            hk.Flatten(),
            hk.Linear(128),
            jax.nn.relu,
            hk.Linear(64),
            jax.nn.relu,
        ]
        features = hk.Sequential(torso_layers)(observation)
        mean = hk.Linear(self._action_dim)(features)
        log_std = hk.get_parameter(
            "pi_log_std",
            (1, self._action_dim),
            init=jnp.zeros,
        )
        value = hk.Linear(1)(features)[..., 0]
        return mean, log_std, value
@chex.dataclass(frozen=True, mappable_dataclass=False)
class PPOBatch:
    """Arrays consumed by one PPO update.

    Indexing a batch (``batch[idx]``) applies ``idx`` to every field's
    leading axis and returns a new ``PPOBatch``.
    """

    observation: JaxArray
    action: JaxArray
    reward: JaxArray
    mask: JaxArray
    advantage: JaxArray
    value_target: JaxArray
    log_prob: JaxArray

    def __getitem__(self, idx: JaxArray) -> t.Any:
        """Select ``idx`` from each field and wrap the results in a new batch."""
        selected = {
            field.name: getattr(self, field.name)[idx]
            for field in dataclasses.fields(self)
        }
        return self.__class__(**selected)
def make_ppo_batch(
    rollout_result: RolloutResult,
    next_value: JaxArray,
    gamma: float,
    gae_lambda: float,
) -> PPOBatch:
    """Convert a rollout into a PPO training batch with GAE advantages.

    Args:
        rollout_result: Trajectory whose ``outputs`` entries are the
            ``(mu, log_std, value)`` triples returned by the actor.
        next_value: Value estimate appended after the per-step values, used
            as the bootstrap for the final step.
        gamma: Discount factor.
        gae_lambda: GAE lambda.

    Returns:
        A ``PPOBatch`` with one row per environment step in every field.
        ``log_prob`` holds the per-action-dimension log-probabilities of the
        taken actions under the policy that produced them.
    """
    # Read fields by name. Slicing dataclasses.astuple() would depend on
    # field order and deep-copy every recorded output.
    observation = jnp.array(rollout_result.observations)
    action = jnp.array(rollout_result.actions)
    reward = jnp.array(rollout_result.rewards)
    terminal = jnp.array(rollout_result.terminals, dtype=jnp.float32)
    mu, logstd, value = map(jnp.concatenate, zip(*rollout_result.outputs))
    value = jnp.concatenate([value, next_value])
    # Zero the discount on terminal steps so they do not bootstrap.
    mask = 1.0 - terminal
    advantage = rlax.truncated_generalized_advantage_estimation(
        reward, mask * gamma, gae_lambda, value
    )
    value_target = advantage + value[:-1]
    policy = distrax.LogStddevNormal(mu, logstd)
    return PPOBatch(
        # The observations list holds one more entry than there are steps
        # (seeded with the starting observation). Drop the trailing one so
        # every field has exactly one row per step.
        observation=observation[:-1],
        action=action,
        reward=reward,
        mask=mask,
        advantage=advantage,
        value_target=value_target,
        log_prob=policy.log_prob(action),
    )
class Agent:
    """PPO agent that samples Gaussian actions and computes the clipped loss."""

    def __init__(
        self,
        network: hk.Transformed,
        clip_epsilon: float,
        entropy_coeff: float,
    ) -> None:
        self._network = network
        self._clip_epsilon = clip_epsilon
        self._entropy_coef = entropy_coeff

    @functools.partial(jax.jit, static_argnums=0)
    def act(
        self,
        observation: JaxArray,
        *,
        rng_key: JaxArray,
        params: hk.Params,
    ) -> t.Tuple[JaxArray, t.Tuple[JaxArray, JaxArray]]:
        """Sample an action and return it with the raw network outputs.

        ``mu`` and ``logstd`` are flattened before sampling, so this expects a
        batch containing a single observation. Returns
        ``(action, (mu, logstd, value))``. The caller must pass a fresh
        ``rng_key`` per call, otherwise the same noise is drawn every time.
        """
        mu, logstd, value = self._network.apply(params, observation)
        _, step_key = jax.random.split(rng_key)
        distrib = distrax.LogStddevNormal(mu.flatten(), logstd.flatten())
        return distrib.sample(seed=step_key), (mu, logstd, value)

    def _loss(self, params: hk.Params, batch: PPOBatch) -> JaxArray:
        """Clipped PPO objective + value L2 loss + weighted negative entropy."""
        # Apply the network to the whole batch at once; it already handles a
        # leading batch axis. The (1, action_dim) log-std parameter broadcasts
        # against mu (B, action_dim). Mapping per example with jax.vmap would
        # instead yield a (B, 1, action_dim) log-std, which broadcasts with mu
        # into a (B, B, action_dim) distribution (quadratic memory and compute).
        mu, logstd, value = self._network.apply(params, batch.observation)
        policy = distrax.LogStddevNormal(mu, logstd)
        log_prob = policy.log_prob(batch.action)
        # Joint probability ratio of the diagonal Gaussian: sum log-probs over action dims.
        prob_ratio = jnp.exp(jnp.sum(log_prob - batch.log_prob, axis=-1))
        clipped_prob_ratio = jnp.clip(
            prob_ratio,
            1.0 - self._clip_epsilon,
            1.0 + self._clip_epsilon,
        )
        clipped_objective = jnp.fmin(
            prob_ratio * batch.advantage, clipped_prob_ratio * batch.advantage
        )
        policy_loss = -jnp.mean(clipped_objective)
        # The entropy of a diagonal Gaussian is the sum of per-dimension
        # entropies. Sum over action dims, matching the log-prob treatment above.
        entropy_loss = -jnp.mean(jnp.sum(policy.entropy(), axis=-1))
        value_loss = jnp.mean(rlax.l2_loss(value - batch.value_target))
        return policy_loss + value_loss + self._entropy_coef * entropy_loss
def get_updater(
    loss_function: t.Callable[..., t.Any],
    updater: optax.TransformUpdateFn,
) -> t.Callable[..., t.Tuple[hk.Params, optax.OptState]]:
    """Build a jitted optimisation step for ``loss_function``.

    The returned ``update(params, opt_state, ppo_batch)`` differentiates
    ``loss_function(params, ppo_batch)`` with respect to ``params``, passes
    the gradient through ``updater``, and returns
    ``(new_params, new_opt_state)``.
    """

    def update(
        params: hk.Params,
        opt_state: optax.OptState,
        ppo_batch: PPOBatch,
    ) -> t.Tuple[hk.Params, optax.OptState]:
        grad_fn = jax.grad(loss_function)
        grads = grad_fn(params, ppo_batch)
        updates, new_opt_state = updater(grads, opt_state)
        new_params = optax.apply_updates(params, updates)
        return new_params, new_opt_state

    return jax.jit(update)
def batch_sample_indices(
    n_instances: int,
    n_minibatches: int,
    rng_key: JaxArray,
) -> t.Iterable[JaxArray]:
    """Randomly partition ``range(n_instances)`` into ``n_minibatches`` index arrays.

    Exactly ``n_minibatches`` arrays are produced. Their sizes differ by at
    most one, and together they cover every index exactly once.

    Raises:
        ValueError: If ``n_minibatches`` is not in ``[1, n_instances]``.
    """
    if not 1 <= n_minibatches <= n_instances:
        raise ValueError(
            f"n_minibatches must be in [1, {n_instances}], got {n_minibatches}"
        )
    indices = jax.random.permutation(rng_key, n_instances)
    # array_split spreads the remainder over the leading minibatches instead
    # of emitting an extra tiny trailing one (128 / 3 -> 43, 43, 42 rather
    # than 42, 42, 42, 2). It also keeps the number of distinct minibatch
    # shapes, and hence jit retraces, to at most two.
    return iter(jnp.array_split(indices, n_minibatches))
def main(
    total_steps: int = 100000,
    n_rollout_steps: int = 128,
    n_minibatches: int = 1,
    n_epochs: int = 10,
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
    clip_epsilon: float = 0.1,
    entropy_coeff: float = 0.01,
    seed: int = 1,
    env: str = "Hopper-v3",
    render: bool = False,
) -> None:
    """Train a Gaussian-policy PPO agent on ``env`` and save episode returns.

    Training alternates between collecting ``n_rollout_steps`` environment
    steps and running ``n_epochs`` passes of minibatch updates over them,
    until ``total_steps`` steps have been collected. ``seed`` seeds both the
    environment and the JAX key sequence. Completed-episode returns are saved
    to ``result.npy``.
    """
    environment = gym.make(env)
    environment.seed(seed)
    # Flatten like every other observation handed to the network.
    current_obs = environment.reset().flatten()
    action_dim = environment.action_space.shape[0]
    network = hk.without_apply_rng(
        hk.transform(lambda ts: GaussianPiAndVNet(action_dim)(ts))
    )
    # Construct the agent.
    agent = Agent(network, clip_epsilon, entropy_coeff)
    opt = optax.adam(3e-4, eps=1e-4)
    updater = get_updater(agent._loss, opt.update)
    rng_seq = hk.PRNGSequence(seed)
    # Initialize the network parameters and optimizer state.
    params = jax.jit(network.init)(
        next(rng_seq),
        np.expand_dims(current_obs, axis=0),
    )
    opt_state = opt.init(params)
    reporter = ReturnReporter()
    for _ in range(total_steps // n_rollout_steps):

        def actor(obs, rollout_params=params):
            # Draw a fresh key on every step. Binding one key for the whole
            # rollout would reuse identical exploration noise at each step.
            return agent.act(obs, rng_key=next(rng_seq), params=rollout_params)

        current_obs, rollout_result = rollout(
            environment,
            current_obs,
            n_rollout_steps,
            actor,
            reporter.experience,
        )
        _, _, next_value = network.apply(params, np.expand_dims(current_obs, 0))
        ppo_batch = make_ppo_batch(rollout_result, next_value, gamma, gae_lambda)
        for _ in range(n_epochs):
            indices_iter = batch_sample_indices(
                n_rollout_steps,
                n_minibatches,
                next(rng_seq),
            )
            for indices in indices_iter:
                minibatch = ppo_batch[indices]
                params, opt_state = updater(params, opt_state, minibatch)
        if render:
            environment.render()
    np.save("result.npy", np.array(reporter.episode_returns))
if __name__ == "__main__":
    # typer builds the command-line interface from main()'s keyword parameters.
    typer.run(main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment