Created
February 28, 2022 09:32
-
-
Save allenday/24bbdbc72b6019b919bc37d571665da5 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "771857ee", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import datetime\n", | |
"import numpy as np\n", | |
"import tensorflow as tf\n", | |
"import tensorflow_datasets as tfds\n", | |
"from classification_models.keras import Classifiers\n", | |
"from PIL import Image" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "f030ab29", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Load both MNIST splits as (image, label) pairs, together with the dataset metadata.\n", | |
"ds, info = tfds.load(\n", | |
"    'mnist',\n", | |
"    split=['train', 'test'],\n", | |
"    with_info=True,\n", | |
"    as_supervised=True,\n", | |
"    shuffle_files=True,\n", | |
")\n", | |
"ds_train, ds_validation = ds\n", | |
"\n", | |
"# Class count and class names come from the dataset's own label feature.\n", | |
"labels = info.features['label'].names\n", | |
"NUM_CLASSES = info.features['label'].num_classes\n", | |
"\n", | |
"# Image size used when resizing inputs, and the training-loop settings.\n", | |
"SIZE = (28, 28)\n", | |
"BATCH_SIZE = 400\n", | |
"EPOCHS = 5\n", | |
"# Path handed to ModelCheckpoint as `filepath`.\n", | |
"model_base = 'mnist-model'\n", | |
"# Value handed to ModelCheckpoint as `initial_value_threshold`.\n", | |
"thresh = 0" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "179be129", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"tfds.core.DatasetInfo(\n", | |
" name='mnist',\n", | |
" full_name='mnist/3.0.1',\n", | |
" description=\"\"\"\n", | |
" The MNIST database of handwritten digits.\n", | |
" \"\"\",\n", | |
" homepage='http://yann.lecun.com/exdb/mnist/',\n", | |
" data_path='/home/allenday/tensorflow_datasets/mnist/3.0.1',\n", | |
" download_size=11.06 MiB,\n", | |
" dataset_size=21.00 MiB,\n", | |
" features=FeaturesDict({\n", | |
" 'image': Image(shape=(28, 28, 1), dtype=tf.uint8),\n", | |
" 'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),\n", | |
" }),\n", | |
" supervised_keys=('image', 'label'),\n", | |
" disable_shuffling=False,\n", | |
" splits={\n", | |
" 'test': <SplitInfo num_examples=10000, num_shards=1>,\n", | |
" 'train': <SplitInfo num_examples=60000, num_shards=1>,\n", | |
" },\n", | |
" citation=\"\"\"@article{lecun2010mnist,\n", | |
" title={MNIST handwritten digit database},\n", | |
" author={LeCun, Yann and Cortes, Corinna and Burges, CJ},\n", | |
" journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist},\n", | |
" volume={2},\n", | |
" year={2010}\n", | |
" }\"\"\",\n", | |
")\n", | |
"Number of classes: 10\n", | |
"Number of training samples: 60000\n", | |
"Number of validation samples: 10000\n" | |
] | |
} | |
], | |
"source": [ | |
"# Element counts of the two splits, read from the tf.data pipelines themselves.\n", | |
"n_train = tf.data.experimental.cardinality(ds_train)\n", | |
"n_validation = tf.data.experimental.cardinality(ds_validation)\n", | |
"\n", | |
"print(info)\n", | |
"print(\"Number of classes: %d\" % NUM_CLASSES)\n", | |
"print(\"Number of training samples: %d\" % n_train)\n", | |
"print(\"Number of validation samples: %d\" % n_validation)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "5e65f447", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Turn each single-channel digit into a 3-channel image of shape SIZE + (3,),\n", | |
"# matching the (28, 28, 3) input the model is built with. The label passes through.\n", | |
"def to_rgb(image, label):\n", | |
"    return tf.image.resize(tf.image.grayscale_to_rgb(image), SIZE), label\n", | |
"\n", | |
"\n", | |
"ds_train = ds_train.map(to_rgb, num_parallel_calls=tf.data.AUTOTUNE)\n", | |
"ds_validation = ds_validation.map(to_rgb, num_parallel_calls=tf.data.AUTOTUNE)\n", | |
"\n", | |
"# The dataset fits in memory, so cache it before shuffling for better performance.\n", | |
"# Note: random transformations should be applied after caching.\n", | |
"ds_train = ds_train.cache()\n", | |
"# Shuffle with a buffer as large as the training split so every epoch is a full\n", | |
"# uniform shuffle; a buffer smaller than the dataset only mixes nearby examples.\n", | |
"# Note: for datasets that cannot fit in memory, fall back to a smaller buffer (e.g. 1000).\n", | |
"ds_train = ds_train.shuffle(info.splits['train'].num_examples)\n", | |
"# Batch after shuffling to get different batches at each epoch.\n", | |
"ds_train = ds_train.batch(BATCH_SIZE)\n", | |
"# End the pipeline with prefetch so input preparation overlaps with training.\n", | |
"ds_train = ds_train.prefetch(tf.data.AUTOTUNE)\n", | |
"\n", | |
"# Validation data needs no shuffling: cache, batch and prefetch only.\n", | |
"ds_validation = ds_validation.cache()\n", | |
"ds_validation = ds_validation.batch(BATCH_SIZE)\n", | |
"ds_validation = ds_validation.prefetch(tf.data.AUTOTUNE)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "cc86a2ed", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# ResNet18 constructor from classification_models; preprocess_input is returned\n", | |
"# alongside it but is not used below (scaling is done by scale_layer instead).\n", | |
"ResNet18, preprocess_input = Classifiers.get('resnet18')\n", | |
"\n", | |
"# Rescale inputs inside the model: x * (1 / 127.5) - 1, i.e. 0..255 -> -1..1.\n", | |
"scale_layer = tf.keras.layers.Rescaling(scale=1 / 127.5, offset=-1)\n", | |
"\n", | |
"# Backbone without its classification head, initialised from ImageNet weights\n", | |
"# (pass weights=None instead to train from scratch).\n", | |
"base_model = ResNet18(\n", | |
"    input_shape=(28, 28, 3),\n", | |
"    weights='imagenet',\n", | |
"    include_top=False,\n", | |
")\n", | |
"# Fine-tune the whole backbone (set to False to train only the new head).\n", | |
"base_model.trainable = True\n", | |
"\n", | |
"inputs = tf.keras.Input(shape=(28, 28, 3))\n", | |
"x = scale_layer(inputs)\n", | |
"\n", | |
"# Do not hard-code training=True here: that would keep the backbone's\n", | |
"# BatchNormalization layers in training mode during evaluate()/predict(), making\n", | |
"# predictions depend on the statistics of whatever batch is passed in (and\n", | |
"# degenerate for a single image). Leaving the argument out lets Keras pass\n", | |
"# training=True during fit() and training=False during inference.\n", | |
"x = base_model(x)\n", | |
"x = tf.keras.layers.GlobalAveragePooling2D()(x)\n", | |
"x = tf.keras.layers.Dense(NUM_CLASSES, activation=None)(x)\n", | |
"# The model outputs class probabilities, so the loss is configured with from_logits=False.\n", | |
"outputs = tf.keras.layers.Activation(activation=\"softmax\", name=\"activation\")(x)\n", | |
"loss_function = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "02002aee", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Define the per-epoch callbacks\n", | |
"\n", | |
"# One TensorBoard log directory per run, named by the current timestamp.\n", | |
"run_stamp = datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n", | |
"logdir = \"logs/position/\" + run_stamp\n", | |
"cbTensorBoard = tf.keras.callbacks.TensorBoard(histogram_freq=1, log_dir=logdir)\n", | |
"\n", | |
"# Stop once val_loss has gone 100 epochs without improving.\n", | |
"cbEarlyStop = tf.keras.callbacks.EarlyStopping(patience=100, monitor='val_loss')\n", | |
"\n", | |
"# Save to model_base only when validation accuracy reaches a new maximum.\n", | |
"cbCheckPoint = tf.keras.callbacks.ModelCheckpoint(\n", | |
"    filepath=model_base,\n", | |
"    monitor=\"val_sparse_categorical_accuracy\",\n", | |
"    mode='max',\n", | |
"    save_best_only=True,\n", | |
"    initial_value_threshold=thresh,\n", | |
"    verbose=1,\n", | |
")\n", | |
"\n", | |
"# SGD with momentum.\n", | |
"opt = tf.keras.optimizers.SGD(momentum=0.9, learning_rate=0.001)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "06ea64b9", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Model: \"model_1\"\n", | |
"_________________________________________________________________\n", | |
"Layer (type) Output Shape Param # \n", | |
"=================================================================\n", | |
"input_1 (InputLayer) [(None, 28, 28, 3)] 0 \n", | |
"_________________________________________________________________\n", | |
"rescaling (Rescaling) (None, 28, 28, 3) 0 \n", | |
"_________________________________________________________________\n", | |
"model (Functional) (None, 1, 1, 512) 11186889 \n", | |
"_________________________________________________________________\n", | |
"global_average_pooling2d (Gl (None, 512) 0 \n", | |
"_________________________________________________________________\n", | |
"dense (Dense) (None, 10) 5130 \n", | |
"_________________________________________________________________\n", | |
"activation (Activation) (None, 10) 0 \n", | |
"=================================================================\n", | |
"Total params: 11,192,019\n", | |
"Trainable params: 11,184,077\n", | |
"Non-trainable params: 7,942\n", | |
"_________________________________________________________________\n" | |
] | |
} | |
], | |
"source": [ | |
"# Wrap the input and output tensors in a Keras Model and print its layer table.\n", | |
"model = tf.keras.Model(inputs=inputs, outputs=outputs)\n", | |
"model.summary()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "558affcc", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Track sparse categorical accuracy; the checkpoint callback monitors its\n", | |
"# validation value (val_sparse_categorical_accuracy).\n", | |
"accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy()\n", | |
"model.compile(loss=loss_function, optimizer=opt, metrics=[accuracy_metric])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "ccc6d981", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/allenday/venv/tf/lib/python3.6/site-packages/keras/utils/generic_utils.py:497: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.\n", | |
" category=CustomMaskWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 1/5\n", | |
"150/150 [==============================] - 8s 31ms/step - loss: 0.4073 - sparse_categorical_accuracy: 0.8792 - val_loss: 0.1005 - val_sparse_categorical_accuracy: 0.9683\n", | |
"\n", | |
"Epoch 00001: val_sparse_categorical_accuracy improved from -inf to 0.96830, saving model to mnist-model\n", | |
"INFO:tensorflow:Assets written to: mnist-model/assets\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:Assets written to: mnist-model/assets\n", | |
"/home/allenday/venv/tf/lib/python3.6/site-packages/keras/utils/generic_utils.py:497: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.\n", | |
" category=CustomMaskWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 2/5\n", | |
"150/150 [==============================] - 4s 24ms/step - loss: 0.0747 - sparse_categorical_accuracy: 0.9776 - val_loss: 0.0720 - val_sparse_categorical_accuracy: 0.9770\n", | |
"\n", | |
"Epoch 00002: val_sparse_categorical_accuracy improved from 0.96830 to 0.97700, saving model to mnist-model\n", | |
"INFO:tensorflow:Assets written to: mnist-model/assets\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:Assets written to: mnist-model/assets\n", | |
"/home/allenday/venv/tf/lib/python3.6/site-packages/keras/utils/generic_utils.py:497: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.\n", | |
" category=CustomMaskWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 3/5\n", | |
"150/150 [==============================] - 4s 24ms/step - loss: 0.0469 - sparse_categorical_accuracy: 0.9859 - val_loss: 0.0608 - val_sparse_categorical_accuracy: 0.9806\n", | |
"\n", | |
"Epoch 00003: val_sparse_categorical_accuracy improved from 0.97700 to 0.98060, saving model to mnist-model\n", | |
"INFO:tensorflow:Assets written to: mnist-model/assets\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:Assets written to: mnist-model/assets\n", | |
"/home/allenday/venv/tf/lib/python3.6/site-packages/keras/utils/generic_utils.py:497: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.\n", | |
" category=CustomMaskWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 4/5\n", | |
"150/150 [==============================] - 4s 24ms/step - loss: 0.0342 - sparse_categorical_accuracy: 0.9901 - val_loss: 0.0546 - val_sparse_categorical_accuracy: 0.9817\n", | |
"\n", | |
"Epoch 00004: val_sparse_categorical_accuracy improved from 0.98060 to 0.98170, saving model to mnist-model\n", | |
"INFO:tensorflow:Assets written to: mnist-model/assets\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:Assets written to: mnist-model/assets\n", | |
"/home/allenday/venv/tf/lib/python3.6/site-packages/keras/utils/generic_utils.py:497: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.\n", | |
" category=CustomMaskWarning)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 5/5\n", | |
"150/150 [==============================] - 4s 24ms/step - loss: 0.0259 - sparse_categorical_accuracy: 0.9929 - val_loss: 0.0509 - val_sparse_categorical_accuracy: 0.9832\n", | |
"\n", | |
"Epoch 00005: val_sparse_categorical_accuracy improved from 0.98170 to 0.98320, saving model to mnist-model\n", | |
"INFO:tensorflow:Assets written to: mnist-model/assets\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:Assets written to: mnist-model/assets\n", | |
"/home/allenday/venv/tf/lib/python3.6/site-packages/keras/utils/generic_utils.py:497: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.\n", | |
" category=CustomMaskWarning)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<keras.callbacks.History at 0x7ffa206281d0>" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Train for EPOCHS passes over ds_train, validating on ds_validation after each epoch.\n", | |
"model.fit(\n", | |
"    x=ds_train,\n", | |
"    validation_data=ds_validation,\n", | |
"    epochs=EPOCHS,\n", | |
"    verbose=1,\n", | |
"    callbacks=[cbCheckPoint, cbEarlyStop, cbTensorBoard],\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "90dc0b84", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"true=3\n", | |
"prob=tf.Tensor(\n", | |
"[0.1004887 0.09967535 0.10066849 0.09976515 0.09914842 0.10214555\n", | |
" 0.09826583 0.09943706 0.09950258 0.10090292], shape=(10,), dtype=float32)\n", | |
"pred=5\t0.102145545\n" | |
] | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAIAAAD9b0jDAAAB6UlEQVR4nO2Uva8hURyGT8b4nEaDAoVK4aNRShQKiU5rIjT+gGkoRhQSjaChQD9RINGoFEKtnIRIRMZHQggxEjKT+N3ZQlZu7iI79la7+5TvyXnOez5yEPrPX4JCofD5fKVSSRCEj59IkuRwON40er3eTqcDj6Bp+tks/NmAUqnMZDIURanVaoTQ9Xqt1WqiKDqdzkAggBDa7/eyO8ZisVujzWZTKBSMRuMtTyQSt/yd7VMUNRwOSZK02Wz3MBQK8TwPAI1GgyAI2VIMw3Q63eeEpunz+QwAHMdptVrZxi8QBBEOhwVBAIDBYOByuf7UqFKpUqnU/dLX63Uul7NYLG/qcByPx+Pz+fzX98SyrMFgkG1UKBT5fP5u2W63vV4vEolks9nVagUAfr//aZsXXkmSEEIsy3a73Wq1OpvN7uslk0m3293r9d4pq9frNRrN59Dj8ZxOp9fv9FVTADgej1/CYDBIEATP85fLRXbNh5hMJo7jAKBcLn+PESHEMAwALJdLHH+1RRk4HI7FYgEA6XT6e4wIoel0CgCtVgvDMNmT7XZ7u90WRVEUxXq9XiwWGYYRRREALpcLSZLvNKpUKg9/5clkEo1Gf8fw4Lz7/b7VajWbzbvdbjwe38JmszkajQ6Hwzs1/zF+AAjUMpuD885eAAAAAElFTkSuQmCC\n", | |
"text/plain": [ | |
"<PIL.Image.Image image mode=RGB size=28x28 at 0x7FF9AC0C8A58>" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Index of the example to inspect within the first batch (must be < BATCH_SIZE).\n", | |
"OFF = 88\n", | |
"# QA on training data: push one training image back through the model.\n", | |
"batch_images, batch_labels = next(iter(ds_train.take(1)))\n", | |
"pixels = batch_images[OFF].numpy()\n", | |
"image = Image.fromarray(pixels.astype(np.uint8))\n", | |
"# Feed raw 0-255 pixel values, exactly as during training: the model's own\n", | |
"# Rescaling layer does the normalisation, so no manual division is applied here.\n", | |
"input_tensor = np.array(image, dtype=np.float32)[np.newaxis, ...]\n", | |
"# The model ends in a softmax layer, so predict() already returns class\n", | |
"# probabilities; applying softmax again would flatten them towards uniform.\n", | |
"detections = model.predict(input_tensor)[0]\n", | |
"maxval = int(np.argmax(detections))\n", | |
"maxlab = labels[maxval]\n", | |
"print(\"true=\" + labels[int(batch_labels[OFF])])\n", | |
"print(\"prob=\" + str(detections))\n", | |
"print(\"pred=\" + maxlab + \"\\t\" + str(detections[maxval]))\n", | |
"image" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "e08e6ad9", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.9" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment