Skip to content

Instantly share code, notes, and snippets.

@Roffild
Created September 27, 2019 23:39
Show Gist options
  • Save Roffild/d7606bcf46ed1d05d0088c09930091da to your computer and use it in GitHub Desktop.
Save Roffild/d7606bcf46ed1d05d0088c09930091da to your computer and use it in GitHub Desktop.
import torch
from torch.utils.data.dataloader import _SingleProcessDataLoaderIter
class CUDAPrefetcher(_SingleProcessDataLoaderIter):
    """DataLoader iterator that stages the *next* batch on the GPU early.

    Each call to ``__next__`` returns a batch that already lives on
    ``device`` and, before returning, enqueues the host-to-device copy of
    the following batch on a dedicated CUDA stream so that the copy can
    overlap with whatever the caller runs on its own stream.

    NOTE: per the PyTorch CUDA semantics notes, ``non_blocking=True`` copies
    are only asynchronous when the source tensors are in pinned host memory,
    so the wrapped loader should be created with ``pin_memory=True``;
    otherwise no overlap (and no speed-up) is to be expected.

    Args:
        loader: the ``DataLoader`` to iterate over (passed to the base
            iterator unchanged).
        device: target CUDA device.  ``None`` selects the current CUDA
            device.
        priority: priority of the side stream used for the copies.

    Raises:
        RuntimeError: if CUDA is not available.
    """

    def __init__(self, loader, device=None, priority=0):
        if not torch.cuda.is_available():
            raise RuntimeError("Only CUDA")
        super(CUDAPrefetcher, self).__init__(loader)
        if device is None:
            # Tensor.to(device=None) leaves the tensor where it is, so an
            # unresolved None would silently keep every batch on the CPU.
            device = torch.device("cuda", torch.cuda.current_device())
        self.device = device
        self.stream = torch.cuda.Stream(device=device, priority=priority)
        # Batch whose copy has already been enqueued on ``self.stream``;
        # None before the first call and after the loader is exhausted.
        self.last = None

    @classmethod
    def _map_tensors(cls, batch, fn):
        """Return ``batch`` rebuilt with ``fn`` applied to every tensor.

        Handles bare tensors, dicts, namedtuples, tuples and lists
        (recursively); any other object is passed through untouched.
        """
        if isinstance(batch, torch.Tensor):
            return fn(batch)
        if isinstance(batch, dict):
            return {key: cls._map_tensors(value, fn)
                    for key, value in batch.items()}
        if isinstance(batch, tuple) and hasattr(batch, "_fields"):
            # namedtuple: the constructor takes positional fields.
            return type(batch)(*(cls._map_tensors(item, fn) for item in batch))
        if isinstance(batch, (list, tuple)):
            return type(batch)(cls._map_tensors(item, fn) for item in batch)
        return batch

    def _to_device(self, batch, non_blocking):
        """Copy every tensor in ``batch`` to ``self.device``."""
        target = self.device
        return self._map_tensors(
            batch, lambda t: t.to(device=target, non_blocking=non_blocking))

    def _record_stream(self, batch, stream):
        """Tell the allocator that ``stream`` uses the CUDA tensors in ``batch``."""
        def mark(tensor):
            if tensor.is_cuda:
                tensor.record_stream(stream)
            return tensor

        self._map_tensors(batch, mark)

    def __next__(self):
        # Make the caller's stream wait for the copy enqueued by the previous
        # call before it touches the prefetched tensors.
        consumer = torch.cuda.current_stream(device=self.device)
        consumer.wait_stream(self.stream)

        batch = self.last
        if batch is None:
            # First call (or exhausted loader): nothing was prefetched, so
            # fetch and copy synchronously.  May raise StopIteration.
            batch = self._to_device(
                super(CUDAPrefetcher, self).__next__(), non_blocking=False)
        else:
            # The tensors were allocated on the side stream but will be used
            # on the consumer stream; without this the caching allocator may
            # hand their memory to the next prefetch while still in use.
            self._record_stream(batch, consumer)

        try:
            upcoming = super(CUDAPrefetcher, self).__next__()
        except StopIteration:
            # ``batch`` is the final one; the next call re-raises
            # StopIteration from the base iterator.
            self.last = None
        else:
            with torch.cuda.stream(self.stream):
                self.last = self._to_device(upcoming, non_blocking=True)
        return batch
@Roffild
Copy link
Author

Roffild commented Sep 27, 2019

There is no acceleration.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment