Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/11/29 07:04:45 UTC

[GitHub] leezu commented on a change in pull request #13447: Rewrite dataloader, improves responsiveness and reliability

leezu commented on a change in pull request #13447: Rewrite dataloader, improves responsiveness and reliability
URL: https://github.com/apache/incubator-mxnet/pull/13447#discussion_r237371949
 
 

 ##########
 File path: python/mxnet/gluon/data/dataloader.py
 ##########
 @@ -390,8 +369,187 @@ def same_process_iter():
             return same_process_iter()
 
         # multi-worker
-        return _MultiWorkerIter(self._num_workers, self._dataset,
-                                self._batchify_fn, self._batch_sampler, self._pin_memory)
+        return _MultiWorkerIterV1(self._num_workers, self._dataset,
+                                  self._batchify_fn, self._batch_sampler, self._pin_memory)
+
+    def __len__(self):
+        return len(self._batch_sampler)
+
+_worker_dataset = None
+def _worker_fn(samples, batchify_fn):
+    """Function for processing data in worker process."""
 +    # each worker process is required to fork a new MXIndexedRecordIO handle
 +    # keeping the dataset as a module-level global avoids significant per-batch
 +    # overhead and is safe because each worker runs in its own process
+    global _worker_dataset
+    batch = batchify_fn([_worker_dataset[i] for i in samples])
+    batch = [batch] if not isinstance(batch, (list, tuple)) else batch
+    ret = [reduce_ndarray(x)[1] for x in batch]  # reduce_ndarray(x)[0] is the rebuild function
+    return ret
+
+class _MultiWorkerIter(object):
+    """Internal multi-worker iterator for DataLoader."""
+    def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False,
+                 worker_fn=_worker_fn, prefetch=0):
+        self._worker_pool = worker_pool
+        self._batchify_fn = batchify_fn
+        self._batch_sampler = batch_sampler
+        self._data_buffer = {}
+        self._rcvd_idx = 0
+        self._sent_idx = 0
+        self._iter = iter(self._batch_sampler)
+        self._worker_fn = worker_fn
+        self._pin_memory = pin_memory
+        # pre-fetch
+        for _ in range(prefetch):
+            self._push_next()
+
+    def __len__(self):
+        return len(self._batch_sampler)
+
+    def _push_next(self):
+        """Assign next batch workload to workers."""
+        r = next(self._iter, None)
+        if r is None:
+            return
+        async_ret = self._worker_pool.apply_async(self._worker_fn, (r, self._batchify_fn))
+        self._data_buffer[self._sent_idx] = async_ret
+        self._sent_idx += 1
+
+    def __next__(self):
+        self._push_next()
+        if self._rcvd_idx == self._sent_idx:
+            assert not self._data_buffer, "Data buffer should be empty at this moment"
+            raise StopIteration
+
+        assert self._rcvd_idx < self._sent_idx, "rcvd_idx must be smaller than sent_idx"
+        assert self._rcvd_idx in self._data_buffer, "fatal error with _push_next, rcvd_idx missing"
+        ret = self._data_buffer.pop(self._rcvd_idx)
+        shared_batch = ret.get()
+        batch = tuple([rebuild_ndarray(*x) for x in shared_batch])
+        if self._pin_memory:
+            batch = _as_in_context(batch, context.cpu_pinned())
+        batch = batch[0] if len(batch) == 1 else batch
+        self._rcvd_idx += 1
+        return batch
+
+    def next(self):
+        return self.__next__()
+
+    def __iter__(self):
+        return self
+
+
+class DataLoader(object):
+    """Loads data from a dataset and returns mini-batches of data.
+
+    Parameters
+    ----------
+    dataset : Dataset
+        Source dataset. Note that numpy and mxnet arrays can be directly used
+        as a Dataset.
+    batch_size : int
+        Size of mini-batch.
+    shuffle : bool
+        Whether to shuffle the samples.
+    sampler : Sampler
+        The sampler to use. Either specify sampler or shuffle, not both.
+    last_batch : {'keep', 'discard', 'rollover'}
+        How to handle the last batch if batch_size does not evenly divide
+        `len(dataset)`.
+
 +        keep - A batch with fewer samples than previous batches is returned.
 +        discard - The last batch is discarded if it is incomplete.
+        rollover - The remaining samples are rolled over to the next epoch.
+    batch_sampler : Sampler
 +        A sampler that returns mini-batches. Do not specify batch_size,
 +        shuffle, sampler, or last_batch if batch_sampler is specified.
+    batchify_fn : callable
+        Callback function to allow users to specify how to merge samples
+        into a batch. Defaults to `default_batchify_fn`::
+
+            def default_batchify_fn(data):
+                if isinstance(data[0], nd.NDArray):
+                    return nd.stack(*data)
+                elif isinstance(data[0], tuple):
+                    data = zip(*data)
+                    return [default_batchify_fn(i) for i in data]
+                else:
+                    data = np.asarray(data)
+                    return nd.array(data, dtype=data.dtype)
+
+    num_workers : int, default 0
+        The number of multiprocessing workers to use for data preprocessing.
+    pin_memory : boolean, default False
+        If ``True``, the dataloader will copy NDArrays into pinned memory
+        before returning them. Copying from CPU pinned memory to GPU is faster
+        than from normal CPU memory.
+    prefetch : int, default is `num_workers * 2`
 +        The number of batches to prefetch; only takes effect when `num_workers` > 0.
 +        If `prefetch` > 0, worker processes collate the specified number of batches
 +        ahead of the batches requested through the iterator.
 +        A larger value gives smoother performance when iteration starts, but consumes
 +        more shared memory. Too small a value may forfeit the benefit of multiple
 +        worker processes; consider reducing `num_workers` in that case.
 +        Defaults to `num_workers * 2`.
+    """
+    def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
+                 last_batch=None, batch_sampler=None, batchify_fn=None,
+                 num_workers=0, pin_memory=False, prefetch=None):
+        self._dataset = dataset
+        self._pin_memory = pin_memory
+
+        if batch_sampler is None:
+            if batch_size is None:
+                raise ValueError("batch_size must be specified unless " \
+                                 "batch_sampler is specified")
+            if sampler is None:
+                if shuffle:
+                    sampler = _sampler.RandomSampler(len(dataset))
+                else:
+                    sampler = _sampler.SequentialSampler(len(dataset))
+            elif shuffle:
+                raise ValueError("shuffle must not be specified if sampler is specified")
+
+            batch_sampler = _sampler.BatchSampler(
+                sampler, batch_size, last_batch if last_batch else 'keep')
+        elif batch_size is not None or shuffle or sampler is not None or \
+                last_batch is not None:
+            raise ValueError("batch_size, shuffle, sampler and last_batch must " \
+                             "not be specified if batch_sampler is specified.")
+
+        self._batch_sampler = batch_sampler
+        self._num_workers = num_workers if num_workers >= 0 else 0
+        self._worker_pool = None
+        self._prefetch = max(0, int(prefetch) if prefetch is not None else 2 * self._num_workers)
+        if self._num_workers > 0:
+            def worker_initializer(data):
+                global _worker_dataset
 
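For context, a minimal usage sketch of the DataLoader interface defined in this diff (illustrative only; the toy `ArrayDataset` and the parameter values below are not part of the change):

    from mxnet import nd
    from mxnet.gluon.data import ArrayDataset, DataLoader

    # toy dataset: 100 samples with 3 features each, plus integer labels
    dataset = ArrayDataset(nd.random.uniform(shape=(100, 3)), nd.arange(100))
    loader = DataLoader(dataset, batch_size=8, shuffle=True,
                        num_workers=2,     # > 0 enables the multi-worker path above
                        pin_memory=False,  # True copies batches into pinned CPU memory
                        prefetch=4)        # defaults to 2 * num_workers when left unset
    for data, label in loader:
        pass  # each batch is collated by default_batchify_fn in the workers
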
 Review comment:
   Won't this break when using multiple DataLoader with different datasets?
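  A hypothetical sketch of the scenario behind this question (`dataset_a` and `dataset_b` are placeholder names, not from the PR):

    # The worker_initializer above declares the module-level _worker_dataset
    # global in each worker process; _worker_fn reads it to build batches.
    loader_a = DataLoader(dataset_a, batch_size=4, num_workers=2)
    loader_b = DataLoader(dataset_b, batch_size=4, num_workers=2)
    # The question: if the two loaders' workers ever saw one shared _worker_dataset,
    # a batch requested through loader_a could be assembled from dataset_b's samples.
    for batch_a, batch_b in zip(loader_a, loader_b):
        pass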

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services