Note that returning CUDA tensors from workers in multi-process loading is generally discouraged: there are many subtleties in using CUDA and sharing CUDA tensors across processes (the docs point out that the sending process must keep the original tensor alive for as long as the receiving process holds a copy of it). The recommended approach is to set pin_memory=True, so that data can be transferred quickly to a CUDA-enabled GPU. In short, avoid returning CUDA tensors when multi-process loading (num_workers > 0) is in use; a sketch of the recommended pattern follows.
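A minimal sketch of that pattern (the dataset shapes and sizes here are invented for illustration): keep tensors on the CPU inside the workers, let the DataLoader pin them, and perform the host-to-device copy in the main process.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset; shapes are arbitrary for illustration
dataset = TensorDataset(torch.randn(1024, 3, 32, 32),
                        torch.randint(0, 10, (1024,)))
loader = DataLoader(dataset, batch_size=64, num_workers=4,
                    pin_memory=True)  # workers return CPU tensors; the main process pins them

for images, labels in loader:
    # non_blocking=True lets the copy overlap with computation,
    # which only works because the source tensors are in pinned memory
    images = images.cuda(non_blocking=True)
    labels = labels.cuda(non_blocking=True)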
def _reset(self, loader, first_iter=False):
    ...
    # prime the prefetch loop
    for _ in range(self._prefetch_factor * self._num_workers):
        self._try_put_index()
def _next_index(self):
    return next(self._sampler_iter)  # may raise StopIteration
def _next_data(self):
    raise NotImplementedError
def __next__(self) -> Any:
    with torch.autograd.profiler.record_function(self._profile_name):
        if self._sampler_iter is None:
            self._reset()
        data = self._next_data()  # the key line: this is where the data is actually fetched
        self._num_yielded += 1
        ...
        return data
def _next_data(self):
    index = self._next_index()  # may raise StopIteration
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    if self._pin_memory:
        data = _utils.pin_memory.pin_memory(data)
    return data
def fetch(self, possibly_batched_index):
    if self.auto_collation:
        # If a batch_sampler is present, _auto_collation is True, so the
        # batch_sampler takes priority: the index passed into the fetcher
        # is a whole batch of indices.
        data = [self.dataset[idx] for idx in possibly_batched_index]
    else:
        data = self.dataset[possibly_batched_index]
    return self.collate_fn(data)
· For Iterable-style datasets: __init__ sets up the dataset's initial iterator, and fetch retrieves elements from it, so at this point the index no longer plays much of a role (see the sketch below).
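For reference, a simplified sketch of the Iterable-style fetcher, condensed from torch/utils/data/_utils/fetch.py (some bookkeeping around exhausted iterators is trimmed):

class _IterableDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super().__init__(dataset, auto_collation, collate_fn, drop_last)
        # The iterator is created once, up front; fetch() just advances it
        self.dataset_iter = iter(dataset)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            # The indices only determine the batch size; their values are ignored
            data = []
            for _ in possibly_batched_index:
                try:
                    data.append(next(self.dataset_iter))
                except StopIteration:
                    break
            if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)):
                raise StopIteration
        else:
            data = next(self.dataset_iter)
        return self.collate_fn(data)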
loader.__iter__() --> self._get_iterator() --> _SingleProcessDataLoaderIter (a subclass of _BaseDataLoaderIter) --> __next__() --> self._next_data() --> self._next_index() --> next(self._sampler_iter), i.e. next(iter(self._index_sampler)) --> index obtained --> self._dataset_fetcher.fetch(index) --> data obtained
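The chain can be observed from user code. The last two lines below poke at a private attribute (_index_sampler) purely to make the index half of the chain visible; it is not a public API and may change across PyTorch versions:

from torch.utils.data import DataLoader

loader = DataLoader(list(range(10)), batch_size=4)

it = iter(loader)   # loader.__iter__() -> _SingleProcessDataLoaderIter
print(next(it))     # __next__ -> _next_data -> _next_index -> fetch; prints tensor([0, 1, 2, 3])

# The index half of the chain, spelled out by hand:
sampler_iter = iter(loader._index_sampler)  # here: a BatchSampler
print(next(sampler_iter))                   # [0, 1, 2, 3]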
For the multi-process case, borrowing the comment from the PyTorch source itself, the execution flow is as follows:
# Our data model looks like this (queues are indicated with curly brackets):
#
#                main process                              ||
#                     |                                    ||
#               {index_queue}                              ||
#                     |                                    ||
#              worker processes                            ||     DATA
#                     |                                    ||
#            {worker_result_queue}                         ||     FLOW
#                     |                                    ||
#      pin_memory_thread of main process                   ||   DIRECTION
#                     |                                    ||
#               {data_queue}                               ||
#                     |                                    ||
#                data output                               \/
#
# P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
#   `pin_memory=False`.
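The topology is easy to reproduce with a toy producer/consumer. This is not the real worker protocol (no done events, no error handling, pin_memory assumed False so the result queue doubles as the data queue), just the same queue layout:

import multiprocessing as mp

def worker_loop(index_queue, data_queue):
    # Stand-in for _utils.worker._worker_loop: take an idx, "fetch" a batch, send (idx, data)
    while True:
        idx = index_queue.get()
        if idx is None:            # sentinel instead of workers_done_event
            break
        data_queue.put((idx, f"batch-{idx}"))

if __name__ == "__main__":
    index_queue = mp.Queue()       # main process -> worker
    data_queue = mp.Queue()        # worker -> main process
    w = mp.Process(target=worker_loop, args=(index_queue, data_queue))
    w.start()
    for i in range(3):
        index_queue.put(i)         # plays the role of _try_put_index()
    for _ in range(3):
        print(data_queue.get())    # plays the role of _get_data()
    index_queue.put(None)
    w.join()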
class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_MultiProcessingDataLoaderIter, self).__init__(loader)
        ...
        # Each worker puts the data it fetches into this queue (inter-process communication).
        self._worker_result_queue = multiprocessing_context.Queue()
        ...
        self._workers_done_event = multiprocessing_context.Event()
        self._index_queues = []
        self._workers = []
        for i in range(self._num_workers):
            # One index queue per worker subprocess, holding the indices it should process.
            index_queue = multiprocessing_context.Queue()
            index_queue.cancel_join_thread()
            # _worker_loop takes indices from index_queue, processes the data through
            # collate_fn, and puts the finished batch into the result queue.
            # (The idx sent into the queue is self._send_idx.)
            w = multiprocessing_context.Process(
                target=_utils.worker._worker_loop,  # the loop each worker subprocess runs; it puts results into _worker_result_queue as (idx, data) pairs
                args=(self._dataset_kind, self._dataset, index_queue,
                      self._worker_result_queue, self._workers_done_event,
                      self._auto_collation, self._collate_fn, self._drop_last,
                      self._base_seed + i, self._worker_init_fn, i, self._num_workers,
                      self._persistent_workers))
            w.daemon = True
            w.start()
            self._index_queues.append(index_queue)
            self._workers.append(w)

        if self._pin_memory:
            self._pin_memory_thread_done_event = threading.Event()
            # Holds the results after the fetched data has gone through pin_memory.
            self._data_queue = queue.Queue()
            pin_memory_thread = threading.Thread(
                target=_utils.pin_memory._pin_memory_loop,
                args=(self._worker_result_queue, self._data_queue,
                      torch.cuda.current_device(),
                      self._pin_memory_thread_done_event))
            pin_memory_thread.daemon = True
            pin_memory_thread.start()
            # Similar to workers (see comment above), we only register
            # pin_memory_thread once it is started.
            self._pin_memory_thread = pin_memory_thread
        else:
            self._data_queue = self._worker_result_queue
        ...
        self._reset(loader, first_iter=True)

    def _reset(self, loader, first_iter=False):
        super()._reset(loader, first_iter)
        # idx of the next task to be sent to workers: records the idx of the
        # batch about to be put into index_queue.
        self._send_idx = 0
        # idx of the next task to be returned in __next__: records the idx of
        # the batch to be taken out of data_queue next.
        self._rcvd_idx = 0
        # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
        # map: task idx => - (worker_id,)       if data isn't fetched (outstanding)
        #                  \ (worker_id, data)  if data is already fetched (out-of-order)
        self._task_info = {}
        # _tasks_outstanding is the number of tasks/batches currently dispatched
        # (some may still be in flight). It starts at 0, is incremented in
        # self._try_put_index(), and decremented in self._next_data().
        self._tasks_outstanding = 0  # always equal to count(v for v in task_info.values() if len(v) == 1)
        # this indicates status that a worker still has work to do *for this epoch*.
        self._workers_status = [True for i in range(self._num_workers)]
        # We resume the prefetching in case it was enabled
        if not first_iter:
            for idx in range(self._num_workers):
                self._index_queues[idx].put(_utils.worker._ResumeIteration())
            resume_iteration_cnt = self._num_workers
            while resume_iteration_cnt > 0:
                data = self._get_data()
                if isinstance(data, _utils.worker._ResumeIteration):
                    resume_iteration_cnt -= 1
        ...
        # At initialization, put prefetch_factor * num_workers (2 * num_workers
        # by default) (batch_idx, sampler_indices) pairs into the index queues.
        for _ in range(self._prefetch_factor * self._num_workers):
            self._try_put_index()  # prefetch
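Seen from the constructor above, the DataLoader arguments map directly onto this machinery. A hedged example (dataset stands for whatever map-style dataset you already have):

loader = DataLoader(dataset, batch_size=64,
                    num_workers=2,            # two worker processes, each with its own index_queue
                    prefetch_factor=2,        # _reset pre-queues 2 * 2 = 4 batches of indices
                    persistent_workers=True,  # workers survive across epochs; _reset sends _ResumeIteration instead of respawning
                    pin_memory=True)          # inserts the pin_memory_thread between workers and __next__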
def _next_data(self):
    while True:
        # Advance self._rcvd_idx past tasks whose worker has already ended,
        # making sure there are still dispatched-but-unreturned batches
        # (indices in [rcvd_idx, send_idx)).
        while self._rcvd_idx < self._send_idx:
            info = self._task_info[self._rcvd_idx]
            worker_id = info[0]
            if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still active
                break
            del self._task_info[self._rcvd_idx]
            self._rcvd_idx += 1
        else:
            # no valid `self._rcvd_idx` is found (i.e., didn't break)
            if not self._persistent_workers:
                self._shutdown_workers()
            raise StopIteration

        # Now `self._rcvd_idx` is the batch index we want to fetch

        # Check if the next sample has already been generated
        if len(self._task_info[self._rcvd_idx]) == 2:
            data = self._task_info.pop(self._rcvd_idx)[1]
            return self._process_data(data)

        assert not self._shutdown and self._tasks_outstanding > 0
        # Calls self._try_get_data() to pull a result from self._data_queue.
        idx, data = self._get_data()
        self._tasks_outstanding -= 1  # one fewer prefetched batch in flight
        if self._dataset_kind == _DatasetKind.Iterable:
            # Check for _IterableDatasetStopIteration
            if isinstance(data, _utils.worker._IterableDatasetStopIteration):
                if self._persistent_workers:
                    self._workers_status[data.worker_id] = False
                else:
                    self._mark_worker_as_unavailable(data.worker_id)
                self._try_put_index()
                continue

        if idx != self._rcvd_idx:
            # store out-of-order samples
            self._task_info[idx] += (data,)
        else:
            del self._task_info[idx]
            return self._process_data(data)  # return the batch
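To see why _task_info keeps (worker_id, data) tuples, here is a toy replay of the reordering logic (pure Python, no workers; names mirror the source but the values are invented): batches 0 and 1 are both dispatched, but a fast worker delivers batch 1 first.

task_info = {0: ("w0",), 1: ("w1",)}   # task idx -> (worker_id,): dispatched, not yet fetched
rcvd_idx = 0                            # next batch index __next__ must return

for idx, data in [(1, "batch-1"), (0, "batch-0")]:   # arrival order from data_queue
    if idx != rcvd_idx:
        task_info[idx] += (data,)       # out of order: buffer as (worker_id, data)
        print(f"buffered {data}")
    else:
        del task_info[idx]
        print(f"returned {data}")       # in order: handed straight to _process_data
        rcvd_idx += 1

# batch 1 was buffered; the next call finds len(task_info[rcvd_idx]) == 2 and pops it
print("returned", task_info.pop(rcvd_idx)[1])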