Trivial example for Mixed-Gradient-Error and Mean-Gradient-Error
import tensorflow as tf

def MeanGradientError(outputs, targets, weight):
    # Build Sobel kernels of shape [3, 3, in_channels, out_channels] for tf.nn.conv2d
    filter_x = tf.tile(tf.expand_dims(tf.constant([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=outputs.dtype), axis=-1), [1, 1, outputs.shape[-1]])
    filter_x = tf.tile(tf.expand_dims(filter_x, axis=-1), [1, 1, 1, outputs.shape[-1]])
    filter_y = tf.tile(tf.expand_dims(tf.constant([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=outputs.dtype), axis=-1), [1, 1, targets.shape[-1]])
    filter_y = tf.tile(tf.expand_dims(filter_y, axis=-1), [1, 1, 1, targets.shape[-1]])

    # output gradients
    output_gradient_x = tf.math.square(tf.nn.conv2d(outputs, filter_x, strides=1, padding='SAME'))
    output_gradient_y = tf.math.square(tf.nn.conv2d(outputs, filter_y, strides=1, padding='SAME'))

    # target gradients
    target_gradient_x = tf.math.square(tf.nn.conv2d(targets, filter_x, strides=1, padding='SAME'))
    target_gradient_y = tf.math.square(tf.nn.conv2d(targets, filter_y, strides=1, padding='SAME'))

    # gradient magnitudes
    output_gradients = tf.math.sqrt(tf.math.add(output_gradient_x, output_gradient_y))
    target_gradients = tf.math.sqrt(tf.math.add(target_gradient_x, target_gradient_y))

    # compute mean gradient error, averaged over the spatial dimensions
    shape = output_gradients.shape[1:3]
    mge = tf.math.reduce_sum(tf.math.squared_difference(output_gradients, target_gradients) / (shape[0] * shape[1]))
    return mge * weight

# tf.nn.conv2d expects 4-D input, so include a batch dimension
x = tf.random.normal(shape=(1, 224, 224, 3))
y = tf.random.normal(shape=(1, 224, 224, 3))
gradient_loss = MeanGradientError(x, y, weight=0.1)
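The gist title also mentions Mixed-Gradient-Error, which the code above does not show. A minimal sketch, assuming the common formulation where the mixed loss is plain pixel-wise MSE plus the weighted mean gradient error term; the name MixedGradientError is illustrative, not from the original gist:

# Sketch of Mixed-Gradient-Error, assuming MixGE = MSE + weight * MGE.
# MeanGradientError already applies `weight` internally, so it is not
# applied again here.
def MixedGradientError(outputs, targets, weight):
    mse = tf.math.reduce_mean(tf.math.squared_difference(outputs, targets))
    mge = MeanGradientError(outputs, targets, weight)
    return mse + mge

mixed_loss = MixedGradientError(x, y, weight=0.1)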
Just letting people know: this loss tends to produce NaN losses in the early stages of training. It behaves once you are an epoch or two in, but I wouldn't recommend risking it. There should be a better implementation out there somewhere... except that the author's GitHub is no longer visible. This is with the Sobel filter fix, of course.
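One plausible source of those NaNs: tf.math.sqrt has an unbounded derivative at zero, so wherever both squared gradients are exactly zero the backward pass produces NaN. A minimal sketch of a stabilized variant of the magnitude step; the epsilon value is an assumption, not from the original code:

# Stabilized gradient magnitude: a small epsilon inside the sqrt keeps
# the derivative finite when both squared gradients are zero.
# eps = 1e-8 is an assumed value; tune as needed.
eps = 1e-8
output_gradients = tf.math.sqrt(output_gradient_x + output_gradient_y + eps)
target_gradients = tf.math.sqrt(target_gradient_x + target_gradient_y + eps)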
I guess there is a small problem in your code at line 4: the Sobel filter should be [[-1, -2, -1], [0, 0, 0], [1, 2, 1]], but it is [[-1, -2, -2], [0, 0, 0], [1, 2, 1]] in your code.
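To sidestep hand-typed kernels (and this kind of typo) entirely, TensorFlow ships the Sobel filters: a minimal sketch using tf.image.sobel_edges, which returns per-channel dy/dx edge maps. The helper name is mine, and the reduction differs from the gist's cross-channel convolution, so the two losses are not numerically identical:

# Alternative using TensorFlow's built-in Sobel kernels.
# tf.image.sobel_edges returns shape [batch, h, w, channels, 2] (dy, dx);
# it computes per-channel gradients, unlike the gist's filters, which
# also mix channels.
def mean_gradient_error_builtin(outputs, targets, weight):
    out_mag = tf.norm(tf.image.sobel_edges(outputs), axis=-1)
    tgt_mag = tf.norm(tf.image.sobel_edges(targets), axis=-1)
    return weight * tf.math.reduce_mean(tf.math.squared_difference(out_mag, tgt_mag))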