Created June 22, 2021 23:11
-
-
Save angeligareta/77ea24e08f46a124c4761a908b6cfdb9 to your computer and use it in GitHub Desktop.
Custom Early Stopping callback to monitor multiple metrics by combining them using a harmonic mean calculation.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import numpy as np | |
class CustomEarlyStopping(tf.keras.callbacks.Callback):
    """Stop training when a combination of several metrics stops improving.

    The monitored quantity is the harmonic mean of the values logged under
    ``metrics_names`` (for ``["val_precision", "val_recall"]`` this is the F1
    score). Training stops once that value has not improved for ``patience``
    consecutive epochs.

    Adapted from the
    [TensorFlow EarlyStopping source](https://github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/python/keras/callbacks.py#L1683-L1823).

    Author: Angel Igareta

    Args:
        metrics_names: Keys of the per-epoch ``logs`` dict to combine. A single
            string is treated as a one-element list. Defaults to ``("loss",)``.
        mode: ``"min"`` if a lower combined value is better, ``"max"`` if a
            higher one is better.
        patience: Number of consecutive non-improving epochs after which
            training is stopped. ``0`` and ``1`` both stop at the first epoch
            that does not improve.
        restore_weights: If True, roll the model back to the weights recorded
            at the best epoch when training is stopped early.
        logdir: Optional TensorBoard log directory. When set, the combined
            metric is written there as the scalar ``"combined_metric"``.

    Raises:
        ValueError: If ``mode`` is not ``"min"``/``"max"`` or
            ``metrics_names`` is empty.
    """

    def __init__(
        self,
        metrics_names=("loss",),
        mode="min",
        patience=0,
        restore_weights=False,
        logdir=None,
    ):
        super().__init__()
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        if metrics_names is None:
            metrics_names = ("loss",)
        elif isinstance(metrics_names, str):
            # Avoid iterating a bare string character by character.
            metrics_names = (metrics_names,)
        # Copy so the default is never a shared mutable object and later
        # mutation of the caller's list cannot change what is monitored.
        self.metrics_names = list(metrics_names)
        if not self.metrics_names:
            raise ValueError("metrics_names must contain at least one metric name")
        self.mode = mode
        self.patience = patience
        self.restore_weights = restore_weights
        self.logdir = logdir
        self.best_weights = None
        # Summary writer, created once per training run in on_train_begin.
        self._writer = None

    @staticmethod
    def _harmonic_mean(values):
        """Return the harmonic mean of ``values`` as a Python float.

        If any value is exactly zero the result is 0.0, which is the
        mathematical limit. The previous ``reciprocal_no_nan`` formulation
        silently dropped zero terms and could report a perfect score.
        """
        arr = np.asarray(values, dtype=np.float64)
        if np.any(arr == 0):
            return 0.0
        return float(arr.size / np.sum(1.0 / arr))

    def _is_improvement(self, value):
        """Return True if ``value`` beats the best combined metric so far.

        NaN never counts as an improvement, because comparisons with NaN are
        always False.
        """
        if self.mode == "min":
            return value < self.best_combined_metric
        return value > self.best_combined_metric

    def on_train_begin(self, logs=None):
        # Number of consecutive epochs since the combined metric last improved.
        self.wait = 0
        # The epoch training was stopped at (0 if it never was).
        self.stopped_epoch = 0
        # Start from the worst possible value so the first finite one improves.
        self.best_combined_metric = np.inf if self.mode == "min" else -np.inf
        # Do not restore weights left over from a previous fit() call.
        self.best_weights = None
        # One writer for the whole run instead of a fresh one every epoch.
        self._writer = (
            tf.summary.create_file_writer(self.logdir) if self.logdir else None
        )

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        missing = [name for name in self.metrics_names if logs.get(name) is None]
        if missing:
            # Mirror Keras EarlyStopping: warn and skip rather than crash.
            tf.get_logger().warning(
                "CustomEarlyStopping: metrics %s not found in logs "
                "(available: %s); skipping this epoch.",
                missing,
                ", ".join(logs.keys()),
            )
            return

        combined_metric = self._harmonic_mean(
            [float(logs[name]) for name in self.metrics_names]
        )

        if self._writer is not None:
            with self._writer.as_default():
                tf.summary.scalar("combined_metric", data=combined_metric, step=epoch)
            self._writer.flush()

        if self._is_improvement(combined_metric):
            self.best_combined_metric = combined_metric
            self.wait = 0
            # Record the weights of the best epoch seen so far.
            self.best_weights = self.model.get_weights()
            return

        # Count consecutive non-improving epochs. This was previously
        # `wait = 1`, which never reached patience > 1.
        self.wait += 1
        if self.wait >= self.patience:
            self.stopped_epoch = epoch
            self.model.stop_training = True
            # best_weights stays None if no epoch ever improved (e.g. NaN metrics).
            if self.restore_weights and self.best_weights is not None:
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self._writer is not None:
            self._writer.close()
            self._writer = None
        if self.stopped_epoch > 0:
            print(f"Epoch {self.stopped_epoch + 1:05d}: early stopping")
# Example usage: monitor the harmonic mean of the logged "val_precision" and
# "val_recall" metrics (higher is better), with patience=10, restoring the
# best epoch's weights when training is stopped. Pass the instance to
# `model.fit(..., callbacks=[early_stopping_callback])` like any other Keras
# callback. This assumes the model actually logs metrics under those names.
early_stopping_callback = CustomEarlyStopping(
    patience=10,
    mode="max",
    restore_weights=True,
    metrics_names=["val_precision", "val_recall"],
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.