@irhum
Last active October 1, 2021 07:28
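# Assumed import: the original snippet uses K without defining it.
from tensorflow.keras import backend as K

def vae_loss(input_img, output):
    # NOTE: `mean` and `log_stddev` are not defined here; presumably the encoder's output tensors from the enclosing scope.
    # reconstruction loss: squared error summed over all axes (batch axis included)
    reconstruction_loss = K.sum(K.square(output - input_img))
    # KL divergence from N(0, I), one value per image; log(sigma^2) = 2 * log_stddev
    kl_loss = -0.5 * K.sum(1 + 2 * log_stddev - K.square(mean) - K.square(K.exp(log_stddev)), axis=-1)
    # return the average loss over all images in the batch
    total_loss = K.mean(reconstruction_loss + kl_loss)
    return total_loss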
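
The KL term in `vae_loss` is the closed-form divergence between a diagonal Gaussian and a standard normal. As a rough sanity check, the sketch below recomputes that closed form in plain Python and compares it with a numerical integral for one latent dimension. It assumes `log_stddev` holds log(sigma), so that log(sigma^2) = 2 * log_stddev. The helper names `kl_to_standard_normal` and `kl_numeric_1d` are hypothetical and not part of the gist.

```python
import math

def kl_to_standard_normal(mean, log_stddev):
    """Closed-form KL( N(mean, sigma^2) || N(0, 1) ), summed over latent dimensions.

    Assumes log_stddev holds log(sigma), so log(sigma^2) = 2 * log_stddev.
    """
    return -0.5 * sum(
        1.0 + 2.0 * ls - m * m - math.exp(ls) ** 2
        for m, ls in zip(mean, log_stddev)
    )

def kl_numeric_1d(mean, log_stddev, lo=-12.0, hi=12.0, steps=100_000):
    """Midpoint-rule estimate of the same KL for a single latent dimension."""
    sigma = math.exp(log_stddev)
    dx = (hi - lo) / steps
    total = 0.0
    for i in range(steps):
        x = lo + (i + 0.5) * dx
        # log density of q = N(mean, sigma^2) and of p = N(0, 1) at x
        log_q = -0.5 * math.log(2 * math.pi) - log_stddev - 0.5 * ((x - mean) / sigma) ** 2
        log_p = -0.5 * math.log(2 * math.pi) - 0.5 * x * x
        total += math.exp(log_q) * (log_q - log_p) * dx
    return total

closed_form = kl_to_standard_normal([0.7], [-0.3])
numeric = kl_numeric_1d(0.7, -0.3)
print(closed_form, numeric)
```

The two printed values should agree closely. Dropping the factor of 2 on `log_stddev` would make the closed form disagree with the integral.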
irhum commented Oct 1, 2021

Copy-pasting this over from the old gist; it appears the previous gist and its links broke.
