You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/05/09 04:11:54 UTC

[GitHub] szha closed pull request #10820: fix thread contention caused by openmp

szha closed pull request #10820: fix thread contention caused by openmp
URL: https://github.com/apache/incubator-mxnet/pull/10820
 
 
   

This PR was merged from a forked repository. Because GitHub hides the
original diff of such a pull request once it is merged, the diff is
reproduced below for the sake of provenance:

diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index 7f09e286742..7ef18bd53d3 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -83,6 +83,21 @@ def __init__(self, *args, **kwargs):
         self._recv = self._reader.recv
 
 
+class SimpleQueue(multiprocessing.queues.SimpleQueue):
+    """Wrapper for multiprocessing SimpleQueue that dumps NDArray with shared memory.
+       SimpleQueue doesn't use threading internally.
+    """
+    def __init__(self, *args, **kwargs):
+        if sys.version_info[0] <= 2:
+            super(SimpleQueue, self).__init__(*args, **kwargs)
+        else:
+            super(SimpleQueue, self).__init__(*args, ctx=multiprocessing.get_context(),
+                                              **kwargs)
+        self._reader = ConnectionWrapper(self._reader)
+        self._writer = ConnectionWrapper(self._writer)
+        self._send = self._writer.send
+        self._recv = self._reader.recv
+
 def default_batchify_fn(data):
     """Collate data into batch."""
     if isinstance(data[0], nd.NDArray):
@@ -128,7 +143,7 @@ def __init__(self, num_workers, dataset, batchify_fn, batch_sampler):
         self._batchify_fn = batchify_fn
         self._batch_sampler = batch_sampler
         self._key_queue = Queue()
-        self._data_queue = Queue(2*self._num_workers)
+        self._data_queue = SimpleQueue()
         self._data_buffer = {}
         self._rcvd_idx = 0
         self._sent_idx = 0
@@ -170,10 +185,10 @@ def __next__(self):
             raise StopIteration
 
         while True:
+            self._push_next()
             if self._rcvd_idx in self._data_buffer:
                 batch = self._data_buffer.pop(self._rcvd_idx)
                 self._rcvd_idx += 1
-                self._push_next()
                 return batch
             idx, batch = self._data_queue.get()
             self._data_buffer[idx] = batch
diff --git a/src/engine/threaded_engine_perdevice.cc b/src/engine/threaded_engine_perdevice.cc
index a8227886f84..2f77380baf8 100644
--- a/src/engine/threaded_engine_perdevice.cc
+++ b/src/engine/threaded_engine_perdevice.cc
@@ -53,22 +53,6 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
 
   ThreadedEnginePerDevice() noexcept(false) {
     this->Start();
-#ifndef _WIN32
-    pthread_atfork(
-      []() {
-        Engine::Get()->Stop();
-      },
-      []() {
-        Engine::Get()->Start();
-      },
-      []() {
-        // Make children single threaded since they are typically workers
-        dmlc::SetEnv("MXNET_CPU_WORKER_NTHREADS", 1);
-        dmlc::SetEnv("OMP_NUM_THREADS", 1);
-        OpenMP::Get()->set_enabled(false);
-        Engine::Get()->Start();
-      });
-#endif
   }
   ~ThreadedEnginePerDevice() noexcept(false) {
     this->StopNoWait();
diff --git a/src/initialize.cc b/src/initialize.cc
index 69d408d7d87..1fd92628e9b 100644
--- a/src/initialize.cc
+++ b/src/initialize.cc
@@ -25,6 +25,7 @@
 #include <signal.h>
 #include <dmlc/logging.h>
 #include <mxnet/engine.h>
+#include "./engine/openmp.h"
 
 namespace mxnet {
 #if MXNET_USE_SIGNAL_HANDLER && DMLC_LOG_STACK_TRACE
@@ -42,6 +43,24 @@ class LibraryInitializer {
 #if MXNET_USE_SIGNAL_HANDLER && DMLC_LOG_STACK_TRACE
     signal(SIGSEGV, SegfaultLogger);
 #endif
+
+// disable openmp for multithreaded workers
+#ifndef _WIN32
+    pthread_atfork(
+      []() {
+        Engine::Get()->Stop();
+      },
+      []() {
+        Engine::Get()->Start();
+      },
+      []() {
+        // Make children single threaded since they are typically workers
+        dmlc::SetEnv("MXNET_CPU_WORKER_NTHREADS", 1);
+        dmlc::SetEnv("OMP_NUM_THREADS", 1);
+        engine::OpenMP::Get()->set_enabled(false);
+        Engine::Get()->Start();
+      });
+#endif
   }
 
   static LibraryInitializer* Get();


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services