Code for our ICML23W paper "Learning to Optimize with Recurrent Hierarchical Transformers"
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" | |
A transformer-based learned optimizer which synthesizes inter-tensor | |
communication with self-attention and propagates CLS token as hidden | |
state to keep track of optimization history. | |
This optimizer was introduced in: | |
https://openreview.net/forum?id=MusMaHCrs2 | |
Acknowledgements: | |
* We use learned_optimization library for meta-training: | |
https://github.com/google/learned_optimization/ | |
* Haiku transformer implementation: | |
https://github.com/google-deepmind/dm-haiku/blob/master/examples/transformer/ | |
""" | |

from typing import Any, Optional, Tuple, Sequence

import dataclasses
import functools
import warnings

import flax
import gin
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as onp
from jax import lax
from learned_optimization import summary
from learned_optimization import tree_utils
from learned_optimization.learned_optimizers import base as lopt_base
from learned_optimization.learned_optimizers import common
from learned_optimization.optimizers import base as opt_base

PRNGKey = jnp.ndarray


def layer_norm(x: jax.Array) -> jax.Array:
  """Applies a unique LayerNorm to x with default settings."""
  ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
  return ln(x)


class MultiHeadAttention(hk.Module):
  """Multi-headed attention (MHA) module.

  This module is intended for attending over sequences of vectors.

  Rough sketch:
  - Compute keys (K), queries (Q), and values (V) as projections of inputs.
  - Attention weights are computed as W = softmax(QK^T / sqrt(key_size)).
  - Output is another projection of WV^T.

  For more detail, see the original Transformer paper:
    "Attention is all you need" https://arxiv.org/abs/1706.03762.

  Glossary of shapes:
  - T: Sequence length.
  - D: Vector (embedding) size.
  - H: Number of attention heads.
  """

  def __init__(
      self,
      num_heads: int,
      key_size: int,
      w_init_scale: Optional[float] = None,
      *,
      w_init: Optional[hk.initializers.Initializer] = None,
      with_bias: bool = True,
      b_init: Optional[hk.initializers.Initializer] = None,
      value_size: Optional[int] = None,
      model_size: Optional[int] = None,
      name: Optional[str] = None,
  ):
    """Initialises the module.

    Args:
      num_heads: Number of independent attention heads (H).
      key_size: The size of keys (K) and queries used for attention.
      w_init_scale: DEPRECATED. Please use w_init instead.
      w_init: Initialiser for weights in the linear map. Once `w_init_scale` is
        fully deprecated `w_init` will become mandatory. Until then it has a
        default value of `None` for backwards compatibility.
      with_bias: Whether to add a bias when computing various linear
        projections.
      b_init: Optional initializer for bias. By default, zero.
      value_size: Optional size of the value projection (V). If None, defaults
        to the key size (K).
      model_size: Optional size of the output embedding (D'). If None, defaults
        to the key size multiplied by the number of heads (K * H).
      name: Optional name for this module.
    """
    super().__init__(name=name)
    self.num_heads = num_heads
    self.key_size = key_size
    self.value_size = value_size or key_size
    self.model_size = model_size or key_size * num_heads

    # Backwards-compatibility for w_init_scale.
    if w_init_scale is not None:
      warnings.warn(
          "w_init_scale is deprecated; please pass an explicit weight "
          "initialiser instead.",
          DeprecationWarning,
      )
    if w_init and w_init_scale:
      raise ValueError("Please provide only `w_init`, not `w_init_scale`.")
    if w_init is None and w_init_scale is None:
      raise ValueError(
          "Please provide a weight initializer: `w_init`. "
          "`w_init` will become mandatory once `w_init_scale` is "
          "fully deprecated."
      )
    if w_init is None:
      w_init = hk.initializers.VarianceScaling(w_init_scale)
    self.w_init = w_init
    self.with_bias = with_bias
    self.b_init = b_init

  def __call__(
      self,
      query: jax.Array,
      key: jax.Array,
      value: jax.Array,
      mask: Optional[jax.Array] = None,
  ) -> jax.Array:
    """Computes (optionally masked) MHA with queries, keys & values.

    This module broadcasts over zero or more 'batch-like' leading dimensions.

    Args:
      query: Embeddings sequence used to compute queries; shape [..., T', D_q].
      key: Embeddings sequence used to compute keys; shape [..., T, D_k].
      value: Embeddings sequence used to compute values; shape [..., T, D_v].
      mask: Optional mask applied to attention weights; shape [..., H=1, T', T].

    Returns:
      A new sequence of embeddings, consisting of a projection of the
        attention-weighted value projections; shape [..., T', D'].
    """
    # In shape hints below, we suppress the leading dims [...] for brevity.
    # Hence e.g. [A, B] should be read in every case as [..., A, B].
    *leading_dims, sequence_length, _ = query.shape
    projection = self._linear_projection

    # Compute key/query/values (overload K/Q/V to denote the respective sizes).
    query_heads = projection(query, self.key_size, "query")  # [T', H, Q=K]
    key_heads = projection(key, self.key_size, "key")  # [T, H, K]
    value_heads = projection(value, self.value_size, "value")  # [T, H, V]

    # Compute attention weights.
    attn_logits = jnp.einsum("...thd,...Thd->...htT", query_heads, key_heads)
    attn_logits = attn_logits / onp.sqrt(self.key_size).astype(key.dtype)
    if mask is not None:
      if mask.ndim != attn_logits.ndim:
        raise ValueError(
            f"Mask dimensionality {mask.ndim} must match logits dimensionality "
            f"{attn_logits.ndim}."
        )
      attn_logits = jnp.where(mask, attn_logits, -1e30)
    attn_weights = jax.nn.softmax(attn_logits)  # [H, T', T]

    # Weight the values by the attention and flatten the head vectors.
    attn = jnp.einsum("...htT,...Thd->...thd", attn_weights, value_heads)
    attn = jnp.reshape(attn, (*leading_dims, sequence_length, -1))  # [T', H*V]

    # Apply another projection to get the final embeddings.
    final_projection = hk.Linear(
        self.model_size,
        w_init=self.w_init,
        with_bias=self.with_bias,
        b_init=self.b_init,
    )
    return final_projection(attn)  # [T', D']

  @hk.transparent
  def _linear_projection(
      self,
      x: jax.Array,
      head_size: int,
      name: Optional[str] = None,
  ) -> jax.Array:
    y = hk.Linear(
        self.num_heads * head_size,
        w_init=self.w_init,
        with_bias=self.with_bias,
        b_init=self.b_init,
        name=name,
    )(x)
    *leading_dims, _ = x.shape
    return y.reshape((*leading_dims, self.num_heads, head_size))
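

# Minimal usage sketch for the MultiHeadAttention module above (added for
# illustration; the `_demo_mha` helper is hypothetical and not part of the
# optimizer). It shows the [T, D] -> [T, D'] shape contract for self-attention.
def _demo_mha():
  def fwd(x):
    mha = MultiHeadAttention(
        num_heads=4,
        key_size=8,
        w_init=hk.initializers.VarianceScaling(1.0),
    )
    return mha(x, x, x)  # self-attention over a [T, D] sequence

  f = hk.transform(fwd)
  x = jnp.zeros([5, 16])  # T=5 tokens, D=16 features each
  params = f.init(jax.random.PRNGKey(0), x)
  return f.apply(params, None, x)  # [T=5, D'=num_heads * key_size=32]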


@dataclasses.dataclass
class Transformer(hk.Module):
  """A transformer stack, adapted from the DM Haiku example implementation.

  NOTE: Dropout is turned off in this model; `dropout_rate` is currently just
  a dummy parameter.
  """

  num_heads: int
  num_layers: int
  key_size: int
  dropout_rate: float
  widening_factor: int = 4
  name: Optional[str] = None

  def __call__(
      self,
      embeddings: jax.Array,  # [T, D]
      mask: jax.Array,  # [T]
      *,
      is_training: bool = True,
  ) -> jax.Array:  # [T, D]
    """Transforms input embedding sequences to output embedding sequences."""
    initializer = hk.initializers.VarianceScaling(2 / self.num_layers)
    # Dropout is disabled in the optimizer.
    # dropout_rate = self.dropout_rate if is_training else 0.0
    seq_len, model_size = embeddings.shape

    # Compute bidirectional mask.
    mask = mask[None, None, :]  # [H=1, T'=1, T]
    bidirectional_mask = onp.ones((1, seq_len, seq_len))
    mask = mask * bidirectional_mask  # [H=1, T, T]

    h = embeddings
    for _ in range(self.num_layers):
      # First the attention block.
      attn_block = MultiHeadAttention(
          num_heads=self.num_heads,
          key_size=self.key_size,
          model_size=model_size,
          w_init=initializer,
      )
      h_norm = layer_norm(h)
      h_attn = attn_block(h_norm, h_norm, h_norm, mask=mask)
      # h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
      h = h + h_attn

      # Then the dense block.
      dense_block = hk.Sequential(
          [
              hk.Linear(self.widening_factor * model_size, w_init=initializer),
              jax.nn.gelu,
              hk.Linear(model_size, w_init=initializer),
          ]
      )
      h_norm = layer_norm(h)
      h_dense = dense_block(h_norm)
      # h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
      h = h + h_dense

    return layer_norm(h)
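

# Minimal usage sketch for the Transformer stack above (added for illustration;
# `_demo_transformer` is hypothetical and not part of the optimizer). Note that
# this stack operates on an unbatched [T, D] token sequence with a [T] mask.
def _demo_transformer():
  def fwd(tokens, mask):
    tx = Transformer(num_heads=2, num_layers=1, key_size=8, dropout_rate=0.0)
    return tx(tokens, mask)

  f = hk.without_apply_rng(hk.transform(fwd))
  tokens = jnp.zeros([5, 16])  # T=5 tokens, D=16 features
  mask = jnp.ones([5])  # all tokens are valid
  params = f.init(jax.random.PRNGKey(0), tokens, mask)
  return f.apply(params, tokens, mask)  # [T=5, D=16]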


@dataclasses.dataclass
class EncoderModel(hk.Module):
  """A transformer encoder model."""

  transformer: Transformer
  model_size: int
  name: Optional[str] = None

  def __call__(
      self,
      tokens: jax.Array,
      hidden_state: jax.Array,
      *,
      is_training: bool = True,
  ) -> Tuple[jax.Array, jax.Array]:
    """Forward pass of the transformer encoder.

    N: number of tensors in a neural net.
    D: size of a tensor feature vector.
    H: hidden size of the transformer, aka `model_size`.

    Args:
      tokens (jax.Array): Tensor features with shape (N, D).
      hidden_state (jax.Array): Hidden state vector with shape (H).
      is_training (bool, optional): Training mode. Defaults to True.

    Returns:
      Tuple[jax.Array, jax.Array]: Transformed tensor embeddings with shape
        (N, H) and the updated hidden state with shape (H).
    """
    # Embed the input tokens and positions.
    embed_token = hk.Linear(self.model_size)
    token_embeddings = embed_token(tokens)

    # We just add a 'token type' positional embedding to separate the hidden
    # state from the tensor tokens.
    embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
    positional_embeddings = hk.get_parameter(
        "positional_embeddings", [2, self.model_size], init=embed_init
    )
    token_embeddings = token_embeddings + positional_embeddings[1]  # [N, H]
    hidden_embedding = hidden_state[None, :] + positional_embeddings[0]  # [1, H]
    input_embeddings = jnp.concatenate(
        [hidden_embedding, token_embeddings], axis=0
    )  # [N + 1, H]
    input_mask = jnp.ones(input_embeddings.shape[:-1])

    # Run the transformer over the inputs.
    output_embeddings = self.transformer(
        input_embeddings,
        input_mask,
        is_training=is_training,
    )  # [N + 1, H]
    hidden_state = output_embeddings[0]
    embeddings = output_embeddings[1:]
    return embeddings, hidden_state
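

# Minimal usage sketch for the EncoderModel above (added for illustration;
# `_demo_encoder` is hypothetical and not part of the optimizer). It shows how
# the persistent hidden-state token is prepended, attended over together with
# the per-tensor tokens, and returned as the next hidden state.
def _demo_encoder():
  model_size = 16

  def fwd(tokens, hidden):
    tx = Transformer(num_heads=2, num_layers=1, key_size=8, dropout_rate=0.0)
    encoder = EncoderModel(transformer=tx, model_size=model_size)
    return encoder(tokens, hidden)

  f = hk.without_apply_rng(hk.transform(fwd))
  tokens = jnp.zeros([3, 34])  # N=3 tensor tokens, D=34 features each
  hidden = jnp.zeros([model_size])  # hidden state carried across steps
  params = f.init(jax.random.PRNGKey(0), tokens, hidden)
  embeddings, new_hidden = f.apply(params, tokens, hidden)
  return embeddings.shape, new_hidden.shape  # (3, 16), (16,)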


def _second_moment_normalizer(x, axis, eps=1e-5):
  return x * lax.rsqrt(eps + jnp.mean(jnp.square(x), axis=axis, keepdims=True))


def _sin_embedding(iteration: jnp.ndarray) -> jnp.ndarray:
  """Embed the inner-training iteration with sin of various frequency."""

  def one_freq(timescale):
    return jnp.sin(iteration / (jnp.float32(timescale) * jnp.pi))

  timescales = jnp.asarray(
      [1, 3, 10, 30, 100, 300, 1000, 3000, 10000, 30000, 100000],
      dtype=jnp.float32,
  )
  return jax.vmap(one_freq)(timescales)
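
# For example, `_sin_embedding(jnp.asarray(100.0))` returns an 11-dimensional
# vector containing sin(100 / (timescale * pi)) for each timescale above.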


@flax.struct.dataclass
class _LossNormalizerState:
  mean: jnp.ndarray
  var: jnp.ndarray
  updates: jnp.ndarray


class _LossNormalizer:
  """Tracks loss through time and normalizes to a similar range across tasks."""

  def __init__(self, decay: float):
    self.decay = decay

  def init(self) -> _LossNormalizerState:
    return _LossNormalizerState(
        mean=jnp.asarray(0.0), var=jnp.asarray(0.0), updates=jnp.int32(0)
    )

  def next_state(
      self, state: _LossNormalizerState, loss: jnp.ndarray
  ) -> _LossNormalizerState:
    new_mean = self.decay * state.mean + (1.0 - self.decay) * loss
    new_var = self.decay * state.var + (1.0 - self.decay) * jnp.square(
        new_mean - loss
    )
    new_updates = state.updates + 1
    return _LossNormalizerState(mean=new_mean, var=new_var, updates=new_updates)

  def weight_loss(
      self, state: _LossNormalizerState, loss: jnp.ndarray
  ) -> jnp.ndarray:
    c = 1.0 / (1 - self.decay ** jnp.asarray(state.updates, jnp.float32) + 1e-8)
    cor_mean = state.mean * c
    cor_var = state.var * c
    l = (loss - cor_mean) * lax.rsqrt(cor_var + 1e-8)
    return jnp.clip(l, -5, 5)

  def corrected_mean(self, state: _LossNormalizerState) -> jnp.ndarray:
    c = 1.0 / (1 - self.decay ** jnp.asarray(state.updates, jnp.float32) + 1e-7)
    return state.mean * c
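
# Usage sketch for the loss normalizer (illustrative values only):
#   norm = _LossNormalizer(decay=0.9)
#   s = norm.init()
#   s = norm.next_state(s, loss)       # exponential moving mean/variance
#   feat = norm.weight_loss(s, loss)   # bias-corrected, standardized loss,
#                                      # clipped to [-5, 5]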


def _avg_square_mean(tree: Any) -> jnp.ndarray:
  leaves = jax.tree_util.tree_leaves(tree)
  return sum([jnp.mean(jnp.square(x)) for x in leaves]) / len(leaves)


def _clip_log_abs(value: jnp.ndarray) -> jnp.ndarray:
  mag = jnp.log(1e-8 + jnp.abs(value))
  return jnp.clip(mag, -5, 5)


def _sorted_values(dd):
  return list(zip(*sorted(dd.items(), key=lambda x: x[0])))[1]


def _unstack(a: jnp.ndarray, axis: int = 0) -> Sequence[jnp.ndarray]:
  """The opposite of jnp.stack()."""
  shape = a.shape
  return [jnp.squeeze(b, axis=axis) for b in jnp.split(a, shape[axis], axis=axis)]


@flax.struct.dataclass
class _DynamicGradientClipperState:
  iteration: jnp.ndarray
  value: jnp.ndarray


class _DynamicGradientClipper:
  """Keep track of gradient norms and clip gradients to reasonable range."""

  def __init__(self, alpha: float = 0.99, clip_mult: float = 10.0):
    self.alpha = alpha
    self.clip_mult = clip_mult

  def initial_state(self) -> _DynamicGradientClipperState:
    return _DynamicGradientClipperState(
        jnp.asarray(1, dtype=jnp.float32),
        jnp.asarray(1.0, dtype=jnp.float32) * (1 - self.alpha),
    )

  def _normalize(
      self, state: _DynamicGradientClipperState, grads: opt_base.Params
  ) -> opt_base.Params:
    t, snd = state.iteration, state.value
    clip_amount = (snd / (1 - self.alpha**t)) * self.clip_mult
    summary.summary("dynamic_grad_clip", clip_amount)
    return jax.tree_util.tree_map(
        lambda g: jnp.clip(g, -clip_amount, clip_amount), grads
    )

  def next_state_and_normalize(
      self, state: _DynamicGradientClipperState, grads: opt_base.Params
  ) -> Tuple[_DynamicGradientClipperState, opt_base.Params]:
    t, snd = state.iteration, state.value
    clipped_grads = self._normalize(state, grads)
    avg_squared_mean = _avg_square_mean(clipped_grads)
    new_snd_moment = jnp.sqrt(1e-8 + avg_squared_mean)
    next_snd = snd * self.alpha + new_snd_moment * (1.0 - self.alpha)
    return _DynamicGradientClipperState(t + 1, next_snd), clipped_grads
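
# Usage sketch (illustrative):
#   clipper = _DynamicGradientClipper()
#   state = clipper.initial_state()
#   state, grads = clipper.next_state_and_normalize(state, grads)
# i.e. gradients are clipped to a multiple of a running estimate of their scale.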


@flax.struct.dataclass
class LOptState:
  """State used to train a Task / inner-problem."""

  params: opt_base.Params
  mom_rolling: common.MomAccumulator
  rms_rolling: common.RMSAccumulator
  iteration: jnp.ndarray
  state: Optional[opt_base.ModelState]
  tx_hidden_state: Any
  from_mlp: Any
  train_loss_accum: Any
  valid_loss_accum: _LossNormalizerState
  dynamic_clip: _DynamicGradientClipperState


@gin.configurable
class TxLOpt(lopt_base.LearnedOptimizer):
  """Learned optimizer with a transformer and a per-param MLP.

  See the top-level docstring for more information.
  """

  def __init__(
      self,
      step_multiplier: float = 0.001,
      magnitude_rate: float = 0.001,
      hidden_size: int = 32,
      hidden_layer: int = 2,
      from_mlp_size: int = 16,
      tx_to_ff: int = 17,
      tx_hidden_size: int = 64,
      num_heads: int = 4,
      num_layers: int = 4,
      decays: Sequence[float] = (0.5, 0.9, 0.99, 0.999, 0.9999),
  ):
    self.step_multiplier = step_multiplier
    self.magnitude_rate = magnitude_rate
    self.hidden_size = hidden_size
    self.hidden_layer = hidden_layer
    self.from_mlp_size = from_mlp_size
    self.tx_to_ff = tx_to_ff
    self.tx_hidden_size = tx_hidden_size
    self.decays = jnp.asarray(decays)
    self.num_heads = num_heads
    self.num_layers = num_layers

    def _per_param_mlp_network(inp):
      hiddens = [hidden_size] * hidden_layer + [2 + from_mlp_size]
      return hk.nets.MLP(hiddens)(inp)

    self.per_param_mlp_network = hk.without_apply_rng(
        hk.transform(_per_param_mlp_network)
    )
    self.tx_to_mlp_network = hk.without_apply_rng(
        hk.transform(lambda x: hk.Linear(tx_to_ff, name="tx_to_ff")(x))
    )

    def _forward_tx(tokens, state):
      tx = Transformer(
          num_heads=num_heads,
          num_layers=num_layers,
          key_size=32,
          dropout_rate=0.1,
      )
      encoder = EncoderModel(model_size=tx_hidden_size, transformer=tx)
      return encoder(tokens, state)

    self.tx_network = hk.without_apply_rng(hk.transform(_forward_tx))

    def initial_state(hidden_size) -> jax.Array:
      embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
      hidden_embedding = hk.get_parameter(
          "hidden_embedding", [hidden_size], init=embed_init
      )
      return hidden_embedding

    self.initial_state = initial_state

  def init(self, key) -> lopt_base.MetaParams:
    """Initialization of the meta-parameters."""
    key1, key2, key3, key4, key5 = jax.random.split(key, 5)

    # To create the weights of the transformer, we must know the number of
    # inputs created by the `features_for_tensor` function.
    tensor_features = 18
    tx_inp_size = tensor_features + self.from_mlp_size

    # To create the weights of the MLP, we must know the number of inputs
    # created by the `mlp_features_for_tensor` function.
    feed_forward_features = 37
    mlp_inp_size = feed_forward_features + self.tx_to_ff

    _, var_init = hk.transform(hk.initializers.VarianceScaling())
    initial_state_fn = hk.transform(self.initial_state)
    params = initial_state_fn.init(key2, self.tx_hidden_size)
    tx_initial_state = initial_state_fn.apply(params, None, self.tx_hidden_size)
    return {
        "initial_from_mlp": var_init(
            None, key1, [self.from_mlp_size], dtype=jnp.float32
        ),
        "tx_init_state": tx_initial_state,
        "tx_params": self.tx_network.init(
            key3, jnp.zeros([1, tx_inp_size]), tx_initial_state
        ),
        "tx_to_ff_params": self.tx_to_mlp_network.init(
            key4, jnp.zeros([0, self.tx_hidden_size])
        ),
        "ffmod_params": self.per_param_mlp_network.init(
            key5, jnp.zeros([0, mlp_inp_size])
        ),
    }

  def opt_fn(
      self, theta: lopt_base.MetaParams, is_training: bool = False
  ) -> opt_base.Optimizer:
    vec_roll_rms = common.vec_rolling_rms(self.decays)
    vec_roll_mom = common.vec_rolling_mom(self.decays)
    valid_loss_normalizer = _LossNormalizer(0.95)
    train_loss_normalizer = _LossNormalizer(0.9)
    dynamic_gradient_clip = _DynamicGradientClipper()
    parent = self

    class _Opt(opt_base.Optimizer):
      """Optimizer which contains meta-parameters."""

      def __init__(self, theta: lopt_base.MetaParams):
        super().__init__()
        self.theta = theta

      def init(
          self,
          params: opt_base.Params,
          model_state: Optional[opt_base.ModelState] = None,
          num_steps: Optional[jnp.ndarray] = None,
          key: Optional[PRNGKey] = None,
      ) -> LOptState:
        # n_states: number of tensors in the optimizee net.
        n_states = len(jax.tree_util.tree_leaves(params))
        tx_hidden_state = self.theta["tx_init_state"]
        from_mlp = jax.tree_util.tree_map(
            lambda x: self.theta["initial_from_mlp"], params
        )
        return LOptState(
            params=params,
            mom_rolling=vec_roll_mom.init(params),
            rms_rolling=vec_roll_rms.init(params),
            iteration=jnp.asarray(0, dtype=jnp.int32),
            state=model_state,
            tx_hidden_state=tx_hidden_state,
            from_mlp=from_mlp,
            train_loss_accum=train_loss_normalizer.init(),
            valid_loss_accum=valid_loss_normalizer.init(),
            dynamic_clip=dynamic_gradient_clip.initial_state(),
        )

      def features_for_tensor(
          self,
          ms: jnp.ndarray,
          rms: jnp.ndarray,
          g: jnp.ndarray,
          v: jnp.ndarray,
          from_mlp: jnp.ndarray,
          train_loss_feat: jnp.ndarray,
          valid_loss_feat: jnp.ndarray,
      ) -> Sequence[jnp.ndarray]:
        """Compute per-tensor features.

        This function is called once per tensor.

        Args:
          ms: momentum accumulators
          rms: second moment accumulators
          g: gradient value
          v: parameter value
          from_mlp: conditioning value sent from the per-param MLP.
          train_loss_feat: Array which contains the featurized train loss
          valid_loss_feat: Array which contains the featurized valid loss

        Returns:
          A list of features. Each feature is a vector.
        """
        inputs = {}

        mean_ms = jnp.mean(ms)
        inputs["mean_ms_mag"] = _clip_log_abs(mean_ms)
        inputs["mean_ms_sign"] = jnp.sign(mean_ms)
        var_ms = jnp.mean(jnp.square(ms - mean_ms))
        inputs["var_ms"] = _clip_log_abs(var_ms)

        mean_rms = jnp.mean(rms)
        inputs["mean_rms"] = _clip_log_abs(mean_rms)
        inputs["mean_sign"] = jnp.sign(mean_rms)
        var_rms = jnp.mean(jnp.square(rms - mean_rms))
        inputs["var_rms"] = _clip_log_abs(var_rms)

        mean_v = jnp.mean(v)
        inputs["mean_v_mag"] = _clip_log_abs(mean_v)
        inputs["mean_v_sign"] = jnp.sign(mean_v)
        var_v = jnp.mean(jnp.square(v - mean_v))
        inputs["var_v"] = _clip_log_abs(var_v)

        inputs["norm_weight"] = _clip_log_abs(jnp.linalg.norm(v))
        g_norm = jnp.linalg.norm(g)
        inputs["g_norm"] = _clip_log_abs(g_norm)
        inputs["is_scalar"] = jnp.asarray(
            1.0 if len(v.shape) == 0 else -1.0  # pylint: disable=g-explicit-length-test
        )
        extra_dims = [1.0] * (4 - len(v.shape))
        shape_stack = jnp.concatenate(
            [onp.asarray(v.shape, jnp.float32), jnp.asarray(extra_dims)], axis=0
        )
        for j in range(4):
          # Shift so that these are closer to zero mean.
          inputs["shape_%d" % j] = jnp.log(shape_stack)[j] - 1.0

        # Features from the training loss.
        inputs["train_loss_feat"] = train_loss_feat
        inputs["valid_loss_feat"] = valid_loss_feat

        # Features from the lower-level MLP.
        inputs["from_mlp"] = from_mlp

        values = _sorted_values(inputs)
        reshaped = [
            jnp.expand_dims(v, 0) if len(v.shape) == 0 else v  # pylint: disable=g-explicit-length-test
            for v in values
        ]
        return reshaped
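
      # Note: concatenating the features above yields 18 scalars per tensor
      # (12 summary statistics + 4 shape features + 2 loss features) plus the
      # `from_mlp` vector, which is what `tensor_features = 18` and
      # `from_mlp_size` refer to in `TxLOpt.init`.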

      def mlp_features_for_tensor(
          self,
          m: jnp.ndarray,
          rms: jnp.ndarray,
          g: jnp.ndarray,
          v: jnp.ndarray,
          ff_inputs: jnp.ndarray,
          training_step: jnp.ndarray,
          num_tensors: jnp.ndarray,
      ) -> jnp.ndarray:
        flat_g = jnp.reshape(g, [-1, 1])

        # These have a trailing dim of decays. We want to reshape them so that
        # they have the leading dimensions flattened.
        rms = jnp.reshape(rms, [int(onp.prod(rms.shape[0:-1])), rms.shape[-1]])
        m = jnp.reshape(m, [int(onp.prod(m.shape[0:-1])), m.shape[-1]])

        rsqrt = lax.rsqrt(rms + 1e-6)
        rms_scaled_g = m * rsqrt
        flat_v = jnp.reshape(v, [-1, 1])

        # Per-component features.
        inps = {}
        inps["flat_g"] = flat_g
        inps["flat_v"] = flat_v
        inps["log_abs_v"] = jnp.log(jnp.abs(flat_v) + 1e-8)
        inps["m"] = m
        inps["rms_scaled_g"] = rms_scaled_g
        inps["rms"] = rms
        inps["rsqrt"] = rsqrt

        # Stack the values to form one vector which we normalize.
        inp = jnp.concatenate(_sorted_values(inps), 1)

        # Normalize across all the values of the tensor.
        inp = _second_moment_normalizer(inp, axis=0)

        step = _sin_embedding(training_step)
        stack_step = jnp.tile(
            jnp.reshape(step, [1, -1]), onp.asarray([flat_g.shape[0], 1])
        )

        # These are all features that are computed across the tensor. We tile
        # them to be able to pass them into the MLP.
        # Subtract 1. to at least attempt to zero center.
        log_num_tensors = jnp.log(float(num_tensors)) - 1.0
        stack_num_tensors = jnp.tile(
            jnp.reshape(log_num_tensors, [1, 1]), [flat_g.shape[0], 1]
        )

        # Feature based on the norm of the parameters -- this should not be
        # normalized as we care about the absolute magnitude.
        log_norm = jnp.log(jnp.linalg.norm(flat_v) + 1e-8)
        stack_log_norm = jnp.tile(
            jnp.reshape(log_norm, [1, 1]), [flat_g.shape[0], 1]
        )

        # Feature which is the number of parameters in the current layer.
        log_n_weight = jnp.log(float(flat_v.shape[0]))
        stack_log_n_weight = jnp.tile(
            jnp.reshape(log_n_weight, [1, 1]), [flat_g.shape[0], 1]
        )

        ff_inp = jnp.tile(jnp.reshape(ff_inputs, [1, -1]), [flat_g.shape[0], 1])

        # Stack up all the features.
        return jnp.concatenate(
            [
                inp,
                stack_step,
                stack_num_tensors,
                stack_log_norm,
                stack_log_n_weight,
                ff_inp,
            ],
            axis=1,
        )
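
      # Note: per parameter, this produces 37 features (23 per-component
      # features given the 5 decay channels, an 11-dimensional timestep
      # embedding, and 3 tiled tensor-level scalars) plus the `tx_to_ff`
      # conditioning vector, matching `feed_forward_features = 37` in
      # `TxLOpt.init`.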

      def update(
          self,
          opt_state: LOptState,
          grads,
          loss: Optional[jnp.ndarray] = None,
          model_state: Optional[opt_base.ModelState] = None,
          is_valid: bool = False,
          key: Optional[PRNGKey] = None,
          **kwargs,
      ) -> LOptState:
        """Perform a single inner-problem update."""
        if loss is None:
          raise ValueError("This optimizer must be called with a loss!")

        # Instead of doing jax.lax.cond to swap implementations, we run both
        # computations and select one. This is required to get summaries to
        # work through a cond. This is fine as the validation path is quite
        # cheap.
        opt1 = self.update_is_valid(opt_state, loss)
        opt2 = self.update_is_training(opt_state, grads, loss, model_state)
        return jax.lax.cond(is_valid, lambda _: opt1, lambda _: opt2, ())

      def update_is_valid(self, opt_state, loss) -> LOptState:
        # When computing an update with validation data, all we do is update
        # the validation loss.
        next_valid_loss_accum = valid_loss_normalizer.next_state(
            opt_state.valid_loss_accum, loss
        )
        next_opt_state = opt_state.replace(
            iteration=opt_state.iteration + 1,
            valid_loss_accum=next_valid_loss_accum,
        )
        return tree_utils.match_type(next_opt_state, opt_state)

      def update_is_training(
          self, opt_state, grads, loss, model_state
      ) -> LOptState:
        theta = self.theta

        # Update the training loss.
        next_train_loss_accum = train_loss_normalizer.next_state(
            opt_state.train_loss_accum, loss
        )

        # Compute various loss features.
        train_loss_feat = train_loss_normalizer.weight_loss(
            next_train_loss_accum, loss
        )
        valid_loss = valid_loss_normalizer.corrected_mean(
            opt_state.valid_loss_accum
        )
        valid_loss_feat = train_loss_normalizer.weight_loss(
            next_train_loss_accum, valid_loss
        )
        summary.summary("valid_loss", valid_loss)

        # Clip the gradients and update the gradient clipper state.
        (
            next_dynamic_clip,
            grads,
        ) = dynamic_gradient_clip.next_state_and_normalize(
            opt_state.dynamic_clip, grads
        )

        next_mom_rolling = vec_roll_mom.update(opt_state.mom_rolling, grads)
        next_rms_rolling = vec_roll_rms.update(opt_state.rms_rolling, grads)
        ms = next_mom_rolling.m
        rms = next_rms_rolling.rms
        param_tree = jax.tree_util.tree_structure(ms)

        def to_map_per_tensor(ms, rms, g, v, from_mlp):
          return self.features_for_tensor(
              ms, rms, g, v, from_mlp, train_loss_feat, valid_loss_feat
          )

        tree_args = (ms, rms, grads, opt_state.params, opt_state.from_mlp)
        flat_args = [jax.tree_util.tree_leaves(a) for a in tree_args]
        stacked_inp_tree = jax.tree_util.tree_map(to_map_per_tensor, *flat_args)

        # We stack all the different tensors together so that we can run the
        # transformer only once.
        tx_inputs = jnp.stack(
            [jnp.concatenate(v, axis=0) for v in stacked_inp_tree]
        )

        # Run the transformer on the features.
        tx_out, next_tx_hidden_state = parent.tx_network.apply(
            theta["tx_params"], tx_inputs, opt_state.tx_hidden_state
        )

        # Compute values passed from the transformer into the FF network.
        ff_inputs = parent.tx_to_mlp_network.apply(
            theta["tx_to_ff_params"], tx_out
        )
        # These need to be unstacked as they are currently concatenated.
        ff_inputs = _unstack(ff_inputs)
        # And need to be converted back to a parameter tree structure.
        ff_inputs = jax.tree_util.tree_unflatten(param_tree, ff_inputs)

        num_tensors = len(jax.tree_util.tree_leaves(opt_state.params))

        def to_map_get_mlp_features(m, rms, g, v, ff_inputs):
          return self.mlp_features_for_tensor(
              m,
              rms,
              g,
              v,
              ff_inputs,  # pytype: disable=wrong-arg-types  # jax-ndarray
              opt_state.iteration,
              num_tensors,
          )

        # Prep the features.
        ff_feats = jax.tree_util.tree_map(
            to_map_get_mlp_features, ms, rms, grads, opt_state.params, ff_inputs
        )

        # Apply the per-parameter MLP on these features.
        outputs = jax.tree_util.tree_map(
            functools.partial(
                parent.per_param_mlp_network.apply, theta["ffmod_params"]
            ),
            ff_feats,
        )

        # Split apart the outputs and create both the next parameters and the
        # inputs needed for the next learned-optimizer application.
        new_params = []
        from_mlp = []
        for o, v in zip(
            jax.tree_util.tree_leaves(outputs),
            jax.tree_util.tree_leaves(opt_state.params),
        ):
          direction = o[:, 0:1]
          magnitude = o[:, 1:2]
          step = (
              direction
              * jnp.exp(magnitude * parent.magnitude_rate)
              * parent.step_multiplier
          )
          step = step.reshape(v.shape)
          new_params.append(v - step)
          to_tx = jnp.mean(o[:, 2:], axis=0)
          from_mlp.append(to_tx)

        # Convert these structures back to match the parameter tree.
        new_params = jax.tree_util.tree_unflatten(param_tree, new_params)
        new_from_mlp = jax.tree_util.tree_unflatten(param_tree, from_mlp)

        # Finally, package all these values up and return.
        next_opt_state = LOptState(
            params=new_params,
            mom_rolling=next_mom_rolling,
            rms_rolling=next_rms_rolling,
            iteration=opt_state.iteration + 1,
            state=model_state,
            tx_hidden_state=next_tx_hidden_state,
            from_mlp=new_from_mlp,
            train_loss_accum=next_train_loss_accum,
            valid_loss_accum=opt_state.valid_loss_accum,
            dynamic_clip=next_dynamic_clip,
        )
        return tree_utils.match_type(next_opt_state, opt_state)

    return _Opt(theta)
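

# --- Usage sketch (added for illustration; not part of the optimizer). -------
# This assumes the standard `Optimizer` interface (`init` / `update` /
# `get_params`) from learned_optimization's `opt_base`. The toy quadratic
# inner problem and the `_demo_inner_loop` helper are hypothetical; they only
# illustrate how `TxLOpt` is expected to be driven on an inner problem.
def _demo_inner_loop(num_inner_steps: int = 5):
  key = jax.random.PRNGKey(0)
  lopt = TxLOpt()
  theta = lopt.init(key)  # meta-parameters (normally meta-trained)
  opt = lopt.opt_fn(theta)

  # Toy inner problem: minimize the squared norm of a small parameter pytree.
  params = {"w": jnp.ones([4, 3]), "b": jnp.zeros([3])}

  def loss_fn(p):
    return sum(jnp.sum(jnp.square(x)) for x in jax.tree_util.tree_leaves(p))

  opt_state = opt.init(params, num_steps=num_inner_steps, key=key)
  for _ in range(num_inner_steps):
    p = opt.get_params(opt_state)
    loss, grads = jax.value_and_grad(loss_fn)(p)
    opt_state = opt.update(opt_state, grads, loss=loss, key=key)
  return opt.get_params(opt_state)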

If you find this work useful, please cite the ICML 2023 workshop paper "Learning to Optimize with Recurrent Hierarchical Transformers" (https://openreview.net/forum?id=MusMaHCrs2).