Created April 7, 2020 17:53
-
-
Save quocdat32461997/10358455066ccc76817b54d20613c1dd to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Dense, Input, Lambda
from tensorflow.keras.models import Model
def loss_fn(args):
    """Return a constant float32 "loss" of 1, ignoring ``args``.

    Used as the function of a Keras ``Lambda`` layer, so ``args`` is whatever
    tensor that layer receives (presumably the classification output — confirm
    at the call site).

    Returns:
        A rank-0 float32 constant equal to 1. It has no batch dimension.

    NOTE(review): the result does not depend on ``args`` or on any trainable
    weight, so no gradient can flow back through it. If this value is used as a
    training loss, the optimizer will have nothing to update. This is presumably
    intentional for a minimal reproduction; replace it with a real computation
    over ``args`` for actual training.
    """
    return K.constant(1, dtype='float32')
# Create the model: a stack of ReLU Dense layers (512 -> 128 -> 32) over a
# 784-feature input, a 10-way softmax head, and a final Lambda layer whose
# output is treated as the loss.
inputs = Input(shape=(784,))
dense1 = Dense(512, activation='relu')(inputs)
dense2 = Dense(128, activation='relu')(dense1)
dense3 = Dense(32, activation='relu')(dense2)

# Classification output: softmax over 10 classes.
classification_output = Dense(10, activation='softmax')(dense3)

# The model's sole output is whatever `loss_fn` computes from the
# classification output; the layer is named 'loss' and declares shape (1,).
outputs = Lambda(loss_fn, name='loss', output_shape=(1,))(classification_output)

model = Model(inputs=inputs, outputs=outputs)

# The loss is computed inside the graph (by the 'loss' Lambda layer), so the
# compiled loss simply returns the model output and ignores y_true.
model.compile(
    optimizer=tensorflow.keras.optimizers.Adam(learning_rate=0.01),
    loss=lambda y_true, y_pred: y_pred,
)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment