You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/05/15 18:22:44 UTC
[incubator-mxnet] branch master updated: Fix engine stop/start (#10911)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 57c8ca1 Fix engine stop/start (#10911)
57c8ca1 is described below
commit 57c8ca1a0a6dae36dc27a9f054041ecce652e4c8
Author: Joshua Z. Zhang <ch...@gmail.com>
AuthorDate: Tue May 15 11:22:39 2018 -0700
Fix engine stop/start (#10911)
* fix engine start/stop
* add tests
* fix test
* fix
* fix tests
---
python/mxnet/gluon/data/dataloader.py | 2 +-
src/engine/naive_engine.cc | 6 +++
src/engine/threaded_engine_pooled.cc | 57 +++++++++++++++++++++--------
tests/cpp/engine/threaded_engine_test.cc | 17 +++++++++
tests/python/unittest/test_engine_import.py | 44 ++++++++++++++++++++++
5 files changed, 109 insertions(+), 17 deletions(-)
diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index 7ef18bd..d80a6bf 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -143,7 +143,7 @@ class _MultiWorkerIter(object):
self._batchify_fn = batchify_fn
self._batch_sampler = batch_sampler
self._key_queue = Queue()
- self._data_queue = SimpleQueue()
+ self._data_queue = Queue() if sys.version_info[0] <= 2 else SimpleQueue()
self._data_buffer = {}
self._rcvd_idx = 0
self._sent_idx = 0
diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc
index 1fa5306..8196af2 100644
--- a/src/engine/naive_engine.cc
+++ b/src/engine/naive_engine.cc
@@ -63,6 +63,12 @@ class NaiveEngine final : public Engine {
#endif
}
+ void Stop() override {
+ }
+
+ void Start() override {
+ }
+
// new variables
VarHandle NewVariable() override {
size_t v = ++counter_;
diff --git a/src/engine/threaded_engine_pooled.cc b/src/engine/threaded_engine_pooled.cc
index 074ea4e..574e832 100644
--- a/src/engine/threaded_engine_pooled.cc
+++ b/src/engine/threaded_engine_pooled.cc
@@ -27,6 +27,7 @@
#include <dmlc/logging.h>
#include <dmlc/concurrency.h>
#include <cassert>
+#include <utility>
#include "./threaded_engine.h"
#include "./thread_pool.h"
#include "./stream_manager.h"
@@ -42,14 +43,38 @@ namespace engine {
*/
class ThreadedEnginePooled : public ThreadedEngine {
public:
- ThreadedEnginePooled() :
- thread_pool_(kNumWorkingThreads, [this]() { ThreadWorker(&task_queue_); }),
- io_thread_pool_(1, [this]() { ThreadWorker(&io_task_queue_); }) {}
+ ThreadedEnginePooled() {
+ this->Start();
+ }
~ThreadedEnginePooled() noexcept(false) {
- streams_.Finalize();
- task_queue_.SignalForKill();
- io_task_queue_.SignalForKill();
+ StopNoWait();
+ }
+
+ void StopNoWait() {
+ streams_->Finalize();
+ task_queue_->SignalForKill();
+ io_task_queue_->SignalForKill();
+ task_queue_ = nullptr;
+ io_task_queue_ = nullptr;
+ thread_pool_ = nullptr;
+ io_thread_pool_ = nullptr;
+ streams_ = nullptr;
+ }
+
+ void Stop() override {
+ WaitForAll();
+ StopNoWait();
+ }
+
+ void Start() override {
+ streams_.reset(new StreamManager<kMaxNumGpus, kNumStreamsPerGpu>());
+ task_queue_.reset(new dmlc::ConcurrentBlockingQueue<OprBlock*>());
+ io_task_queue_.reset(new dmlc::ConcurrentBlockingQueue<OprBlock*>());
+ thread_pool_.reset(new ThreadPool(kNumWorkingThreads, [this]() {
+ ThreadWorker(task_queue_); }));
+ io_thread_pool_.reset(new ThreadPool(1, [this]() {
+ ThreadWorker(io_task_queue_); }));
}
protected:
@@ -71,24 +96,24 @@ class ThreadedEnginePooled : public ThreadedEngine {
/*!
* \brief Streams.
*/
- StreamManager<kMaxNumGpus, kNumStreamsPerGpu> streams_;
+ std::unique_ptr<StreamManager<kMaxNumGpus, kNumStreamsPerGpu>> streams_;
/*!
* \brief Task queues.
*/
- dmlc::ConcurrentBlockingQueue<OprBlock*> task_queue_;
- dmlc::ConcurrentBlockingQueue<OprBlock*> io_task_queue_;
+ std::shared_ptr<dmlc::ConcurrentBlockingQueue<OprBlock*>> task_queue_;
+ std::shared_ptr<dmlc::ConcurrentBlockingQueue<OprBlock*>> io_task_queue_;
/*!
* \brief Thread pools.
*/
- ThreadPool thread_pool_;
- ThreadPool io_thread_pool_;
+ std::unique_ptr<ThreadPool> thread_pool_;
+ std::unique_ptr<ThreadPool> io_thread_pool_;
/*!
* \brief Worker.
* \param task_queue Queue to work on.
*
* The method to pass to thread pool to parallelize.
*/
- void ThreadWorker(dmlc::ConcurrentBlockingQueue<OprBlock*>* task_queue) {
+ void ThreadWorker(std::shared_ptr<dmlc::ConcurrentBlockingQueue<OprBlock*>> task_queue) {
OprBlock* opr_block;
while (task_queue->Pop(&opr_block)) {
DoExecute(opr_block);
@@ -110,8 +135,8 @@ class ThreadedEnginePooled : public ThreadedEngine {
bool is_copy = (opr_block->opr->prop == FnProperty::kCopyFromGPU ||
opr_block->opr->prop == FnProperty::kCopyToGPU);
auto&& rctx = is_copy
- ? streams_.GetIORunContext(opr_block->ctx)
- : streams_.GetRunContext(opr_block->ctx);
+ ? streams_->GetIORunContext(opr_block->ctx)
+ : streams_->GetRunContext(opr_block->ctx);
this->ExecuteOprBlock(rctx, opr_block);
}
/*!
@@ -122,11 +147,11 @@ class ThreadedEnginePooled : public ThreadedEngine {
switch (opr_block->opr->prop) {
case FnProperty::kCopyFromGPU:
case FnProperty::kCopyToGPU: {
- io_task_queue_.Push(opr_block);
+ io_task_queue_->Push(opr_block);
break;
}
default: {
- task_queue_.Push(opr_block);
+ task_queue_->Push(opr_block);
break;
}
}
diff --git a/tests/cpp/engine/threaded_engine_test.cc b/tests/cpp/engine/threaded_engine_test.cc
index 945c083..92d0958 100644
--- a/tests/cpp/engine/threaded_engine_test.cc
+++ b/tests/cpp/engine/threaded_engine_test.cc
@@ -121,6 +121,23 @@ double EvaluateWorloads(const std::vector<Workload>& workloads,
return dmlc::GetTime() - t;
}
+TEST(Engine, start_stop) {
+ const int num_engine = 3;
+ std::vector<mxnet::Engine*> engine(num_engine);
+ engine[0] = mxnet::engine::CreateNaiveEngine();
+ engine[1] = mxnet::engine::CreateThreadedEnginePooled();
+ engine[2] = mxnet::engine::CreateThreadedEnginePerDevice();
+ std::string type_names[3] = {"NaiveEngine", "ThreadedEnginePooled", "ThreadedEnginePerDevice"};
+
+ for (int i = 0; i < num_engine; ++i) {
+ LOG(INFO) << "Stopping: " << type_names[i];
+ engine[i]->Stop();
+ LOG(INFO) << "Stopped: " << type_names[i] << " Starting...";
+ engine[i]->Start();
+ LOG(INFO) << "Started: " << type_names[i] << " Done...";
+ }
+}
+
TEST(Engine, RandSumExpr) {
std::vector<Workload> workloads;
int num_repeat = 5;
diff --git a/tests/python/unittest/test_engine_import.py b/tests/python/unittest/test_engine_import.py
new file mode 100644
index 0000000..bd34eff
--- /dev/null
+++ b/tests/python/unittest/test_engine_import.py
@@ -0,0 +1,44 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import sys
+
+def test_engine_import():
+ import mxnet
+ def test_import():
+ version = sys.version_info
+ if version >= (3, 4):
+ import importlib
+ importlib.reload(mxnet)
+ elif version >= (3, ):
+ import imp
+ imp.reload(mxnet)
+ else:
+ reload(mxnet)
+ engine_types = ['', 'NaiveEngine', 'ThreadedEngine', 'ThreadedEnginePerDevice']
+
+ for type in engine_types:
+ if not type:
+ os.environ.pop('MXNET_ENGINE_TYPE', None)
+ else:
+ os.environ['MXNET_ENGINE_TYPE'] = type
+ test_import()
+
+if __name__ == '__main__':
+ import nose
+ nose.runmodule()
--
To stop receiving notification emails like this one, please contact
jxie@apache.org.