Skip to content

Instantly share code, notes, and snippets.

@sadimanna
Last active August 28, 2023 18:19
Show Gist options
  • Save sadimanna/19117a47dda74ebcf08929598063ae4a to your computer and use it in GitHub Desktop.
Save sadimanna/19117a47dda74ebcf08929598063ae4a to your computer and use it in GitHub Desktop.
@tf.keras.saving.register_keras_serializable(name="weighted_binary_crossentropy")
def weighted_binary_crossentropy(target, output, weights):
    """Return the element-wise, class-weighted binary cross-entropy.

    For each element the loss is::

        -(weights[1] * target * log(p) + weights[0] * (1 - target) * log(1 - p))

    where ``p`` is ``output`` clipped to ``[eps, 1 - eps]`` and ``eps`` is
    ``tf.keras.backend.epsilon()``. No reduction is applied; the result has
    the broadcast shape of ``target`` and ``output``.

    Args:
        target: Ground-truth labels, presumably 0/1 (or soft labels in
            ``[0, 1]``). Any numeric dtype is accepted; it is cast to the
            dtype of ``output``.
        output: Predicted probabilities (not logits). Its dtype determines
            the dtype of the computation and of the result.
        weights: Indexable with at least two elements; ``weights[0]`` scales
            the ``target == 0`` term and ``weights[1]`` scales the
            ``target == 1`` term. Cast to the dtype of ``output``.

    Returns:
        A tensor holding the per-element weighted loss.
    """
    output = tf.convert_to_tensor(output)
    dtype = output.dtype.base_dtype
    # Cast labels and weights to the prediction dtype. Integer labels (or a
    # float64/float32 mismatch) would otherwise make the products below fail,
    # and float weights cannot be converted to an integer label dtype.
    target = tf.cast(tf.convert_to_tensor(target), dtype)
    weights = tf.cast(tf.convert_to_tensor(weights), dtype)
    epsilon_ = tf.constant(tf.keras.backend.epsilon(), dtype)
    # Keep log() finite at exactly 0 and 1.
    output = tf.clip_by_value(output, epsilon_, 1.0 - epsilon_)
    # Compute cross entropy from probabilities. The extra +epsilon_ inside
    # log() follows tf.keras.backend.binary_crossentropy.
    bce = weights[1] * target * tf.math.log(output + epsilon_)
    bce += weights[0] * (1 - target) * tf.math.log(1 - output + epsilon_)
    return -bce
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment