Created August 31, 2018 12:14
-
-
Save jkjung-avt/48a0e59ea970c9ea6b44bf789e511121 to your computer and use it in GitHub Desktop.
An example multiprocessing-ready data generator for Keras, taken from https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import keras


class DataGenerator(keras.utils.Sequence):
    """Generate batches of samples and one-hot labels for Keras.

    Each sample is read from ``'data/' + ID + '.npy'`` and its integer class
    is looked up in ``labels[ID]``. Only full batches are produced: when the
    number of IDs is not a multiple of ``batch_size``, the leftover samples
    are skipped for that epoch (a different subset is skipped each epoch when
    shuffling is enabled).
    """

    def __init__(self, list_IDs, labels, batch_size=32, dim=(32, 32, 32),
                 n_channels=1, n_classes=10, shuffle=True):
        """Store the configuration and build the first epoch's sample order.

        Args:
            list_IDs: Sequence of sample IDs (strings used to build file
                names).
            labels: Mapping from sample ID to integer class index.
            batch_size: Number of samples per batch.
            dim: Shape of one sample, excluding the channel axis.
            n_channels: Size of the trailing channel axis.
            n_classes: Number of classes used for one-hot encoding.
            shuffle: Whether to reshuffle the sample order at each epoch end.
        """
        # Newer Keras releases expect Sequence subclasses to initialise the
        # base class; on older releases this is a harmless no-op.
        super().__init__()
        self.dim = dim
        self.batch_size = batch_size
        self.labels = labels
        self.list_IDs = list_IDs
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches per epoch."""
        return len(self.list_IDs) // self.batch_size

    def __getitem__(self, index):
        """Return batch number ``index`` as an ``(X, y)`` tuple."""
        # Positions (into list_IDs) that belong to this batch.
        start = index * self.batch_size
        indexes = self.indexes[start:start + self.batch_size]

        # Translate positions into sample IDs.
        list_IDs_temp = [self.list_IDs[k] for k in indexes]

        return self.__data_generation(list_IDs_temp)

    def on_epoch_end(self):
        """Rebuild the sample order, shuffling it when ``shuffle`` is set."""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp):
        """Load the samples for the given IDs.

        Returns:
            A tuple ``(X, y)`` where ``X`` has shape
            ``(n_samples, *dim, n_channels)`` and ``y`` is the one-hot
            encoding of the labels with ``n_classes`` columns.
        """
        # Size the arrays from the IDs actually supplied so that a short
        # batch can never expose uninitialised np.empty() rows.
        n_samples = len(list_IDs_temp)
        X = np.empty((n_samples, *self.dim, self.n_channels))
        y = np.empty(n_samples, dtype=int)

        for i, ID in enumerate(list_IDs_temp):
            # Store sample
            X[i,] = np.load('data/' + ID + '.npy')
            # Store class
            y[i] = self.labels[ID]

        return X, keras.utils.to_categorical(y, num_classes=self.n_classes)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
from keras.models import Sequential

from my_classes import DataGenerator

# Parameters forwarded as keyword arguments to both DataGenerator instances.
params = {
    'dim': (32, 32, 32),
    'batch_size': 64,
    'n_classes': 6,
    'n_channels': 1,
    'shuffle': True,
}

# Datasets
# TODO: fill in. `partition` is indexed with 'train' and 'validation' below,
# and each entry is passed to DataGenerator as its list of sample IDs.
partition = ...  # IDs
# TODO: fill in. Passed to DataGenerator as its `labels` argument
# (looked up per sample ID).
labels = ...  # Labels

# Generators
training_generator = DataGenerator(partition['train'], labels, **params)
validation_generator = DataGenerator(partition['validation'], labels, **params)

# Design model
model = Sequential()
[...]  # Architecture
model.compile()

# Train model on dataset
# NOTE(review): fit_generator is deprecated in newer Keras releases in favour
# of model.fit, which accepts a Sequence directly -- confirm against the
# installed Keras version before running.
model.fit_generator(generator=training_generator,
                    validation_data=validation_generator,
                    use_multiprocessing=True,
                    workers=6)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment