Skip to content

Instantly share code, notes, and snippets.

@alessiamarcolini
Created June 7, 2018 16:50
Show Gist options
  • Save alessiamarcolini/cb3214f5bc72f15f73579995d02bb318 to your computer and use it in GitHub Desktop.
Save alessiamarcolini/cb3214f5bc72f15f73579995d02bb318 to your computer and use it in GitHub Desktop.
Snapshot Ensembles - Keras
def get_callbacks(self, model_prefix='Model', weights_dir='weights/', monitor='val_acc'):
    """
    Creates a list of callbacks that can be used during training to create a
    snapshot ensemble of the model.

    Args:
        model_prefix: prefix for the filename of the weights.
        weights_dir: directory the weight files are written into; it is
            created if missing. Defaults to 'weights/' (the original,
            previously hard-coded location). An empty string means the
            current working directory.
        monitor: quantity watched by the best-model ModelCheckpoint.
            Defaults to 'val_acc'. The logged key depends on the Keras
            version and on the metric name given to compile(); on newer
            Keras it may be 'val_accuracy' -- pass whichever key your
            training history actually contains.
    Returns: list of 3 callbacks [ModelCheckpoint, LearningRateScheduler,
        SnapshotModelCheckpoint] which can be provided to the 'fit' function.
        The ModelCheckpoint keeps only the best weights (by `monitor`) in
        '<weights_dir>/<model_prefix>-Best.h5'; the LearningRateScheduler
        drives the cyclic cosine schedule; SnapshotModelCheckpoint receives
        self.T, self.M and '<weights_dir>/<model_prefix>' as its prefix.
    Raises:
        OSError: if weights_dir cannot be created and is not already a
            directory (e.g. a regular file with that name exists).
    """
    if weights_dir:
        # Try-then-verify instead of exists()-then-makedirs(): the
        # check-then-create pattern races with anything else creating the
        # same directory. Written without exist_ok=True to stay compatible
        # with Python 2, which this Keras-era code may still run on.
        try:
            os.makedirs(weights_dir)
        except OSError:
            if not os.path.isdir(weights_dir):
                raise

    best_weights_path = os.path.join(weights_dir, '%s-Best.h5' % model_prefix)
    snapshot_prefix = os.path.join(weights_dir, model_prefix)

    callback_list = [ModelCheckpoint(best_weights_path, monitor=monitor,
                                     save_best_only=True, save_weights_only=True),
                     LearningRateScheduler(schedule=self._cosine_anneal_schedule),
                     SnapshotModelCheckpoint(self.T, self.M, fn_prefix=snapshot_prefix)]
    return callback_list
def _cosine_anneal_schedule(self, t):
cos_inner = np.pi * (t % (self.T // self.M))
cos_inner /= self.T // self.M
cos_out = np.cos(cos_inner) + 1
return float(self.alpha_zero / 2 * cos_out)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment