Last active October 28, 2019 18:38
Gist arose13/a3618531bf1387f67695f41120a84143
Automatic one-hot encoding layer for Keras
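import tensorflow.keras as k
import tensorflow.keras.backend as K


def _one_hot_layer(num_classes: int):
    """
    One-hot encoding layer for Keras.

    Saves memory by letting the model consume integer class indices and
    expanding them to one-hot vectors inside the graph, instead of storing
    pre-expanded one-hot arrays.

    :param num_classes: number of classes (length of each one-hot vector)
    :return: a Lambda layer mapping integer indices of shape (...,) to
        float one-hot tensors of shape (..., num_classes)
    """
    return k.layers.Lambda(lambda x: K.one_hot(K.cast(x, 'int64'), num_classes))


OneHotLayer = _one_hot_layer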
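To make the memory claim concrete, here is a stdlib-only sketch (no TensorFlow needed) of what the layer computes and of the storage arithmetic behind it. `one_hot_reference` and `storage_bytes` are hypothetical helpers written for illustration, not part of Keras. The sketch assumes labels are stored as int64 indices and that a pre-expanded one-hot array would use float32, the default dtype of `tf.one_hot`.

```python
from array import array
from typing import List, Tuple


def one_hot_reference(indices: List[int], num_classes: int) -> List[List[float]]:
    """Pure-Python stand-in for what OneHotLayer computes on a 1-D batch."""
    rows = []
    for i in indices:
        row = [0.0] * num_classes
        # tf.one_hot gives an all-zero row for an index of -1; this sketch
        # applies the same rule to every out-of-range index.
        if 0 <= i < num_classes:
            row[i] = 1.0
        rows.append(row)
    return rows


def storage_bytes(sample_count: int, num_classes: int) -> Tuple[int, int]:
    """Return (index_bytes, one_hot_bytes) for storing the same labels two ways."""
    # Assumption: one int64 index per sample, matching the layer's 'int64' cast.
    index_bytes = sample_count * array("q").itemsize
    # Assumption: one float32 value per sample per class in a pre-expanded array.
    one_hot_bytes = sample_count * num_classes * array("f").itemsize
    return index_bytes, one_hot_bytes


if __name__ == "__main__":
    print(one_hot_reference([0, 2, 1], 3))
    print(storage_bytes(1_000_000, 1_000))
```

For 1,000,000 labels over 1,000 classes, `storage_bytes` returns `(8000000, 4000000000)`: about 8 MB of indices against about 4 GB of one-hot rows, a ratio of `num_classes / 2` under these dtype assumptions. The layer still builds one-hot tensors, but only for the batch currently passing through the model.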