Last active
January 23, 2024 06:45
-
-
Save ischlag/41d15424e7989b936c1609b53edd1390 to your computer and use it in GitHub Desktop.
Simple Python script that downloads the MNIST data (using code adapted from the TensorFlow MNIST example) and builds a dataset of JPEG image files plus CSV files listing each image path and its label.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import gzip | |
import os | |
import sys | |
import time | |
from six.moves import urllib | |
from six.moves import xrange # pylint: disable=redefined-builtin | |
from scipy.misc import imsave | |
import tensorflow as tf | |
import numpy as np | |
import csv | |
# Base URL of the gzipped MNIST IDX archives. The original host
# (http://yann.lecun.com/exdb/mnist/) is no longer reliably reachable for
# these files, so use the public mirror that TensorFlow Datasets also uses.
SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'
# Local directory where downloaded archives are cached between runs.
WORK_DIRECTORY = 'data'
# Each MNIST image is IMAGE_SIZE x IMAGE_SIZE pixels with NUM_CHANNELS channel.
IMAGE_SIZE = 28
NUM_CHANNELS = 1
# Largest raw pixel value (the IDX files store one unsigned byte per pixel).
PIXEL_DEPTH = 255
# Number of label classes (the digits 0-9).
NUM_LABELS = 10
def maybe_download(filename, source_url=None):
    """Download ``filename`` into WORK_DIRECTORY unless it is already there.

    Args:
        filename: Name of the archive to fetch; appended to the base URL.
        source_url: Base URL to download from (must end with '/').
            Defaults to the module-level SOURCE_URL.

    Returns:
        The local path of the file inside WORK_DIRECTORY.
    """
    if not os.path.exists(WORK_DIRECTORY):
        os.makedirs(WORK_DIRECTORY)
    filepath = os.path.join(WORK_DIRECTORY, filename)
    if not os.path.exists(filepath):
        base_url = SOURCE_URL if source_url is None else source_url
        # Download under a temporary name and rename only after success, so an
        # interrupted download is not mistaken for a complete file next run.
        partial_path = filepath + '.part'
        urllib.request.urlretrieve(base_url + filename, partial_path)
        os.rename(partial_path, filepath)
        # os.path.getsize replaces the former GFile.Size() call, which does not
        # exist on current TensorFlow's GFile.
        size = os.path.getsize(filepath)
        print('Successfully downloaded', filename, size, 'bytes.')
    return filepath
def extract_data(filename, num_images):
    """Extract the images into a 4D tensor [image index, y, x, channels].

    Reads a gzipped IDX image file and returns its first ``num_images``
    images. Pixel values are NOT normalized: they are the raw bytes
    (0-255) converted to float32.

    Args:
        filename: Path to the gzipped IDX image file.
        num_images: Number of images to read from the start of the file.

    Returns:
        np.float32 array of shape
        (num_images, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS).

    Raises:
        ValueError: If the 16-byte header is truncated, its magic number is
            not 2051 (IDX unsigned-byte, 3 dimensions), or the file holds
            fewer than ``num_images`` images.
    """
    print('Extracting', filename)
    num_bytes = IMAGE_SIZE * IMAGE_SIZE * NUM_CHANNELS * num_images
    with gzip.open(filename) as bytestream:
        # Header: magic, image count, rows, cols as big-endian uint32.
        header = bytestream.read(16)
        if len(header) != 16:
            raise ValueError('Truncated IDX header in %s' % filename)
        magic = int(np.frombuffer(header[:4], dtype='>u4')[0])
        if magic != 2051:
            raise ValueError(
                'Invalid magic number %d in image file %s' % (magic, filename))
        buf = bytestream.read(num_bytes)
    if len(buf) != num_bytes:
        raise ValueError(
            'Expected %d images (%d bytes) in %s, found only %d bytes'
            % (num_images, num_bytes, filename, len(buf)))
    data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
    return data.reshape(num_images, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)
def extract_labels(filename, num_images):
    """Extract the labels into a vector of int64 label IDs.

    Reads a gzipped IDX label file and returns its first ``num_images``
    labels (one unsigned byte each).

    Args:
        filename: Path to the gzipped IDX label file.
        num_images: Number of labels to read.

    Returns:
        1-D np.int64 array of length ``num_images``.

    Raises:
        ValueError: If the 8-byte header is truncated, its magic number is
            not 2049 (IDX unsigned-byte, 1 dimension), or the file holds
            fewer than ``num_images`` labels.
    """
    print('Extracting', filename)
    with gzip.open(filename) as bytestream:
        # Header: magic number and item count as big-endian uint32.
        header = bytestream.read(8)
        if len(header) != 8:
            raise ValueError('Truncated IDX header in %s' % filename)
        magic = int(np.frombuffer(header[:4], dtype='>u4')[0])
        if magic != 2049:
            raise ValueError(
                'Invalid magic number %d in label file %s' % (magic, filename))
        buf = bytestream.read(num_images)
    if len(buf) != num_images:
        raise ValueError('Expected %d labels in %s, found only %d'
                         % (num_images, filename, len(buf)))
    return np.frombuffer(buf, dtype=np.uint8).astype(np.int64)
def _open_csv_for_writing(path):
    """Open ``path`` for csv.writer in the mode the running Python needs."""
    if sys.version_info[0] >= 3:
        # Python 3's csv module writes str, so the file must be in text mode;
        # newline='' stops csv from emitting blank rows on Windows.
        return open(path, 'w', newline='')
    # Python 2's csv module expects a binary-mode file.
    return open(path, 'wb')


def _export_split(split, images, labels, out_dir='mnist'):
    """Save one split's images as JPEGs and index them in a CSV file.

    Image i goes to ``<out_dir>/<split>-images/<i>.jpg``; the CSV file
    ``<out_dir>/<split>-labels.csv`` gets one ``<split>-images/<i>.jpg,<label>``
    row per image, with paths relative to ``out_dir``.
    """
    if len(images) != len(labels):
        raise ValueError('%s split has %d images but %d labels'
                         % (split, len(images), len(labels)))
    image_dir = os.path.join(out_dir, split + '-images')
    if not os.path.isdir(image_dir):
        os.makedirs(image_dir)
    csv_path = os.path.join(out_dir, split + '-labels.csv')
    with _open_csv_for_writing(csv_path) as csv_file:
        writer = csv.writer(csv_file, delimiter=',', quotechar='"')
        for i, (image, label) in enumerate(zip(images, labels)):
            rel_path = split + '-images/' + str(i) + '.jpg'
            # Cast to uint8: for float input, scipy's imsave rescales each
            # image to its own min/max, which would alter the pixel values.
            # NOTE: scipy.misc.imsave was removed in SciPy 1.2; this script
            # needs an older SciPy (or a replacement such as imageio.imwrite).
            imsave(os.path.join(out_dir, rel_path),
                   image[:, :, 0].astype(np.uint8))
            writer.writerow([rel_path, int(label)])


train_data_filename = maybe_download('train-images-idx3-ubyte.gz')
train_labels_filename = maybe_download('train-labels-idx1-ubyte.gz')
test_data_filename = maybe_download('t10k-images-idx3-ubyte.gz')
test_labels_filename = maybe_download('t10k-labels-idx1-ubyte.gz')

# Extract it into np arrays.
train_data = extract_data(train_data_filename, 60000)
train_labels = extract_labels(train_labels_filename, 60000)
test_data = extract_data(test_data_filename, 10000)
test_labels = extract_labels(test_labels_filename, 10000)

# Write mnist/train-images/*.jpg + mnist/train-labels.csv, then the same
# layout for the test split.
_export_split('train', train_data, train_labels)
_export_split('test', test_data, test_labels)
In line 83,
writer.writerow(["train-images/" + str(i) + ".jpg", train_labels[i]])
PermissionError: [Errno 13] Permission denied
Please help! Thanks in advance
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Does anyone know how to adapt this for EMNIST? I think the number of classes changes from 10 to 47 — is anything else needed? I ask because the shapes don't match.