Last active
April 12, 2021 02:26
-
-
Save andrewliao11/ba460909bd07548a30336a53da5461ea to your computer and use it in GitHub Desktop.
fix_for_numpy_rng_and_torch_dataset
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# inspired by the post: https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/ | |
# tl;dr | |
# If you are using numpy random generator with `torch.utils.data.Dataset`, | |
# you might get identical results either across different workers or epochs | |
# disclaimer: this might not be the best choice since setting worker to be persistent requires additional RAM. | |
# Welcome for any idea | |
# Here's a simple fix with torch>=1.7.0 | |
# See the original example here: https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/#a-minimal-example | |
import torch | |
import numpy as np | |
from torch.utils.data import Dataset, DataLoader | |
class RandomDataset(Dataset):
    """Dataset that draws each sample from a per-worker NumPy RNG.

    Using a single module-level ``np.random`` inside ``__getitem__`` can
    yield identical samples across DataLoader workers (each forked worker
    inherits the same RNG state).  Giving every worker its own seeded
    ``RandomState`` — and keeping workers alive via
    ``persistent_workers=True`` — makes the streams distinct across
    workers and advancing across epochs.
    """

    def __init__(self, num_workers, base_seed=123):
        """Create one RNG per worker.

        Args:
            num_workers: number of DataLoader workers this dataset will be
                used with.  ``max(1, ...)`` guarantees at least one RNG so
                the dataset also works with ``num_workers=0``.
            base_seed: seed offset for worker ``i``'s RNG (worker ``i`` is
                seeded with ``base_seed + i``).  Defaults to the original
                hard-coded value 123.
        """
        self.worker_id_to_rng = {
            i: np.random.RandomState(base_seed + i)
            for i in range(max(1, num_workers))
        }

    def __getitem__(self, index):
        """Return 3 random ints in [0, 1000) from this worker's RNG.

        ``index`` is intentionally ignored — every item is a fresh draw.
        """
        worker_info = torch.utils.data.get_worker_info()
        # get_worker_info() returns None in the main process
        # (num_workers=0); fall back to worker id 0 instead of crashing.
        worker_id = 0 if worker_info is None else worker_info.id
        rng = self.worker_id_to_rng[worker_id]
        return rng.randint(0, 1000, 3)

    def __len__(self):
        # Fixed small length — this gist only demonstrates RNG behavior.
        return 16
# Drive the dataset through a multi-worker DataLoader for two epochs and
# print every batch, demonstrating that samples differ across workers and
# keep advancing between epochs.
num_workers = 4
dataset = RandomDataset(num_workers)

# persistent_workers (available since torch 1.7.0) keeps the worker
# processes alive between epochs, so each worker's RNG state carries over
# instead of being re-created at every epoch boundary.
dataloader = DataLoader(
    dataset,
    batch_size=2,
    num_workers=num_workers,
    persistent_workers=True,
)

for epoch in range(2):
    print(f'Epoch {epoch}')
    for batch_idx, batch in enumerate(dataloader):
        print(batch_idx, batch)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment