Test PyTorch multiprocess IterableDataset
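The script below times a PyTorch DataLoader over an IterableDataset (with 0, 2, and 20 workers) and over a map-style Dataset (20 workers). A collate_fn that sleeps 0.1 s per batch stands in for per-batch processing cost, so the printed durations show how much of that cost the worker processes absorb in parallel. (On platforms that spawn worker processes, e.g. Windows, the driver code would additionally need an `if __name__ == "__main__":` guard.)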
import torch
import math
import time

SLEEP_TIME = 0.1


class MyMapDataset(torch.utils.data.Dataset):
    # A trivial map-style dataset over the integers [0, size).
    def __init__(self, size):
        self.dataset = list(range(size))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        return self.dataset[index]

class MyIterableDataset(torch.utils.data.IterableDataset):
    # An iterable-style dataset yielding the integers [start, end),
    # sharding the range across DataLoader workers (the pattern from the
    # PyTorch IterableDataset docs).
    def __init__(self, start, end):
        super().__init__()
        assert end > start, "this example code only works with end > start"
        self.start = start
        self.end = end

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading, return the full iterator
            iter_start = self.start
            iter_end = self.end
        else:  # in a worker process: split the workload
            per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        return iter(range(iter_start, iter_end))
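
# Worked example of the split above, using this script's values:
# with start=0, end=1000 and num_workers=2, per_worker = ceil(1000 / 2) = 500,
# so worker 0 iterates range(0, 500) and worker 1 iterates range(500, 1000);
# with num_workers=20, each worker gets a 50-element slice (5 batches of 10).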

class SleepCollate:
    # A collate_fn that sleeps to simulate per-batch work. With
    # num_workers > 0, collate_fn runs inside the worker processes,
    # so the sleeps can overlap across workers.
    def __call__(self, batch):
        time.sleep(SLEEP_TIME)
        return batch

# The iterable dataset yields the same data as range(0, 1000); with
# batch_size=10 that is 100 batches in total.
ds = MyIterableDataset(start=0, end=1000)
collate = SleepCollate()

# Single-process loading: collate_fn (and its sleep) runs in the main process.
loader = torch.utils.data.DataLoader(ds, num_workers=0, batch_size=10, collate_fn=collate)
start = time.time()
d = list(loader)
print(len(d))  # 100
end = time.time()
print("single process duration: {}".format(end - start))

# Multi-process loading with two worker processes.
# Worker 0 yields batches from range(0, 500) and worker 1 from range(500, 1000);
# the DataLoader interleaves batches from the two workers.
loader = torch.utils.data.DataLoader(ds, num_workers=2, batch_size=10, collate_fn=collate)
start = time.time()
d = list(loader)
print(len(d))  # still 100
end = time.time()
print("multi #2 process duration: {}".format(end - start))

# With even more workers: each of the 20 workers gets a 50-element slice
# and contributes 5 batches.
loader = torch.utils.data.DataLoader(ds, num_workers=20, batch_size=10, collate_fn=collate)
start = time.time()
d = list(loader)
print(len(d))  # 100
end = time.time()
print("multi #20 process duration: {}".format(end - start))

# Same experiment with the map-style dataset: the DataLoader's sampler
# distributes batches across the 20 workers, so no manual sharding is needed.
ds = MyMapDataset(1000)
loader = torch.utils.data.DataLoader(ds, num_workers=20, batch_size=10, collate_fn=collate)
start = time.time()
d = list(loader)
print(len(d))  # 100
end = time.time()
print("map multi #20 process duration: {}".format(end - start))