You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/11/30 18:10:03 UTC

[incubator-mxnet] branch master updated: Rewrite dataloader with process pool, improves responsiveness and reliability (#13447)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 883d771  Rewrite dataloader with process pool, improves responsiveness and reliability (#13447)
883d771 is described below

commit 883d7712af9d3d6f3a112746c6a318f5f2677b7c
Author: Joshua Z. Zhang <ch...@gmail.com>
AuthorDate: Fri Nov 30 10:09:47 2018 -0800

    Rewrite dataloader with process pool, improves responsiveness and reliability (#13447)
    
    * fix recordio.py
    
    * rewrite dataloader with pool
    
    * fix batch as tuple
    
    * fix prefetching
    
    * fix pylint
    
    * picklable function
    
    * use pickle
    
    * add missing commit
---
 python/mxnet/gluon/data/dataloader.py | 223 +++++++++++++++++++++++++++++-----
 python/mxnet/recordio.py              |  17 +++
 2 files changed, 209 insertions(+), 31 deletions(-)

diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index 86cb835f..ad0f534 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -36,7 +36,6 @@ except ImportError:
 
 from . import sampler as _sampler
 from ... import nd, context
-from ...recordio import MXRecordIO
 
 if sys.platform == 'darwin' or sys.platform == 'win32':
     def rebuild_ndarray(*args):
@@ -159,29 +158,9 @@ def _as_in_context(data, ctx):
         return [_as_in_context(d, ctx) for d in data]
     return data
 
-def _recursive_fork_recordio(obj, depth, max_depth=1000):
-    """Recursively find instance of MXRecordIO and reset file handler.
-    This is required for MXRecordIO which holds a C pointer to a opened file after fork.
-    """
-    if depth >= max_depth:
-        return
-    if isinstance(obj, MXRecordIO):
-        obj.close()
-        obj.open()  # re-obtain file hanlder in new process
-    elif (hasattr(obj, '__dict__')):
-        for _, v in obj.__dict__.items():
-            _recursive_fork_recordio(v, depth + 1, max_depth)
-
-def worker_loop(dataset, key_queue, data_queue, batchify_fn):
-    """Worker loop for multiprocessing DataLoader."""
-    # re-fork a new recordio handler in new process if applicable
-    # for a dataset with transform function, the depth of MXRecordIO is 1
-    # for a lazy transformer, the depth is 2
-    # for a user defined transformer, the depth is unknown, try a reasonable depth
-    limit = sys.getrecursionlimit()
-    max_recursion_depth = min(limit - 5, max(10, limit // 2))
-    _recursive_fork_recordio(dataset, 0, max_recursion_depth)
 
+def worker_loop_v1(dataset, key_queue, data_queue, batchify_fn):
+    """Worker loop for multiprocessing DataLoader."""
     while True:
         idx, samples = key_queue.get()
         if idx is None:
@@ -189,7 +168,7 @@ def worker_loop(dataset, key_queue, data_queue, batchify_fn):
         batch = batchify_fn([dataset[i] for i in samples])
         data_queue.put((idx, batch))
 
-def fetcher_loop(data_queue, data_buffer, pin_memory=False, data_buffer_lock=None):
+def fetcher_loop_v1(data_queue, data_buffer, pin_memory=False, data_buffer_lock=None):
     """Fetcher loop for fetching data from queue and put in reorder dict."""
     while True:
         idx, batch = data_queue.get()
@@ -206,10 +185,10 @@ def fetcher_loop(data_queue, data_buffer, pin_memory=False, data_buffer_lock=Non
             data_buffer[idx] = batch
 
 
-class _MultiWorkerIter(object):
-    """Interal multi-worker iterator for DataLoader."""
+class _MultiWorkerIterV1(object):
+    """Internal multi-worker iterator for DataLoader."""
     def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=False,
-                 worker_fn=worker_loop):
+                 worker_fn=worker_loop_v1):
         assert num_workers > 0, "_MultiWorkerIter is not for {} workers".format(num_workers)
         self._num_workers = num_workers
         self._dataset = dataset
@@ -237,7 +216,7 @@ class _MultiWorkerIter(object):
         self._workers = workers
 
         self._fetcher = threading.Thread(
-            target=fetcher_loop,
+            target=fetcher_loop_v1,
             args=(self._data_queue, self._data_buffer, pin_memory, self._data_buffer_lock))
         self._fetcher.daemon = True
         self._fetcher.start()
@@ -299,7 +278,7 @@ class _MultiWorkerIter(object):
             self._shutdown = True
 
 
-class DataLoader(object):
+class DataLoaderV1(object):
     """Loads data from a dataset and returns mini-batches of data.
 
     Parameters
@@ -390,8 +369,190 @@ class DataLoader(object):
             return same_process_iter()
 
         # multi-worker
-        return _MultiWorkerIter(self._num_workers, self._dataset,
-                                self._batchify_fn, self._batch_sampler, self._pin_memory)
+        return _MultiWorkerIterV1(self._num_workers, self._dataset,
+                                  self._batchify_fn, self._batch_sampler, self._pin_memory)
+
+    def __len__(self):
+        return len(self._batch_sampler)
+
+_worker_dataset = None
+def _worker_initializer(dataset):
+    """Initializer for processing pool."""
+    # global dataset is per-process based and only available in worker processes
+    # this is only necessary to handle MXIndexedRecordIO because otherwise dataset
+    # can be passed as argument
+    global _worker_dataset
+    _worker_dataset = dataset
+
+def _worker_fn(samples, batchify_fn):
+    """Function for processing data in worker process."""
+    # it is required that each worker process has to fork a new MXIndexedRecordIO handle
+    # preserving dataset as global variable can save tons of overhead and is safe in new process
+    global _worker_dataset
+    batch = batchify_fn([_worker_dataset[i] for i in samples])
+    buf = io.BytesIO()
+    ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(batch)
+    return buf.getvalue()
+
+class _MultiWorkerIter(object):
+    """Internal multi-worker iterator for DataLoader."""
+    def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False,
+                 worker_fn=_worker_fn, prefetch=0):
+        self._worker_pool = worker_pool
+        self._batchify_fn = batchify_fn
+        self._batch_sampler = batch_sampler
+        self._data_buffer = {}
+        self._rcvd_idx = 0
+        self._sent_idx = 0
+        self._iter = iter(self._batch_sampler)
+        self._worker_fn = worker_fn
+        self._pin_memory = pin_memory
+        # pre-fetch
+        for _ in range(prefetch):
+            self._push_next()
+
+    def __len__(self):
+        return len(self._batch_sampler)
+
+    def _push_next(self):
+        """Assign next batch workload to workers."""
+        r = next(self._iter, None)
+        if r is None:
+            return
+        async_ret = self._worker_pool.apply_async(self._worker_fn, (r, self._batchify_fn))
+        self._data_buffer[self._sent_idx] = async_ret
+        self._sent_idx += 1
+
+    def __next__(self):
+        self._push_next()
+        if self._rcvd_idx == self._sent_idx:
+            assert not self._data_buffer, "Data buffer should be empty at this moment"
+            raise StopIteration
+
+        assert self._rcvd_idx < self._sent_idx, "rcvd_idx must be smaller than sent_idx"
+        assert self._rcvd_idx in self._data_buffer, "fatal error with _push_next, rcvd_idx missing"
+        ret = self._data_buffer.pop(self._rcvd_idx)
+        batch = pickle.loads(ret.get())
+        if self._pin_memory:
+            batch = _as_in_context(batch, context.cpu_pinned())
+        batch = batch[0] if len(batch) == 1 else batch
+        self._rcvd_idx += 1
+        return batch
+
+    def next(self):
+        return self.__next__()
+
+    def __iter__(self):
+        return self
+
+
+class DataLoader(object):
+    """Loads data from a dataset and returns mini-batches of data.
+
+    Parameters
+    ----------
+    dataset : Dataset
+        Source dataset. Note that numpy and mxnet arrays can be directly used
+        as a Dataset.
+    batch_size : int
+        Size of mini-batch.
+    shuffle : bool
+        Whether to shuffle the samples.
+    sampler : Sampler
+        The sampler to use. Either specify sampler or shuffle, not both.
+    last_batch : {'keep', 'discard', 'rollover'}
+        How to handle the last batch if batch_size does not evenly divide
+        `len(dataset)`.
+
+        keep - A batch with fewer samples than previous batches is returned.
+        discard - The last batch is discarded if it's incomplete.
+        rollover - The remaining samples are rolled over to the next epoch.
+    batch_sampler : Sampler
+        A sampler that returns mini-batches. Do not specify batch_size,
+        shuffle, sampler, and last_batch if batch_sampler is specified.
+    batchify_fn : callable
+        Callback function to allow users to specify how to merge samples
+        into a batch. Defaults to `default_batchify_fn`::
+
+            def default_batchify_fn(data):
+                if isinstance(data[0], nd.NDArray):
+                    return nd.stack(*data)
+                elif isinstance(data[0], tuple):
+                    data = zip(*data)
+                    return [default_batchify_fn(i) for i in data]
+                else:
+                    data = np.asarray(data)
+                    return nd.array(data, dtype=data.dtype)
+
+    num_workers : int, default 0
+        The number of multiprocessing workers to use for data preprocessing.
+    pin_memory : boolean, default False
+        If ``True``, the dataloader will copy NDArrays into pinned memory
+        before returning them. Copying from CPU pinned memory to GPU is faster
+        than from normal CPU memory.
+    prefetch : int, default is `num_workers * 2`
+        The number of batches to prefetch. Only takes effect if `num_workers` > 0.
+        If `prefetch` > 0, worker processes are allowed to prefetch that many batches
+        before data is requested from the iterator.
+        Note that a larger prefetch count gives smoother bootstrapping performance
+        but consumes more shared memory. A smaller count may defeat the purpose of using
+        multiple worker processes; try reducing `num_workers` in that case.
+        Defaults to `num_workers * 2`.
+    """
+    def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
+                 last_batch=None, batch_sampler=None, batchify_fn=None,
+                 num_workers=0, pin_memory=False, prefetch=None):
+        self._dataset = dataset
+        self._pin_memory = pin_memory
+
+        if batch_sampler is None:
+            if batch_size is None:
+                raise ValueError("batch_size must be specified unless " \
+                                 "batch_sampler is specified")
+            if sampler is None:
+                if shuffle:
+                    sampler = _sampler.RandomSampler(len(dataset))
+                else:
+                    sampler = _sampler.SequentialSampler(len(dataset))
+            elif shuffle:
+                raise ValueError("shuffle must not be specified if sampler is specified")
+
+            batch_sampler = _sampler.BatchSampler(
+                sampler, batch_size, last_batch if last_batch else 'keep')
+        elif batch_size is not None or shuffle or sampler is not None or \
+                last_batch is not None:
+            raise ValueError("batch_size, shuffle, sampler and last_batch must " \
+                             "not be specified if batch_sampler is specified.")
+
+        self._batch_sampler = batch_sampler
+        self._num_workers = num_workers if num_workers >= 0 else 0
+        self._worker_pool = None
+        self._prefetch = max(0, int(prefetch) if prefetch is not None else 2 * self._num_workers)
+        if self._num_workers > 0:
+            self._worker_pool = multiprocessing.Pool(
+                self._num_workers, initializer=_worker_initializer, initargs=[self._dataset])
+        if batchify_fn is None:
+            if num_workers > 0:
+                self._batchify_fn = default_mp_batchify_fn
+            else:
+                self._batchify_fn = default_batchify_fn
+        else:
+            self._batchify_fn = batchify_fn
+
+    def __iter__(self):
+        if self._num_workers == 0:
+            def same_process_iter():
+                for batch in self._batch_sampler:
+                    ret = self._batchify_fn([self._dataset[idx] for idx in batch])
+                    if self._pin_memory:
+                        ret = _as_in_context(ret, context.cpu_pinned())
+                    yield ret
+            return same_process_iter()
+
+        # multi-worker
+        return _MultiWorkerIter(self._worker_pool, self._batchify_fn, self._batch_sampler,
+                                pin_memory=self._pin_memory, worker_fn=_worker_fn,
+                                prefetch=self._prefetch)
 
     def __len__(self):
         return len(self._batch_sampler)
diff --git a/python/mxnet/recordio.py b/python/mxnet/recordio.py
index 2def141..bdc6323 100644
--- a/python/mxnet/recordio.py
+++ b/python/mxnet/recordio.py
@@ -18,6 +18,7 @@
 """Read and write for the RecordIO data format."""
 from __future__ import absolute_import
 from collections import namedtuple
+from multiprocessing import current_process
 
 import ctypes
 import struct
@@ -65,6 +66,7 @@ class MXRecordIO(object):
         self.uri = c_str(uri)
         self.handle = RecordIOHandle()
         self.flag = flag
+        self.pid = None
         self.is_open = False
         self.open()
 
@@ -78,6 +80,7 @@ class MXRecordIO(object):
             self.writable = False
         else:
             raise ValueError("Invalid flag %s"%self.flag)
+        self.pid = current_process().pid
         self.is_open = True
 
     def __del__(self):
@@ -109,6 +112,14 @@ class MXRecordIO(object):
         if is_open:
             self.open()
 
+    def _check_pid(self, allow_reset=False):
+        """Check process id to ensure integrity, reset if in new process."""
+        if not self.pid == current_process().pid:
+            if allow_reset:
+                self.reset()
+            else:
+                raise RuntimeError("Forbidden operation in multiple processes")
+
     def close(self):
         """Closes the record file."""
         if not self.is_open:
@@ -118,6 +129,7 @@ class MXRecordIO(object):
         else:
             check_call(_LIB.MXRecordIOReaderFree(self.handle))
         self.is_open = False
+        self.pid = None
 
     def reset(self):
         """Resets the pointer to first item.
@@ -156,6 +168,7 @@ class MXRecordIO(object):
             Buffer to write.
         """
         assert self.writable
+        self._check_pid(allow_reset=False)
         check_call(_LIB.MXRecordIOWriterWriteRecord(self.handle,
                                                     ctypes.c_char_p(buf),
                                                     ctypes.c_size_t(len(buf))))
@@ -182,6 +195,9 @@ class MXRecordIO(object):
             Buffer read.
         """
         assert not self.writable
+        # implicitly reading from multiple processes is forbidden;
+        # there is no elegant way to handle it unless a lock is introduced
+        self._check_pid(allow_reset=False)
         buf = ctypes.c_char_p()
         size = ctypes.c_size_t()
         check_call(_LIB.MXRecordIOReaderReadRecord(self.handle,
@@ -255,6 +271,7 @@ class MXIndexedRecordIO(MXRecordIO):
         This function is internally called by `read_idx(idx)` to find the current
         reader pointer position. It doesn't return anything."""
         assert not self.writable
+        self._check_pid(allow_reset=True)
         pos = ctypes.c_size_t(self.idx[idx])
         check_call(_LIB.MXRecordIOReaderSeek(self.handle, pos))