Created August 4, 2022 12:58
-
-
Save naiborhujosua/5010c0e68c213e98100cfe09c15e40da to your computer and use it in GitHub Desktop.
EfficientNetB0
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Transfer-learning feature extractor: a frozen EfficientNetB0 backbone from
# tf.keras.applications with a new pooling + softmax classification head.
#
# NOTE(review): this snippet relies on names defined elsewhere -- `tf`
# (TensorFlow), `train_data`, `val_data` and `create_tensorboard_callback` --
# confirm they are in scope before running.

# 1. Create base model with tf.keras.applications; include_top=False drops the
#    original classifier so a new head can be attached below.
base_model = tf.keras.applications.EfficientNetB0(include_top=False)

# 2. Freeze the base model (so the pre-learned patterns remain); only the new
#    head is trained.
base_model.trainable = False

# 3. Create inputs into the base model (224x224 images with 3 channels).
inputs = tf.keras.layers.Input(shape=(224, 224, 3), name="input_layer")

# 4. If using ResNet50V2, add a rescaling layer to speed up convergence; leave
#    it out for EfficientNet (see step 5):
# x = tf.keras.layers.Rescaling(1./255)(inputs)

# 5. Pass the inputs to the base_model (with tf.keras.applications,
#    EfficientNet inputs don't have to be normalized). training=False keeps the
#    frozen backbone in inference mode (BatchNorm statistics and internal
#    dropout layers), which also stays correct if the model is later unfrozen
#    for fine-tuning.
x = base_model(inputs, training=False)
# Check data shape after passing it to base_model.
print(f"Shape after base_model: {x.shape}")

# 6. Average pool the outputs of the base model (aggregate the feature maps
#    into one vector per image, reducing the number of computations).
x = tf.keras.layers.GlobalAveragePooling2D(name="global_average_pooling_layer")(x)
print(f"After GlobalAveragePooling2D(): {x.shape}")

# 7. Create the output activation layer: one softmax unit per class.
#    NOTE(review): 6 is hard-coded; it must equal the number of classes in
#    train_data -- confirm.
outputs = tf.keras.layers.Dense(6, activation="softmax", name="output_layer")(x)

# 8. Combine the inputs with the outputs into a model.
model_2 = tf.keras.Model(inputs, outputs)

# 9. Compile the model. categorical_crossentropy expects one-hot labels;
#    presumably train_data/val_data yield one-hot labels -- verify (use
#    sparse_categorical_crossentropy for integer labels).
model_2.compile(loss="categorical_crossentropy",
                optimizer=tf.keras.optimizers.Adam(),
                metrics=["accuracy"])

# 10. Fit the model. No validation_steps is given, so each epoch validates on
#     the whole of val_data.
history_2 = model_2.fit(train_data,
                        epochs=10,
                        steps_per_epoch=len(train_data),
                        validation_data=val_data,
                        # Track model training logs
                        callbacks=[create_tensorboard_callback("CNN_learning", "transfer_learning_efficientNet_0")])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment