Last active
October 29, 2024 12:07
-
-
Save martinferianc/d6090fffb4c95efed6f1152d5fde079d to your computer and use it in GitHub Desktop.
Quantisation example in PyTorch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "PyTorch Quantisation.ipynb", | |
"provenance": [], | |
"collapsed_sections": [] | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"accelerator": "GPU", | |
"widgets": { | |
"application/vnd.jupyter.widget-state+json": { | |
"4ef3ac3a55b1405cbbfb495d511f8c57": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HBoxModel", | |
"state": { | |
"_view_name": "HBoxView", | |
"_dom_classes": [], | |
"_model_name": "HBoxModel", | |
"_view_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_view_count": null, | |
"_view_module_version": "1.5.0", | |
"box_style": "", | |
"layout": "IPY_MODEL_63735be95ec2408abd58846e255a7547", | |
"_model_module": "@jupyter-widgets/controls", | |
"children": [ | |
"IPY_MODEL_09c7204e26a44c95872da1e227c30315", | |
"IPY_MODEL_1d71085620904406b58384c3bb0b1edf" | |
] | |
} | |
}, | |
"63735be95ec2408abd58846e255a7547": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"state": { | |
"_view_name": "LayoutView", | |
"grid_template_rows": null, | |
"right": null, | |
"justify_content": null, | |
"_view_module": "@jupyter-widgets/base", | |
"overflow": null, | |
"_model_module_version": "1.2.0", | |
"_view_count": null, | |
"flex_flow": null, | |
"width": null, | |
"min_width": null, | |
"border": null, | |
"align_items": null, | |
"bottom": null, | |
"_model_module": "@jupyter-widgets/base", | |
"top": null, | |
"grid_column": null, | |
"overflow_y": null, | |
"overflow_x": null, | |
"grid_auto_flow": null, | |
"grid_area": null, | |
"grid_template_columns": null, | |
"flex": null, | |
"_model_name": "LayoutModel", | |
"justify_items": null, | |
"grid_row": null, | |
"max_height": null, | |
"align_content": null, | |
"visibility": null, | |
"align_self": null, | |
"height": null, | |
"min_height": null, | |
"padding": null, | |
"grid_auto_rows": null, | |
"grid_gap": null, | |
"max_width": null, | |
"order": null, | |
"_view_module_version": "1.2.0", | |
"grid_template_areas": null, | |
"object_position": null, | |
"object_fit": null, | |
"grid_auto_columns": null, | |
"margin": null, | |
"display": null, | |
"left": null | |
} | |
}, | |
"09c7204e26a44c95872da1e227c30315": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "FloatProgressModel", | |
"state": { | |
"_view_name": "ProgressView", | |
"style": "IPY_MODEL_31805a8fbfde40cc9867a3643e829a44", | |
"_dom_classes": [], | |
"description": "", | |
"_model_name": "FloatProgressModel", | |
"bar_style": "info", | |
"max": 1, | |
"_view_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"value": 1, | |
"_view_count": null, | |
"_view_module_version": "1.5.0", | |
"orientation": "horizontal", | |
"min": 0, | |
"description_tooltip": null, | |
"_model_module": "@jupyter-widgets/controls", | |
"layout": "IPY_MODEL_e65e31421f9b40d4a3c59f36622e3163" | |
} | |
}, | |
"1d71085620904406b58384c3bb0b1edf": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"state": { | |
"_view_name": "HTMLView", | |
"style": "IPY_MODEL_485835aae0cf48feac7afade58a6a521", | |
"_dom_classes": [], | |
"description": "", | |
"_model_name": "HTMLModel", | |
"placeholder": "", | |
"_view_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"value": " 9920512/? [00:19<00:00, 1054429.46it/s]", | |
"_view_count": null, | |
"_view_module_version": "1.5.0", | |
"description_tooltip": null, | |
"_model_module": "@jupyter-widgets/controls", | |
"layout": "IPY_MODEL_502a360a8774499c8b9a950c34a684e5" | |
} | |
}, | |
"31805a8fbfde40cc9867a3643e829a44": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ProgressStyleModel", | |
"state": { | |
"_view_name": "StyleView", | |
"_model_name": "ProgressStyleModel", | |
"description_width": "initial", | |
"_view_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.5.0", | |
"_view_count": null, | |
"_view_module_version": "1.2.0", | |
"bar_color": null, | |
"_model_module": "@jupyter-widgets/controls" | |
} | |
}, | |
"e65e31421f9b40d4a3c59f36622e3163": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"state": { | |
"_view_name": "LayoutView", | |
"grid_template_rows": null, | |
"right": null, | |
"justify_content": null, | |
"_view_module": "@jupyter-widgets/base", | |
"overflow": null, | |
"_model_module_version": "1.2.0", | |
"_view_count": null, | |
"flex_flow": null, | |
"width": null, | |
"min_width": null, | |
"border": null, | |
"align_items": null, | |
"bottom": null, | |
"_model_module": "@jupyter-widgets/base", | |
"top": null, | |
"grid_column": null, | |
"overflow_y": null, | |
"overflow_x": null, | |
"grid_auto_flow": null, | |
"grid_area": null, | |
"grid_template_columns": null, | |
"flex": null, | |
"_model_name": "LayoutModel", | |
"justify_items": null, | |
"grid_row": null, | |
"max_height": null, | |
"align_content": null, | |
"visibility": null, | |
"align_self": null, | |
"height": null, | |
"min_height": null, | |
"padding": null, | |
"grid_auto_rows": null, | |
"grid_gap": null, | |
"max_width": null, | |
"order": null, | |
"_view_module_version": "1.2.0", | |
"grid_template_areas": null, | |
"object_position": null, | |
"object_fit": null, | |
"grid_auto_columns": null, | |
"margin": null, | |
"display": null, | |
"left": null | |
} | |
}, | |
"485835aae0cf48feac7afade58a6a521": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"state": { | |
"_view_name": "StyleView", | |
"_model_name": "DescriptionStyleModel", | |
"description_width": "", | |
"_view_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.5.0", | |
"_view_count": null, | |
"_view_module_version": "1.2.0", | |
"_model_module": "@jupyter-widgets/controls" | |
} | |
}, | |
"502a360a8774499c8b9a950c34a684e5": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"state": { | |
"_view_name": "LayoutView", | |
"grid_template_rows": null, | |
"right": null, | |
"justify_content": null, | |
"_view_module": "@jupyter-widgets/base", | |
"overflow": null, | |
"_model_module_version": "1.2.0", | |
"_view_count": null, | |
"flex_flow": null, | |
"width": null, | |
"min_width": null, | |
"border": null, | |
"align_items": null, | |
"bottom": null, | |
"_model_module": "@jupyter-widgets/base", | |
"top": null, | |
"grid_column": null, | |
"overflow_y": null, | |
"overflow_x": null, | |
"grid_auto_flow": null, | |
"grid_area": null, | |
"grid_template_columns": null, | |
"flex": null, | |
"_model_name": "LayoutModel", | |
"justify_items": null, | |
"grid_row": null, | |
"max_height": null, | |
"align_content": null, | |
"visibility": null, | |
"align_self": null, | |
"height": null, | |
"min_height": null, | |
"padding": null, | |
"grid_auto_rows": null, | |
"grid_gap": null, | |
"max_width": null, | |
"order": null, | |
"_view_module_version": "1.2.0", | |
"grid_template_areas": null, | |
"object_position": null, | |
"object_fit": null, | |
"grid_auto_columns": null, | |
"margin": null, | |
"display": null, | |
"left": null | |
} | |
}, | |
"c1d10c7a472146e6bc09d22632fd848e": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HBoxModel", | |
"state": { | |
"_view_name": "HBoxView", | |
"_dom_classes": [], | |
"_model_name": "HBoxModel", | |
"_view_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_view_count": null, | |
"_view_module_version": "1.5.0", | |
"box_style": "", | |
"layout": "IPY_MODEL_a5171886b4bc4c67b758ce052c1cee77", | |
"_model_module": "@jupyter-widgets/controls", | |
"children": [ | |
"IPY_MODEL_deb6986103c4477cad2b07f7a5ea4974", | |
"IPY_MODEL_a1653abe0c694d5d8676a4c09bfe5800" | |
] | |
} | |
}, | |
"a5171886b4bc4c67b758ce052c1cee77": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"state": { | |
"_view_name": "LayoutView", | |
"grid_template_rows": null, | |
"right": null, | |
"justify_content": null, | |
"_view_module": "@jupyter-widgets/base", | |
"overflow": null, | |
"_model_module_version": "1.2.0", | |
"_view_count": null, | |
"flex_flow": null, | |
"width": null, | |
"min_width": null, | |
"border": null, | |
"align_items": null, | |
"bottom": null, | |
"_model_module": "@jupyter-widgets/base", | |
"top": null, | |
"grid_column": null, | |
"overflow_y": null, | |
"overflow_x": null, | |
"grid_auto_flow": null, | |
"grid_area": null, | |
"grid_template_columns": null, | |
"flex": null, | |
"_model_name": "LayoutModel", | |
"justify_items": null, | |
"grid_row": null, | |
"max_height": null, | |
"align_content": null, | |
"visibility": null, | |
"align_self": null, | |
"height": null, | |
"min_height": null, | |
"padding": null, | |
"grid_auto_rows": null, | |
"grid_gap": null, | |
"max_width": null, | |
"order": null, | |
"_view_module_version": "1.2.0", | |
"grid_template_areas": null, | |
"object_position": null, | |
"object_fit": null, | |
"grid_auto_columns": null, | |
"margin": null, | |
"display": null, | |
"left": null | |
} | |
}, | |
"deb6986103c4477cad2b07f7a5ea4974": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "FloatProgressModel", | |
"state": { | |
"_view_name": "ProgressView", | |
"style": "IPY_MODEL_b18f211b24934aba92ec6c43d564061c", | |
"_dom_classes": [], | |
"description": "", | |
"_model_name": "FloatProgressModel", | |
"bar_style": "success", | |
"max": 1, | |
"_view_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"value": 1, | |
"_view_count": null, | |
"_view_module_version": "1.5.0", | |
"orientation": "horizontal", | |
"min": 0, | |
"description_tooltip": null, | |
"_model_module": "@jupyter-widgets/controls", | |
"layout": "IPY_MODEL_90a2ffa504ff42329da39e7c1497c888" | |
} | |
}, | |
"a1653abe0c694d5d8676a4c09bfe5800": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"state": { | |
"_view_name": "HTMLView", | |
"style": "IPY_MODEL_1fba5940cf064fa68317af9792da8c3b", | |
"_dom_classes": [], | |
"description": "", | |
"_model_name": "HTMLModel", | |
"placeholder": "", | |
"_view_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"value": " 32768/? [00:00<00:00, 112823.38it/s]", | |
"_view_count": null, | |
"_view_module_version": "1.5.0", | |
"description_tooltip": null, | |
"_model_module": "@jupyter-widgets/controls", | |
"layout": "IPY_MODEL_88accd52e8fe4ccd8907edf082b47d97" | |
} | |
}, | |
"b18f211b24934aba92ec6c43d564061c": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ProgressStyleModel", | |
"state": { | |
"_view_name": "StyleView", | |
"_model_name": "ProgressStyleModel", | |
"description_width": "initial", | |
"_view_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.5.0", | |
"_view_count": null, | |
"_view_module_version": "1.2.0", | |
"bar_color": null, | |
"_model_module": "@jupyter-widgets/controls" | |
} | |
}, | |
"90a2ffa504ff42329da39e7c1497c888": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"state": { | |
"_view_name": "LayoutView", | |
"grid_template_rows": null, | |
"right": null, | |
"justify_content": null, | |
"_view_module": "@jupyter-widgets/base", | |
"overflow": null, | |
"_model_module_version": "1.2.0", | |
"_view_count": null, | |
"flex_flow": null, | |
"width": null, | |
"min_width": null, | |
"border": null, | |
"align_items": null, | |
"bottom": null, | |
"_model_module": "@jupyter-widgets/base", | |
"top": null, | |
"grid_column": null, | |
"overflow_y": null, | |
"overflow_x": null, | |
"grid_auto_flow": null, | |
"grid_area": null, | |
"grid_template_columns": null, | |
"flex": null, | |
"_model_name": "LayoutModel", | |
"justify_items": null, | |
"grid_row": null, | |
"max_height": null, | |
"align_content": null, | |
"visibility": null, | |
"align_self": null, | |
"height": null, | |
"min_height": null, | |
"padding": null, | |
"grid_auto_rows": null, | |
"grid_gap": null, | |
"max_width": null, | |
"order": null, | |
"_view_module_version": "1.2.0", | |
"grid_template_areas": null, | |
"object_position": null, | |
"object_fit": null, | |
"grid_auto_columns": null, | |
"margin": null, | |
"display": null, | |
"left": null | |
} | |
}, | |
"1fba5940cf064fa68317af9792da8c3b": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"state": { | |
"_view_name": "StyleView", | |
"_model_name": "DescriptionStyleModel", | |
"description_width": "", | |
"_view_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.5.0", | |
"_view_count": null, | |
"_view_module_version": "1.2.0", | |
"_model_module": "@jupyter-widgets/controls" | |
} | |
}, | |
"88accd52e8fe4ccd8907edf082b47d97": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"state": { | |
"_view_name": "LayoutView", | |
"grid_template_rows": null, | |
"right": null, | |
"justify_content": null, | |
"_view_module": "@jupyter-widgets/base", | |
"overflow": null, | |
"_model_module_version": "1.2.0", | |
"_view_count": null, | |
"flex_flow": null, | |
"width": null, | |
"min_width": null, | |
"border": null, | |
"align_items": null, | |
"bottom": null, | |
"_model_module": "@jupyter-widgets/base", | |
"top": null, | |
"grid_column": null, | |
"overflow_y": null, | |
"overflow_x": null, | |
"grid_auto_flow": null, | |
"grid_area": null, | |
"grid_template_columns": null, | |
"flex": null, | |
"_model_name": "LayoutModel", | |
"justify_items": null, | |
"grid_row": null, | |
"max_height": null, | |
"align_content": null, | |
"visibility": null, | |
"align_self": null, | |
"height": null, | |
"min_height": null, | |
"padding": null, | |
"grid_auto_rows": null, | |
"grid_gap": null, | |
"max_width": null, | |
"order": null, | |
"_view_module_version": "1.2.0", | |
"grid_template_areas": null, | |
"object_position": null, | |
"object_fit": null, | |
"grid_auto_columns": null, | |
"margin": null, | |
"display": null, | |
"left": null | |
} | |
}, | |
"f9f8b1e0739844809134cba5ab30e74d": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HBoxModel", | |
"state": { | |
"_view_name": "HBoxView", | |
"_dom_classes": [], | |
"_model_name": "HBoxModel", | |
"_view_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_view_count": null, | |
"_view_module_version": "1.5.0", | |
"box_style": "", | |
"layout": "IPY_MODEL_4722aa0e04264c76964d58b5e3a65a32", | |
"_model_module": "@jupyter-widgets/controls", | |
"children": [ | |
"IPY_MODEL_dec70d900ac442138b73f1d58da684af", | |
"IPY_MODEL_f369bec9f8ff41338384dd48065e2d59" | |
] | |
} | |
}, | |
"4722aa0e04264c76964d58b5e3a65a32": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"state": { | |
"_view_name": "LayoutView", | |
"grid_template_rows": null, | |
"right": null, | |
"justify_content": null, | |
"_view_module": "@jupyter-widgets/base", | |
"overflow": null, | |
"_model_module_version": "1.2.0", | |
"_view_count": null, | |
"flex_flow": null, | |
"width": null, | |
"min_width": null, | |
"border": null, | |
"align_items": null, | |
"bottom": null, | |
"_model_module": "@jupyter-widgets/base", | |
"top": null, | |
"grid_column": null, | |
"overflow_y": null, | |
"overflow_x": null, | |
"grid_auto_flow": null, | |
"grid_area": null, | |
"grid_template_columns": null, | |
"flex": null, | |
"_model_name": "LayoutModel", | |
"justify_items": null, | |
"grid_row": null, | |
"max_height": null, | |
"align_content": null, | |
"visibility": null, | |
"align_self": null, | |
"height": null, | |
"min_height": null, | |
"padding": null, | |
"grid_auto_rows": null, | |
"grid_gap": null, | |
"max_width": null, | |
"order": null, | |
"_view_module_version": "1.2.0", | |
"grid_template_areas": null, | |
"object_position": null, | |
"object_fit": null, | |
"grid_auto_columns": null, | |
"margin": null, | |
"display": null, | |
"left": null | |
} | |
}, | |
"dec70d900ac442138b73f1d58da684af": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "FloatProgressModel", | |
"state": { | |
"_view_name": "ProgressView", | |
"style": "IPY_MODEL_3f37affbd69e486dad7cddeb8c8c9cb5", | |
"_dom_classes": [], | |
"description": "", | |
"_model_name": "FloatProgressModel", | |
"bar_style": "info", | |
"max": 1, | |
"_view_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"value": 1, | |
"_view_count": null, | |
"_view_module_version": "1.5.0", | |
"orientation": "horizontal", | |
"min": 0, | |
"description_tooltip": null, | |
"_model_module": "@jupyter-widgets/controls", | |
"layout": "IPY_MODEL_84107f986d5b45bfb4d9a9c9f934d1de" | |
} | |
}, | |
"f369bec9f8ff41338384dd48065e2d59": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"state": { | |
"_view_name": "HTMLView", | |
"style": "IPY_MODEL_4725f0f119e04f81a780a2489a8ef8dd", | |
"_dom_classes": [], | |
"description": "", | |
"_model_name": "HTMLModel", | |
"placeholder": "", | |
"_view_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"value": " 1654784/? [00:18<00:00, 546804.35it/s]", | |
"_view_count": null, | |
"_view_module_version": "1.5.0", | |
"description_tooltip": null, | |
"_model_module": "@jupyter-widgets/controls", | |
"layout": "IPY_MODEL_ce6dcf9416d141c9b5d5ad1e6394b621" | |
} | |
}, | |
"3f37affbd69e486dad7cddeb8c8c9cb5": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ProgressStyleModel", | |
"state": { | |
"_view_name": "StyleView", | |
"_model_name": "ProgressStyleModel", | |
"description_width": "initial", | |
"_view_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.5.0", | |
"_view_count": null, | |
"_view_module_version": "1.2.0", | |
"bar_color": null, | |
"_model_module": "@jupyter-widgets/controls" | |
} | |
}, | |
"84107f986d5b45bfb4d9a9c9f934d1de": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"state": { | |
"_view_name": "LayoutView", | |
"grid_template_rows": null, | |
"right": null, | |
"justify_content": null, | |
"_view_module": "@jupyter-widgets/base", | |
"overflow": null, | |
"_model_module_version": "1.2.0", | |
"_view_count": null, | |
"flex_flow": null, | |
"width": null, | |
"min_width": null, | |
"border": null, | |
"align_items": null, | |
"bottom": null, | |
"_model_module": "@jupyter-widgets/base", | |
"top": null, | |
"grid_column": null, | |
"overflow_y": null, | |
"overflow_x": null, | |
"grid_auto_flow": null, | |
"grid_area": null, | |
"grid_template_columns": null, | |
"flex": null, | |
"_model_name": "LayoutModel", | |
"justify_items": null, | |
"grid_row": null, | |
"max_height": null, | |
"align_content": null, | |
"visibility": null, | |
"align_self": null, | |
"height": null, | |
"min_height": null, | |
"padding": null, | |
"grid_auto_rows": null, | |
"grid_gap": null, | |
"max_width": null, | |
"order": null, | |
"_view_module_version": "1.2.0", | |
"grid_template_areas": null, | |
"object_position": null, | |
"object_fit": null, | |
"grid_auto_columns": null, | |
"margin": null, | |
"display": null, | |
"left": null | |
} | |
}, | |
"4725f0f119e04f81a780a2489a8ef8dd": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"state": { | |
"_view_name": "StyleView", | |
"_model_name": "DescriptionStyleModel", | |
"description_width": "", | |
"_view_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.5.0", | |
"_view_count": null, | |
"_view_module_version": "1.2.0", | |
"_model_module": "@jupyter-widgets/controls" | |
} | |
}, | |
"ce6dcf9416d141c9b5d5ad1e6394b621": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"state": { | |
"_view_name": "LayoutView", | |
"grid_template_rows": null, | |
"right": null, | |
"justify_content": null, | |
"_view_module": "@jupyter-widgets/base", | |
"overflow": null, | |
"_model_module_version": "1.2.0", | |
"_view_count": null, | |
"flex_flow": null, | |
"width": null, | |
"min_width": null, | |
"border": null, | |
"align_items": null, | |
"bottom": null, | |
"_model_module": "@jupyter-widgets/base", | |
"top": null, | |
"grid_column": null, | |
"overflow_y": null, | |
"overflow_x": null, | |
"grid_auto_flow": null, | |
"grid_area": null, | |
"grid_template_columns": null, | |
"flex": null, | |
"_model_name": "LayoutModel", | |
"justify_items": null, | |
"grid_row": null, | |
"max_height": null, | |
"align_content": null, | |
"visibility": null, | |
"align_self": null, | |
"height": null, | |
"min_height": null, | |
"padding": null, | |
"grid_auto_rows": null, | |
"grid_gap": null, | |
"max_width": null, | |
"order": null, | |
"_view_module_version": "1.2.0", | |
"grid_template_areas": null, | |
"object_position": null, | |
"object_fit": null, | |
"grid_auto_columns": null, | |
"margin": null, | |
"display": null, | |
"left": null | |
} | |
}, | |
"a4a8ee6bf1da4549a0e1fcfdba392b17": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HBoxModel", | |
"state": { | |
"_view_name": "HBoxView", | |
"_dom_classes": [], | |
"_model_name": "HBoxModel", | |
"_view_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"_view_count": null, | |
"_view_module_version": "1.5.0", | |
"box_style": "", | |
"layout": "IPY_MODEL_aaf6e3a398a84014ad173b40d54a1a45", | |
"_model_module": "@jupyter-widgets/controls", | |
"children": [ | |
"IPY_MODEL_777a4203229d4a00b9dbf1e7603e0d36", | |
"IPY_MODEL_093c877eaa43432eabc2aa71c7f2b587" | |
] | |
} | |
}, | |
"aaf6e3a398a84014ad173b40d54a1a45": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"state": { | |
"_view_name": "LayoutView", | |
"grid_template_rows": null, | |
"right": null, | |
"justify_content": null, | |
"_view_module": "@jupyter-widgets/base", | |
"overflow": null, | |
"_model_module_version": "1.2.0", | |
"_view_count": null, | |
"flex_flow": null, | |
"width": null, | |
"min_width": null, | |
"border": null, | |
"align_items": null, | |
"bottom": null, | |
"_model_module": "@jupyter-widgets/base", | |
"top": null, | |
"grid_column": null, | |
"overflow_y": null, | |
"overflow_x": null, | |
"grid_auto_flow": null, | |
"grid_area": null, | |
"grid_template_columns": null, | |
"flex": null, | |
"_model_name": "LayoutModel", | |
"justify_items": null, | |
"grid_row": null, | |
"max_height": null, | |
"align_content": null, | |
"visibility": null, | |
"align_self": null, | |
"height": null, | |
"min_height": null, | |
"padding": null, | |
"grid_auto_rows": null, | |
"grid_gap": null, | |
"max_width": null, | |
"order": null, | |
"_view_module_version": "1.2.0", | |
"grid_template_areas": null, | |
"object_position": null, | |
"object_fit": null, | |
"grid_auto_columns": null, | |
"margin": null, | |
"display": null, | |
"left": null | |
} | |
}, | |
"777a4203229d4a00b9dbf1e7603e0d36": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "FloatProgressModel", | |
"state": { | |
"_view_name": "ProgressView", | |
"style": "IPY_MODEL_5840df2790f24f2c9c8d4d8427a2922b", | |
"_dom_classes": [], | |
"description": "", | |
"_model_name": "FloatProgressModel", | |
"bar_style": "success", | |
"max": 1, | |
"_view_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"value": 1, | |
"_view_count": null, | |
"_view_module_version": "1.5.0", | |
"orientation": "horizontal", | |
"min": 0, | |
"description_tooltip": null, | |
"_model_module": "@jupyter-widgets/controls", | |
"layout": "IPY_MODEL_3c122842f89b418cbf2133d83e5ff48e" | |
} | |
}, | |
"093c877eaa43432eabc2aa71c7f2b587": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "HTMLModel", | |
"state": { | |
"_view_name": "HTMLView", | |
"style": "IPY_MODEL_1c08dece76594deea344614336c0991d", | |
"_dom_classes": [], | |
"description": "", | |
"_model_name": "HTMLModel", | |
"placeholder": "", | |
"_view_module": "@jupyter-widgets/controls", | |
"_model_module_version": "1.5.0", | |
"value": " 8192/? [00:00<00:00, 19092.15it/s]", | |
"_view_count": null, | |
"_view_module_version": "1.5.0", | |
"description_tooltip": null, | |
"_model_module": "@jupyter-widgets/controls", | |
"layout": "IPY_MODEL_24c95a6a11b449c0b4c698ce5bdda18b" | |
} | |
}, | |
"5840df2790f24f2c9c8d4d8427a2922b": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "ProgressStyleModel", | |
"state": { | |
"_view_name": "StyleView", | |
"_model_name": "ProgressStyleModel", | |
"description_width": "initial", | |
"_view_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.5.0", | |
"_view_count": null, | |
"_view_module_version": "1.2.0", | |
"bar_color": null, | |
"_model_module": "@jupyter-widgets/controls" | |
} | |
}, | |
"3c122842f89b418cbf2133d83e5ff48e": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"state": { | |
"_view_name": "LayoutView", | |
"grid_template_rows": null, | |
"right": null, | |
"justify_content": null, | |
"_view_module": "@jupyter-widgets/base", | |
"overflow": null, | |
"_model_module_version": "1.2.0", | |
"_view_count": null, | |
"flex_flow": null, | |
"width": null, | |
"min_width": null, | |
"border": null, | |
"align_items": null, | |
"bottom": null, | |
"_model_module": "@jupyter-widgets/base", | |
"top": null, | |
"grid_column": null, | |
"overflow_y": null, | |
"overflow_x": null, | |
"grid_auto_flow": null, | |
"grid_area": null, | |
"grid_template_columns": null, | |
"flex": null, | |
"_model_name": "LayoutModel", | |
"justify_items": null, | |
"grid_row": null, | |
"max_height": null, | |
"align_content": null, | |
"visibility": null, | |
"align_self": null, | |
"height": null, | |
"min_height": null, | |
"padding": null, | |
"grid_auto_rows": null, | |
"grid_gap": null, | |
"max_width": null, | |
"order": null, | |
"_view_module_version": "1.2.0", | |
"grid_template_areas": null, | |
"object_position": null, | |
"object_fit": null, | |
"grid_auto_columns": null, | |
"margin": null, | |
"display": null, | |
"left": null | |
} | |
}, | |
"1c08dece76594deea344614336c0991d": { | |
"model_module": "@jupyter-widgets/controls", | |
"model_name": "DescriptionStyleModel", | |
"state": { | |
"_view_name": "StyleView", | |
"_model_name": "DescriptionStyleModel", | |
"description_width": "", | |
"_view_module": "@jupyter-widgets/base", | |
"_model_module_version": "1.5.0", | |
"_view_count": null, | |
"_view_module_version": "1.2.0", | |
"_model_module": "@jupyter-widgets/controls" | |
} | |
}, | |
"24c95a6a11b449c0b4c698ce5bdda18b": { | |
"model_module": "@jupyter-widgets/base", | |
"model_name": "LayoutModel", | |
"state": { | |
"_view_name": "LayoutView", | |
"grid_template_rows": null, | |
"right": null, | |
"justify_content": null, | |
"_view_module": "@jupyter-widgets/base", | |
"overflow": null, | |
"_model_module_version": "1.2.0", | |
"_view_count": null, | |
"flex_flow": null, | |
"width": null, | |
"min_width": null, | |
"border": null, | |
"align_items": null, | |
"bottom": null, | |
"_model_module": "@jupyter-widgets/base", | |
"top": null, | |
"grid_column": null, | |
"overflow_y": null, | |
"overflow_x": null, | |
"grid_auto_flow": null, | |
"grid_area": null, | |
"grid_template_columns": null, | |
"flex": null, | |
"_model_name": "LayoutModel", | |
"justify_items": null, | |
"grid_row": null, | |
"max_height": null, | |
"align_content": null, | |
"visibility": null, | |
"align_self": null, | |
"height": null, | |
"min_height": null, | |
"padding": null, | |
"grid_auto_rows": null, | |
"grid_gap": null, | |
"max_width": null, | |
"order": null, | |
"_view_module_version": "1.2.0", | |
"grid_template_areas": null, | |
"object_position": null, | |
"object_fit": null, | |
"grid_auto_columns": null, | |
"margin": null, | |
"display": null, | |
"left": null | |
} | |
} | |
} | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "EENlZvOtPDZ6", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Quantization tutorial\n", | |
"\n", | |
"This tutorial shows how to do post-training static quantization, as well as illustrating two more advanced techniques - per-channel quantization and quantization-aware training - to further improve the model’s accuracy. The task is to classify MNIST digits with a simple LeNet architecture. It is also hosted on Google Colab: [Open in Colab](https://colab.research.google.com/drive/1ptzMOHcU5IrtWaSjHvxGYsX6BozcxgF6?usp=sharing)\n", | |
"\n", | |
"\n", | |
"This is a minimalistic tutorial to show you a starting point for quantisation in PyTorch. For theory and more in-depth explanations of what is actually happening, I would recommend checking out: [Quantizing deep convolutional networks for efficient inference: A whitepaper\n", | |
"](https://arxiv.org/abs/1806.08342). \n", | |
"\n", | |
"The tutorial is heavily adapted from: https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "zTvIwDlYvBzC", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"### Initial Setup\n", | |
"\n", | |
"Before beginning the tutorial, we import the MNIST dataset and train a simple convolutional neural network (CNN) to classify it." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "hbiiMcdNJI--", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"import torch\n", | |
"import torchvision\n", | |
"import torchvision.transforms as transforms\n", | |
"import torch.nn as nn\n", | |
"import torch.nn.functional as F\n", | |
"import torch.optim as optim\n", | |
"import os\n", | |
"from torch.utils.data import DataLoader\n", | |
"import torch.quantization\n", | |
"from torch.quantization import QuantStub, DeQuantStub" | |
], | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "nCaMDWYArEXO", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Load training and test data from the MNIST dataset and apply a normalizing transformation.\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "_5UuOjjrnogR", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 386, | |
"referenced_widgets": [ | |
"4ef3ac3a55b1405cbbfb495d511f8c57", | |
"63735be95ec2408abd58846e255a7547", | |
"09c7204e26a44c95872da1e227c30315", | |
"1d71085620904406b58384c3bb0b1edf", | |
"31805a8fbfde40cc9867a3643e829a44", | |
"e65e31421f9b40d4a3c59f36622e3163", | |
"485835aae0cf48feac7afade58a6a521", | |
"502a360a8774499c8b9a950c34a684e5", | |
"c1d10c7a472146e6bc09d22632fd848e", | |
"a5171886b4bc4c67b758ce052c1cee77", | |
"deb6986103c4477cad2b07f7a5ea4974", | |
"a1653abe0c694d5d8676a4c09bfe5800", | |
"b18f211b24934aba92ec6c43d564061c", | |
"90a2ffa504ff42329da39e7c1497c888", | |
"1fba5940cf064fa68317af9792da8c3b", | |
"88accd52e8fe4ccd8907edf082b47d97", | |
"f9f8b1e0739844809134cba5ab30e74d", | |
"4722aa0e04264c76964d58b5e3a65a32", | |
"dec70d900ac442138b73f1d58da684af", | |
"f369bec9f8ff41338384dd48065e2d59", | |
"3f37affbd69e486dad7cddeb8c8c9cb5", | |
"84107f986d5b45bfb4d9a9c9f934d1de", | |
"4725f0f119e04f81a780a2489a8ef8dd", | |
"ce6dcf9416d141c9b5d5ad1e6394b621", | |
"a4a8ee6bf1da4549a0e1fcfdba392b17", | |
"aaf6e3a398a84014ad173b40d54a1a45", | |
"777a4203229d4a00b9dbf1e7603e0d36", | |
"093c877eaa43432eabc2aa71c7f2b587", | |
"5840df2790f24f2c9c8d4d8427a2922b", | |
"3c122842f89b418cbf2133d83e5ff48e", | |
"1c08dece76594deea344614336c0991d", | |
"24c95a6a11b449c0b4c698ce5bdda18b" | |
] | |
}, | |
"outputId": "84bb5a95-be28-428d-e2b5-4ff5b03b8aa0" | |
}, | |
"source": [ | |
"transform = transforms.Compose(\n", | |
" [transforms.ToTensor(),\n", | |
" transforms.Normalize((0.5,), (0.5,))])\n", | |
"\n", | |
"trainset = torchvision.datasets.MNIST(root='./data', train=True,\n", | |
" download=True, transform=transform)\n", | |
"trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,\n", | |
" shuffle=True, num_workers=16, pin_memory=True)\n", | |
"\n", | |
"testset = torchvision.datasets.MNIST(root='./data', train=False,\n", | |
" download=True, transform=transform)\n", | |
"testloader = torch.utils.data.DataLoader(testset, batch_size=64,\n", | |
" shuffle=False, num_workers=16, pin_memory=True)" | |
], | |
"execution_count": 2, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "4ef3ac3a55b1405cbbfb495d511f8c57", | |
"version_minor": 0, | |
"version_major": 2 | |
}, | |
"text/plain": [ | |
"HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw\n", | |
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "c1d10c7a472146e6bc09d22632fd848e", | |
"version_minor": 0, | |
"version_major": 2 | |
}, | |
"text/plain": [ | |
"HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw\n", | |
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz\n", | |
"\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "f9f8b1e0739844809134cba5ab30e74d", | |
"version_minor": 0, | |
"version_major": 2 | |
}, | |
"text/plain": [ | |
"HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw\n", | |
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "a4a8ee6bf1da4549a0e1fcfdba392b17", | |
"version_minor": 0, | |
"version_major": 2 | |
}, | |
"text/plain": [ | |
"HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw\n", | |
"Processing...\n", | |
"Done!\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"/usr/local/lib/python3.6/dist-packages/torchvision/datasets/mnist.py:469: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:141.)\n", | |
" return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)\n" | |
], | |
"name": "stderr" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "aG5qXPDxnUnj", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Define some helper functions and classes that help us to track the statistics and accuracy with respect to the train/test data." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "WetzHpQybN1k", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "f772c10f-da62-46c7-ccdb-c217a34f9a3d" | |
}, | |
"source": [ | |
"class AverageMeter(object):\n", | |
" \"\"\"Computes and stores the average and current value\"\"\"\n", | |
" def __init__(self, name, fmt=':f'):\n", | |
" self.name = name\n", | |
" self.fmt = fmt\n", | |
" self.reset()\n", | |
"\n", | |
" def reset(self):\n", | |
" self.val = 0\n", | |
" self.avg = 0\n", | |
" self.sum = 0\n", | |
" self.count = 0\n", | |
"\n", | |
" def update(self, val, n=1):\n", | |
" self.val = val\n", | |
" self.sum += val * n\n", | |
" self.count += n\n", | |
" self.avg = self.sum / self.count\n", | |
"\n", | |
" def __str__(self):\n", | |
" fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'\n", | |
" return fmtstr.format(**self.__dict__)\n", | |
"\n", | |
"def accuracy(output, target):\n", | |
" \"\"\" Computes the top 1 accuracy \"\"\"\n", | |
" with torch.no_grad():\n", | |
" batch_size = target.size(0)\n", | |
"\n", | |
" _, pred = output.topk(1, 1, True, True)\n", | |
" pred = pred.t()\n", | |
" correct = pred.eq(target.view(1, -1).expand_as(pred))\n", | |
"\n", | |
" res = []\n", | |
" correct_one = correct[:1].view(-1).float().sum(0, keepdim=True)\n", | |
" return correct_one.mul_(100.0 / batch_size).item()\n", | |
"\n", | |
"def print_size_of_model(model):\n", | |
" \"\"\" Prints the real size of the model \"\"\"\n", | |
" torch.save(model.state_dict(), \"temp.p\")\n", | |
" print('Size (MB):', os.path.getsize(\"temp.p\")/1e6)\n", | |
" os.remove('temp.p')\n", | |
"\n", | |
"def load_model(quantized_model, model):\n", | |
" \"\"\" Loads in the weights into an object meant for quantization \"\"\"\n", | |
" state_dict = model.state_dict()\n", | |
" model = model.to('cpu')\n", | |
" quantized_model.load_state_dict(state_dict)\n", | |
"\n", | |
"def fuse_modules(model):\n", | |
" \"\"\" Fuse together convolutions/linear layers and ReLU \"\"\"\n", | |
" torch.quantization.fuse_modules(model, [['conv1', 'relu1'], \n", | |
" ['conv2', 'relu2'],\n", | |
" ['fc1', 'relu3'],\n", | |
" ['fc2', 'relu4']], inplace=True)" | |
], | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "l62CkyIwtSOv", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Define a simple CNN that classifies MNIST images.\n", | |
"\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "9fL3F-7Rntog", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"class Net(nn.Module):\n", | |
" def __init__(self, q = False):\n", | |
" # By turning on Q we can turn on/off the quantization\n", | |
" super(Net, self).__init__()\n", | |
" self.conv1 = nn.Conv2d(1, 6, 5, bias=False)\n", | |
" self.relu1 = nn.ReLU()\n", | |
" self.pool1 = nn.MaxPool2d(2, 2)\n", | |
" self.conv2 = nn.Conv2d(6, 16, 5, bias=False)\n", | |
" self.relu2 = nn.ReLU()\n", | |
" self.pool2 = nn.MaxPool2d(2, 2)\n", | |
" self.fc1 = nn.Linear(256, 120, bias=False)\n", | |
" self.relu3 = nn.ReLU()\n", | |
" self.fc2 = nn.Linear(120, 84, bias=False)\n", | |
" self.relu4 = nn.ReLU()\n", | |
" self.fc3 = nn.Linear(84, 10, bias=False)\n", | |
" self.q = q\n", | |
" if q:\n", | |
" self.quant = QuantStub()\n", | |
" self.dequant = DeQuantStub()\n", | |
"\n", | |
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n", | |
" if self.q:\n", | |
" x = self.quant(x)\n", | |
" x = self.conv1(x)\n", | |
" x = self.relu1(x)\n", | |
" x = self.pool1(x)\n", | |
" x = self.conv2(x)\n", | |
" x = self.relu2(x)\n", | |
" x = self.pool2(x)\n", | |
" # Be careful to use reshape here instead of view\n", | |
" x = x.reshape(x.shape[0], -1)\n", | |
" x = self.fc1(x)\n", | |
" x = self.relu3(x)\n", | |
" x = self.fc2(x)\n", | |
" x = self.relu4(x)\n", | |
" x = self.fc3(x)\n", | |
" if self.q:\n", | |
" x = self.dequant(x)\n", | |
" return x" | |
], | |
"execution_count": 4, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "W9_LdxSTb3BJ", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 136 | |
}, | |
"outputId": "b5223438-8aed-4612-f040-120188d23e3c" | |
}, | |
"source": [ | |
"net = Net(q=False).cuda()\n", | |
"print_size_of_model(net)" | |
], | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"/usr/local/lib/python3.6/dist-packages/torch/cuda/__init__.py:125: UserWarning: \n", | |
"Tesla T4 with CUDA capability sm_75 is not compatible with the current PyTorch installation.\n", | |
"The current PyTorch install supports CUDA capabilities sm_37 sm_50 sm_60 sm_70.\n", | |
"If you want to use the Tesla T4 GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/\n", | |
"\n", | |
" warnings.warn(incompatible_device_warn.format(device_name, capability, \" \".join(arch_list), device_name))\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Size (MB): 0.178947\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "Nijieuxptag6", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Train this CNN on the training dataset (this may take a few moments)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "CzK6ohj5oNCT", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def train(model: nn.Module, dataloader: DataLoader, cuda=False, q=False):\n", | |
" criterion = nn.CrossEntropyLoss()\n", | |
" optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)\n", | |
" model.train()\n", | |
" for epoch in range(20): # loop over the dataset multiple times\n", | |
"\n", | |
" running_loss = AverageMeter('loss')\n", | |
" acc = AverageMeter('train_acc')\n", | |
" for i, data in enumerate(dataloader, 0):\n", | |
" # get the inputs; data is a list of [inputs, labels]\n", | |
" inputs, labels = data\n", | |
" if cuda:\n", | |
" inputs = inputs.cuda()\n", | |
" labels = labels.cuda()\n", | |
"\n", | |
" # zero the parameter gradients\n", | |
" optimizer.zero_grad()\n", | |
"\n", | |
" if epoch>=3 and q:\n", | |
" model.apply(torch.quantization.disable_observer)\n", | |
"\n", | |
" # forward + backward + optimize\n", | |
" outputs = model(inputs)\n", | |
" loss = criterion(outputs, labels)\n", | |
" loss.backward()\n", | |
" optimizer.step()\n", | |
"\n", | |
" # print statistics\n", | |
" running_loss.update(loss.item(), outputs.shape[0])\n", | |
" acc.update(accuracy(outputs, labels), outputs.shape[0])\n", | |
" if i % 100 == 0: # print every 100 mini-batches\n", | |
" print('[%d, %5d] ' %\n", | |
" (epoch + 1, i + 1), running_loss, acc)\n", | |
" print('Finished Training')\n", | |
"\n", | |
"\n", | |
"def test(model: nn.Module, dataloader: DataLoader, cuda=False) -> float:\n", | |
" correct = 0\n", | |
" total = 0\n", | |
" model.eval()\n", | |
" with torch.no_grad():\n", | |
" for data in dataloader:\n", | |
" inputs, labels = data\n", | |
"\n", | |
" if cuda:\n", | |
" inputs = inputs.cuda()\n", | |
" labels = labels.cuda()\n", | |
"\n", | |
" outputs = model(inputs)\n", | |
" _, predicted = torch.max(outputs.data, 1)\n", | |
" total += labels.size(0)\n", | |
" correct += (predicted == labels).sum().item()\n", | |
" \n", | |
" return 100 * correct / total" | |
], | |
"execution_count": 6, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "HixhBHaqtmZU", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 1000 | |
}, | |
"outputId": "60a4467f-d04f-4e89-edb6-80d33dafaa8c" | |
}, | |
"source": [ | |
"train(net, trainloader, cuda=True)" | |
], | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"[1, 1] loss 2.304688 (2.304688) train_acc 7.812500 (7.812500)\n", | |
"[1, 101] loss 2.298546 (2.301924) train_acc 14.062500 (11.154084)\n", | |
"[1, 201] loss 2.298169 (2.300080) train_acc 18.750000 (12.927550)\n", | |
"[1, 301] loss 2.287714 (2.298172) train_acc 18.750000 (14.150748)\n", | |
"[1, 401] loss 2.284940 (2.295719) train_acc 18.750000 (15.410692)\n", | |
"[1, 501] loss 2.268452 (2.292275) train_acc 31.250000 (16.869386)\n", | |
"[1, 601] loss 2.245095 (2.287118) train_acc 43.750000 (18.711002)\n", | |
"[1, 701] loss 2.186713 (2.278292) train_acc 45.312500 (21.143902)\n", | |
"[1, 801] loss 1.973244 (2.257265) train_acc 64.062500 (24.475265)\n", | |
"[1, 901] loss 1.177035 (2.186206) train_acc 71.875000 (28.799598)\n", | |
"[2, 1] loss 0.966906 (0.966906) train_acc 75.000000 (75.000000)\n", | |
"[2, 101] loss 0.651747 (0.683141) train_acc 71.875000 (79.146040)\n", | |
"[2, 201] loss 0.314332 (0.600203) train_acc 92.187500 (81.825249)\n", | |
"[2, 301] loss 0.536951 (0.552717) train_acc 87.500000 (83.284884)\n", | |
"[2, 401] loss 0.487067 (0.516860) train_acc 82.812500 (84.289277)\n", | |
"[2, 501] loss 0.214907 (0.485390) train_acc 93.750000 (85.170284)\n", | |
"[2, 601] loss 0.267149 (0.459768) train_acc 92.187500 (86.025894)\n", | |
"[2, 701] loss 0.268161 (0.438155) train_acc 92.187500 (86.697575)\n", | |
"[2, 801] loss 0.184571 (0.420378) train_acc 96.875000 (87.273720)\n", | |
"[2, 901] loss 0.207725 (0.404604) train_acc 93.750000 (87.777469)\n", | |
"[3, 1] loss 0.135524 (0.135524) train_acc 96.875000 (96.875000)\n", | |
"[3, 101] loss 0.464874 (0.251084) train_acc 85.937500 (92.527847)\n", | |
"[3, 201] loss 0.139773 (0.241396) train_acc 96.875000 (92.716107)\n", | |
"[3, 301] loss 0.220053 (0.238938) train_acc 95.312500 (92.732558)\n", | |
"[3, 401] loss 0.188812 (0.231188) train_acc 92.187500 (92.962905)\n", | |
"[3, 501] loss 0.149028 (0.227094) train_acc 93.750000 (93.032685)\n", | |
"[3, 601] loss 0.312316 (0.222020) train_acc 87.500000 (93.167637)\n", | |
"[3, 701] loss 0.171428 (0.217455) train_acc 95.312500 (93.295292)\n", | |
"[3, 801] loss 0.167455 (0.212298) train_acc 95.312500 (93.463249)\n", | |
"[3, 901] loss 0.075115 (0.207077) train_acc 96.875000 (93.651151)\n", | |
"[4, 1] loss 0.057395 (0.057395) train_acc 98.437500 (98.437500)\n", | |
"[4, 101] loss 0.313285 (0.156543) train_acc 87.500000 (95.049505)\n", | |
"[4, 201] loss 0.101192 (0.152110) train_acc 98.437500 (95.304726)\n", | |
"[4, 301] loss 0.210053 (0.151548) train_acc 95.312500 (95.395556)\n", | |
"[4, 401] loss 0.089693 (0.151048) train_acc 98.437500 (95.370948)\n", | |
"[4, 501] loss 0.119123 (0.148517) train_acc 95.312500 (95.487151)\n", | |
"[4, 601] loss 0.071205 (0.145240) train_acc 98.437500 (95.572483)\n", | |
"[4, 701] loss 0.250533 (0.144675) train_acc 93.750000 (95.586662)\n", | |
"[4, 801] loss 0.203986 (0.143856) train_acc 93.750000 (95.581695)\n", | |
"[4, 901] loss 0.110775 (0.143300) train_acc 92.187500 (95.607311)\n", | |
"[5, 1] loss 0.543333 (0.543333) train_acc 90.625000 (90.625000)\n", | |
"[5, 101] loss 0.079244 (0.115989) train_acc 96.875000 (96.813119)\n", | |
"[5, 201] loss 0.075053 (0.113501) train_acc 98.437500 (96.719527)\n", | |
"[5, 301] loss 0.099348 (0.112655) train_acc 96.875000 (96.620640)\n", | |
"[5, 401] loss 0.063704 (0.114829) train_acc 98.437500 (96.524314)\n", | |
"[5, 501] loss 0.100347 (0.117312) train_acc 98.437500 (96.460205)\n", | |
"[5, 601] loss 0.147361 (0.116615) train_acc 96.875000 (96.461626)\n", | |
"[5, 701] loss 0.056320 (0.116270) train_acc 98.437500 (96.449269)\n", | |
"[5, 801] loss 0.116084 (0.114737) train_acc 90.625000 (96.484863)\n", | |
"[5, 901] loss 0.182074 (0.114901) train_acc 92.187500 (96.463998)\n", | |
"[6, 1] loss 0.168051 (0.168051) train_acc 96.875000 (96.875000)\n", | |
"[6, 101] loss 0.056149 (0.104143) train_acc 98.437500 (97.060644)\n", | |
"[6, 201] loss 0.036050 (0.102476) train_acc 98.437500 (97.077114)\n", | |
"[6, 301] loss 0.036538 (0.101845) train_acc 98.437500 (97.051495)\n", | |
"[6, 401] loss 0.085093 (0.101028) train_acc 96.875000 (97.030860)\n", | |
"[6, 501] loss 0.132319 (0.100827) train_acc 95.312500 (97.009107)\n", | |
"[6, 601] loss 0.160901 (0.101214) train_acc 93.750000 (96.981593)\n", | |
"[6, 701] loss 0.061668 (0.099115) train_acc 98.437500 (97.008738)\n", | |
"[6, 801] loss 0.052418 (0.098736) train_acc 100.000000 (96.993992)\n", | |
"[6, 901] loss 0.080187 (0.097460) train_acc 95.312500 (97.027608)\n", | |
"[7, 1] loss 0.028568 (0.028568) train_acc 100.000000 (100.000000)\n", | |
"[7, 101] loss 0.050156 (0.087661) train_acc 98.437500 (97.261757)\n", | |
"[7, 201] loss 0.043363 (0.088958) train_acc 98.437500 (97.294776)\n", | |
"[7, 301] loss 0.119842 (0.090724) train_acc 96.875000 (97.264327)\n", | |
"[7, 401] loss 0.162804 (0.090694) train_acc 96.875000 (97.229582)\n", | |
"[7, 501] loss 0.057630 (0.088077) train_acc 98.437500 (97.296033)\n", | |
"[7, 601] loss 0.094305 (0.089035) train_acc 96.875000 (97.285774)\n", | |
"[7, 701] loss 0.024901 (0.088697) train_acc 100.000000 (97.302960)\n", | |
"[7, 801] loss 0.017982 (0.087811) train_acc 100.000000 (97.349017)\n", | |
"[7, 901] loss 0.093661 (0.086671) train_acc 96.875000 (97.386584)\n", | |
"[8, 1] loss 0.083651 (0.083651) train_acc 98.437500 (98.437500)\n", | |
"[8, 101] loss 0.155878 (0.078249) train_acc 96.875000 (97.555693)\n", | |
"[8, 201] loss 0.103502 (0.077470) train_acc 95.312500 (97.605721)\n", | |
"[8, 301] loss 0.036938 (0.077900) train_acc 98.437500 (97.591362)\n", | |
"[8, 401] loss 0.023906 (0.076936) train_acc 100.000000 (97.658198)\n", | |
"[8, 501] loss 0.034372 (0.074795) train_acc 100.000000 (97.713947)\n", | |
"[8, 601] loss 0.205080 (0.075102) train_acc 98.437500 (97.719946)\n", | |
"[8, 701] loss 0.121761 (0.075344) train_acc 95.312500 (97.713088)\n", | |
"[8, 801] loss 0.029189 (0.075255) train_acc 98.437500 (97.715746)\n", | |
"[8, 901] loss 0.050637 (0.075133) train_acc 98.437500 (97.738624)\n", | |
"[9, 1] loss 0.239281 (0.239281) train_acc 92.187500 (92.187500)\n", | |
"[9, 101] loss 0.118830 (0.064811) train_acc 98.437500 (97.911510)\n", | |
"[9, 201] loss 0.037551 (0.068871) train_acc 98.437500 (97.807836)\n", | |
"[9, 301] loss 0.073378 (0.070502) train_acc 96.875000 (97.721138)\n", | |
"[9, 401] loss 0.033203 (0.069951) train_acc 100.000000 (97.817955)\n", | |
"[9, 501] loss 0.061615 (0.069760) train_acc 100.000000 (97.813748)\n", | |
"[9, 601] loss 0.054376 (0.069367) train_acc 98.437500 (97.852537)\n", | |
"[9, 701] loss 0.011276 (0.068432) train_acc 100.000000 (97.895863)\n", | |
"[9, 801] loss 0.057917 (0.068809) train_acc 98.437500 (97.885456)\n", | |
"[9, 901] loss 0.117029 (0.068969) train_acc 95.312500 (97.882561)\n", | |
"[10, 1] loss 0.143128 (0.143128) train_acc 96.875000 (96.875000)\n", | |
"[10, 101] loss 0.010576 (0.060005) train_acc 100.000000 (98.097153)\n", | |
"[10, 201] loss 0.171911 (0.062318) train_acc 92.187500 (97.971082)\n", | |
"[10, 301] loss 0.159923 (0.063579) train_acc 95.312500 (97.959925)\n", | |
"[10, 401] loss 0.087808 (0.064096) train_acc 96.875000 (97.973815)\n", | |
"[10, 501] loss 0.015892 (0.063259) train_acc 100.000000 (98.007111)\n", | |
"[10, 601] loss 0.070107 (0.064106) train_acc 98.437500 (98.013727)\n", | |
"[10, 701] loss 0.102943 (0.064184) train_acc 98.437500 (98.002853)\n", | |
"[10, 801] loss 0.119761 (0.064228) train_acc 96.875000 (98.002497)\n", | |
"[10, 901] loss 0.085556 (0.063829) train_acc 98.437500 (98.033435)\n", | |
"[11, 1] loss 0.094940 (0.094940) train_acc 96.875000 (96.875000)\n", | |
"[11, 101] loss 0.061906 (0.057557) train_acc 98.437500 (98.189975)\n", | |
"[11, 201] loss 0.030660 (0.058833) train_acc 98.437500 (98.118781)\n", | |
"[11, 301] loss 0.216185 (0.060542) train_acc 95.312500 (98.094892)\n", | |
"[11, 401] loss 0.079462 (0.058608) train_acc 98.437500 (98.188123)\n", | |
"[11, 501] loss 0.104220 (0.059387) train_acc 95.312500 (98.153693)\n", | |
"[11, 601] loss 0.054998 (0.058962) train_acc 98.437500 (98.185316)\n", | |
"[11, 701] loss 0.089961 (0.058905) train_acc 96.875000 (98.192315)\n", | |
"[11, 801] loss 0.064019 (0.058255) train_acc 96.875000 (98.215122)\n", | |
"[11, 901] loss 0.008153 (0.058769) train_acc 100.000000 (98.205119)\n", | |
"[12, 1] loss 0.157928 (0.157928) train_acc 95.312500 (95.312500)\n", | |
"[12, 101] loss 0.058884 (0.051442) train_acc 98.437500 (98.576733)\n", | |
"[12, 201] loss 0.156951 (0.054435) train_acc 95.312500 (98.406405)\n", | |
"[12, 301] loss 0.050974 (0.057206) train_acc 98.437500 (98.245432)\n", | |
"[12, 401] loss 0.111937 (0.057305) train_acc 96.875000 (98.238778)\n", | |
"[12, 501] loss 0.027392 (0.055179) train_acc 98.437500 (98.328343)\n", | |
"[12, 601] loss 0.022829 (0.055833) train_acc 100.000000 (98.315308)\n", | |
"[12, 701] loss 0.013768 (0.055340) train_acc 100.000000 (98.305991)\n", | |
"[12, 801] loss 0.010317 (0.055156) train_acc 100.000000 (98.299001)\n", | |
"[12, 901] loss 0.022112 (0.055095) train_acc 98.437500 (98.295297)\n", | |
"[13, 1] loss 0.076422 (0.076422) train_acc 98.437500 (98.437500)\n", | |
"[13, 101] loss 0.030248 (0.051467) train_acc 98.437500 (98.452970)\n", | |
"[13, 201] loss 0.031034 (0.051938) train_acc 98.437500 (98.491915)\n", | |
"[13, 301] loss 0.008323 (0.050674) train_acc 100.000000 (98.489410)\n", | |
"[13, 401] loss 0.035221 (0.050260) train_acc 98.437500 (98.445293)\n", | |
"[13, 501] loss 0.010200 (0.051086) train_acc 100.000000 (98.425025)\n", | |
"[13, 601] loss 0.082217 (0.050737) train_acc 98.437500 (98.442700)\n", | |
"[13, 701] loss 0.014336 (0.050259) train_acc 100.000000 (98.457561)\n", | |
"[13, 801] loss 0.088226 (0.051342) train_acc 93.750000 (98.386782)\n", | |
"[13, 901] loss 0.076974 (0.050634) train_acc 98.437500 (98.420158)\n", | |
"[14, 1] loss 0.100131 (0.100131) train_acc 98.437500 (98.437500)\n", | |
"[14, 101] loss 0.040545 (0.047123) train_acc 96.875000 (98.452970)\n", | |
"[14, 201] loss 0.012062 (0.049064) train_acc 100.000000 (98.491915)\n", | |
"[14, 301] loss 0.006671 (0.046118) train_acc 100.000000 (98.593231)\n", | |
"[14, 401] loss 0.031865 (0.046901) train_acc 100.000000 (98.585567)\n", | |
"[14, 501] loss 0.011261 (0.048362) train_acc 100.000000 (98.543538)\n", | |
"[14, 601] loss 0.024894 (0.048486) train_acc 98.437500 (98.528494)\n", | |
"[14, 701] loss 0.020390 (0.049081) train_acc 100.000000 (98.497682)\n", | |
"[14, 801] loss 0.006450 (0.048856) train_acc 100.000000 (98.511626)\n", | |
"[14, 901] loss 0.033789 (0.048329) train_acc 98.437500 (98.529412)\n", | |
"[15, 1] loss 0.020200 (0.020200) train_acc 100.000000 (100.000000)\n", | |
"[15, 101] loss 0.037451 (0.048491) train_acc 96.875000 (98.654084)\n", | |
"[15, 201] loss 0.050812 (0.045352) train_acc 96.875000 (98.670709)\n", | |
"[15, 301] loss 0.067918 (0.045223) train_acc 98.437500 (98.655523)\n", | |
"[15, 401] loss 0.098601 (0.046091) train_acc 96.875000 (98.585567)\n", | |
"[15, 501] loss 0.015852 (0.045776) train_acc 100.000000 (98.599676)\n", | |
"[15, 601] loss 0.065180 (0.046779) train_acc 98.437500 (98.554493)\n", | |
"[15, 701] loss 0.029441 (0.046184) train_acc 98.437500 (98.542261)\n", | |
"[15, 801] loss 0.055909 (0.046825) train_acc 98.437500 (98.513577)\n", | |
"[15, 901] loss 0.140400 (0.046106) train_acc 96.875000 (98.520741)\n", | |
"[16, 1] loss 0.054232 (0.054232) train_acc 98.437500 (98.437500)\n", | |
"[16, 101] loss 0.026743 (0.043004) train_acc 100.000000 (98.654084)\n", | |
"[16, 201] loss 0.044241 (0.040062) train_acc 98.437500 (98.756219)\n", | |
"[16, 301] loss 0.010601 (0.041674) train_acc 100.000000 (98.748962)\n", | |
"[16, 401] loss 0.057229 (0.042652) train_acc 96.875000 (98.698566)\n", | |
"[16, 501] loss 0.037290 (0.043432) train_acc 98.437500 (98.643338)\n", | |
"[16, 601] loss 0.090433 (0.043415) train_acc 95.312500 (98.635087)\n", | |
"[16, 701] loss 0.036102 (0.043018) train_acc 98.437500 (98.662625)\n", | |
"[16, 801] loss 0.004385 (0.043026) train_acc 100.000000 (98.689139)\n", | |
"[16, 901] loss 0.152339 (0.043535) train_acc 96.875000 (98.666412)\n", | |
"[17, 1] loss 0.015599 (0.015599) train_acc 100.000000 (100.000000)\n", | |
"[17, 101] loss 0.022941 (0.041634) train_acc 98.437500 (98.746906)\n", | |
"[17, 201] loss 0.083984 (0.041910) train_acc 96.875000 (98.732898)\n", | |
"[17, 301] loss 0.020260 (0.040495) train_acc 100.000000 (98.723007)\n", | |
"[17, 401] loss 0.041039 (0.039550) train_acc 98.437500 (98.768703)\n", | |
"[17, 501] loss 0.007996 (0.040921) train_acc 100.000000 (98.715070)\n", | |
"[17, 601] loss 0.029095 (0.041137) train_acc 100.000000 (98.694884)\n", | |
"[17, 701] loss 0.030888 (0.040955) train_acc 98.437500 (98.698288)\n", | |
"[17, 801] loss 0.058018 (0.041377) train_acc 98.437500 (98.677434)\n", | |
"[17, 901] loss 0.082321 (0.041022) train_acc 96.875000 (98.690691)\n", | |
"[18, 1] loss 0.069109 (0.069109) train_acc 98.437500 (98.437500)\n", | |
"[18, 101] loss 0.013385 (0.034434) train_acc 100.000000 (98.917079)\n", | |
"[18, 201] loss 0.065414 (0.039689) train_acc 98.437500 (98.779540)\n", | |
"[18, 301] loss 0.024400 (0.037887) train_acc 98.437500 (98.811254)\n", | |
"[18, 401] loss 0.058656 (0.038318) train_acc 98.437500 (98.823254)\n", | |
"[18, 501] loss 0.077966 (0.038661) train_acc 98.437500 (98.824227)\n", | |
"[18, 601] loss 0.040021 (0.038291) train_acc 98.437500 (98.822275)\n", | |
"[18, 701] loss 0.005405 (0.038337) train_acc 100.000000 (98.811965)\n", | |
"[18, 801] loss 0.009531 (0.038522) train_acc 100.000000 (98.804229)\n", | |
"[18, 901] loss 0.056885 (0.038404) train_acc 96.875000 (98.812084)\n", | |
"[19, 1] loss 0.048130 (0.048130) train_acc 98.437500 (98.437500)\n", | |
"[19, 101] loss 0.018478 (0.037876) train_acc 98.437500 (98.700495)\n", | |
"[19, 201] loss 0.018664 (0.034083) train_acc 100.000000 (98.857276)\n", | |
"[19, 301] loss 0.029210 (0.034680) train_acc 100.000000 (98.816445)\n", | |
"[19, 401] loss 0.017671 (0.036185) train_acc 100.000000 (98.784289)\n", | |
"[19, 501] loss 0.007982 (0.035806) train_acc 100.000000 (98.793039)\n", | |
"[19, 601] loss 0.033168 (0.036059) train_acc 98.437500 (98.798877)\n", | |
"[19, 701] loss 0.010806 (0.036784) train_acc 100.000000 (98.791904)\n", | |
"[19, 801] loss 0.052159 (0.036979) train_acc 98.437500 (98.790574)\n", | |
"[19, 901] loss 0.019309 (0.037361) train_acc 100.000000 (98.775666)\n", | |
"[20, 1] loss 0.039824 (0.039824) train_acc 98.437500 (98.437500)\n", | |
"[20, 101] loss 0.036246 (0.035860) train_acc 96.875000 (98.948020)\n", | |
"[20, 201] loss 0.006700 (0.033980) train_acc 100.000000 (99.028296)\n", | |
"[20, 301] loss 0.032477 (0.035361) train_acc 98.437500 (98.935839)\n", | |
"[20, 401] loss 0.131802 (0.035441) train_acc 92.187500 (98.912874)\n", | |
"[20, 501] loss 0.028100 (0.035540) train_acc 98.437500 (98.911552)\n", | |
"[20, 601] loss 0.142932 (0.035561) train_acc 96.875000 (98.913270)\n", | |
"[20, 701] loss 0.008505 (0.035124) train_acc 100.000000 (98.916726)\n", | |
"[20, 801] loss 0.012377 (0.034781) train_acc 100.000000 (98.915418)\n", | |
"[20, 901] loss 0.036740 (0.035315) train_acc 98.437500 (98.895325)\n", | |
"Finished Training\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "EJggxnCVuRxU", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Now that the CNN has been trained, let's test it on our test dataset." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "y27_n-djuEdz", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "ec5b6136-ca04-4ba0-d591-e19229d6ec24" | |
}, | |
"source": [ | |
"score = test(net, testloader, cuda=True)\n", | |
"print('Accuracy of the network on the test images: {}% - FP32'.format(score))" | |
], | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Accuracy of the network on the test images: 98.65% - FP32\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "_Lp-ElDsrKua", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"### Post-training quantization\n", | |
"\n", | |
"Define a new quantized network architecture, where we also define the quantization and dequantization stubs that will be important at the start and at the end.\n", | |
"\n", | |
"Next, we’ll “fuse modules”; this can both make the model faster by saving on memory access while also improving numerical accuracy. While this can be used with any model, this is especially common with quantized models." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "X-nQWDXrhItv", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"qnet = Net(q=True)\n", | |
"load_model(qnet, net)\n", | |
"fuse_modules(qnet)" | |
], | |
"execution_count": 9, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "wQQRNAEGYVUe", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 51 | |
}, | |
"outputId": "596f0caf-1be9-471f-8473-4e3df00d66d0" | |
}, | |
"source": [ | |
"print_size_of_model(qnet)\n", | |
"score = test(qnet, testloader, cuda=False)\n", | |
"print('Accuracy of the fused network on the test images: {}% - FP32'.format(score))" | |
], | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Size (MB): 0.179144\n", | |
"Accuracy of the fused network on the test images: 98.65% - FP32\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "qiaQkj6wJuC6", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Post-training static quantization involves not just converting the weights from float to int, as in dynamic quantization, but also performing the additional\n", | |
"step of first feeding batches of data through the network and computing the resulting distributions of the different activations (specifically,\n", | |
"this is done by inserting observer modules at different\n", | |
"points that record this data). These distributions are then used to determine how specifically the different activations should be quantized at\n", | |
"inference time (a simple technique would be to simply divide the entire range of activations into 256 levels).\n", | |
"Importantly, this additional step allows us to pass quantized values between operations instead of converting these values to floats - and then back to ints - between every operation, \n", | |
"resulting in a significant speed-up." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "x-ZaMV4bUb6-", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 411 | |
}, | |
"outputId": "37a7ada4-2f61-4b57-80da-836c3e1ab469" | |
}, | |
"source": [ | |
"qnet.qconfig = torch.quantization.default_qconfig\n", | |
"print(qnet.qconfig)\n", | |
"torch.quantization.prepare(qnet, inplace=True)\n", | |
"print('Post Training Quantization Prepare: Inserting Observers')\n", | |
"print('\\n Conv1: After observer insertion \\n\\n', qnet.conv1)\n", | |
"\n", | |
"test(qnet, trainloader, cuda=False)\n", | |
"print('Post Training Quantization: Calibration done')\n", | |
"torch.quantization.convert(qnet, inplace=True)\n", | |
"print('Post Training Quantization: Convert done')\n", | |
"print('\\n Conv1: After fusion and quantization \\n\\n', qnet.conv1)\n", | |
"print(\"Size of model after quantization\")\n", | |
"print_size_of_model(qnet)" | |
], | |
"execution_count": 11, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"QConfig(activation=functools.partial(<class 'torch.quantization.observer.MinMaxObserver'>, reduce_range=True), weight=functools.partial(<class 'torch.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))\n", | |
"Post Training Quantization Prepare: Inserting Observers\n", | |
"\n", | |
" Conv1: After observer insertion \n", | |
"\n", | |
" ConvReLU2d(\n", | |
" (0): Conv2d(\n", | |
" 1, 6, kernel_size=(5, 5), stride=(1, 1), bias=False\n", | |
" (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([]))\n", | |
" )\n", | |
" (1): ReLU(\n", | |
" (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([]))\n", | |
" )\n", | |
")\n", | |
"Post Training Quantization: Calibration done\n", | |
"Post Training Quantization: Convert done\n", | |
"\n", | |
" Conv1: After fusion and quantization \n", | |
"\n", | |
" QuantizedConvReLU2d(1, 6, kernel_size=(5, 5), stride=(1, 1), scale=0.07420612871646881, zero_point=0, bias=False)\n", | |
"Size of model after quantization\n", | |
"Size (MB): 0.050052\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "wbDvGBtMavCO", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "2b9f7749-3905-4f3d-bbec-1d87ac6299ee" | |
}, | |
"source": [ | |
"score = test(qnet, testloader, cuda=False)\n", | |
"print('Accuracy of the fused and quantized network on the test images: {}% - INT8'.format(score))" | |
], | |
"execution_count": 12, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Accuracy of the fused and quantized network on the test images: 98.67% - INT8\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "lcv6Gi45lZ4L", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"We can also define a cusom quantization configuration, where we replace the default observers and instead of quantising with respect to max/min we can take an average of the observed max/min, hopefully for a better generalization performance. " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "qNj6TNFu1ljn", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 428 | |
}, | |
"outputId": "9e9db78b-dd1b-4b43-b0dc-718d5d8491c5" | |
}, | |
"source": [ | |
"from torch.quantization.observer import MovingAverageMinMaxObserver\n", | |
"\n", | |
"qnet = Net(q=True)\n", | |
"load_model(qnet, net)\n", | |
"fuse_modules(qnet)\n", | |
"\n", | |
"qnet.qconfig = torch.quantization.QConfig(\n", | |
" activation=MovingAverageMinMaxObserver.with_args(reduce_range=True), \n", | |
" weight=MovingAverageMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))\n", | |
"print(qnet.qconfig)\n", | |
"torch.quantization.prepare(qnet, inplace=True)\n", | |
"print('Post Training Quantization Prepare: Inserting Observers')\n", | |
"print('\\n Conv1: After observer insertion \\n\\n', qnet.conv1)\n", | |
"\n", | |
"test(qnet, trainloader, cuda=False)\n", | |
"print('Post Training Quantization: Calibration done')\n", | |
"torch.quantization.convert(qnet, inplace=True)\n", | |
"print('Post Training Quantization: Convert done')\n", | |
"print('\\n Conv1: After fusion and quantization \\n\\n', qnet.conv1)\n", | |
"print(\"Size of model after quantization\")\n", | |
"print_size_of_model(qnet)\n", | |
"score = test(qnet, testloader, cuda=False)\n", | |
"print('Accuracy of the fused and quantized network on the test images: {}% - INT8'.format(score))" | |
], | |
"execution_count": 13, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"QConfig(activation=functools.partial(<class 'torch.quantization.observer.MovingAverageMinMaxObserver'>, reduce_range=True), weight=functools.partial(<class 'torch.quantization.observer.MovingAverageMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric))\n", | |
"Post Training Quantization Prepare: Inserting Observers\n", | |
"\n", | |
" Conv1: After observer insertion \n", | |
"\n", | |
" ConvReLU2d(\n", | |
" (0): Conv2d(\n", | |
" 1, 6, kernel_size=(5, 5), stride=(1, 1), bias=False\n", | |
" (activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([]), max_val=tensor([]))\n", | |
" )\n", | |
" (1): ReLU(\n", | |
" (activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([]), max_val=tensor([]))\n", | |
" )\n", | |
")\n", | |
"Post Training Quantization: Calibration done\n", | |
"Post Training Quantization: Convert done\n", | |
"\n", | |
" Conv1: After fusion and quantization \n", | |
"\n", | |
" QuantizedConvReLU2d(1, 6, kernel_size=(5, 5), stride=(1, 1), scale=0.07174129039049149, zero_point=0, bias=False)\n", | |
"Size of model after quantization\n", | |
"Size (MB): 0.050052\n", | |
"Accuracy of the fused and quantized network on the test images: 98.69% - INT8\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "8LXNCT7fgcMx", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"In addition, we can significantly improve on the accuracy simply by using a different quantization configuration. We repeat the same exercise with the recommended configuration for quantizing for x86 architectures. This configuration does the following:\n", | |
"Quantizes weights on a per-channel basis. It \n", | |
"uses a histogram observer that collects a histogram of activations and then picks quantization parameters in an optimal manner." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "-nZq5yF_gWBs", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"qnet = Net(q=True)\n", | |
"load_model(qnet, net)\n", | |
"fuse_modules(qnet)" | |
], | |
"execution_count": 14, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "HXv5pAwVlGFh", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 88 | |
}, | |
"outputId": "d7119fc5-7aef-4da4-8f9a-1de9705b086f" | |
}, | |
"source": [ | |
"qnet.qconfig = torch.quantization.get_default_qconfig('fbgemm')\n", | |
"print(qnet.qconfig)\n", | |
"\n", | |
"torch.quantization.prepare(qnet, inplace=True)\n", | |
"test(qnet, trainloader, cuda=False)\n", | |
"torch.quantization.convert(qnet, inplace=True)\n", | |
"print(\"Size of model after quantization\")\n", | |
"print_size_of_model(qnet)" | |
], | |
"execution_count": 15, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"QConfig(activation=functools.partial(<class 'torch.quantization.observer.HistogramObserver'>, reduce_range=True), weight=functools.partial(<class 'torch.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric))\n", | |
"Size of model after quantization\n", | |
"Size (MB): 0.056182\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "X5Vjyayimv8n", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "348eb6fb-0873-4a36-97c9-03da9adc1a27" | |
}, | |
"source": [ | |
"score = test(qnet, testloader, cuda=False)\n", | |
"print('Accuracy of the fused and quantized network on the test images: {}% - INT8'.format(score))" | |
], | |
"execution_count": 16, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Accuracy of the fused and quantized network on the test images: 98.64% - INT8\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "5A_G3tsasU6U", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"### Quantization aware training\n", | |
"\n", | |
"Quantization-aware training (QAT) is the quantization method that typically results in the highest accuracy. With QAT, all weights and activations are “fake quantized” during both the forward and backward passes of training: that is, float values are rounded to mimic int8 values, but all computations are still done with floating point numbers. " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "o-mGba7QsXzf", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 255 | |
}, | |
"outputId": "a49c21e3-4c8a-4dc5-bcc7-51f360eb0863" | |
}, | |
"source": [ | |
"qnet = Net(q=True)\n", | |
"fuse_modules(qnet)\n", | |
"qnet.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')\n", | |
"torch.quantization.prepare_qat(qnet, inplace=True)\n", | |
"print('\\n Conv1: After fusion and quantization \\n\\n', qnet.conv1)\n", | |
"qnet=qnet.cuda()" | |
], | |
"execution_count": 17, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"\n", | |
" Conv1: After fusion and quantization \n", | |
"\n", | |
" ConvReLU2d(\n", | |
" 1, 6, kernel_size=(5, 5), stride=(1, 1), bias=False\n", | |
" (activation_post_process): FakeQuantize(\n", | |
" fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), scale=tensor([1.]), zero_point=tensor([0])\n", | |
" (activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([]), max_val=tensor([]))\n", | |
" )\n", | |
" (weight_fake_quant): FakeQuantize(\n", | |
" fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), scale=tensor([1.]), zero_point=tensor([0])\n", | |
" (activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))\n", | |
" )\n", | |
")\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "mmiecLHIuRI4", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 1000 | |
}, | |
"outputId": "d427a094-fe7c-4649-f9cb-8d22b023030b" | |
}, | |
"source": [ | |
"train(qnet, trainloader, cuda=True)" | |
], | |
"execution_count": 18, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"[1, 1] loss 2.301113 (2.301113) train_acc 10.937500 (10.937500)\n", | |
"[1, 101] loss 2.304500 (2.300252) train_acc 17.187500 (14.511139)\n", | |
"[1, 201] loss 2.280855 (2.295559) train_acc 25.000000 (17.918221)\n", | |
"[1, 301] loss 2.280437 (2.289621) train_acc 18.750000 (20.681063)\n", | |
"[1, 401] loss 2.225562 (2.280254) train_acc 28.125000 (22.919264)\n", | |
"[1, 501] loss 2.135546 (2.262463) train_acc 32.812500 (24.120509)\n", | |
"[1, 601] loss 1.859752 (2.219802) train_acc 50.000000 (26.526102)\n", | |
"[1, 701] loss 1.159023 (2.118334) train_acc 76.562500 (31.361448)\n", | |
"[1, 801] loss 0.852159 (1.974637) train_acc 70.312500 (36.641698)\n", | |
"[1, 901] loss 0.576404 (1.831879) train_acc 82.812500 (41.374168)\n", | |
"[2, 1] loss 0.672451 (0.672451) train_acc 79.687500 (79.687500)\n", | |
"[2, 101] loss 0.583598 (0.543110) train_acc 78.125000 (83.539604)\n", | |
"[2, 201] loss 0.255931 (0.519815) train_acc 92.187500 (84.227301)\n", | |
"[2, 301] loss 0.608585 (0.494296) train_acc 82.812500 (84.930440)\n", | |
"[2, 401] loss 0.388425 (0.477482) train_acc 90.625000 (85.345231)\n", | |
"[2, 501] loss 0.305046 (0.461796) train_acc 89.062500 (85.875125)\n", | |
"[2, 601] loss 0.489494 (0.446556) train_acc 84.375000 (86.428869)\n", | |
"[2, 701] loss 0.262945 (0.435555) train_acc 90.625000 (86.706491)\n", | |
"[2, 801] loss 0.504077 (0.422689) train_acc 85.937500 (87.047441)\n", | |
"[2, 901] loss 0.158998 (0.411567) train_acc 95.312500 (87.392481)\n", | |
"[3, 1] loss 0.429101 (0.429101) train_acc 92.187500 (92.187500)\n", | |
"[3, 101] loss 0.228263 (0.287405) train_acc 93.750000 (91.336634)\n", | |
"[3, 201] loss 0.226833 (0.282767) train_acc 92.187500 (91.371269)\n", | |
"[3, 301] loss 0.298973 (0.275515) train_acc 92.187500 (91.626869)\n", | |
"[3, 401] loss 0.269069 (0.267401) train_acc 90.625000 (91.871883)\n", | |
"[3, 501] loss 0.192888 (0.262484) train_acc 90.625000 (91.984780)\n", | |
"[3, 601] loss 0.160487 (0.255122) train_acc 90.625000 (92.164101)\n", | |
"[3, 701] loss 0.229499 (0.250564) train_acc 93.750000 (92.285574)\n", | |
"[3, 801] loss 0.224272 (0.245883) train_acc 90.625000 (92.454744)\n", | |
"[3, 901] loss 0.159028 (0.240255) train_acc 93.750000 (92.612375)\n", | |
"[4, 1] loss 0.152180 (0.152180) train_acc 96.875000 (96.875000)\n", | |
"[4, 101] loss 0.182985 (0.183208) train_acc 92.187500 (94.399752)\n", | |
"[4, 201] loss 0.181597 (0.182311) train_acc 93.750000 (94.449627)\n", | |
"[4, 301] loss 0.288057 (0.179880) train_acc 93.750000 (94.601329)\n", | |
"[4, 401] loss 0.130797 (0.178715) train_acc 98.437500 (94.607232)\n", | |
"[4, 501] loss 0.192190 (0.174472) train_acc 93.750000 (94.748004)\n", | |
"[4, 601] loss 0.188324 (0.175655) train_acc 92.187500 (94.693740)\n", | |
"[4, 701] loss 0.115135 (0.172211) train_acc 96.875000 (94.773092)\n", | |
"[4, 801] loss 0.157602 (0.170105) train_acc 93.750000 (94.830680)\n", | |
"[4, 901] loss 0.093530 (0.168401) train_acc 95.312500 (94.847739)\n", | |
"[5, 1] loss 0.137994 (0.137994) train_acc 96.875000 (96.875000)\n", | |
"[5, 101] loss 0.125304 (0.132565) train_acc 95.312500 (96.209777)\n", | |
"[5, 201] loss 0.122497 (0.135991) train_acc 93.750000 (95.911070)\n", | |
"[5, 301] loss 0.072496 (0.137191) train_acc 96.875000 (95.826412)\n", | |
"[5, 401] loss 0.039640 (0.137389) train_acc 98.437500 (95.846322)\n", | |
"[5, 501] loss 0.189348 (0.135704) train_acc 93.750000 (95.852046)\n", | |
"[5, 601] loss 0.154829 (0.134721) train_acc 95.312500 (95.871464)\n", | |
"[5, 701] loss 0.069499 (0.132365) train_acc 98.437500 (95.936608)\n", | |
"[5, 801] loss 0.183797 (0.132483) train_acc 95.312500 (95.921114)\n", | |
"[5, 901] loss 0.281056 (0.131096) train_acc 90.625000 (95.959351)\n", | |
"[6, 1] loss 0.167680 (0.167680) train_acc 93.750000 (93.750000)\n", | |
"[6, 101] loss 0.152279 (0.112083) train_acc 93.750000 (96.689356)\n", | |
"[6, 201] loss 0.096758 (0.120158) train_acc 95.312500 (96.276430)\n", | |
"[6, 301] loss 0.053522 (0.116962) train_acc 98.437500 (96.402616)\n", | |
"[6, 401] loss 0.032996 (0.115049) train_acc 96.875000 (96.465867)\n", | |
"[6, 501] loss 0.077693 (0.114345) train_acc 98.437500 (96.466442)\n", | |
"[6, 601] loss 0.075313 (0.113844) train_acc 98.437500 (96.459027)\n", | |
"[6, 701] loss 0.201612 (0.113165) train_acc 93.750000 (96.502764)\n", | |
"[6, 801] loss 0.096951 (0.112418) train_acc 96.875000 (96.512172)\n", | |
"[6, 901] loss 0.064057 (0.110610) train_acc 98.437500 (96.592328)\n", | |
"[7, 1] loss 0.156796 (0.156796) train_acc 92.187500 (92.187500)\n", | |
"[7, 101] loss 0.045035 (0.092368) train_acc 98.437500 (97.091584)\n", | |
"[7, 201] loss 0.218562 (0.092530) train_acc 92.187500 (97.077114)\n", | |
"[7, 301] loss 0.135667 (0.094215) train_acc 96.875000 (97.030731)\n", | |
"[7, 401] loss 0.097405 (0.095017) train_acc 96.875000 (96.991895)\n", | |
"[7, 501] loss 0.104499 (0.094556) train_acc 93.750000 (96.993513)\n", | |
"[7, 601] loss 0.262179 (0.095343) train_acc 96.875000 (96.965994)\n", | |
"[7, 701] loss 0.178149 (0.095932) train_acc 96.875000 (96.973074)\n", | |
"[7, 801] loss 0.023350 (0.095856) train_acc 100.000000 (96.972534)\n", | |
"[7, 901] loss 0.184805 (0.096887) train_acc 95.312500 (96.965178)\n", | |
"[8, 1] loss 0.095157 (0.095157) train_acc 96.875000 (96.875000)\n", | |
"[8, 101] loss 0.059145 (0.081031) train_acc 98.437500 (97.679455)\n", | |
"[8, 201] loss 0.175315 (0.083644) train_acc 93.750000 (97.613495)\n", | |
"[8, 301] loss 0.021776 (0.082633) train_acc 100.000000 (97.627699)\n", | |
"[8, 401] loss 0.024037 (0.084577) train_acc 100.000000 (97.517924)\n", | |
"[8, 501] loss 0.028669 (0.084019) train_acc 100.000000 (97.517465)\n", | |
"[8, 601] loss 0.038613 (0.084986) train_acc 98.437500 (97.485961)\n", | |
"[8, 701] loss 0.103677 (0.085835) train_acc 96.875000 (97.456758)\n", | |
"[8, 801] loss 0.037021 (0.086333) train_acc 98.437500 (97.417291)\n", | |
"[8, 901] loss 0.007270 (0.085118) train_acc 100.000000 (97.445547)\n", | |
"[9, 1] loss 0.068532 (0.068532) train_acc 96.875000 (96.875000)\n", | |
"[9, 101] loss 0.050806 (0.070079) train_acc 98.437500 (97.865099)\n", | |
"[9, 201] loss 0.102909 (0.070620) train_acc 95.312500 (97.800062)\n", | |
"[9, 301] loss 0.038826 (0.073330) train_acc 98.437500 (97.715947)\n", | |
"[9, 401] loss 0.091265 (0.074437) train_acc 96.875000 (97.665991)\n", | |
"[9, 501] loss 0.090041 (0.076175) train_acc 95.312500 (97.601672)\n", | |
"[9, 601] loss 0.104969 (0.075018) train_acc 98.437500 (97.652350)\n", | |
"[9, 701] loss 0.151797 (0.075251) train_acc 95.312500 (97.666280)\n", | |
"[9, 801] loss 0.082873 (0.075931) train_acc 95.312500 (97.637718)\n", | |
"[9, 901] loss 0.108088 (0.077397) train_acc 98.437500 (97.610294)\n", | |
"[10, 1] loss 0.065671 (0.065671) train_acc 96.875000 (96.875000)\n", | |
"[10, 101] loss 0.092590 (0.072622) train_acc 96.875000 (97.586634)\n", | |
"[10, 201] loss 0.122239 (0.071050) train_acc 95.312500 (97.737873)\n", | |
"[10, 301] loss 0.047051 (0.069421) train_acc 98.437500 (97.804194)\n", | |
"[10, 401] loss 0.101737 (0.068446) train_acc 96.875000 (97.907575)\n", | |
"[10, 501] loss 0.016245 (0.068498) train_acc 100.000000 (97.894835)\n", | |
"[10, 601] loss 0.030443 (0.070868) train_acc 100.000000 (97.800541)\n", | |
"[10, 701] loss 0.059260 (0.069604) train_acc 98.437500 (97.835681)\n", | |
"[10, 801] loss 0.073055 (0.070661) train_acc 96.875000 (97.811330)\n", | |
"[10, 901] loss 0.058888 (0.070754) train_acc 96.875000 (97.807991)\n", | |
"[11, 1] loss 0.124136 (0.124136) train_acc 95.312500 (95.312500)\n", | |
"[11, 101] loss 0.049476 (0.061531) train_acc 98.437500 (97.834158)\n", | |
"[11, 201] loss 0.103255 (0.068060) train_acc 98.437500 (97.745647)\n", | |
"[11, 301] loss 0.025115 (0.067871) train_acc 100.000000 (97.783430)\n", | |
"[11, 401] loss 0.022324 (0.066899) train_acc 100.000000 (97.845231)\n", | |
"[11, 501] loss 0.103879 (0.067301) train_acc 95.312500 (97.851173)\n", | |
"[11, 601] loss 0.038985 (0.065532) train_acc 98.437500 (97.862937)\n", | |
"[11, 701] loss 0.082056 (0.065909) train_acc 96.875000 (97.862429)\n", | |
"[11, 801] loss 0.023336 (0.065132) train_acc 100.000000 (97.891308)\n", | |
"[11, 901] loss 0.004702 (0.065778) train_acc 100.000000 (97.866953)\n", | |
"[12, 1] loss 0.141098 (0.141098) train_acc 96.875000 (96.875000)\n", | |
"[12, 101] loss 0.054198 (0.060383) train_acc 98.437500 (98.097153)\n", | |
"[12, 201] loss 0.035249 (0.060116) train_acc 98.437500 (97.994403)\n", | |
"[12, 301] loss 0.062350 (0.058442) train_acc 98.437500 (98.094892)\n", | |
"[12, 401] loss 0.014107 (0.061312) train_acc 100.000000 (98.020574)\n", | |
"[12, 501] loss 0.011643 (0.061494) train_acc 100.000000 (98.053892)\n", | |
"[12, 601] loss 0.094898 (0.060842) train_acc 98.437500 (98.057924)\n", | |
"[12, 701] loss 0.087104 (0.062387) train_acc 98.437500 (98.011769)\n", | |
"[12, 801] loss 0.091956 (0.061537) train_acc 98.437500 (98.045412)\n", | |
"[12, 901] loss 0.058924 (0.061003) train_acc 98.437500 (98.090663)\n", | |
"[13, 1] loss 0.031761 (0.031761) train_acc 100.000000 (100.000000)\n", | |
"[13, 101] loss 0.074488 (0.065928) train_acc 98.437500 (97.973391)\n", | |
"[13, 201] loss 0.022381 (0.059269) train_acc 100.000000 (98.173197)\n", | |
"[13, 301] loss 0.047876 (0.057667) train_acc 98.437500 (98.183140)\n", | |
"[13, 401] loss 0.126233 (0.057638) train_acc 96.875000 (98.176434)\n", | |
"[13, 501] loss 0.010063 (0.057087) train_acc 100.000000 (98.219187)\n", | |
"[13, 601] loss 0.089117 (0.057305) train_acc 93.750000 (98.187916)\n", | |
"[13, 701] loss 0.107229 (0.058360) train_acc 96.875000 (98.163338)\n", | |
"[13, 801] loss 0.026806 (0.058444) train_acc 100.000000 (98.154650)\n", | |
"[13, 901] loss 0.063143 (0.057368) train_acc 96.875000 (98.192980)\n", | |
"[14, 1] loss 0.146043 (0.146043) train_acc 96.875000 (96.875000)\n", | |
"[14, 101] loss 0.015386 (0.051809) train_acc 100.000000 (98.391089)\n", | |
"[14, 201] loss 0.039313 (0.056064) train_acc 98.437500 (98.289801)\n", | |
"[14, 301] loss 0.028365 (0.054364) train_acc 98.437500 (98.312915)\n", | |
"[14, 401] loss 0.013911 (0.054484) train_acc 100.000000 (98.312812)\n", | |
"[14, 501] loss 0.017320 (0.054501) train_acc 100.000000 (98.303393)\n", | |
"[14, 601] loss 0.010553 (0.053376) train_acc 100.000000 (98.323107)\n", | |
"[14, 701] loss 0.003856 (0.054123) train_acc 100.000000 (98.288160)\n", | |
"[14, 801] loss 0.042699 (0.054347) train_acc 98.437500 (98.275593)\n", | |
"[14, 901] loss 0.016714 (0.053905) train_acc 100.000000 (98.302234)\n", | |
"[15, 1] loss 0.011831 (0.011831) train_acc 100.000000 (100.000000)\n", | |
"[15, 101] loss 0.025428 (0.045962) train_acc 98.437500 (98.592203)\n", | |
"[15, 201] loss 0.063596 (0.047157) train_acc 98.437500 (98.515236)\n", | |
"[15, 301] loss 0.057301 (0.049301) train_acc 96.875000 (98.473837)\n", | |
"[15, 401] loss 0.045780 (0.051966) train_acc 96.875000 (98.386845)\n", | |
"[15, 501] loss 0.137832 (0.051855) train_acc 96.875000 (98.390719)\n", | |
"[15, 601] loss 0.027245 (0.051376) train_acc 100.000000 (98.403702)\n", | |
"[15, 701] loss 0.079835 (0.051196) train_acc 96.875000 (98.419668)\n", | |
"[15, 801] loss 0.039238 (0.051334) train_acc 98.437500 (98.388733)\n", | |
"[15, 901] loss 0.065377 (0.050911) train_acc 98.437500 (98.404550)\n", | |
"[16, 1] loss 0.080587 (0.080587) train_acc 98.437500 (98.437500)\n", | |
"[16, 101] loss 0.190880 (0.055706) train_acc 93.750000 (98.128094)\n", | |
"[16, 201] loss 0.011572 (0.051067) train_acc 100.000000 (98.243159)\n", | |
"[16, 301] loss 0.035581 (0.050523) train_acc 98.437500 (98.349252)\n", | |
"[16, 401] loss 0.052534 (0.050344) train_acc 96.875000 (98.332294)\n", | |
"[16, 501] loss 0.024507 (0.050509) train_acc 98.437500 (98.356412)\n", | |
"[16, 601] loss 0.050702 (0.049475) train_acc 96.875000 (98.388103)\n", | |
"[16, 701] loss 0.022349 (0.049447) train_acc 98.437500 (98.399608)\n", | |
"[16, 801] loss 0.035069 (0.048531) train_acc 98.437500 (98.414092)\n", | |
"[16, 901] loss 0.053336 (0.048325) train_acc 98.437500 (98.420158)\n", | |
"[17, 1] loss 0.012353 (0.012353) train_acc 100.000000 (100.000000)\n", | |
"[17, 101] loss 0.069803 (0.047732) train_acc 96.875000 (98.483911)\n", | |
"[17, 201] loss 0.066356 (0.045589) train_acc 98.437500 (98.523010)\n", | |
"[17, 301] loss 0.017363 (0.046701) train_acc 100.000000 (98.489410)\n", | |
"[17, 401] loss 0.020765 (0.044417) train_acc 98.437500 (98.577774)\n", | |
"[17, 501] loss 0.093905 (0.045263) train_acc 96.875000 (98.574726)\n", | |
"[17, 601] loss 0.015167 (0.045330) train_acc 100.000000 (98.601290)\n", | |
"[17, 701] loss 0.019424 (0.046020) train_acc 100.000000 (98.582382)\n", | |
"[17, 801] loss 0.029602 (0.046489) train_acc 98.437500 (98.562344)\n", | |
"[17, 901] loss 0.044802 (0.046264) train_acc 98.437500 (98.567564)\n", | |
"[18, 1] loss 0.019147 (0.019147) train_acc 100.000000 (100.000000)\n", | |
"[18, 101] loss 0.019280 (0.044617) train_acc 100.000000 (98.669554)\n", | |
"[18, 201] loss 0.010242 (0.041541) train_acc 100.000000 (98.725124)\n", | |
"[18, 301] loss 0.015521 (0.042972) train_acc 100.000000 (98.660714)\n", | |
"[18, 401] loss 0.046253 (0.043096) train_acc 98.437500 (98.690773)\n", | |
"[18, 501] loss 0.128635 (0.043363) train_acc 96.875000 (98.640220)\n", | |
"[18, 601] loss 0.049832 (0.042982) train_acc 96.875000 (98.619488)\n", | |
"[18, 701] loss 0.023064 (0.043624) train_acc 98.437500 (98.606901)\n", | |
"[18, 801] loss 0.118694 (0.043160) train_acc 98.437500 (98.620865)\n", | |
"[18, 901] loss 0.013449 (0.043680) train_acc 100.000000 (98.610918)\n", | |
"[19, 1] loss 0.006306 (0.006306) train_acc 100.000000 (100.000000)\n", | |
"[19, 101] loss 0.017740 (0.036281) train_acc 100.000000 (98.793317)\n", | |
"[19, 201] loss 0.011227 (0.042090) train_acc 100.000000 (98.546331)\n", | |
"[19, 301] loss 0.088088 (0.041728) train_acc 98.437500 (98.634759)\n", | |
"[19, 401] loss 0.067727 (0.041819) train_acc 96.875000 (98.671291)\n", | |
"[19, 501] loss 0.125511 (0.042568) train_acc 95.312500 (98.655813)\n", | |
"[19, 601] loss 0.062227 (0.042360) train_acc 96.875000 (98.663686)\n", | |
"[19, 701] loss 0.057913 (0.042971) train_acc 98.437500 (98.667083)\n", | |
"[19, 801] loss 0.028299 (0.042464) train_acc 98.437500 (98.669632)\n", | |
"[19, 901] loss 0.010545 (0.042111) train_acc 100.000000 (98.683754)\n", | |
"[20, 1] loss 0.012830 (0.012830) train_acc 100.000000 (100.000000)\n", | |
"[20, 101] loss 0.042024 (0.036151) train_acc 98.437500 (98.746906)\n", | |
"[20, 201] loss 0.022006 (0.040467) train_acc 100.000000 (98.732898)\n", | |
"[20, 301] loss 0.047596 (0.040482) train_acc 98.437500 (98.660714)\n", | |
"[20, 401] loss 0.021512 (0.039404) train_acc 100.000000 (98.737531)\n", | |
"[20, 501] loss 0.050536 (0.038565) train_acc 98.437500 (98.780564)\n", | |
"[20, 601] loss 0.023132 (0.039123) train_acc 98.437500 (98.757280)\n", | |
"[20, 701] loss 0.028590 (0.038926) train_acc 98.437500 (98.729494)\n", | |
"[20, 801] loss 0.014648 (0.038681) train_acc 100.000000 (98.741807)\n", | |
"[20, 901] loss 0.125813 (0.038972) train_acc 95.312500 (98.730577)\n", | |
"Finished Training\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "vKoPoXnPuxWR", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 68 | |
}, | |
"outputId": "24e85c70-aab2-45a0-d8c2-3a2bea53c95c" | |
}, | |
"source": [ | |
"qnet = qnet.cpu()\n", | |
"torch.quantization.convert(qnet, inplace=True)\n", | |
"print(\"Size of model after quantization\")\n", | |
"print_size_of_model(qnet)\n", | |
"\n", | |
"score = test(qnet, testloader, cuda=False)\n", | |
"print('Accuracy of the fused and quantized network (trained quantized) on the test images: {}% - INT8'.format(score))" | |
], | |
"execution_count": 19, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Size of model after quantization\n", | |
"Size (MB): 0.056182\n", | |
"Accuracy of the fused and quantized network (trained quantized) on the test images: 98.36% - INT8\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "HcHE8SBitv9W", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Training a quantized model with high accuracy requires accurate modeling of numerics at inference. For quantization aware training, therefore, we can modify the training loop by freezing the quantizer parameters (scale and zero-point) and fine tune the weights." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "1OvCOpSFvIJt", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 1000 | |
}, | |
"outputId": "d6aaaf2e-46e5-441b-97c7-0293d2a479be" | |
}, | |
"source": [ | |
"qnet = Net(q=True)\n", | |
"fuse_modules(qnet)\n", | |
"qnet.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')\n", | |
"torch.quantization.prepare_qat(qnet, inplace=True)\n", | |
"qnet = qnet.cuda()\n", | |
"train(qnet, trainloader, cuda=True, q=True)\n", | |
"qnet = qnet.cpu()\n", | |
"torch.quantization.convert(qnet, inplace=True)\n", | |
"print(\"Size of model after quantization\")\n", | |
"print_size_of_model(qnet)\n", | |
"\n", | |
"score = test(qnet, testloader, cuda=False)\n", | |
"print('Accuracy of the fused and quantized network (trained quantized) on the test images: {}% - INT8'.format(score))" | |
], | |
"execution_count": 20, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"[1, 1] loss 2.302550 (2.302550) train_acc 7.812500 (7.812500)\n", | |
"[1, 101] loss 2.297554 (2.300715) train_acc 20.312500 (13.845916)\n", | |
"[1, 201] loss 2.282641 (2.297055) train_acc 34.375000 (18.135883)\n", | |
"[1, 301] loss 2.270876 (2.292123) train_acc 39.062500 (22.809385)\n", | |
"[1, 401] loss 2.262033 (2.285715) train_acc 37.500000 (26.683292)\n", | |
"[1, 501] loss 2.202892 (2.275584) train_acc 53.125000 (30.417290)\n", | |
"[1, 601] loss 2.071962 (2.256968) train_acc 45.312500 (33.236273)\n", | |
"[1, 701] loss 1.763640 (2.214525) train_acc 59.375000 (35.529601)\n", | |
"[1, 801] loss 1.114516 (2.123895) train_acc 76.562500 (39.054697)\n", | |
"[1, 901] loss 0.700801 (1.992948) train_acc 85.937500 (43.215871)\n", | |
"[2, 1] loss 0.738226 (0.738226) train_acc 84.375000 (84.375000)\n", | |
"[2, 101] loss 0.517096 (0.564659) train_acc 85.937500 (84.730817)\n", | |
"[2, 201] loss 0.456042 (0.522940) train_acc 87.500000 (85.432214)\n", | |
"[2, 301] loss 0.316576 (0.485950) train_acc 89.062500 (86.399502)\n", | |
"[2, 401] loss 0.328412 (0.451746) train_acc 93.750000 (87.375312)\n", | |
"[2, 501] loss 0.150412 (0.423844) train_acc 98.437500 (88.045783)\n", | |
"[2, 601] loss 0.179842 (0.400242) train_acc 93.750000 (88.649126)\n", | |
"[2, 701] loss 0.228518 (0.383382) train_acc 92.187500 (89.042439)\n", | |
"[2, 801] loss 0.474089 (0.366581) train_acc 85.937500 (89.493602)\n", | |
"[2, 901] loss 0.137446 (0.352672) train_acc 93.750000 (89.820339)\n", | |
"[3, 1] loss 0.157492 (0.157492) train_acc 95.312500 (95.312500)\n", | |
"[3, 101] loss 0.191074 (0.214266) train_acc 93.750000 (93.641708)\n", | |
"[3, 201] loss 0.286669 (0.212748) train_acc 89.062500 (93.547886)\n", | |
"[3, 301] loss 0.118892 (0.206445) train_acc 96.875000 (93.781146)\n", | |
"[3, 401] loss 0.155797 (0.203225) train_acc 95.312500 (93.812344)\n", | |
"[3, 501] loss 0.220620 (0.199711) train_acc 93.750000 (93.902819)\n", | |
"[3, 601] loss 0.103094 (0.198659) train_acc 96.875000 (93.944988)\n", | |
"[3, 701] loss 0.128453 (0.193869) train_acc 96.875000 (94.068741)\n", | |
"[3, 801] loss 0.313520 (0.191558) train_acc 90.625000 (94.163546)\n", | |
"[3, 901] loss 0.275842 (0.189198) train_acc 95.312500 (94.251179)\n", | |
"[4, 1] loss 0.120804 (0.120804) train_acc 95.312500 (95.312500)\n", | |
"[4, 101] loss 0.112875 (0.151322) train_acc 98.437500 (95.745668)\n", | |
"[4, 201] loss 0.098576 (0.148086) train_acc 96.875000 (95.708955)\n", | |
"[4, 301] loss 0.267649 (0.147990) train_acc 93.750000 (95.639535)\n", | |
"[4, 401] loss 0.125630 (0.146719) train_acc 93.750000 (95.647600)\n", | |
"[4, 501] loss 0.094601 (0.144580) train_acc 96.875000 (95.683633)\n", | |
"[4, 601] loss 0.159038 (0.142246) train_acc 95.312500 (95.736273)\n", | |
"[4, 701] loss 0.090629 (0.140874) train_acc 95.312500 (95.773894)\n", | |
"[4, 801] loss 0.094944 (0.140627) train_acc 98.437500 (95.774813)\n", | |
"[4, 901] loss 0.066600 (0.139774) train_acc 98.437500 (95.778996)\n", | |
"[5, 1] loss 0.083039 (0.083039) train_acc 96.875000 (96.875000)\n", | |
"[5, 101] loss 0.105796 (0.128018) train_acc 96.875000 (96.147896)\n", | |
"[5, 201] loss 0.136708 (0.124957) train_acc 95.312500 (96.105410)\n", | |
"[5, 301] loss 0.180986 (0.124585) train_acc 96.875000 (96.137874)\n", | |
"[5, 401] loss 0.228525 (0.120904) train_acc 90.625000 (96.290524)\n", | |
"[5, 501] loss 0.157131 (0.119609) train_acc 95.312500 (96.344810)\n", | |
"[5, 601] loss 0.147801 (0.120152) train_acc 93.750000 (96.339434)\n", | |
"[5, 701] loss 0.092475 (0.119133) train_acc 95.312500 (96.398003)\n", | |
"[5, 801] loss 0.030053 (0.117845) train_acc 100.000000 (96.445849)\n", | |
"[5, 901] loss 0.057119 (0.116151) train_acc 98.437500 (96.509087)\n", | |
"[6, 1] loss 0.065908 (0.065908) train_acc 96.875000 (96.875000)\n", | |
"[6, 101] loss 0.178128 (0.102600) train_acc 95.312500 (96.859530)\n", | |
"[6, 201] loss 0.106632 (0.103455) train_acc 93.750000 (96.735075)\n", | |
"[6, 301] loss 0.149924 (0.104807) train_acc 95.312500 (96.688123)\n", | |
"[6, 401] loss 0.201291 (0.104085) train_acc 93.750000 (96.765898)\n", | |
"[6, 501] loss 0.044459 (0.103514) train_acc 100.000000 (96.765843)\n", | |
"[6, 601] loss 0.103602 (0.104020) train_acc 98.437500 (96.752808)\n", | |
"[6, 701] loss 0.050781 (0.103043) train_acc 96.875000 (96.785842)\n", | |
"[6, 801] loss 0.092638 (0.101568) train_acc 96.875000 (96.828184)\n", | |
"[6, 901] loss 0.121193 (0.101234) train_acc 93.750000 (96.821240)\n", | |
"[7, 1] loss 0.027914 (0.027914) train_acc 100.000000 (100.000000)\n", | |
"[7, 101] loss 0.085401 (0.105230) train_acc 98.437500 (96.612005)\n", | |
"[7, 201] loss 0.036594 (0.095268) train_acc 100.000000 (97.022699)\n", | |
"[7, 301] loss 0.076280 (0.097793) train_acc 98.437500 (97.015158)\n", | |
"[7, 401] loss 0.043487 (0.095206) train_acc 98.437500 (97.100998)\n", | |
"[7, 501] loss 0.024349 (0.095355) train_acc 100.000000 (97.099551)\n", | |
"[7, 601] loss 0.162680 (0.095196) train_acc 96.875000 (97.103785)\n", | |
"[7, 701] loss 0.061703 (0.093244) train_acc 98.437500 (97.133559)\n", | |
"[7, 801] loss 0.074943 (0.092670) train_acc 96.875000 (97.151998)\n", | |
"[7, 901] loss 0.060584 (0.091453) train_acc 98.437500 (97.192356)\n", | |
"[8, 1] loss 0.152832 (0.152832) train_acc 95.312500 (95.312500)\n", | |
"[8, 101] loss 0.035814 (0.086930) train_acc 98.437500 (97.354579)\n", | |
"[8, 201] loss 0.102473 (0.089991) train_acc 93.750000 (97.201493)\n", | |
"[8, 301] loss 0.141122 (0.085625) train_acc 93.750000 (97.285091)\n", | |
"[8, 401] loss 0.084509 (0.085933) train_acc 96.875000 (97.295823)\n", | |
"[8, 501] loss 0.046042 (0.084555) train_acc 98.437500 (97.339696)\n", | |
"[8, 601] loss 0.144402 (0.084298) train_acc 96.875000 (97.379368)\n", | |
"[8, 701] loss 0.063557 (0.084965) train_acc 96.875000 (97.354226)\n", | |
"[8, 801] loss 0.034335 (0.083684) train_acc 98.437500 (97.384129)\n", | |
"[8, 901] loss 0.135376 (0.083938) train_acc 96.875000 (97.388319)\n", | |
"[9, 1] loss 0.080674 (0.080674) train_acc 95.312500 (95.312500)\n", | |
"[9, 101] loss 0.080960 (0.079276) train_acc 96.875000 (97.447401)\n", | |
"[9, 201] loss 0.039168 (0.082197) train_acc 100.000000 (97.465796)\n", | |
"[9, 301] loss 0.040251 (0.079286) train_acc 100.000000 (97.477159)\n", | |
"[9, 401] loss 0.080177 (0.078448) train_acc 96.875000 (97.564682)\n", | |
"[9, 501] loss 0.137624 (0.078713) train_acc 93.750000 (97.548653)\n", | |
"[9, 601] loss 0.048188 (0.078182) train_acc 98.437500 (97.592554)\n", | |
"[9, 701] loss 0.152732 (0.079140) train_acc 92.187500 (97.572664)\n", | |
"[9, 801] loss 0.076186 (0.078587) train_acc 98.437500 (97.598705)\n", | |
"[9, 901] loss 0.081240 (0.078819) train_acc 96.875000 (97.596421)\n", | |
"[10, 1] loss 0.036151 (0.036151) train_acc 98.437500 (98.437500)\n", | |
"[10, 101] loss 0.065628 (0.070995) train_acc 98.437500 (97.818688)\n", | |
"[10, 201] loss 0.081905 (0.072232) train_acc 95.312500 (97.854478)\n", | |
"[10, 301] loss 0.036074 (0.071981) train_acc 98.437500 (97.840532)\n", | |
"[10, 401] loss 0.050633 (0.074849) train_acc 98.437500 (97.747818)\n", | |
"[10, 501] loss 0.038039 (0.074749) train_acc 98.437500 (97.735778)\n", | |
"[10, 601] loss 0.092357 (0.076482) train_acc 96.875000 (97.665349)\n", | |
"[10, 701] loss 0.024741 (0.074530) train_acc 100.000000 (97.737607)\n", | |
"[10, 801] loss 0.015401 (0.073452) train_acc 100.000000 (97.787921)\n", | |
"[10, 901] loss 0.013214 (0.073495) train_acc 100.000000 (97.809725)\n", | |
"[11, 1] loss 0.021667 (0.021667) train_acc 98.437500 (98.437500)\n", | |
"[11, 101] loss 0.037800 (0.078155) train_acc 98.437500 (97.586634)\n", | |
"[11, 201] loss 0.049924 (0.070107) train_acc 98.437500 (97.846704)\n", | |
"[11, 301] loss 0.070509 (0.071629) train_acc 98.437500 (97.788621)\n", | |
"[11, 401] loss 0.087520 (0.070533) train_acc 96.875000 (97.786783)\n", | |
"[11, 501] loss 0.115143 (0.069339) train_acc 96.875000 (97.816866)\n", | |
"[11, 601] loss 0.067641 (0.069719) train_acc 96.875000 (97.805740)\n", | |
"[11, 701] loss 0.133686 (0.069835) train_acc 95.312500 (97.824536)\n", | |
"[11, 801] loss 0.019427 (0.069404) train_acc 100.000000 (97.840590)\n", | |
"[11, 901] loss 0.182608 (0.069460) train_acc 93.750000 (97.830536)\n", | |
"[12, 1] loss 0.013651 (0.013651) train_acc 100.000000 (100.000000)\n", | |
"[12, 101] loss 0.084725 (0.060728) train_acc 95.312500 (98.174505)\n", | |
"[12, 201] loss 0.107834 (0.066546) train_acc 95.312500 (98.002177)\n", | |
"[12, 301] loss 0.064416 (0.064022) train_acc 98.437500 (98.105274)\n", | |
"[12, 401] loss 0.102449 (0.064468) train_acc 96.875000 (98.114090)\n", | |
"[12, 501] loss 0.228205 (0.066359) train_acc 93.750000 (98.047655)\n", | |
"[12, 601] loss 0.029645 (0.065243) train_acc 100.000000 (98.068324)\n", | |
"[12, 701] loss 0.052124 (0.065814) train_acc 98.437500 (98.018456)\n", | |
"[12, 801] loss 0.156821 (0.065505) train_acc 96.875000 (98.033708)\n", | |
"[12, 901] loss 0.028394 (0.065984) train_acc 98.437500 (98.003954)\n", | |
"[13, 1] loss 0.119072 (0.119072) train_acc 96.875000 (96.875000)\n", | |
"[13, 101] loss 0.050072 (0.054741) train_acc 98.437500 (98.422030)\n", | |
"[13, 201] loss 0.067381 (0.060345) train_acc 98.437500 (98.180970)\n", | |
"[13, 301] loss 0.048072 (0.060722) train_acc 96.875000 (98.183140)\n", | |
"[13, 401] loss 0.186138 (0.061943) train_acc 96.875000 (98.129676)\n", | |
"[13, 501] loss 0.027444 (0.064065) train_acc 98.437500 (98.075724)\n", | |
"[13, 601] loss 0.047273 (0.062451) train_acc 96.875000 (98.094322)\n", | |
"[13, 701] loss 0.057710 (0.062872) train_acc 96.875000 (98.076409)\n", | |
"[13, 801] loss 0.060320 (0.063511) train_acc 96.875000 (98.076623)\n", | |
"[13, 901] loss 0.037433 (0.063473) train_acc 98.437500 (98.061182)\n", | |
"[14, 1] loss 0.038703 (0.038703) train_acc 100.000000 (100.000000)\n", | |
"[14, 101] loss 0.021609 (0.053827) train_acc 100.000000 (98.391089)\n", | |
"[14, 201] loss 0.058564 (0.056649) train_acc 98.437500 (98.266480)\n", | |
"[14, 301] loss 0.037816 (0.057818) train_acc 98.437500 (98.203904)\n", | |
"[14, 401] loss 0.067520 (0.059272) train_acc 98.437500 (98.211502)\n", | |
"[14, 501] loss 0.019934 (0.060878) train_acc 100.000000 (98.116267)\n", | |
"[14, 601] loss 0.035314 (0.060599) train_acc 98.437500 (98.146319)\n", | |
"[14, 701] loss 0.124766 (0.061728) train_acc 98.437500 (98.096469)\n", | |
"[14, 801] loss 0.040369 (0.061778) train_acc 98.437500 (98.098081)\n", | |
"[14, 901] loss 0.010305 (0.060288) train_acc 100.000000 (98.154828)\n", | |
"[15, 1] loss 0.022719 (0.022719) train_acc 98.437500 (98.437500)\n", | |
"[15, 101] loss 0.263579 (0.057640) train_acc 92.187500 (98.313738)\n", | |
"[15, 201] loss 0.030529 (0.058956) train_acc 98.437500 (98.219838)\n", | |
"[15, 301] loss 0.064915 (0.058186) train_acc 98.437500 (98.229859)\n", | |
"[15, 401] loss 0.011687 (0.056681) train_acc 100.000000 (98.308915)\n", | |
"[15, 501] loss 0.140703 (0.056459) train_acc 96.875000 (98.309631)\n", | |
"[15, 601] loss 0.142636 (0.057355) train_acc 95.312500 (98.291909)\n", | |
"[15, 701] loss 0.024441 (0.057568) train_acc 98.437500 (98.268099)\n", | |
"[15, 801] loss 0.080110 (0.057504) train_acc 95.312500 (98.248283)\n", | |
"[15, 901] loss 0.058720 (0.057768) train_acc 98.437500 (98.245006)\n", | |
"[16, 1] loss 0.011128 (0.011128) train_acc 100.000000 (100.000000)\n", | |
"[16, 101] loss 0.026723 (0.057070) train_acc 100.000000 (98.128094)\n", | |
"[16, 201] loss 0.051001 (0.051976) train_acc 96.875000 (98.336443)\n", | |
"[16, 301] loss 0.031592 (0.053245) train_acc 100.000000 (98.328488)\n", | |
"[16, 401] loss 0.098772 (0.053506) train_acc 96.875000 (98.332294)\n", | |
"[16, 501] loss 0.141219 (0.054016) train_acc 96.875000 (98.328343)\n", | |
"[16, 601] loss 0.058870 (0.053847) train_acc 98.437500 (98.354305)\n", | |
"[16, 701] loss 0.007093 (0.053943) train_acc 100.000000 (98.350571)\n", | |
"[16, 801] loss 0.132708 (0.054902) train_acc 96.875000 (98.332163)\n", | |
"[16, 901] loss 0.037177 (0.055351) train_acc 98.437500 (98.324778)\n", | |
"[17, 1] loss 0.005327 (0.005327) train_acc 100.000000 (100.000000)\n", | |
"[17, 101] loss 0.077196 (0.053524) train_acc 98.437500 (98.530322)\n", | |
"[17, 201] loss 0.021868 (0.054216) train_acc 100.000000 (98.445274)\n", | |
"[17, 301] loss 0.011024 (0.055832) train_acc 100.000000 (98.333679)\n", | |
"[17, 401] loss 0.027054 (0.055034) train_acc 98.437500 (98.332294)\n", | |
"[17, 501] loss 0.010150 (0.054225) train_acc 100.000000 (98.337700)\n", | |
"[17, 601] loss 0.029712 (0.054110) train_acc 98.437500 (98.349106)\n", | |
"[17, 701] loss 0.035342 (0.054022) train_acc 98.437500 (98.352800)\n", | |
"[17, 801] loss 0.059302 (0.053984) train_acc 96.875000 (98.343867)\n", | |
"[17, 901] loss 0.068517 (0.053667) train_acc 98.437500 (98.342120)\n", | |
"[18, 1] loss 0.075476 (0.075476) train_acc 98.437500 (98.437500)\n", | |
"[18, 101] loss 0.145757 (0.052178) train_acc 96.875000 (98.329208)\n", | |
"[18, 201] loss 0.043926 (0.053253) train_acc 98.437500 (98.359764)\n", | |
"[18, 301] loss 0.049170 (0.053832) train_acc 98.437500 (98.375208)\n", | |
"[18, 401] loss 0.036931 (0.053273) train_acc 98.437500 (98.359570)\n", | |
"[18, 501] loss 0.019969 (0.053371) train_acc 100.000000 (98.381362)\n", | |
"[18, 601] loss 0.038306 (0.053556) train_acc 98.437500 (98.372504)\n", | |
"[18, 701] loss 0.041648 (0.052409) train_acc 100.000000 (98.417439)\n", | |
"[18, 801] loss 0.024367 (0.051907) train_acc 98.437500 (98.431648)\n", | |
"[18, 901] loss 0.028199 (0.051279) train_acc 98.437500 (98.449639)\n", | |
"[19, 1] loss 0.077943 (0.077943) train_acc 98.437500 (98.437500)\n", | |
"[19, 101] loss 0.041550 (0.052222) train_acc 98.437500 (98.205446)\n", | |
"[19, 201] loss 0.050326 (0.046886) train_acc 96.875000 (98.445274)\n", | |
"[19, 301] loss 0.012051 (0.047147) train_acc 100.000000 (98.494601)\n", | |
"[19, 401] loss 0.034183 (0.049236) train_acc 98.437500 (98.476465)\n", | |
"[19, 501] loss 0.020020 (0.048748) train_acc 100.000000 (98.506113)\n", | |
"[19, 601] loss 0.141519 (0.048220) train_acc 95.312500 (98.525894)\n", | |
"[19, 701] loss 0.032047 (0.049645) train_acc 100.000000 (98.495453)\n", | |
"[19, 801] loss 0.082106 (0.049142) train_acc 96.875000 (98.505774)\n", | |
"[19, 901] loss 0.072608 (0.049163) train_acc 96.875000 (98.505133)\n", | |
"[20, 1] loss 0.042147 (0.042147) train_acc 100.000000 (100.000000)\n", | |
"[20, 101] loss 0.069253 (0.049164) train_acc 98.437500 (98.592203)\n", | |
"[20, 201] loss 0.066975 (0.049972) train_acc 98.437500 (98.561878)\n", | |
"[20, 301] loss 0.114514 (0.050178) train_acc 96.875000 (98.536130)\n", | |
"[20, 401] loss 0.014213 (0.048549) train_acc 100.000000 (98.554395)\n", | |
"[20, 501] loss 0.032720 (0.048352) train_acc 98.437500 (98.546657)\n", | |
"[20, 601] loss 0.129583 (0.048267) train_acc 96.875000 (98.538894)\n", | |
"[20, 701] loss 0.070022 (0.048461) train_acc 95.312500 (98.535574)\n", | |
"[20, 801] loss 0.007984 (0.048427) train_acc 100.000000 (98.540886)\n", | |
"[20, 901] loss 0.098393 (0.048538) train_acc 96.875000 (98.531146)\n", | |
"Finished Training\n", | |
"Size of model after quantization\n", | |
"Size (MB): 0.056182\n", | |
"Accuracy of the fused and quantized network (trained quantized) on the test images: 98.51% - INT8\n" | |
], | |
"name": "stdout" | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment