You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/12/14 18:16:34 UTC

[GitHub] anirudh2290 closed pull request #13606: Complimentary gluon DataLoader improvements

anirudh2290 closed pull request #13606: Complimentary gluon DataLoader improvements
URL: https://github.com/apache/incubator-mxnet/pull/13606
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/docs/faq/env_var.md b/docs/faq/env_var.md
index 8d08e320721..b1f5014ff16 100644
--- a/docs/faq/env_var.md
+++ b/docs/faq/env_var.md
@@ -37,6 +37,12 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
 * MXNET_CPU_NNPACK_NTHREADS
   - Values: Int ```(default=4)```
   - The number of threads used for NNPACK. NNPACK package aims to provide high-performance implementations of some layers for multi-core CPUs. Checkout [NNPACK](http://mxnet.io/faq/nnpack.html) to know more about it.
+* MXNET_MP_WORKER_NTHREADS
+  - Values: Int ```(default=1)```
+  - The number of scheduling threads on CPU given to multiprocess workers. Enlarging this number allows more operators to run in parallel in individual workers, but please consider reducing the overall `num_workers` to avoid thread contention (not available on Windows).
+* MXNET_MP_OPENCV_NUM_THREADS
+  - Values: Int ```(default=0)```
+  - The number of OpenCV execution threads given to multiprocess workers. OpenCV multithreading is disabled if `MXNET_MP_OPENCV_NUM_THREADS` < 1 (default). Enlarging this number may boost the performance of individual workers when executing underlying OpenCV functions, but please consider reducing the overall `num_workers` to avoid thread contention (not available on Windows).
 
 ## Memory Options
 
@@ -99,10 +105,10 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
 * MXNET_KVSTORE_REDUCTION_NTHREADS
   - Values: Int ```(default=4)```
   - The number of CPU threads used for summing up big arrays on a single machine
-  - This will also be used for `dist_sync` kvstore to sum up arrays from different contexts on a single machine. 
-  - This does not affect summing up of arrays from different machines on servers. 
+  - This will also be used for `dist_sync` kvstore to sum up arrays from different contexts on a single machine.
+  - This does not affect summing up of arrays from different machines on servers.
   - Summing up of arrays for `dist_sync_device` kvstore is also unaffected as that happens on GPUs.
-  
+
 * MXNET_KVSTORE_BIGARRAY_BOUND
   - Values: Int ```(default=1000000)```
   - The minimum size of a "big array".
@@ -166,7 +172,7 @@ When USE_PROFILER is enabled in Makefile or CMake, the following environments ca
 
 * MXNET_CUDNN_AUTOTUNE_DEFAULT
   - Values: 0, 1, or 2 ```(default=1)```
-  - The default value of cudnn auto tuning for convolution layers. 
+  - The default value of cudnn auto tuning for convolution layers.
   - Value of 0 means there is no auto tuning to pick the convolution algo
   - Performance tests are run to pick the convolution algo when value is 1 or 2
   - Value of 1 chooses the best algo in a limited workspace
@@ -190,12 +196,12 @@ When USE_PROFILER is enabled in Makefile or CMake, the following environments ca
 * MXNET_HOME
   - Data directory in the filesystem for storage, for example when downloading gluon models.
   - Default in *nix is .mxnet APPDATA/mxnet in windows.
-  
+
 * MXNET_MKLDNN_ENABLED
   - Values: 0, 1 ```(default=1)```
   - Flag to enable or disable MKLDNN accelerator. On by default.
   - Only applies to mxnet that has been compiled with MKLDNN (```pip install mxnet-mkl``` or built from source with ```USE_MKLDNN=1```)
-  
+
 * MXNET_MKLDNN_CACHE_NUM
   - Values: Int ```(default=-1)```
   - Flag to set num of elements that MKLDNN cache can hold. Default is -1 which means cache size is unbounded. Should only be set if your model has variable input shapes, as cache size may grow unbounded. The number represents the number of items in the cache and is proportional to the number of layers that use MKLDNN and different input shape.
diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index 586e620470d..9d762745a40 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -26,6 +26,7 @@
 import multiprocessing
 import multiprocessing.queues
 from multiprocessing.reduction import ForkingPickler
+from multiprocessing.pool import ThreadPool
 import threading
 import numpy as np
 
@@ -384,8 +385,9 @@ def _worker_initializer(dataset):
     global _worker_dataset
     _worker_dataset = dataset
 
-def _worker_fn(samples, batchify_fn):
+def _worker_fn(samples, batchify_fn, dataset=None):
     """Function for processing data in worker process."""
+    # pylint: disable=unused-argument
     # it is required that each worker process has to fork a new MXIndexedRecordIO handle
     # preserving dataset as global variable can save tons of overhead and is safe in new process
     global _worker_dataset
@@ -394,10 +396,14 @@ def _worker_fn(samples, batchify_fn):
     ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(batch)
     return buf.getvalue()
 
+def _thread_worker_fn(samples, batchify_fn, dataset):
+    """Threadpool worker function for processing data."""
+    return batchify_fn([dataset[i] for i in samples])
+
 class _MultiWorkerIter(object):
     """Internal multi-worker iterator for DataLoader."""
     def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False,
-                 worker_fn=_worker_fn, prefetch=0):
+                 worker_fn=_worker_fn, prefetch=0, dataset=None):
         self._worker_pool = worker_pool
         self._batchify_fn = batchify_fn
         self._batch_sampler = batch_sampler
@@ -407,6 +413,7 @@ def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False,
         self._iter = iter(self._batch_sampler)
         self._worker_fn = worker_fn
         self._pin_memory = pin_memory
+        self._dataset = dataset
         # pre-fetch
         for _ in range(prefetch):
             self._push_next()
@@ -419,7 +426,8 @@ def _push_next(self):
         r = next(self._iter, None)
         if r is None:
             return
-        async_ret = self._worker_pool.apply_async(self._worker_fn, (r, self._batchify_fn))
+        async_ret = self._worker_pool.apply_async(
+            self._worker_fn, (r, self._batchify_fn, self._dataset))
         self._data_buffer[self._sent_idx] = async_ret
         self._sent_idx += 1
 
@@ -432,7 +440,7 @@ def __next__(self):
         assert self._rcvd_idx < self._sent_idx, "rcvd_idx must be smaller than sent_idx"
         assert self._rcvd_idx in self._data_buffer, "fatal error with _push_next, rcvd_idx missing"
         ret = self._data_buffer.pop(self._rcvd_idx)
-        batch = pickle.loads(ret.get())
+        batch = pickle.loads(ret.get()) if self._dataset is None else ret.get()
         if self._pin_memory:
             batch = _as_in_context(batch, context.cpu_pinned())
         batch = batch[0] if len(batch) == 1 else batch
@@ -498,12 +506,18 @@ def default_batchify_fn(data):
         but will consume more shared_memory. Using smaller number may forfeit the purpose of using
         multiple worker processes, try reduce `num_workers` in this case.
         By default it defaults to `num_workers * 2`.
+    thread_pool : bool, default False
+        If ``True``, use a thread pool instead of a multiprocessing pool. Using a thread pool
+        can avoid shared memory usage. If `DataLoader` is mostly IO bound, or the GIL is not a
+        bottleneck, the thread pool version may achieve better performance than multiprocessing.
+
     """
     def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
                  last_batch=None, batch_sampler=None, batchify_fn=None,
-                 num_workers=0, pin_memory=False, prefetch=None):
+                 num_workers=0, pin_memory=False, prefetch=None, thread_pool=False):
         self._dataset = dataset
         self._pin_memory = pin_memory
+        self._thread_pool = thread_pool
 
         if batch_sampler is None:
             if batch_size is None:
@@ -529,8 +543,11 @@ def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
         self._worker_pool = None
         self._prefetch = max(0, int(prefetch) if prefetch is not None else 2 * self._num_workers)
         if self._num_workers > 0:
-            self._worker_pool = multiprocessing.Pool(
-                self._num_workers, initializer=_worker_initializer, initargs=[self._dataset])
+            if self._thread_pool:
+                self._worker_pool = ThreadPool(self._num_workers)
+            else:
+                self._worker_pool = multiprocessing.Pool(
+                    self._num_workers, initializer=_worker_initializer, initargs=[self._dataset])
         if batchify_fn is None:
             if num_workers > 0:
                 self._batchify_fn = default_mp_batchify_fn
@@ -551,14 +568,17 @@ def same_process_iter():
 
         # multi-worker
         return _MultiWorkerIter(self._worker_pool, self._batchify_fn, self._batch_sampler,
-                                pin_memory=self._pin_memory, worker_fn=_worker_fn,
-                                prefetch=self._prefetch)
+                                pin_memory=self._pin_memory,
+                                worker_fn=_thread_worker_fn if self._thread_pool else _worker_fn,
+                                prefetch=self._prefetch,
+                                dataset=self._dataset if self._thread_pool else None)
 
     def __len__(self):
         return len(self._batch_sampler)
 
     def __del__(self):
         if self._worker_pool:
-            # manually terminate due to a bug that pool is not automatically terminated on linux
+            # manually terminate due to a bug that pool is not automatically terminated
+            # https://bugs.python.org/issue34172
             assert isinstance(self._worker_pool, multiprocessing.pool.Pool)
             self._worker_pool.terminate()
diff --git a/src/initialize.cc b/src/initialize.cc
index ddda3f18a3a..de7edd1b145 100644
--- a/src/initialize.cc
+++ b/src/initialize.cc
@@ -57,11 +57,13 @@ class LibraryInitializer {
         Engine::Get()->Start();
       },
       []() {
-        // Make children single threaded since they are typically workers
-        dmlc::SetEnv("MXNET_CPU_WORKER_NTHREADS", 1);
+        // Conservative thread management for multiprocess workers
+        const size_t mp_worker_threads = dmlc::GetEnv("MXNET_MP_WORKER_NTHREADS", 1);
+        dmlc::SetEnv("MXNET_CPU_WORKER_NTHREADS", mp_worker_threads);
         dmlc::SetEnv("OMP_NUM_THREADS", 1);
 #if MXNET_USE_OPENCV && !__APPLE__
-        cv::setNumThreads(0);  // disable opencv threading
+        const size_t mp_cv_num_threads = dmlc::GetEnv("MXNET_MP_OPENCV_NUM_THREADS", 0);
+        cv::setNumThreads(mp_cv_num_threads);  // 0 (default) disables opencv threading
 #endif  // MXNET_USE_OPENCV
         engine::OpenMP::Get()->set_enabled(false);
         Engine::Get()->Start();
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index a3ba222c71d..6a5322616e2 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -156,9 +156,10 @@ def __getitem__(self, key):
 @with_seed()
 def test_multi_worker():
     data = Dataset()
-    loader = gluon.data.DataLoader(data, batch_size=1, num_workers=5)
-    for i, batch in enumerate(loader):
-        assert (batch.asnumpy() == i).all()
+    for thread_pool in [True, False]:
+        loader = gluon.data.DataLoader(data, batch_size=1, num_workers=5, thread_pool=thread_pool)
+        for i, batch in enumerate(loader):
+            assert (batch.asnumpy() == i).all()
 
 class _Dummy(Dataset):
     """Dummy dataset for randomized shape arrays."""


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services