import tensorflow_datasets as tfds
import tensorflow as tf
import os

# Download the dataset
tfds.disable_progress_bar()
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
# Define the distribution strategy
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
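# Note (not in the original gist): by default MirroredStrategy replicates the
# model across all GPUs visible to the process, and falls back to a single
# CPU replica when no GPU is available. To restrict it to specific devices,
# you can pass them explicitly, e.g.:
#   strategy = tf.distribute.MirroredStrategy(devices=['/gpu:0', '/gpu:1'])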
# Set up the input pipeline
# You can also do info.splits.total_num_examples to get the total
# number of examples in the dataset.
num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples

BUFFER_SIZE = 10000
BATCH_SIZE_PER_REPLICA = 8
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
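# Worked example (illustrative, not from the original): with a per-replica
# batch size of 8, the global batch size scales with the replica count, so
# 4 synchronized GPUs give BATCH_SIZE = 8 * 4 = 32 samples per training step.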
def scale(image, label):
    """Normalize pixel values from [0, 255] to [0, 1]."""
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label

train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
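# Optional tweak (not in the original): overlap host-side preprocessing with
# device execution by prefetching batches; tf.data.AUTOTUNE lets the runtime
# pick the buffer size.
#   train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)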
# Create the model
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=['accuracy'])
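# Why the scope matters: variables created inside strategy.scope() are
# mirrored across replicas, so both model construction and compile() (which
# creates the optimizer state) belong inside it. Calling fit()/evaluate()
# later can happen outside the scope.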
# Define the callbacks

# Define the checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
# Define the name of the checkpoint files.
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

# Define a function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
    if epoch < 3:
        return 1e-3
    elif epoch < 7:
        return 1e-4
    else:
        return 1e-5
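# Sanity check (illustrative, not in the original): the schedule steps down
# twice over training, so epochs 0-2 run at 1e-3, epochs 3-6 at 1e-4, and
# epoch 7 onward at 1e-5.
assert decay(0) == 1e-3 and decay(5) == 1e-4 and decay(10) == 1e-5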
# Define a callback for printing the learning rate at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        print('\nLearning rate for epoch {} is {}'.format(
            epoch + 1, model.optimizer.learning_rate.numpy()))
# Put all the callbacks together.
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]
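# The logged metrics can be inspected during or after training with:
#   tensorboard --logdir ./logs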
# Train and evaluate
EPOCHS = 12
model.fit(train_dataset, epochs=EPOCHS, callbacks=callbacks)

# Restore the latest checkpoint and evaluate on the test set.
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
eval_loss, eval_acc = model.evaluate(eval_dataset)
print('Eval loss: {}, Eval accuracy: {}'.format(eval_loss, eval_acc))
print('Done')
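# Optional extension (a sketch, not part of the original gist): export the
# trained model in SavedModel format so it can be reloaded later, with or
# without a distribution strategy. The path 'saved_model/' is an arbitrary
# choice for illustration.
#   model.save('saved_model/')
#   restored = tf.keras.models.load_model('saved_model/')
#   restored.evaluate(eval_dataset)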