Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/11/30 18:09:49 UTC

[GitHub] eric-haibin-lin closed pull request #13447: Rewrite dataloader, improves responsiveness and reliability

URL: https://github.com/apache/incubator-mxnet/pull/13447

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index 86cb835f512..ad0f534d16d 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -36,7 +36,6 @@
 
 from . import sampler as _sampler
 from ... import nd, context
-from ...recordio import MXRecordIO
 
 if sys.platform == 'darwin' or sys.platform == 'win32':
     def rebuild_ndarray(*args):
@@ -159,29 +158,9 @@ def _as_in_context(data, ctx):
         return [_as_in_context(d, ctx) for d in data]
     return data
 
-def _recursive_fork_recordio(obj, depth, max_depth=1000):
-    """Recursively find instance of MXRecordIO and reset file handler.
-    This is required for MXRecordIO which holds a C pointer to a opened file after fork.
-    """
-    if depth >= max_depth:
-        return
-    if isinstance(obj, MXRecordIO):
-        obj.close()
-        obj.open()  # re-obtain file hanlder in new process
-    elif (hasattr(obj, '__dict__')):
-        for _, v in obj.__dict__.items():
-            _recursive_fork_recordio(v, depth + 1, max_depth)
-
-def worker_loop(dataset, key_queue, data_queue, batchify_fn):
-    """Worker loop for multiprocessing DataLoader."""
-    # re-fork a new recordio handler in new process if applicable
-    # for a dataset with transform function, the depth of MXRecordIO is 1
-    # for a lazy transformer, the depth is 2
-    # for a user defined transformer, the depth is unknown, try a reasonable depth
-    limit = sys.getrecursionlimit()
-    max_recursion_depth = min(limit - 5, max(10, limit // 2))
-    _recursive_fork_recordio(dataset, 0, max_recursion_depth)
 
+def worker_loop_v1(dataset, key_queue, data_queue, batchify_fn):
+    """Worker loop for multiprocessing DataLoader."""
     while True:
         idx, samples = key_queue.get()
         if idx is None:
@@ -189,7 +168,7 @@ def worker_loop(dataset, key_queue, data_queue, batchify_fn):
         batch = batchify_fn([dataset[i] for i in samples])
         data_queue.put((idx, batch))
 
-def fetcher_loop(data_queue, data_buffer, pin_memory=False, data_buffer_lock=None):
+def fetcher_loop_v1(data_queue, data_buffer, pin_memory=False, data_buffer_lock=None):
     """Fetcher loop for fetching data from queue and put in reorder dict."""
     while True:
         idx, batch = data_queue.get()
@@ -206,10 +185,10 @@ def fetcher_loop(data_queue, data_buffer, pin_memory=False, data_buffer_lock=Non
             data_buffer[idx] = batch
 
 
-class _MultiWorkerIter(object):
-    """Interal multi-worker iterator for DataLoader."""
+class _MultiWorkerIterV1(object):
+    """Internal multi-worker iterator for DataLoader."""
     def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=False,
-                 worker_fn=worker_loop):
+                 worker_fn=worker_loop_v1):
         assert num_workers > 0, "_MultiWorkerIter is not for {} workers".format(num_workers)
         self._num_workers = num_workers
         self._dataset = dataset
@@ -237,7 +216,7 @@ def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=
         self._workers = workers
 
         self._fetcher = threading.Thread(
-            target=fetcher_loop,
+            target=fetcher_loop_v1,
             args=(self._data_queue, self._data_buffer, pin_memory, self._data_buffer_lock))
         self._fetcher.daemon = True
         self._fetcher.start()
@@ -299,7 +278,7 @@ def shutdown(self):
             self._shutdown = True
 
 
-class DataLoader(object):
+class DataLoaderV1(object):
     """Loads data from a dataset and returns mini-batches of data.
 
     Parameters
@@ -390,8 +369,190 @@ def same_process_iter():
             return same_process_iter()
 
         # multi-worker
-        return _MultiWorkerIter(self._num_workers, self._dataset,
-                                self._batchify_fn, self._batch_sampler, self._pin_memory)
+        return _MultiWorkerIterV1(self._num_workers, self._dataset,
+                                  self._batchify_fn, self._batch_sampler, self._pin_memory)
+
+    def __len__(self):
+        return len(self._batch_sampler)
+
+_worker_dataset = None
+def _worker_initializer(dataset):
+    """Initializer for processing pool."""
+    # the global dataset is per-process and only visible to worker processes
+    # this is only necessary to handle MXIndexedRecordIO; otherwise the
+    # dataset could simply be passed as an argument
+    global _worker_dataset
+    _worker_dataset = dataset
+
+def _worker_fn(samples, batchify_fn):
+    """Function for processing data in worker process."""
+    # each worker process must re-open its own MXIndexedRecordIO handle after fork
+    # keeping the dataset as a global avoids per-task pickling overhead and is safe in the new process
+    global _worker_dataset
+    batch = batchify_fn([_worker_dataset[i] for i in samples])
+    buf = io.BytesIO()
+    ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(batch)
+    return buf.getvalue()
+
+class _MultiWorkerIter(object):
+    """Internal multi-worker iterator for DataLoader."""
+    def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False,
+                 worker_fn=_worker_fn, prefetch=0):
+        self._worker_pool = worker_pool
+        self._batchify_fn = batchify_fn
+        self._batch_sampler = batch_sampler
+        self._data_buffer = {}
+        self._rcvd_idx = 0
+        self._sent_idx = 0
+        self._iter = iter(self._batch_sampler)
+        self._worker_fn = worker_fn
+        self._pin_memory = pin_memory
+        # pre-fetch
+        for _ in range(prefetch):
+            self._push_next()
+
+    def __len__(self):
+        return len(self._batch_sampler)
+
+    def _push_next(self):
+        """Assign next batch workload to workers."""
+        r = next(self._iter, None)
+        if r is None:
+            return
+        async_ret = self._worker_pool.apply_async(self._worker_fn, (r, self._batchify_fn))
+        self._data_buffer[self._sent_idx] = async_ret
+        self._sent_idx += 1
+
+    def __next__(self):
+        self._push_next()
+        if self._rcvd_idx == self._sent_idx:
+            assert not self._data_buffer, "Data buffer should be empty at this moment"
+            raise StopIteration
+
+        assert self._rcvd_idx < self._sent_idx, "rcvd_idx must be smaller than sent_idx"
+        assert self._rcvd_idx in self._data_buffer, "fatal error with _push_next, rcvd_idx missing"
+        ret = self._data_buffer.pop(self._rcvd_idx)
+        batch = pickle.loads(ret.get())
+        if self._pin_memory:
+            batch = _as_in_context(batch, context.cpu_pinned())
+        batch = batch[0] if len(batch) == 1 else batch
+        self._rcvd_idx += 1
+        return batch
+
+    def next(self):
+        return self.__next__()
+
+    def __iter__(self):
+        return self
+
+
+class DataLoader(object):
+    """Loads data from a dataset and returns mini-batches of data.
+
+    Parameters
+    ----------
+    dataset : Dataset
+        Source dataset. Note that numpy and mxnet arrays can be directly used
+        as a Dataset.
+    batch_size : int
+        Size of mini-batch.
+    shuffle : bool
+        Whether to shuffle the samples.
+    sampler : Sampler
+        The sampler to use. Either specify sampler or shuffle, not both.
+    last_batch : {'keep', 'discard', 'rollover'}
+        How to handle the last batch if batch_size does not evenly divide
+        `len(dataset)`.
+
+        keep - A batch with fewer samples than previous batches is returned.
+        discard - The last batch is discarded if it is incomplete.
+        rollover - The remaining samples are rolled over to the next epoch.
+    batch_sampler : Sampler
+        A sampler that returns mini-batches. Do not specify batch_size,
+        shuffle, sampler, and last_batch if batch_sampler is specified.
+    batchify_fn : callable
+        Callback function to allow users to specify how to merge samples
+        into a batch. Defaults to `default_batchify_fn`::
+
+            def default_batchify_fn(data):
+                if isinstance(data[0], nd.NDArray):
+                    return nd.stack(*data)
+                elif isinstance(data[0], tuple):
+                    data = zip(*data)
+                    return [default_batchify_fn(i) for i in data]
+                else:
+                    data = np.asarray(data)
+                    return nd.array(data, dtype=data.dtype)
+
+    num_workers : int, default 0
+        The number of multiprocessing workers to use for data preprocessing.
+    pin_memory : boolean, default False
+        If ``True``, the dataloader will copy NDArrays into pinned memory
+        before returning them. Copying from CPU pinned memory to GPU is faster
+        than from normal CPU memory.
+    prefetch : int, default is `num_workers * 2`
+        The number of batches to prefetch; only takes effect if `num_workers` > 0.
+        If `prefetch` > 0, worker processes fetch the specified number of batches
+        ahead of the batches being consumed from the iterator.
+        Note that a larger prefetch count gives smoother start-up performance,
+        but consumes more shared memory. A smaller number may defeat the purpose
+        of using multiple worker processes; try reducing `num_workers` in that case.
+        Defaults to `num_workers * 2`.
+    """
+    def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
+                 last_batch=None, batch_sampler=None, batchify_fn=None,
+                 num_workers=0, pin_memory=False, prefetch=None):
+        self._dataset = dataset
+        self._pin_memory = pin_memory
+
+        if batch_sampler is None:
+            if batch_size is None:
+                raise ValueError("batch_size must be specified unless " \
+                                 "batch_sampler is specified")
+            if sampler is None:
+                if shuffle:
+                    sampler = _sampler.RandomSampler(len(dataset))
+                else:
+                    sampler = _sampler.SequentialSampler(len(dataset))
+            elif shuffle:
+                raise ValueError("shuffle must not be specified if sampler is specified")
+
+            batch_sampler = _sampler.BatchSampler(
+                sampler, batch_size, last_batch if last_batch else 'keep')
+        elif batch_size is not None or shuffle or sampler is not None or \
+                last_batch is not None:
+            raise ValueError("batch_size, shuffle, sampler and last_batch must " \
+                             "not be specified if batch_sampler is specified.")
+
+        self._batch_sampler = batch_sampler
+        self._num_workers = num_workers if num_workers >= 0 else 0
+        self._worker_pool = None
+        self._prefetch = max(0, int(prefetch) if prefetch is not None else 2 * self._num_workers)
+        if self._num_workers > 0:
+            self._worker_pool = multiprocessing.Pool(
+                self._num_workers, initializer=_worker_initializer, initargs=[self._dataset])
+        if batchify_fn is None:
+            if num_workers > 0:
+                self._batchify_fn = default_mp_batchify_fn
+            else:
+                self._batchify_fn = default_batchify_fn
+        else:
+            self._batchify_fn = batchify_fn
+
+    def __iter__(self):
+        if self._num_workers == 0:
+            def same_process_iter():
+                for batch in self._batch_sampler:
+                    ret = self._batchify_fn([self._dataset[idx] for idx in batch])
+                    if self._pin_memory:
+                        ret = _as_in_context(ret, context.cpu_pinned())
+                    yield ret
+            return same_process_iter()
+
+        # multi-worker
+        return _MultiWorkerIter(self._worker_pool, self._batchify_fn, self._batch_sampler,
+                                pin_memory=self._pin_memory, worker_fn=_worker_fn,
+                                prefetch=self._prefetch)
 
     def __len__(self):
         return len(self._batch_sampler)
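
For readers skimming the patch: the new DataLoader keeps one long-lived
multiprocessing.Pool per loader, plants the dataset in each worker via the
pool initializer, and has workers return pre-pickled batches that the parent
deserializes in arrival order. Below is a minimal, self-contained sketch of
that pattern; the names _init_worker/_process_batch and the use of plain
pickle (instead of MXNet's ForkingPickler over shared memory) are
illustrative assumptions, not the PR's API:

    import io
    import pickle
    import multiprocessing

    _worker_data = None  # per-process global, set once in each worker

    def _init_worker(data):
        # Runs once per worker; storing the dataset globally avoids
        # re-pickling it for every dispatched batch.
        global _worker_data
        _worker_data = data

    def _process_batch(indices):
        # Gather the samples and serialize the batch so the parent
        # receives plain bytes it can deserialize on arrival.
        batch = [_worker_data[i] for i in indices]
        buf = io.BytesIO()
        pickle.Pickler(buf, pickle.HIGHEST_PROTOCOL).dump(batch)
        return buf.getvalue()

    if __name__ == '__main__':
        dataset = list(range(100))
        pool = multiprocessing.Pool(
            2, initializer=_init_worker, initargs=(dataset,))
        # dispatch a few batches ahead of consumption, like `prefetch`
        pending = [pool.apply_async(_process_batch, ([i, i + 1],))
                   for i in range(0, 10, 2)]
        for res in pending:
            print(pickle.loads(res.get()))
        pool.close()
        pool.join()
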
diff --git a/python/mxnet/recordio.py b/python/mxnet/recordio.py
index 2def141c934..bdc63235d70 100644
--- a/python/mxnet/recordio.py
+++ b/python/mxnet/recordio.py
@@ -18,6 +18,7 @@
 """Read and write for the RecordIO data format."""
 from __future__ import absolute_import
 from collections import namedtuple
+from multiprocessing import current_process
 
 import ctypes
 import struct
@@ -65,6 +66,7 @@ def __init__(self, uri, flag):
         self.uri = c_str(uri)
         self.handle = RecordIOHandle()
         self.flag = flag
+        self.pid = None
         self.is_open = False
         self.open()
 
@@ -78,6 +80,7 @@ def open(self):
             self.writable = False
         else:
             raise ValueError("Invalid flag %s"%self.flag)
+        self.pid = current_process().pid
         self.is_open = True
 
     def __del__(self):
@@ -109,6 +112,14 @@ def __setstate__(self, d):
         if is_open:
             self.open()
 
+    def _check_pid(self, allow_reset=False):
+        """Check process id to ensure integrity, reset if in new process."""
+        if self.pid != current_process().pid:
+            if allow_reset:
+                self.reset()
+            else:
+                raise RuntimeError("Forbidden operation in multiple processes")
+
     def close(self):
         """Closes the record file."""
         if not self.is_open:
@@ -118,6 +129,7 @@ def close(self):
         else:
             check_call(_LIB.MXRecordIOReaderFree(self.handle))
         self.is_open = False
+        self.pid = None
 
     def reset(self):
         """Resets the pointer to first item.
@@ -156,6 +168,7 @@ def write(self, buf):
             Buffer to write.
         """
         assert self.writable
+        self._check_pid(allow_reset=False)
         check_call(_LIB.MXRecordIOWriterWriteRecord(self.handle,
                                                     ctypes.c_char_p(buf),
                                                     ctypes.c_size_t(len(buf))))
@@ -182,6 +195,9 @@ def read(self):
             Buffer read.
         """
         assert not self.writable
+        # implicit reads from multiple processes are forbidden; there is
+        # no elegant way to handle this without introducing a lock
+        self._check_pid(allow_reset=False)
         buf = ctypes.c_char_p()
         size = ctypes.c_size_t()
         check_call(_LIB.MXRecordIOReaderReadRecord(self.handle,
@@ -255,6 +271,7 @@ def seek(self, idx):
         This function is internally called by `read_idx(idx)` to find the current
         reader pointer position. It doesn't return anything."""
         assert not self.writable
+        self._check_pid(allow_reset=True)
         pos = ctypes.c_size_t(self.idx[idx])
         check_call(_LIB.MXRecordIOReaderSeek(self.handle, pos))
 
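
The recordio.py half of the change records the PID that opened the handle
and validates it before stateful calls, resetting the handle only where a
reset is safe (seek) and raising otherwise (read/write). Below is a minimal
sketch of the same fork-guard idea applied to an ordinary Python file
object; the ForkGuardedReader class is hypothetical, not MXNet API:

    from multiprocessing import current_process

    class ForkGuardedReader:
        """Hypothetical sketch of the PID-check pattern in MXRecordIO."""

        def __init__(self, path):
            self.path = path
            self._fh = open(path, 'rb')
            self._pid = current_process().pid

        def _check_pid(self, allow_reset=False):
            # A handle inherited across fork() shares its read position
            # with the parent; detect use from another process by PID.
            if self._pid != current_process().pid:
                if allow_reset:
                    self._fh.close()
                    self._fh = open(self.path, 'rb')  # re-open in new process
                    self._pid = current_process().pid
                else:
                    raise RuntimeError(
                        "Forbidden operation in multiple processes")

        def seek(self, pos):
            # Explicit repositioning: safe to reset after a fork first.
            self._check_pid(allow_reset=True)
            self._fh.seek(pos)

        def read(self, n=-1):
            # Implicit reads depend on handle state; must not cross a fork.
            self._check_pid(allow_reset=False)
            return self._fh.read(n)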


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services