You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/05/15 18:22:44 UTC

[incubator-mxnet] branch master updated: Fix engine stop/start (#10911)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 57c8ca1  Fix engine stop/start  (#10911)
57c8ca1 is described below

commit 57c8ca1a0a6dae36dc27a9f054041ecce652e4c8
Author: Joshua Z. Zhang <ch...@gmail.com>
AuthorDate: Tue May 15 11:22:39 2018 -0700

    Fix engine stop/start  (#10911)
    
    * fix engine start/stop
    
    * add tests
    
    * fix test
    
    * fix
    
    * fix tests
---
 python/mxnet/gluon/data/dataloader.py       |  2 +-
 src/engine/naive_engine.cc                  |  6 +++
 src/engine/threaded_engine_pooled.cc        | 57 +++++++++++++++++++++--------
 tests/cpp/engine/threaded_engine_test.cc    | 17 +++++++++
 tests/python/unittest/test_engine_import.py | 44 ++++++++++++++++++++++
 5 files changed, 109 insertions(+), 17 deletions(-)

diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index 7ef18bd..d80a6bf 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -143,7 +143,7 @@ class _MultiWorkerIter(object):
         self._batchify_fn = batchify_fn
         self._batch_sampler = batch_sampler
         self._key_queue = Queue()
-        self._data_queue = SimpleQueue()
+        self._data_queue = Queue() if sys.version_info[0] <= 2 else SimpleQueue()
         self._data_buffer = {}
         self._rcvd_idx = 0
         self._sent_idx = 0
diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc
index 1fa5306..8196af2 100644
--- a/src/engine/naive_engine.cc
+++ b/src/engine/naive_engine.cc
@@ -63,6 +63,12 @@ class NaiveEngine final : public Engine {
 #endif
   }
 
+  /*! \brief Empty override; NaiveEngine performs no action on Stop(). */
+  void Stop() override {}
+
+  /*! \brief Empty override; NaiveEngine performs no action on Start(). */
+  void Start() override {}
+
   // new variables
   VarHandle NewVariable() override {
     size_t v = ++counter_;
diff --git a/src/engine/threaded_engine_pooled.cc b/src/engine/threaded_engine_pooled.cc
index 074ea4e..574e832 100644
--- a/src/engine/threaded_engine_pooled.cc
+++ b/src/engine/threaded_engine_pooled.cc
@@ -27,6 +27,7 @@
 #include <dmlc/logging.h>
 #include <dmlc/concurrency.h>
 #include <cassert>
+#include <utility>
 #include "./threaded_engine.h"
 #include "./thread_pool.h"
 #include "./stream_manager.h"
@@ -42,14 +43,56 @@ namespace engine {
  */
 class ThreadedEnginePooled : public ThreadedEngine {
  public:
-  ThreadedEnginePooled() :
-      thread_pool_(kNumWorkingThreads, [this]() { ThreadWorker(&task_queue_); }),
-      io_thread_pool_(1, [this]() { ThreadWorker(&io_task_queue_); }) {}
+  ThreadedEnginePooled() {
+    this->Start();
+  }
 
   ~ThreadedEnginePooled() noexcept(false) {
-    streams_.Finalize();
-    task_queue_.SignalForKill();
-    io_task_queue_.SignalForKill();
+    StopNoWait();
+  }
+
+  /*!
+   * \brief Release streams, queues and worker pools without waiting for pending work.
+   *
+   * Safe to call on an engine that is already stopped, e.g. when the destructor
+   * runs after Stop().
+   */
+  void StopNoWait() {
+    if (streams_ == nullptr) {
+      return;  // already stopped; every member below is null
+    }
+    streams_->Finalize();
+    task_queue_->SignalForKill();
+    io_task_queue_->SignalForKill();
+    // Drop the pools first so no worker outlives the queues and streams it uses
+    // (the ThreadPool destructor is expected to join its threads).
+    thread_pool_ = nullptr;
+    io_thread_pool_ = nullptr;
+    task_queue_ = nullptr;
+    io_task_queue_ = nullptr;
+    streams_ = nullptr;
+  }
+
+  void Stop() override {
+    WaitForAll();
+    StopNoWait();
+  }
+
+  void Start() override {
+    if (streams_ != nullptr) {
+      return;  // already running; the live pools were never signalled to exit
+    }
+    streams_.reset(new StreamManager<kMaxNumGpus, kNumStreamsPerGpu>());
+    task_queue_.reset(new dmlc::ConcurrentBlockingQueue<OprBlock*>());
+    io_task_queue_.reset(new dmlc::ConcurrentBlockingQueue<OprBlock*>());
+    // Bind each worker to its queue here instead of letting the worker thread read
+    // the member later, since StopNoWait() resets the member.
+    auto task_queue = task_queue_;
+    auto io_task_queue = io_task_queue_;
+    thread_pool_.reset(new ThreadPool(kNumWorkingThreads, [this, task_queue]() {
+      ThreadWorker(task_queue); }));
+    io_thread_pool_.reset(new ThreadPool(1, [this, io_task_queue]() {
+      ThreadWorker(io_task_queue); }));
   }
 
  protected:
@@ -71,24 +96,24 @@ class ThreadedEnginePooled : public ThreadedEngine {
   /*!
    * \brief Streams.
    */
-  StreamManager<kMaxNumGpus, kNumStreamsPerGpu> streams_;
+  std::unique_ptr<StreamManager<kMaxNumGpus, kNumStreamsPerGpu>> streams_;
   /*!
    * \brief Task queues.
    */
-  dmlc::ConcurrentBlockingQueue<OprBlock*> task_queue_;
-  dmlc::ConcurrentBlockingQueue<OprBlock*> io_task_queue_;
+  std::shared_ptr<dmlc::ConcurrentBlockingQueue<OprBlock*>> task_queue_;
+  std::shared_ptr<dmlc::ConcurrentBlockingQueue<OprBlock*>> io_task_queue_;
   /*!
    * \brief Thread pools.
    */
-  ThreadPool thread_pool_;
-  ThreadPool io_thread_pool_;
+  std::unique_ptr<ThreadPool> thread_pool_;
+  std::unique_ptr<ThreadPool> io_thread_pool_;
   /*!
    * \brief Worker.
    * \param task_queue Queue to work on.
    *
    * The method to pass to thread pool to parallelize.
    */
-  void ThreadWorker(dmlc::ConcurrentBlockingQueue<OprBlock*>* task_queue) {
+  void ThreadWorker(std::shared_ptr<dmlc::ConcurrentBlockingQueue<OprBlock*>> task_queue) {
     OprBlock* opr_block;
     while (task_queue->Pop(&opr_block)) {
       DoExecute(opr_block);
@@ -110,8 +135,8 @@ class ThreadedEnginePooled : public ThreadedEngine {
     bool is_copy = (opr_block->opr->prop == FnProperty::kCopyFromGPU ||
                     opr_block->opr->prop == FnProperty::kCopyToGPU);
     auto&& rctx = is_copy
-        ? streams_.GetIORunContext(opr_block->ctx)
-        : streams_.GetRunContext(opr_block->ctx);
+        ? streams_->GetIORunContext(opr_block->ctx)
+        : streams_->GetRunContext(opr_block->ctx);
     this->ExecuteOprBlock(rctx, opr_block);
   }
   /*!
@@ -122,11 +147,11 @@ class ThreadedEnginePooled : public ThreadedEngine {
     switch (opr_block->opr->prop) {
       case FnProperty::kCopyFromGPU:
       case FnProperty::kCopyToGPU: {
-        io_task_queue_.Push(opr_block);
+        io_task_queue_->Push(opr_block);
         break;
       }
       default: {
-        task_queue_.Push(opr_block);
+        task_queue_->Push(opr_block);
         break;
       }
     }
diff --git a/tests/cpp/engine/threaded_engine_test.cc b/tests/cpp/engine/threaded_engine_test.cc
index 945c083..92d0958 100644
--- a/tests/cpp/engine/threaded_engine_test.cc
+++ b/tests/cpp/engine/threaded_engine_test.cc
@@ -121,6 +121,26 @@ double EvaluateWorloads(const std::vector<Workload>& workloads,
   return dmlc::GetTime() - t;
 }
 
+TEST(Engine, start_stop) {
+  const int num_engine = 3;
+  // Own the engines so each one is destroyed after its Stop()/Start() cycle
+  // instead of being leaked when the test returns.
+  std::vector<std::unique_ptr<mxnet::Engine>> engine(num_engine);
+  engine[0].reset(mxnet::engine::CreateNaiveEngine());
+  engine[1].reset(mxnet::engine::CreateThreadedEnginePooled());
+  engine[2].reset(mxnet::engine::CreateThreadedEnginePerDevice());
+  const std::string type_names[num_engine] = {"NaiveEngine", "ThreadedEnginePooled",
+                                              "ThreadedEnginePerDevice"};
+
+  for (int i = 0; i < num_engine; ++i) {
+    LOG(INFO) << "Stopping: " << type_names[i];
+    engine[i]->Stop();
+    LOG(INFO) << "Stopped: " << type_names[i] << " Starting...";
+    engine[i]->Start();
+    LOG(INFO) << "Started: " << type_names[i] << " Done...";
+  }
+}
+
 TEST(Engine, RandSumExpr) {
   std::vector<Workload> workloads;
   int num_repeat = 5;
diff --git a/tests/python/unittest/test_engine_import.py b/tests/python/unittest/test_engine_import.py
new file mode 100644
index 0000000..bd34eff
--- /dev/null
+++ b/tests/python/unittest/test_engine_import.py
@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import sys
+
+def test_engine_import():
+    """Reload mxnet under each MXNET_ENGINE_TYPE value ('' means the variable is unset)."""
+    import mxnet
+
+    def reload_mxnet():
+        version = sys.version_info
+        if version >= (3, 4):
+            import importlib
+            importlib.reload(mxnet)
+        elif version >= (3, ):
+            import imp
+            imp.reload(mxnet)
+        else:
+            reload(mxnet)
+
+    engine_types = ['', 'NaiveEngine', 'ThreadedEngine', 'ThreadedEnginePerDevice']
+    # Remember the caller's setting so it does not leak into tests that run later.
+    saved_engine_type = os.environ.get('MXNET_ENGINE_TYPE')
+    try:
+        for engine_type in engine_types:
+            if not engine_type:
+                os.environ.pop('MXNET_ENGINE_TYPE', None)
+            else:
+                os.environ['MXNET_ENGINE_TYPE'] = engine_type
+            reload_mxnet()
+    finally:
+        if saved_engine_type is None:
+            os.environ.pop('MXNET_ENGINE_TYPE', None)
+        else:
+            os.environ['MXNET_ENGINE_TYPE'] = saved_engine_type
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()

-- 
To stop receiving notification emails like this one, please contact
jxie@apache.org.