@andrewliao11 · Last active April 12, 2021
fix_for_numpy_rng_and_torch_dataset
# Inspired by this post: https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/
# tl;dr
# If you use a numpy random generator inside a `torch.utils.data.Dataset`,
# you can get identical "random" numbers across different workers and/or across epochs.
# Disclaimer: this may not be the best fix, since keeping workers persistent requires additional RAM.
# Any better ideas are welcome.
# Below is a simple fix that requires torch>=1.7.0.
# See the original buggy example here: https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/#a-minimal-example
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
class RandomDataset(Dataset):

    def __init__(self, num_workers):
        # create one independent RNG per worker, each with a distinct seed,
        # so forked workers do not all inherit the same numpy state
        base_seed = 123
        self.worker_id_to_rng = {i: np.random.RandomState(base_seed + i) for i in range(num_workers)}

    def __getitem__(self, index):
        # look up the RNG that belongs to the worker serving this index
        # (assumes num_workers > 0; get_worker_info() returns None in the main process)
        worker_info = torch.utils.data.get_worker_info()
        rng = self.worker_id_to_rng[worker_info.id]
        return rng.randint(0, 1000, 3)
        # buggy original: return np.random.randint(0, 1000, 3)

    def __len__(self):
        return 16
num_workers = 4
dataset = RandomDataset(num_workers)

# torch>=1.7.0: persistent_workers keeps the worker processes (and thus their RNG state)
# alive across epochs, so each per-worker stream keeps advancing instead of being reset
dataloader = DataLoader(dataset, batch_size=2, num_workers=num_workers, persistent_workers=True)

for epoch in range(2):
    print(f'Epoch {epoch}')
    for i, batch in enumerate(dataloader):
        print(i, batch)
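
# A possible alternative (a sketch, not part of the original gist): instead of persistent
# workers, re-seed numpy inside every worker via `worker_init_fn`. When workers are
# re-created each epoch, torch.initial_seed() differs per worker and per epoch, so deriving
# the numpy seed from it avoids identical streams without keeping workers alive in RAM.
def seed_numpy_per_worker(worker_id):
    np.random.seed(torch.initial_seed() % 2 ** 32)

# With this, plain np.random calls inside __getitem__ are safe, e.g.:
# dataloader = DataLoader(dataset, batch_size=2, num_workers=num_workers,
#                         worker_init_fn=seed_numpy_per_worker)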