Last active: June 24, 2025 13:14
A method to split a TensorFlow dataset (tf.data.Dataset) into train, validation, and test splits.
def get_dataset_partitions_tf(ds, ds_size, train_split=0.8, val_split=0.1, test_split=0.1, shuffle=True, shuffle_size=10000):
    assert (train_split + val_split + test_split) == 1

    if shuffle:
        # Specify seed to always have the same split distribution between runs.
        # reshuffle_each_iteration=False keeps the shuffled order fixed across
        # epochs, so take()/skip() always carve out the same disjoint examples.
        ds = ds.shuffle(shuffle_size, seed=12, reshuffle_each_iteration=False)

    train_size = int(train_split * ds_size)
    val_size = int(val_split * ds_size)

    train_ds = ds.take(train_size)
    val_ds = ds.skip(train_size).take(val_size)
    test_ds = ds.skip(train_size).skip(val_size)

    return train_ds, val_ds, test_ds
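Conceptually, take(n) keeps the first n elements of the pipeline and skip(n) drops them, so the three splits partition the (shuffled) dataset. A minimal TensorFlow-free sketch of the same arithmetic in plain Python, using hypothetical helper names (partition_sizes, split_like_take_skip are not part of the gist):

```python
def partition_sizes(ds_size, train_split=0.8, val_split=0.1):
    # Same size arithmetic as the function above.
    train_size = int(train_split * ds_size)
    val_size = int(val_split * ds_size)
    return train_size, val_size, ds_size - train_size - val_size

def split_like_take_skip(items, train_size, val_size):
    # A tf.data pipeline streams lazily; a list stands in here for illustration.
    items = list(items)
    train = items[:train_size]                     # ds.take(train_size)
    val = items[train_size:train_size + val_size]  # ds.skip(train_size).take(val_size)
    test = items[train_size + val_size:]           # ds.skip(train_size).skip(val_size)
    return train, val, test

train, val, test = split_like_take_skip(range(10), *partition_sizes(10)[:2])
# The three splits are disjoint and together cover all 10 elements.
```

This also makes the reproducibility requirement visible: take()/skip() only yield disjoint splits if the element order is identical every time the pipeline is iterated, which is why the shuffle seed (and a fixed order across epochs) matters.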
Hi! I have a question related to this discussion and would really appreciate it if anyone could help me. Training works fine when I use take() and skip() to split the dataset into train and test sets, but when I split the data before building the datasets, the loss on the test set does not go down as much during fitting. Here is pseudocode for what I am doing:
def train_generator():
    # Yield train data
    pass

def validation_generator():
    # Yield validation data
    pass

train_dataset = tf.data.Dataset.from_generator(train_generator, ...)
validation_dataset = tf.data.Dataset.from_generator(validation_generator, ...)
Am I missing something here?
Thanks.