-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_prefetcher.py
31 lines (25 loc) · 939 Bytes
/
data_prefetcher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
class DataPrefetcher(object):
    """Overlap host-to-GPU copies of the upcoming batch with the current step.

    Wraps an iterable ``loader`` whose items unpack into ``(input, target)``.
    While the caller works on one batch, the next one is already being copied
    to the current CUDA device on a dedicated side stream.

    Args:
        loader: A re-iterable object (e.g. a ``torch.utils.data.DataLoader``)
            whose items unpack into exactly two tensors. ``iter(loader)`` is
            called here, and again each time the loader is exhausted while
            ``cycle`` is true.
        cycle: If true (the default, matching the original behaviour), restart
            ``loader`` when it runs out, so :meth:`next` never stops. If false,
            :meth:`next` returns ``(None, None)`` once the loader is exhausted.

    Raises:
        StopIteration: from the constructor (or :meth:`next`) when ``cycle`` is
            true and ``loader`` yields no items at all.
    """

    def __init__(self, loader, cycle=True):
        self._loader = loader
        self._cycle = cycle
        self.loader = iter(self._loader)
        # Side stream so the copies issued in preload() can run concurrently
        # with work already queued on the caller's current stream.
        self.stream = torch.cuda.Stream()
        self.preload()

    def preload(self):
        """Fetch the next batch and start copying it to the GPU on ``self.stream``.

        Sets ``self.next_input`` / ``self.next_target``; both are ``None`` when
        the loader is exhausted and ``cycle`` is false.
        """
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            if not self._cycle:
                self.next_input = None
                self.next_target = None
                return
            # Begin a new pass. For an empty loader this raises StopIteration
            # again, which propagates to the caller.
            self.loader = iter(self._loader)
            self.next_input, self.next_target = next(self.loader)
        with torch.cuda.stream(self.stream):
            # non_blocking only yields a truly asynchronous copy when the
            # source tensor lives in pinned host memory (e.g. pin_memory=True).
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)

    def next(self):
        """Return the prefetched ``(input, target)`` and start prefetching the next one.

        Returns ``(None, None)`` once the loader is exhausted and ``cycle`` is
        false.
        """
        current = torch.cuda.current_stream()
        # Make the caller's stream wait until the side-stream copies are done.
        current.wait_stream(self.stream)
        batch_input = self.next_input
        batch_target = self.next_target
        # These tensors were allocated on self.stream. Record their use on the
        # current stream so the caching allocator does not hand their memory to
        # the next prefetch while the caller's kernels are still reading it.
        if batch_input is not None:
            batch_input.record_stream(current)
        if batch_target is not None:
            batch_target.record_stream(current)
        self.preload()
        return batch_input, batch_target