AdaBound + AMSBound implementations with Keras
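Background (not part of the gist itself): AdaBound behaves like Adam early in training and gradually transitions toward SGD by clipping each parameter's step size between a lower and an upper bound that both converge to `final_lr`; `gamma` controls how quickly the bounds tighten. AMSBound applies the same bounding on top of the AMSGrad variant, which keeps a running maximum of the second-moment estimate.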
from keras.optimizers import Optimizer
from keras.legacy import interfaces
from keras import backend as K
import tensorflow as tf
class Adabound(Optimizer):
    """AdaBound / AMSBound optimizer, following the structure of Keras's Adam."""

    def __init__(self, lr=0.001,
                 beta_1=0.9, beta_2=0.999,
                 gamma=0.001,
                 final_lr=0.1,
                 epsilon=None,
                 decay=0.,
                 amsbound=False,
                 **kwargs):
        super().__init__(**kwargs)
        with K.name_scope(self.__class__.__name__):
            self.iterations = K.variable(0, dtype='int64', name='iterations')
            self.lr = K.variable(lr, name='lr')
            self.beta_1 = K.variable(beta_1, name='beta_1')
            self.beta_2 = K.variable(beta_2, name='beta_2')
            self.gamma = K.variable(gamma, name='gamma')
            self.final_lr = K.variable(final_lr, name='final_lr')
            self.decay = K.variable(decay, name='decay')
        if epsilon is None:
            epsilon = K.epsilon()
        self.epsilon = epsilon
        self.initial_decay = decay
        self.amsbound = amsbound

    @interfaces.legacy_get_updates_support
    def get_updates(self, loss, params):
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        lr = self.lr
        if self.initial_decay > 0:
            lr = lr * (1. / (1. + self.decay * K.cast(self.iterations,
                                                      K.dtype(self.decay))))

        t = K.cast(self.iterations, K.floatx()) + 1
        # Bias-corrected learning rate, as in Adam.
        lr_t = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) /
                     (1. - K.pow(self.beta_1, t)))

        ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        if self.amsbound:
            vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        else:
            # Dummy placeholders so self.weights keeps a fixed structure.
            vhats = [K.zeros((1, 1)) for _ in params]
        self.weights = [self.iterations] + ms + vs + vhats

        for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
            m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
            v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)

            # Step-size bounds; both converge to final_lr as t grows:
            #   lower = final_lr * (1 - 1 / (gamma * t + 1))
            #   upper = final_lr * (1 + 1 / (gamma * t))
            lower_bound_t = self.final_lr * (1. - 1. / (self.gamma * t + 1.))
            upper_bound_t = self.final_lr * (1. + 1. / (self.gamma * t))

            if self.amsbound:
                vhat_t = K.maximum(vhat, v_t)
                # Clip the per-parameter step size, then apply the first moment.
                p_t = p - tf.clip_by_value(
                    lr_t / (K.sqrt(vhat_t) + self.epsilon),
                    lower_bound_t,
                    upper_bound_t) * m_t
                self.updates.append(K.update(vhat, vhat_t))
            else:
                p_t = p - tf.clip_by_value(
                    lr_t / (K.sqrt(v_t) + self.epsilon),
                    lower_bound_t,
                    upper_bound_t) * m_t

            self.updates.append(K.update(m, m_t))
            self.updates.append(K.update(v, v_t))
            new_p = p_t

            # Apply constraints.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(K.update(p, new_p))
        return self.updates

    def get_config(self):
        config = {'lr': float(K.get_value(self.lr)),
                  'beta_1': float(K.get_value(self.beta_1)),
                  'beta_2': float(K.get_value(self.beta_2)),
                  'gamma': float(K.get_value(self.gamma)),
                  'final_lr': float(K.get_value(self.final_lr)),
                  'decay': float(K.get_value(self.decay)),
                  'epsilon': self.epsilon,
                  'amsbound': self.amsbound}
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
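
A minimal usage sketch (not part of the original gist): assuming the class above is saved as adabound.py, it plugs into model.compile like any other Keras optimizer. The two-layer model below is purely illustrative.

# Hypothetical usage example; the model and the file name adabound.py are assumptions.
from keras.models import Sequential
from keras.layers import Dense
from adabound import Adabound

model = Sequential([
    Dense(64, activation='relu', input_shape=(20,)),
    Dense(1, activation='sigmoid'),
])

# final_lr is the SGD-like rate the step-size bounds converge to;
# set amsbound=True for the AMSBound variant.
model.compile(optimizer=Adabound(lr=1e-3, final_lr=0.1, amsbound=False),
              loss='binary_crossentropy',
              metrics=['accuracy'])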