-
Notifications
You must be signed in to change notification settings - Fork 63
/
Copy pathdataloader.py
161 lines (132 loc) · 5.22 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
# 2021.05.07-Changed for IPT
# Huawei Technologies Co., Ltd. <[email protected]>
import random
import sys
import threading

import torch
import torch.multiprocessing as multiprocessing
from torch.utils.data import DataLoader
from torch.utils.data import SequentialSampler
from torch.utils.data import RandomSampler
from torch.utils.data import BatchSampler
from torch.utils.data import _utils
from torch.utils.data.dataloader import _DataLoaderIter
from torch.utils.data._utils import collate
from torch.utils.data._utils import signal_handling
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
from torch.utils.data._utils import ExceptionWrapper
from torch.utils.data._utils import IS_WINDOWS
from torch.utils.data._utils.worker import ManagerWatchdog
from torch._six import queue
def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):
    """Run the batch-building loop of one multi-scale data-loading worker.

    Pulls ``(idx, batch_indices)`` requests from ``index_queue``, builds the
    batch with ``collate_fn`` and puts ``(idx, samples)`` on ``data_queue``.
    The index of the scale used for the batch is appended to ``samples``.
    If building the batch raises, ``(idx, ExceptionWrapper(...))`` is put on
    ``data_queue`` instead, so the failure is reported rather than lost.

    Args:
        dataset: Indexable dataset. Must expose a ``train`` attribute and a
            ``set_scale(idx)`` method.
        index_queue: Queue of requests; ``None`` is the shutdown sentinel.
        data_queue: Queue receiving results (or wrapped exceptions).
        done_event: Event-like object; once set, pending requests are skipped.
        collate_fn: Callable turning a list of samples into a batch. Its
            result must support ``append``.
        scale: Sequence of available scales; only its length is used here.
        seed: Seed applied to both ``random`` and ``torch`` in this worker.
        init_fn: Optional callable invoked once with ``worker_id``.
        worker_id: Identifier handed to ``init_fn``.
    """
    try:
        collate._use_shared_memory = True
        signal_handling._set_worker_signal_handlers()

        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)

        data_queue.cancel_join_thread()

        if init_fn is not None:
            init_fn(worker_id)

        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                # Nothing to do yet; loop around to re-check the watchdog.
                continue

            if r is None:
                # Shutdown sentinel: only expected after done_event is set.
                assert done_event.is_set()
                return
            if done_event.is_set():
                # Shutting down: drain remaining requests without working.
                continue

            idx, batch_indices = r
            try:
                idx_scale = 0
                # A random scale is picked only when several scales exist
                # and the dataset is flagged as training.
                if len(scale) > 1 and dataset.train:
                    idx_scale = random.randrange(0, len(scale))
                    dataset.set_scale(idx_scale)

                samples = collate_fn([dataset[i] for i in batch_indices])
                samples.append(idx_scale)
            except Exception:
                # Forward the failure to the consumer of data_queue. This
                # needs ``sys`` to be imported at module level.
                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
            else:
                data_queue.put((idx, samples))
                del samples
    except KeyboardInterrupt:
        # Deliberately ignored so the worker exits quietly on Ctrl-C.
        pass
class _MSDataLoaderIter(_DataLoaderIter):
    """Iterator over an ``MSDataLoader`` whose workers run ``_ms_loop``.

    The initializer does not call the base-class initializer; it sets up the
    iterator state itself so that worker processes are started with
    ``_ms_loop`` as their target and receive the loader's ``scale`` list.
    """

    def __init__(self, loader):
        """Copy settings from ``loader`` and, if requested, start workers.

        Args:
            loader: The owning loader. Must provide ``dataset``, ``scale``,
                ``collate_fn``, ``batch_sampler``, ``num_workers``,
                ``pin_memory``, ``timeout`` and (when ``num_workers > 0``)
                ``worker_init_fn``.
        """
        self.dataset = loader.dataset
        self.scale = loader.scale
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        # Pinning is only enabled when CUDA is actually available.
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout

        self.sample_iter = iter(self.batch_sampler)

        # Draw the base seed exactly once and keep it as a plain Python int.
        # Passing a tensor would make ``random.seed(seed)`` in the worker
        # depend on the object's hash (or fail on interpreters that only
        # accept int/float/str/bytes seeds).
        base_seed = torch.LongTensor(1).random_().item()

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            # Bookkeeping fields; presumably consumed by the inherited
            # _DataLoaderIter methods (not visible here) - verify there.
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.Queue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}
            self.done_event = multiprocessing.Event()

            self.index_queues = []
            self.workers = []
            for i in range(self.num_workers):
                index_queue = multiprocessing.Queue()
                index_queue.cancel_join_thread()
                w = multiprocessing.Process(
                    target=_ms_loop,
                    args=(
                        self.dataset,
                        index_queue,
                        self.worker_result_queue,
                        self.done_event,
                        self.collate_fn,
                        self.scale,
                        base_seed + i,  # distinct integer seed per worker
                        self.worker_init_fn,
                        i,
                    ),
                )
                w.daemon = True
                w.start()
                self.index_queues.append(index_queue)
                self.workers.append(w)

            if self.pin_memory:
                # Results pass through an extra thread running torch's
                # pin-memory loop before reaching ``data_queue``.
                self.data_queue = queue.Queue()
                pin_memory_thread = threading.Thread(
                    target=_utils.pin_memory._pin_memory_loop,
                    args=(
                        self.worker_result_queue,
                        self.data_queue,
                        torch.cuda.current_device(),
                        self.done_event,
                    ),
                )
                pin_memory_thread.daemon = True
                pin_memory_thread.start()
                self.pin_memory_thread = pin_memory_thread
            else:
                self.data_queue = self.worker_result_queue

            # Register worker PIDs and install the SIGCHLD handler.
            _utils.signal_handling._set_worker_pids(
                id(self), tuple(w.pid for w in self.workers)
            )
            _utils.signal_handling._set_SIGCHLD_handler()
            self.worker_pids_set = True

            # Prime the pipeline with two index requests per worker.
            for _ in range(2 * self.num_workers):
                self._put_indices()
class MSDataLoader(DataLoader):
    """``DataLoader`` configured from a settings object for multi-scale use.

    The worker count is taken from ``cfg.n_threads`` and the scale list from
    ``cfg.scale``; iteration is delegated to ``_MSDataLoaderIter``.
    """

    def __init__(self, cfg, *args, **kwargs):
        """Build the loader.

        Args:
            cfg: Object providing ``n_threads`` (forwarded as
                ``num_workers``) and ``scale`` (stored on the loader).
            *args: Positional arguments forwarded to ``DataLoader``.
            **kwargs: Keyword arguments forwarded to ``DataLoader``.
        """
        worker_count = cfg.n_threads
        super().__init__(*args, num_workers=worker_count, **kwargs)
        self.scale = cfg.scale

    def __iter__(self):
        """Return a fresh multi-scale iterator bound to this loader."""
        return _MSDataLoaderIter(self)