@angeligareta
Last active June 24, 2025 13:14
Method to split a tensorflow dataset (tf.data.Dataset) into train, validation and test splits
import math

def get_dataset_partitions_tf(ds, ds_size, train_split=0.8, val_split=0.1, test_split=0.1, shuffle=True, shuffle_size=10000):
    # Compare with a tolerance: float sums such as 0.7 + 0.2 + 0.1 are not exactly 1
    assert math.isclose(train_split + val_split + test_split, 1.0)
    if shuffle:
        # Fix the seed and disable reshuffling so the split is the same on every run and every iteration
        ds = ds.shuffle(shuffle_size, seed=12, reshuffle_each_iteration=False)
    train_size = int(train_split * ds_size)
    val_size = int(val_split * ds_size)
    # First train_size elements for training, the next val_size for validation, the rest for testing
    train_ds = ds.take(train_size)
    val_ds = ds.skip(train_size).take(val_size)
    test_ds = ds.skip(train_size).skip(val_size)
    return train_ds, val_ds, test_ds
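
To see how the three sizes work out when ds_size is not evenly divisible, here is a minimal pure-Python sketch of the same arithmetic. The helper name split_sizes is hypothetical and not part of the gist; it only mirrors the int() truncation above, where the test split ends up with whatever take() and skip() leave over.

```python
def split_sizes(ds_size, train_split=0.8, val_split=0.1):
    # Hypothetical helper: mirrors the int() truncation used in the gist
    train_size = int(train_split * ds_size)
    val_size = int(val_split * ds_size)
    # The test split is the remainder after skipping train and validation
    test_size = ds_size - train_size - val_size
    return train_size, val_size, test_size


print(split_sizes(1003))  # (802, 100, 101)
```

Because both train_size and val_size are rounded down, the test split can be slightly larger than test_split * ds_size.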
@hsparkastro

ds.shuffle, without the additional parameter reshuffle_each_iteration=False, will reshuffle the dataset on every iteration before it is split into the three separate datasets. This causes the three sets to be different on every iteration, so a datapoint that was in val_ds could be in train_ds in the next iteration.
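
A small sketch of how this could be checked, assuming TensorFlow 2.x in eager mode. The helper split_membership and the toy range(100) dataset are made up for illustration; the idea is to iterate the training split twice and compare what comes out.

```python
import tensorflow as tf


def split_membership(reshuffle):
    # Hypothetical helper: collects the elements seen in two passes over
    # the train split and in one pass over the test split
    ds = tf.data.Dataset.range(100).shuffle(
        100, seed=12, reshuffle_each_iteration=reshuffle
    )
    train_ds = ds.take(80)
    test_ds = ds.skip(80)
    first_pass = {int(x) for x in train_ds.as_numpy_iterator()}
    second_pass = {int(x) for x in train_ds.as_numpy_iterator()}
    test_items = {int(x) for x in test_ds.as_numpy_iterator()}
    return first_pass, second_pass, test_items


# With reshuffling disabled, the splits are stable and do not overlap
first, second, test_items = split_membership(reshuffle=False)
print("stable train split:", first == second)
print("train/test overlap:", len(first & test_items))

# With reshuffling enabled, the same check typically reports overlap
first, second, test_items = split_membership(reshuffle=True)
print("stable train split:", first == second)
print("train/test overlap:", len(first & test_items))
```

With reshuffle=True the exact overlap depends on the shuffle order, so the numbers it prints are not fixed.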

@Dennis-Malonza

Hello everyone, I am currently working on a TensorFlow CNN model, but my kernel does not let me visualize my dataset. NumPy operations work fine, but when I run imshow the kernel dies and restarts. This happens again and again, even after restarting the kernel and running all the cells. I would really appreciate it if someone could help me solve this problem.

@mghahremanpour

Hi! I have a question related to this discussion and would really appreciate any help. My training works fine when I use take() and skip() to split the dataset into train and test sets. But when I split the data before creating the datasets, the loss on the test set does not go down as much during fitting. Here is pseudocode of what I am doing.

def train_generator():
    # Yield train data
    pass

def validation_generator():
    # Yield validation data
    pass

train_dataset = tf.data.Dataset.from_generator(train_generator, ...)
validation_dataset = tf.data.Dataset.from_generator(validation_generator, ...)

Am I missing something here?

Thanks.
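
For reference, a runnable version of the pattern described above might look like the sketch below. Everything in it is an assumption made for illustration: the toy features and labels arrays, the make_generator helper, and the shapes in the output_signature. It assumes TensorFlow 2.4 or later, where from_generator accepts output_signature.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data standing in for the real samples
features = np.arange(20, dtype=np.float32).reshape(10, 2)
labels = np.arange(10, dtype=np.int64)


def make_generator(x, y):
    # Returns a zero-argument callable, which is what from_generator expects
    def generator():
        for xi, yi in zip(x, y):
            yield xi, yi
    return generator


# Assumed element structure: a length-2 float vector and a scalar label
signature = (
    tf.TensorSpec(shape=(2,), dtype=tf.float32),
    tf.TensorSpec(shape=(), dtype=tf.int64),
)

# Split the raw arrays first, then build one dataset per split
train_dataset = tf.data.Dataset.from_generator(
    make_generator(features[:8], labels[:8]), output_signature=signature
)
validation_dataset = tf.data.Dataset.from_generator(
    make_generator(features[8:], labels[8:]), output_signature=signature
)
```

This only shows the mechanics of building the two datasets. In a setup like this it may be worth checking whether the data was shuffled before being divided, since take() and skip() in the gist operate on an already shuffled dataset.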
