@angeligareta
Method to split a TensorFlow dataset (tf.data.Dataset) into train, validation, and test splits
def get_dataset_partitions_tf(ds, ds_size, train_split=0.8, val_split=0.1, test_split=0.1, shuffle=True, shuffle_size=10000):
    # The three fractions must cover the whole dataset
    assert (train_split + test_split + val_split) == 1

    if shuffle:
        # Specify seed to always have the same split distribution between runs,
        # and disable re-shuffling between iterations; otherwise take()/skip()
        # would draw from a different ordering on every epoch and the
        # partitions would leak into each other.
        ds = ds.shuffle(shuffle_size, seed=12, reshuffle_each_iteration=False)

    train_size = int(train_split * ds_size)
    val_size = int(val_split * ds_size)

    train_ds = ds.take(train_size)
    val_ds = ds.skip(train_size).take(val_size)
    test_ds = ds.skip(train_size).skip(val_size)

    return train_ds, val_ds, test_ds
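
For example, a minimal sketch of how the helper can be called. The toy dataset below is purely for illustration; any tf.data.Dataset works, as long as ds_size matches its actual number of elements.

import tensorflow as tf

# Hypothetical toy dataset of 1000 elements to demonstrate the 80/10/10 split
ds = tf.data.Dataset.range(1000)

train_ds, val_ds, test_ds = get_dataset_partitions_tf(ds, ds_size=1000)
print(train_ds.cardinality().numpy())  # 800
print(val_ds.cardinality().numpy())    # 100
print(test_ds.cardinality().numpy())   # 100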
@mghahremanpour

Hi! I have a question related to this discussion and would really appreciate it if anyone can help. My training works fine when I use take() and skip() to split the dataset into train and test sets. But when I split the data before building the datasets, the loss on the test set does not go down as much during fitting. This is pseudocode of what I am doing:

def train_generator():
    # Yield train data
    pass

def validation_generator():
    # Yield validation data
    pass

train_dataset = tf.data.Dataset.from_generator(train_generator, ...)
validation_dataset = tf.data.Dataset.from_generator(validation_generator, ...)
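
Spelled out, the setup looks roughly like this; the generator bodies, shapes, and dtypes below are placeholders for my real data:

import tensorflow as tf

# Placeholder generators: in reality these iterate over disjoint
# train / validation portions of the underlying data.
def train_generator():
    for _ in range(800):
        yield tf.random.normal([4]), 0

def validation_generator():
    for _ in range(100):
        yield tf.random.normal([4]), 1

# Placeholder signature; it must match the real shapes and dtypes.
signature = (
    tf.TensorSpec(shape=(4,), dtype=tf.float32),
    tf.TensorSpec(shape=(), dtype=tf.int32),
)

train_dataset = tf.data.Dataset.from_generator(train_generator, output_signature=signature)
validation_dataset = tf.data.Dataset.from_generator(validation_generator, output_signature=signature)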

Am I missing something here?

Thanks.
