@spezold
Last active November 28, 2024 12:46
Demonstrate how a resource can be shared for exclusive access among the workers of a PyTorch dataloader, by passing a shared lock to the dataset, which the DataLoader then distributes to its worker processes.
import multiprocessing as mp
import time

from torch.utils.data import Dataset, DataLoader, get_worker_info

class LockedDataset(Dataset):

    def __init__(self, gpu_id, lock, length):
        self._gpu_id = gpu_id
        self._lock = lock
        self._length = length
        self._worker_id = -1

    def __getitem__(self, idx):
        # Acquire the lock for returning the sample (the content of the sample doesn't really matter here)
        with self._lock:
            item = idx
            from_t = time.time()
            time.sleep(1.)
            to_t = time.time()
            print(f"GPU {self._gpu_id}, worker {self._worker_id}: {from_t % 1000 :06.2f}s – {to_t % 1000 :06.2f}s")
        return item

    def __len__(self):
        return self._length

# This is just for printing the worker's ID and *not* necessary for the actual synchronization setup to work
def worker_init_fn(*args):
    info = get_worker_info()
    info.dataset._worker_id = info.id

def main(gpu_id, shared_lock, num_workers):
    # Provide the lock to the dataset. The DataLoader will then distribute it among its worker processes.
    # Provide one sample per worker (`length=num_workers`). Set `batch_size=1` to encourage each sample being
    # loaded by a different worker (not sure if guaranteed though). Finally, actually load the samples.
    dataset = LockedDataset(gpu_id, shared_lock, length=num_workers)
    dataloader = DataLoader(dataset, batch_size=1, num_workers=num_workers, worker_init_fn=worker_init_fn)
    for batch in dataloader:
        pass  # We don't have any use for the batches other than loading them

if __name__ == "__main__":
    num_gpus = 4
    num_workers_per_gpu = 3
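    # With a 1-second critical section per sample and one sample per worker, a correctly serialized run
    # should take at least `num_gpus * num_workers_per_gpu` seconds (12 s here), plus startup overhead.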
    try:
        mp.set_start_method("spawn")
    except RuntimeError as e:
        # For convenience: "spawn" may have been set already with an earlier run (happens e.g. in Spyder IDE)
        if "already been set" not in str(e):
            raise
    start = time.time()
    shared_lock = mp.Lock()
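    # Note that this single lock is shared twice over: it is passed to every per-GPU process below and,
    # through the dataset, to each of that process's dataloader workers.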
    jobs = []
    for gpu_id in range(num_gpus):
        # Launch one process per GPU
        jobs.append(job := mp.Process(target=main, args=(gpu_id, shared_lock, num_workers_per_gpu)))
        job.start()
    for job in jobs:
        job.join()
    print(f"Took {time.time() - start:.2f}s. Check above:")
    print(" - Exactly `num_gpus * num_workers_per_gpu` outputs altogether?")
    print(" - Exactly one output for each combination of GPU and worker?")
    print(" - No overlapping time intervals?")