Created April 7, 2020 12:00
Save georgwiese/57568ac518813b9d9f0e6785d8f707fa to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
From fb7dd5c67d70f86082efca7cd3f29061b808467f Mon Sep 17 00:00:00 2001
From: Georg Wiese <[email protected]>
Date: Thu, 2 Apr 2020 15:38:00 +0200
Subject: [PATCH] Fix keras handling of targets with no loss
---
tensorflow/python/keras/engine/training_eager.py | 14 ++++++++------
1 file changed, 8 insertions(+), 6 deletions(-)
diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py | |
index be1b2e89d9..e39a571a91 100644 | |
--- a/tensorflow/python/keras/engine/training_eager.py | |
+++ b/tensorflow/python/keras/engine/training_eager.py | |
@@ -143,10 +143,11 @@ def _model_loss(model, | |
output_losses = [] | |
with backend.name_scope('loss'): | |
- loss_fns = [ | |
- loss_fn for loss_fn in model.loss_functions if loss_fn is not None | |
- ] | |
- for i, loss_fn in enumerate(loss_fns): | |
+ i_target = 0 | |
+ for i, loss_fn in enumerate(model.loss_functions): | |
+ if loss_fn is None: | |
+ # Output i has no loss. | |
+ continue | |
weights = sample_weights[i] if sample_weights else None | |
mask = masks[i] | |
with backend.name_scope(model.output_names[i] + '_loss'): | |
@@ -163,7 +164,7 @@ def _model_loss(model, | |
weights *= mask | |
if hasattr(loss_fn, 'reduction'): | |
- per_sample_losses = loss_fn.call(targets[i], outs[i]) | |
+ per_sample_losses = loss_fn.call(targets[i_target], outs[i]) | |
weighted_losses = losses_utils.compute_weighted_loss( | |
per_sample_losses, | |
sample_weight=weights, | |
@@ -193,13 +194,14 @@ def _model_loss(model, | |
# as part of the loss_metrics. | |
if len(model.outputs) > 1: | |
# Keep track of the stateful output loss result. | |
- output_losses.append(output_loss_metrics[i](output_loss)) | |
+ output_losses.append(output_loss_metrics[i_target](output_loss)) | |
# Scale output loss for distribution. For custom losses we assume | |
# reduction was mean. | |
if loss_reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE: | |
output_loss = losses_utils.scale_loss_for_distribution(output_loss) | |
total_loss += model._loss_weights_list[i] * output_loss | |
+ i_target += 1 | |
# Add regularization losses | |
custom_losses = model.losses | |
-- | |
2.21.1 (Apple Git-122.3)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.