  • Save AlisonDavey/17c5a5c09c7bbf31a3b95515c7cb70e6 to your computer and use it in GitHub Desktop.
Image_Classifier_Project.ipynb
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": "# Udacity PyTorch Challenge Classification Project\n\nIn this project an image classifier is trained to recognize different species of flowers using [the Oxford dataset](http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html) of 102 flower categories.\n\nThe project is broken down into:\n\n* Create a test dataset\n* Load and preprocess the image dataset\n* Train the image classifier on the dataset\n* Use the trained classifier to predict image content\n\n[PEP 8 -- Style Guide for Python Code](https://www.python.org/dev/peps/pep-0008/) says that imports are always put at the top of the file.\n\nSource files https://github.com/udacity/pytorch_challenge\n"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "%reload_ext autoreload\n%autoreload 2\n%matplotlib inline\n\nfrom PIL import Image\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport pandas as pd\nfrom collections import OrderedDict\n\nimport torch\nimport torch.optim as optim\nimport torch.nn as nn\nfrom torchvision import datasets, transforms, models\n\nimport scipy.io as sio\nfrom pathlib import Path\nfrom shutil import copy\nimport time\n#import copy\nimport shutil\nimport os\nimport random",
"execution_count": 1,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Data Downloading and Creation of a Test Dataset\n\nCreate a test folder (with its labels) for the Udacity PyTorch Challenge flower dataset that contains the images in the original Oxford dataset that are not in the Udacity dataset.\n\nThe dataset will then have three parts: training, validation and test. "
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "# Download the original dataset of 102 different categories of flowers common to the UK \n# from the Visual Geometry Group, Department of Engineering Science, University of Oxford \n# http://www.robots.ox.ac.uk/~vgg/data/flowers/\n!wget -O ./assets/102flowers.tgz \"http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz\"\n",
"execution_count": 14,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "--2019-01-19 17:06:09-- http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz\nResolving www.robots.ox.ac.uk (www.robots.ox.ac.uk)... 129.67.94.2\nConnecting to www.robots.ox.ac.uk (www.robots.ox.ac.uk)|129.67.94.2|:80... connected.\nHTTP request sent, awaiting response... 200 OK\nLength: 344862509 (329M) [application/x-gzip]\nSaving to: ‘./assets/102flowers.tgz’\n\n./assets/102flowers 100%[===================>] 328.89M 21.7MB/s in 16s \n\n2019-01-19 17:06:27 (20.4 MB/s) - ‘./assets/102flowers.tgz’ saved [344862509/344862509]\n\n"
}
]
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "# Untar the Oxford dataset\n!tar xzf ./assets/102flowers.tgz -C ./assets/\n",
"execution_count": 15,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "## Download the Oxford labels\n!wget -O ./assets/imagelabels.mat \"http://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat\"\n",
"execution_count": 16,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "--2019-01-19 17:07:20-- http://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat\nResolving www.robots.ox.ac.uk (www.robots.ox.ac.uk)... 129.67.94.2\nConnecting to www.robots.ox.ac.uk (www.robots.ox.ac.uk)|129.67.94.2|:80... connected.\nHTTP request sent, awaiting response... 200 OK\nLength: 502\nSaving to: ‘./assets/imagelabels.mat’\n\n./assets/imagelabel 100%[===================>] 502 --.-KB/s in 0s \n\n2019-01-19 17:07:20 (102 MB/s) - ‘./assets/imagelabels.mat’ saved [502/502]\n\n"
}
]
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "# Download the Udacity PyTorch Challenge flower dataset \n!wget -O ./assets/flower_data.zip \"https://s3.amazonaws.com/content.udacity-data.com/courses/nd188/flower_data.zip\"\n",
"execution_count": 2,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "# Unzip the Udacity dataset\n!unzip ./assets/flower_data.zip -d ./assets/\n",
"execution_count": 3,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "root_dir = Path('./assets')\noriginal_dir = root_dir/'jpg'\nlabels_file = root_dir/'imagelabels.mat'\nudacity_dir = root_dir/'flower_data'\nudacity_train_dir = udacity_dir/'train'\nudacity_valid_dir = udacity_dir/'valid'\nudacity_test_dir = udacity_dir/'test'",
"execution_count": 22,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "# Script prepared by a student in the Udacity PyTorch Scholarship Challenge, apologies that \n# I didn't note their name\nlabels=sio.loadmat(labels_file)['labels'][0]\n(_, _, original_images) = next(os.walk(original_dir))\noriginal_images = sorted(original_images)\nimage_to_label = {name: labels[i] for i, name in enumerate(original_images)}\nudacity_images = []\nfor root, dirs, files in os.walk(udacity_dir): udacity_images.extend(files)\ndiff = set(original_images) - set(udacity_images)\nudacity_test_dir.mkdir(parents=True, exist_ok=True)\nfor file in diff:\n dest_dir = udacity_test_dir/str(image_to_label[file])\n dest_dir.mkdir(parents=True, exist_ok=True)\n copy(original_dir/file, dest_dir)",
"execution_count": 29,
"outputs": []
},
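{
"metadata": {},
"cell_type": "markdown",
"source": "A minimal sketch of the test-set logic above, run on a handful of made-up file names and labels. The `demo_` variables and the `label_for` helper are hypothetical and not part of the dataset. The sketch assumes, as in the Oxford release, that entry *i* of `imagelabels.mat` (counting from 0) is the label of `image_{i+1:05d}.jpg`. Reading the image number from the file name, instead of relying on the position in a sorted listing, keeps the lookup correct even if a file were missing from the `jpg` folder."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "from pathlib import PurePosixPath\n\n# Hypothetical stand-ins for the Oxford labels and the two file lists\ndemo_labels = [77, 77, 1, 54, 1]  # demo_labels[i] is assumed to belong to image_{i+1:05d}.jpg\ndemo_original = ['image_00001.jpg', 'image_00002.jpg', 'image_00003.jpg',\n                 'image_00004.jpg', 'image_00005.jpg']\ndemo_udacity = ['image_00001.jpg', 'image_00003.jpg', 'image_00004.jpg']\n\ndef label_for(name, labels):\n    # 'image_00004.jpg' -> 4 -> labels[3]\n    number = int(name.split('_')[1].split('.')[0])\n    return labels[number - 1]\n\n# Oxford images that Udacity did not use become the test set, one sub-folder per label\nheld_out = sorted(set(demo_original) - set(demo_udacity))\nfor name in held_out:\n    print(PurePosixPath('test') / str(label_for(name, demo_labels)) / name)\n# test/77/image_00002.jpg\n# test/1/image_00005.jpg",
"execution_count": null,
"outputs": []
},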
{
"metadata": {},
"cell_type": "markdown",
"source": "## Data Loading and Exploration\n\nFor training, random transformations (flipping, rotation, scaling and cropping) are applied. This helps the network generalize, which leads to better performance on images it has not seen. The images are cropped to 224x224 pixels, the input size expected by the pre-trained networks.\n\nThe validation set is used to measure the model's performance on data it hasn't seen yet. No random scaling, rotation or flipping is applied to it; the images are resized to 256 pixels on the shorter side and then center-cropped to 224x224.\n\nThe test set was created by a student in the challenge by comparing the images selected by Udacity with the original Oxford data.\n\nThe pre-trained networks available from `torchvision` were trained on the ImageNet dataset, where each colour channel was normalized separately. The image data here needs to be normalized with the same per-channel means `[0.485, 0.456, 0.406]` and standard deviations `[0.229, 0.224, 0.225]` that were calculated from the ImageNet images. Subtracting the mean and dividing by the standard deviation shifts each colour channel to be centered at roughly 0 with a standard deviation of roughly 1. Because `ToTensor` first scales pixels to [0, 1], the normalized values fall roughly between -2.1 and 2.6, not between -1 and 1."
},
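{
"metadata": {},
"cell_type": "markdown",
"source": "Worked arithmetic for the normalization step: `transforms.ToTensor()` scales pixel values to [0, 1], and `transforms.Normalize` then computes `(x - mean) / std` for each channel. Passing the two ends of [0, 1] through that formula shows the range the normalized values can take."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "imagenet_mean = (0.485, 0.456, 0.406)\nimagenet_std = (0.229, 0.224, 0.225)\n\n# ToTensor() maps pixels to [0, 1]; Normalize() then applies (x - mean) / std per channel\nfor channel, (m, s) in zip('RGB', zip(imagenet_mean, imagenet_std)):\n    low = (0.0 - m) / s\n    high = (1.0 - m) / s\n    print(f'{channel}: [{low:.2f}, {high:.2f}]')\n# R: [-2.12, 2.25]\n# G: [-2.04, 2.43]\n# B: [-1.80, 2.64]",
"execution_count": null,
"outputs": []
},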
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "#Load and preprocess the image dataset\ndata_dir = './assets/flower_data'\ntrain_dir = data_dir + '/train'\nvalid_dir = data_dir + '/valid'\ntest_dir = data_dir + '/test'",
"execution_count": 2,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "image = Image.open(train_dir+'/54/image_05459.jpg')\nplt.imshow(image);",
"execution_count": 27,
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": "<Figure size 432x288 with 1 Axes>"
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
]
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "# Exploratory data analysis: number of images per class in each split\ndef eda_counts():\n    dirs = [train_dir, valid_dir, test_dir]\n    for directory in dirs:\n        counts = []\n        total_images = 0\n        min_count = float('inf')  # start above any possible class size\n        max_count = 0\n        print('Dataset: ', directory)\n        folders = [name for name in os.listdir(directory)]\n        for folder in folders:\n            contents = os.listdir(os.path.join(directory, folder))\n            total_images += len(contents)\n            min_count = min(min_count, len(contents))\n            max_count = max(max_count, len(contents))\n            counts.append((folder, len(contents)))\n        print('Classes: ', len(counts), 'Fewest: ', min_count, 'Most: ', max_count)\n        print('Number of images: ', total_images)\n        print('Counts per class: ', counts)\n        print()\neda_counts()",
"execution_count": 30,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "Dataset: ./assets/flower_data/train\nClasses: 102 Fewest: 27 Most: 206\nNumber of images: 6552\nCounts per class: [('87', 51), ('99', 50), ('71', 64), ('60', 85), ('27', 36), ('85', 48), ('100', 35), ('53', 70), ('24', 35), ('48', 57), ('96', 72), ('81', 135), ('13', 38), ('62', 48), ('40', 54), ('64', 42), ('49', 38), ('47', 61), ('32', 36), ('89', 153), ('8', 70), ('44', 73), ('10', 38), ('11', 68), ('98', 68), ('46', 157), ('42', 49), ('61', 36), ('20', 46), ('73', 147), ('56', 92), ('25', 34), ('30', 61), ('22', 47), ('43', 100), ('84', 66), ('18', 65), ('70', 51), ('15', 38), ('69', 46), ('37', 92), ('16', 36), ('33', 31), ('57', 50), ('78', 112), ('101', 49), ('67', 36), ('65', 88), ('34', 28), ('3', 36), ('26', 33), ('36', 62), ('21', 34), ('52', 67), ('92', 53), ('55', 56), ('51', 206), ('28', 55), ('45', 33), ('54', 47), ('41', 97), ('88', 116), ('7', 33), ('77', 205), ('14', 44), ('4', 44), ('83', 104), ('50', 73), ('39', 33), ('63', 42), ('90', 66), ('80', 82), ('97', 54), ('59', 56), ('58', 86), ('76', 83), ('29', 62), ('95', 101), ('72', 77), ('5', 54), ('31', 48), ('79', 34), ('12', 73), ('23', 72), ('91', 59), ('66', 51), ('86', 48), ('74', 142), ('75', 95), ('2', 49), ('17', 60), ('93', 34), ('35', 33), ('9', 41), ('94', 132), ('1', 27), ('102', 36), ('68', 43), ('82', 82), ('6', 35), ('19', 38), ('38', 44)]\n\nDataset: ./assets/flower_data/valid\nClasses: 102 Fewest: 1 Most: 28\nNumber of images: 818\nCounts per class: [('87', 6), ('99', 6), ('71', 5), ('60', 14), ('27', 1), ('85', 5), ('100', 6), ('53', 9), ('24', 5), ('48', 9), ('96', 10), ('81', 18), ('13', 5), ('62', 3), ('40', 5), ('64', 5), ('49', 8), ('47', 3), ('32', 3), ('89', 16), ('8', 5), ('44', 9), ('10', 4), ('11', 10), ('98', 10), ('46', 18), ('42', 6), ('61', 6), ('20', 7), ('73', 19), ('56', 9), ('25', 2), ('30', 10), ('22', 8), ('43', 14), ('84', 10), ('18', 11), ('70', 7), ('15', 7), ('69', 5), ('37', 8), ('16', 2), ('33', 7), ('57', 6), ('78', 11), ('101', 5), ('67', 2), ('65', 7), ('34', 7), ('3', 2), ('26', 3), ('36', 6), ('21', 4), ('52', 10), ('92', 2), ('55', 8), ('51', 28), ('28', 5), ('45', 4), ('54', 10), ('41', 16), ('88', 25), ('7', 1), ('77', 21), ('14', 1), ('4', 6), ('83', 13), ('50', 11), ('39', 3), ('63', 8), ('90', 2), ('80', 12), ('97', 7), ('59', 4), ('58', 14), ('76', 20), ('29', 7), ('95', 13), ('72', 8), ('5', 7), ('31', 2), ('79', 4), ('12', 5), ('23', 12), ('91', 9), ('66', 6), ('86', 5), ('74', 15), ('75', 12), ('2', 6), ('17', 16), ('93', 6), ('35', 4), ('9', 3), ('94', 14), ('1', 8), ('102', 6), ('68', 8), ('82', 13), ('6', 1), ('19', 4), ('38', 4)]\n\nDataset: ./assets/flower_data/test\nClasses: 102 Fewest: 2 Most: 28\nNumber of images: 819\nCounts per class: [('87', 6), ('99', 7), ('71', 9), ('60', 10), ('27', 3), ('85', 10), ('100', 8), ('53', 14), ('24', 2), ('48', 5), ('96', 9), ('81', 13), ('13', 6), ('62', 4), ('40', 8), ('64', 5), ('49', 3), ('47', 3), ('32', 6), ('89', 15), ('8', 10), ('44', 11), ('10', 3), ('11', 9), ('98', 4), ('46', 21), ('42', 4), ('61', 8), ('20', 3), ('73', 28), ('56', 8), ('25', 5), ('30', 14), ('22', 4), ('43', 16), ('84', 10), ('18', 6), ('70', 4), ('15', 4), ('69', 3), ('37', 8), ('16', 3), ('33', 8), ('57', 11), ('78', 14), ('101', 4), ('67', 4), ('65', 7), ('34', 5), ('3', 2), ('26', 5), ('36', 7), ('21', 2), ('52', 8), ('92', 11), ('55', 7), ('51', 24), ('28', 6), ('45', 3), ('54', 4), ('41', 14), ('88', 13), ('7', 6), ('77', 25), ('14', 3), ('4', 6), ('83', 14), ('50', 8), ('39', 5), ('63', 4), ('90', 14), ('80', 11), ('97', 5), ('59', 7), ('58', 14), ('76', 4), ('29', 9), ('95', 14), ('72', 11), ('5', 4), ('31', 2), ('79', 3), ('12', 9), ('23', 7), ('91', 8), ('66', 4), ('86', 5), ('74', 14), ('75', 13), ('2', 5), ('17', 9), ('93', 6), ('35', 6), ('9', 2), ('94', 16), ('1', 5), ('102', 6), ('68', 3), ('82', 17), ('6', 9), ('19', 7), ('38', 8)]\n\n"
}
]
},
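{
"metadata": {},
"cell_type": "markdown",
"source": "A more compact way to get the same per-class counts, sketched with `pathlib` and `collections.Counter`. The `class_counts` helper is hypothetical. It runs here on a throwaway directory tree whose folder names and file counts are made up; on the real data the call would be `class_counts(train_dir)`."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "import tempfile\nfrom collections import Counter\nfrom pathlib import Path\n\ndef class_counts(split_dir):\n    # Number of files in each class sub-folder of an ImageFolder-style split\n    return Counter({d.name: sum(1 for f in d.iterdir() if f.is_file())\n                    for d in Path(split_dir).iterdir() if d.is_dir()})\n\nwith tempfile.TemporaryDirectory() as tmp:\n    # Hypothetical layout: class '1' with 3 images, class '2' with 5 images\n    for name, n in [('1', 3), ('2', 5)]:\n        (Path(tmp) / name).mkdir()\n        for i in range(n):\n            (Path(tmp) / name / f'image_{i:05d}.jpg').touch()\n    counts = class_counts(tmp)\n    print('Classes:', len(counts), 'Fewest:', min(counts.values()), 'Most:', max(counts.values()))\n    print('Number of images:', sum(counts.values()))\n# Classes: 2 Fewest: 3 Most: 5\n# Number of images: 8",
"execution_count": null,
"outputs": []
},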
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "# For the classes in the training set with few examples,\n# copy randomly chosen images within the sub-folder to increase the sample size\nmin_number = 54  # twice the size of the smallest class (27 images)\nfolders = [name for name in os.listdir(train_dir)]\nfor folder in folders:\n    contents = os.listdir(os.path.join(train_dir, folder))\n    if len(contents) < min_number:\n        images_to_add = min_number - len(contents)\n        # Sampling without replacement needs images_to_add <= len(contents),\n        # which holds as long as min_number is at most twice the smallest class\n        idx = np.random.choice(len(contents), images_to_add, replace=False)\n        for x in range(images_to_add):\n            src = os.path.join(train_dir, folder, contents[idx[x]])\n            dst = os.path.join(train_dir, folder, 'cp_' + contents[idx[x]])\n            shutil.copy(src, dst)\n        print('sub_folder ', folder, 'copied ' + str(images_to_add) + ' images')",
"execution_count": 31,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "sub_folder 87 copied 3 images\nsub_folder 99 copied 4 images\nsub_folder 27 copied 18 images\nsub_folder 85 copied 6 images\nsub_folder 100 copied 19 images\nsub_folder 24 copied 19 images\nsub_folder 13 copied 16 images\nsub_folder 62 copied 6 images\nsub_folder 64 copied 12 images\nsub_folder 49 copied 16 images\nsub_folder 32 copied 18 images\nsub_folder 10 copied 16 images\nsub_folder 42 copied 5 images\nsub_folder 61 copied 18 images\nsub_folder 20 copied 8 images\nsub_folder 25 copied 20 images\nsub_folder 22 copied 7 images\nsub_folder 70 copied 3 images\nsub_folder 15 copied 16 images\nsub_folder 69 copied 8 images\nsub_folder 16 copied 18 images\nsub_folder 33 copied 23 images\nsub_folder 57 copied 4 images\nsub_folder 101 copied 5 images\nsub_folder 67 copied 18 images\nsub_folder 34 copied 26 images\nsub_folder 3 copied 18 images\nsub_folder 26 copied 21 images\nsub_folder 21 copied 20 images\nsub_folder 92 copied 1 images\nsub_folder 45 copied 21 images\nsub_folder 54 copied 7 images\nsub_folder 7 copied 21 images\nsub_folder 14 copied 10 images\nsub_folder 4 copied 10 images\nsub_folder 39 copied 21 images\nsub_folder 63 copied 12 images\nsub_folder 31 copied 6 images\nsub_folder 79 copied 20 images\nsub_folder 66 copied 3 images\nsub_folder 86 copied 6 images\nsub_folder 2 copied 5 images\nsub_folder 93 copied 20 images\nsub_folder 35 copied 21 images\nsub_folder 9 copied 13 images\nsub_folder 1 copied 27 images\nsub_folder 102 copied 18 images\nsub_folder 68 copied 11 images\nsub_folder 6 copied 19 images\nsub_folder 19 copied 16 images\nsub_folder 38 copied 10 images\n"
}
]
},
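{
"metadata": {},
"cell_type": "markdown",
"source": "Copying files makes the duplicates permanent on disk and only lifts small classes to a fixed floor. A common alternative, swapped in here rather than used in this project, leaves the folders untouched and draws training samples with `torch.utils.data.WeightedRandomSampler`, weighting each image by the inverse of its class size. The toy labels below are made up; with the real data they could come from `train_data.targets`. Because the sampler draws with replacement, some images appear more than once per epoch, each time with a fresh random augmentation."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "from collections import Counter\n\nimport torch\nfrom torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler\n\n# Hypothetical imbalanced labels: class 0 has 8 samples, class 1 has 2\ntoy_targets = [0] * 8 + [1] * 2\ntoy_dataset = TensorDataset(torch.arange(len(toy_targets)), torch.tensor(toy_targets))\n\n# Weight each sample by 1 / (size of its class) so both classes are drawn about equally often\ntoy_class_sizes = Counter(toy_targets)\ntoy_weights = [1.0 / toy_class_sizes[t] for t in toy_targets]\ntoy_sampler = WeightedRandomSampler(toy_weights, num_samples=len(toy_targets), replacement=True)\n\n# A sampler replaces shuffle=True; passing both raises an error\ntoy_loader = DataLoader(toy_dataset, batch_size=5, sampler=toy_sampler)\ndrawn = torch.cat([labels for _, labels in toy_loader])\nprint('labels drawn in one epoch:', drawn.tolist())",
"execution_count": null,
"outputs": []
},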
{
"metadata": {},
"cell_type": "markdown",
"source": "### Label mapping\n\nThe file `cat_to_name.json` is provided for the mapping from category label to category name. It's a JSON object which can be read in with the [`json` module](https://docs.python.org/2/library/json.html). This gives a dictionary mapping the integer encoded categories to the actual names of the flowers."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
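"source": "# Use the raw.githubusercontent.com URL: the github.com 'blob' URL returns an HTML page, not the JSON file\n!wget -O ./assets/cat_to_name.json \"https://raw.githubusercontent.com/udacity/pytorch_challenge/master/cat_to_name.json\"",
"execution_count": 7,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "import json\n\n# Read the file saved by the download cell above\nwith open('./assets/cat_to_name.json', 'r') as f:\n    cat_to_name = json.load(f)\ncat_to_name",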
"execution_count": 2,
"outputs": [
{
"data": {
"text/plain": "{'1': 'pink primrose',\n '10': 'globe thistle',\n '100': 'blanket flower',\n '101': 'trumpet creeper',\n '102': 'blackberry lily',\n '11': 'snapdragon',\n '12': \"colt's foot\",\n '13': 'king protea',\n '14': 'spear thistle',\n '15': 'yellow iris',\n '16': 'globe-flower',\n '17': 'purple coneflower',\n '18': 'peruvian lily',\n '19': 'balloon flower',\n '2': 'hard-leaved pocket orchid',\n '20': 'giant white arum lily',\n '21': 'fire lily',\n '22': 'pincushion flower',\n '23': 'fritillary',\n '24': 'red ginger',\n '25': 'grape hyacinth',\n '26': 'corn poppy',\n '27': 'prince of wales feathers',\n '28': 'stemless gentian',\n '29': 'artichoke',\n '3': 'canterbury bells',\n '30': 'sweet william',\n '31': 'carnation',\n '32': 'garden phlox',\n '33': 'love in the mist',\n '34': 'mexican aster',\n '35': 'alpine sea holly',\n '36': 'ruby-lipped cattleya',\n '37': 'cape flower',\n '38': 'great masterwort',\n '39': 'siam tulip',\n '4': 'sweet pea',\n '40': 'lenten rose',\n '41': 'barbeton daisy',\n '42': 'daffodil',\n '43': 'sword lily',\n '44': 'poinsettia',\n '45': 'bolero deep blue',\n '46': 'wallflower',\n '47': 'marigold',\n '48': 'buttercup',\n '49': 'oxeye daisy',\n '5': 'english marigold',\n '50': 'common dandelion',\n '51': 'petunia',\n '52': 'wild pansy',\n '53': 'primula',\n '54': 'sunflower',\n '55': 'pelargonium',\n '56': 'bishop of llandaff',\n '57': 'gaura',\n '58': 'geranium',\n '59': 'orange dahlia',\n '6': 'tiger lily',\n '60': 'pink-yellow dahlia',\n '61': 'cautleya spicata',\n '62': 'japanese anemone',\n '63': 'black-eyed susan',\n '64': 'silverbush',\n '65': 'californian poppy',\n '66': 'osteospermum',\n '67': 'spring crocus',\n '68': 'bearded iris',\n '69': 'windflower',\n '7': 'moon orchid',\n '70': 'tree poppy',\n '71': 'gazania',\n '72': 'azalea',\n '73': 'water lily',\n '74': 'rose',\n '75': 'thorn apple',\n '76': 'morning glory',\n '77': 'passion flower',\n '78': 'lotus lotus',\n '79': 'toad lily',\n '8': 'bird of paradise',\n '80': 'anthurium',\n '81': 'frangipani',\n '82': 'clematis',\n '83': 'hibiscus',\n '84': 'columbine',\n '85': 'desert-rose',\n '86': 'tree mallow',\n '87': 'magnolia',\n '88': 'cyclamen',\n '89': 'watercress',\n '9': 'monkshood',\n '90': 'canna lily',\n '91': 'hippeastrum',\n '92': 'bee balm',\n '93': 'ball moss',\n '94': 'foxglove',\n '95': 'bougainvillea',\n '96': 'camellia',\n '97': 'mallow',\n '98': 'mexican petunia',\n '99': 'bromelia'}"
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "# ImageFolder does not use the folder name as the label: it sorts the folder names as strings\n# ('1', '10', '100', ..., '99') and numbers them 0 to 101 in that order\nnames = pd.DataFrame.from_dict(cat_to_name, orient='index', columns=['class'])\nnames = names.sort_index(axis=0)\nnames['labels'] = range(102)\nnames.head()",
"execution_count": 3,
"outputs": [
{
"data": {
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>class</th>\n <th>labels</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>1</th>\n <td>pink primrose</td>\n <td>0</td>\n </tr>\n <tr>\n <th>10</th>\n <td>globe thistle</td>\n <td>1</td>\n </tr>\n <tr>\n <th>100</th>\n <td>blanket flower</td>\n <td>2</td>\n </tr>\n <tr>\n <th>101</th>\n <td>trumpet creeper</td>\n <td>3</td>\n </tr>\n <tr>\n <th>102</th>\n <td>blackberry lily</td>\n <td>4</td>\n </tr>\n </tbody>\n</table>\n</div>",
"text/plain": " class labels\n1 pink primrose 0\n10 globe thistle 1\n100 blanket flower 2\n101 trumpet creeper 3\n102 blackberry lily 4"
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
]
},
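{
"metadata": {},
"cell_type": "markdown",
"source": "A quick check, in plain Python, of the ordering described in the comment above. Sorting the folder names `'1'` to `'102'` as strings puts `'10'`, `'100'`, `'101'` and `'102'` before `'11'`, so folder `'2'` only gets index 14. `ImageFolder` assigns indices in the same string-sorted order, which is why the `labels` column does not match the folder numbers. The `demo_` names are hypothetical and keep the notebook's own variables untouched."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "demo_folders = [str(n) for n in range(1, 103)]  # class folders '1' ... '102'\n\n# String sort, the order in which ImageFolder assigns the indices 0..101\ndemo_classes = sorted(demo_folders)\ndemo_class_to_idx = {name: i for i, name in enumerate(demo_classes)}\n\nprint(demo_classes[:6])        # ['1', '10', '100', '101', '102', '11']\nprint(demo_class_to_idx['2'])  # 14",
"execution_count": null,
"outputs": []
},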
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "# Transformations\ntrain_data_transforms = transforms.Compose([\n transforms.RandomHorizontalFlip(), # randomly flip and rotate\n transforms.RandomVerticalFlip(),\n transforms.RandomRotation(10),\n transforms.RandomResizedCrop(224),\n transforms.ToTensor(),\n transforms.Normalize((0.485, 0.456, 0.406), \n (0.229, 0.224, 0.225))])\nvalid_data_transforms = transforms.Compose([\n transforms.Resize(256),\n transforms.CenterCrop(224),\n transforms.ToTensor(),\n transforms.Normalize((0.485, 0.456, 0.406), \n (0.229, 0.224, 0.225))])",
"execution_count": 5,
"outputs": []
},
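{
"metadata": {},
"cell_type": "markdown",
"source": "A small check of what the validation pipeline produces, run on a blank 500x375 RGB image made with PIL (a stand-in, not a dataset image). The pipeline is repeated under a new name, `check_transforms`, so the notebook's own variables are left alone. `Resize(256)` scales the shorter side to 256 pixels while keeping the aspect ratio, `CenterCrop(224)` cuts out the central 224x224 region, and `ToTensor` returns a channels-first float tensor."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "from PIL import Image\nfrom torchvision import transforms\n\n# Same steps as valid_data_transforms above\ncheck_transforms = transforms.Compose([\n    transforms.Resize(256),\n    transforms.CenterCrop(224),\n    transforms.ToTensor(),\n    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])\n\n# Stand-in image: 500 pixels wide, 375 high, mid-grey\nblank_image = Image.new('RGB', (500, 375), color=(128, 128, 128))\nprint(transforms.Resize(256)(blank_image).size)  # (341, 256): shorter side is now 256\nprint(check_transforms(blank_image).shape)       # torch.Size([3, 224, 224])",
"execution_count": null,
"outputs": []
},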
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "# Number of subprocesses to use for data loading\nnum_workers = 0\n\n# Number of samples per batch to load\nbatch_size = 256\n\n# Load the datasets with ImageFolder\ntrain_data = datasets.ImageFolder(train_dir, transform=train_data_transforms)\nvalid_data = datasets.ImageFolder(valid_dir, transform=valid_data_transforms)\n\n# Can be larger than the training batch because no gradients are stored when evaluating under torch.no_grad()\nvalid_batch_size = len(valid_data)\n\n# Using the image datasets and the transforms, define the dataloaders\ntrain_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,\n num_workers=num_workers, shuffle=True)\nvalid_loader = torch.utils.data.DataLoader(valid_data, batch_size=valid_batch_size,\n num_workers=num_workers, shuffle=True)\n# Statistics\nprint('Num training images: ', len(train_data))\nprint('Num valid images: ', len(valid_data))",
"execution_count": 6,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "Num training images: 7241\nNum valid images: 818\n"
}
]
},
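{
"metadata": {},
"cell_type": "markdown",
"source": "Putting all 818 validation images in one batch works because no gradients are kept at evaluation time, but the batch still holds 818 x 3 x 224 x 224 float32 values at once. This is a sketch of a batched alternative under `torch.no_grad()`. The `evaluate` helper is hypothetical, and the tiny linear model and random tensors stand in for the real network and images."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "import torch\nimport torch.nn as nn\nfrom torch.utils.data import DataLoader, TensorDataset\n\n# Memory for one batch of 818 images of 3 x 224 x 224 float32 values (4 bytes each)\nprint(818 * 3 * 224 * 224 * 4 / 2**20, 'MiB')  # 469.7109375 MiB\n\ndef evaluate(model, loader, device='cpu'):\n    # Accuracy over a loader, without building the autograd graph\n    model.eval()\n    correct, total = 0, 0\n    with torch.no_grad():\n        for images, labels in loader:\n            images, labels = images.to(device), labels.to(device)\n            predictions = model(images).argmax(dim=1)\n            correct += (predictions == labels).sum().item()\n            total += labels.size(0)\n    return correct / total\n\n# Hypothetical stand-ins: 100 random 3x8x8 'images' and a linear classifier over 102 classes\ntoy_valid = TensorDataset(torch.randn(100, 3, 8, 8), torch.randint(0, 102, (100,)))\ntoy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 102))\nprint('accuracy:', evaluate(toy_model, DataLoader(toy_valid, batch_size=32)))",
"execution_count": null,
"outputs": []
},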
{
"metadata": {
"scrolled": true,
"trusted": false
},
"cell_type": "code",
"source": "# ImageFolder already builds this mapping from the sorted class folder names\nclass_to_idx = train_data.class_to_idx\n{k: class_to_idx[k] for k in list(class_to_idx)[:5]}",
"execution_count": 7,
"outputs": [
{
"data": {
"text/plain": "{'1': 0, '10': 1, '100': 2, '101': 3, '102': 4}"
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "# Reverse dictionary, used in sanity check\nidx_to_class = {val: key for key, val in class_to_idx.items()}",
"execution_count": 8,
"outputs": []
},
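{
"metadata": {},
"cell_type": "markdown",
"source": "The reverse dictionary turns a predicted index back into a folder name, and `cat_to_name` turns that folder name into a flower name. This sketch uses `torch.topk` on made-up probabilities. The three-entry `example_` dictionaries are a hypothetical excerpt of the real `idx_to_class` and `cat_to_name`, kept under separate names so the real ones are not overwritten."
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "import torch\n\n# Hypothetical excerpts of the real mappings\nexample_idx_to_class = {0: '1', 1: '10', 2: '100'}\nexample_cat_to_name = {'1': 'pink primrose', '10': 'globe thistle', '100': 'blanket flower'}\n\n# Made-up class probabilities for one image\nprobs = torch.tensor([[0.2, 0.7, 0.1]])\ntop_p, top_idx = torch.topk(probs, k=2, dim=1)\n\nfor p, i in zip(top_p[0].tolist(), top_idx[0].tolist()):\n    folder = example_idx_to_class[i]\n    print(f'{example_cat_to_name[folder]} (folder {folder}, index {i}): {p:.2f}')\n# globe thistle (folder 10, index 1): 0.70\n# pink primrose (folder 1, index 0): 0.20",
"execution_count": null,
"outputs": []
},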
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "# Plot one original (not cp_-copied) image from each of 60 randomly chosen training folders,\n# titled with the flower name and the ImageFolder label\n# Folder names run from '1' to '102', so shift the 0-based random choice by 1\nsub_folders = np.random.choice(len(class_to_idx), 60, replace=False) + 1\nimages = []\nfor folder in sub_folders:\n    contents = sorted(f for f in os.listdir(os.path.join(train_dir, str(folder)))\n                      if not f.startswith('cp_'))\n    images.append(Image.open(os.path.join(train_dir, str(folder), contents[0])))\nfig = plt.figure(figsize=(25, 12))\nfor idx in np.arange(60):\n    # add_subplot needs integer grid sizes, so use 10 rather than 60/6 (a float in Python 3)\n    ax = fig.add_subplot(6, 10, idx + 1, xticks=[], yticks=[])\n    plt.imshow(images[idx])\n    key = str(sub_folders[idx])\n    ax.set_title(names.loc[key, 'class'] + ' ' + str(names.loc[key, 'labels']))",
"execution_count": 38,
"outputs": [
{
"data": {