You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text rendering.)
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/11/15 18:28:50 UTC
[incubator-mxnet] branch master updated: Imperative bulk execution (#8520)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new ca4b683 Imperative bulk execution (#8520)
ca4b683 is described below
commit ca4b683935529197cc7f7971765794065d67016c
Author: Eric Junyuan Xie <pi...@users.noreply.github.com>
AuthorDate: Wed Nov 15 10:28:47 2017 -0800
Imperative bulk execution (#8520)
* imperative bulk
* Update data.py
* Update test_engine.py
---
example/gluon/data.py | 2 +-
include/mxnet/c_api.h | 7 ++++
include/mxnet/engine.h | 20 ++++++---
include/mxnet/ndarray.h | 5 ++-
python/mxnet/__init__.py | 1 +
python/mxnet/engine.py | 75 ++++++++++++++++++++++++++++++++++
src/c_api/c_api.cc | 6 +++
src/engine/threaded_engine.cc | 32 +++++++++++++--
src/engine/threaded_engine.h | 75 ++++++++++++++++++++++++++++++++++
src/executor/graph_executor.cc | 4 +-
src/imperative/cached_op.cc | 2 +
src/imperative/imperative.cc | 4 ++
src/imperative/imperative_utils.h | 34 ++++++---------
src/kvstore/comm.h | 16 ++++++--
src/kvstore/kvstore_dist_server.h | 20 ++++-----
src/kvstore/kvstore_local.h | 4 +-
src/ndarray/ndarray.cc | 36 ++++++++++++----
src/operator/cudnn_convolution-inl.h | 3 +-
src/operator/cudnn_deconvolution-inl.h | 3 +-
src/operator/custom/ndarray_op.cc | 16 +++++---
src/resource.cc | 4 +-
tests/python/unittest/test_engine.py | 36 ++++++++++++++++
22 files changed, 339 insertions(+), 66 deletions(-)
diff --git a/example/gluon/data.py b/example/gluon/data.py
index 30c1a8c..67519e6 100644
--- a/example/gluon/data.py
+++ b/example/gluon/data.py
@@ -115,7 +115,7 @@ def imagenet_iterator(train_data, val_data, batch_size, data_shape, resize=-1):
class DummyIter(mx.io.DataIter):
- def __init__(self, batch_size, data_shape, batches = 5):
+ def __init__(self, batch_size, data_shape, batches = 100):
super(DummyIter, self).__init__(batch_size)
self.data_shape = (batch_size,) + data_shape
self.label_shape = (batch_size,)
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index acdd4bd..8ea2b0e 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -233,6 +233,13 @@ MXNET_DLL int MXDumpProfile();
MXNET_DLL int MXSetNumOMPThreads(int thread_num);
/*!
+ * \brief set bulk execution limit
+ * \param bulk_size new bulk_size
+ * \param prev_bulk_size previous bulk_size
+ */
+MXNET_DLL int MXEngineSetBulkSize(int bulk_size, int* prev_bulk_size);
+
+/*!
* \brief get the MXNet library version as an integer
* \param pointer to the integer holding the version number
* \return 0 when success, -1 when failure happens
diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h
index 4c2314e..5a4697d 100644
--- a/include/mxnet/engine.h
+++ b/include/mxnet/engine.h
@@ -221,12 +221,12 @@ class MXNET_API Engine {
* \param opr_name The operator name.
* \tparam SyncFn the synchronous function to be pushed.
*/
- inline void PushSync(SyncFn exec_fn, Context exec_ctx,
- std::vector<VarHandle> const& const_vars,
- std::vector<VarHandle> const& mutable_vars,
- FnProperty prop = FnProperty::kNormal,
- int priority = 0,
- const char* opr_name = nullptr) {
+ virtual void PushSync(SyncFn exec_fn, Context exec_ctx,
+ std::vector<VarHandle> const& const_vars,
+ std::vector<VarHandle> const& mutable_vars,
+ FnProperty prop = FnProperty::kNormal,
+ int priority = 0,
+ const char* opr_name = nullptr) {
this->PushAsync([exec_fn](RunContext ctx, CallbackOnComplete on_complete) {
exec_fn(ctx);
on_complete();
@@ -267,6 +267,14 @@ class MXNET_API Engine {
}
read_vars->resize(rtop - read_vars->begin());
}
+ /*! \brief query current limit for bulk size */
+ virtual int bulk_size() const {
+ return 0;
+ }
+ /*! \brief set maximum limit for bulk size */
+ virtual int set_bulk_size(int) {
+ return 0;
+ }
}; // class Engine
#endif // DMLC_USE_CXX11
} // namespace mxnet
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index 1700084..498a47f 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -335,7 +335,10 @@ class NDArray {
* Push an empty mutable function to flush all preceding reads to the
* variable.
*/
- Engine::Get()->PushSync([](RunContext) {}, Context{}, {}, {ptr_->var});
+ Engine::Get()->PushAsync(
+ [](RunContext, Engine::CallbackOnComplete on_complete) {
+ on_complete();
+ }, Context{}, {}, {ptr_->var});
Engine::Get()->WaitForVar(ptr_->var);
}
/*! \return the associated variable of the ndarray.*/
diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py
index cf0ba37..4e2c4f0 100644
--- a/python/mxnet/__init__.py
+++ b/python/mxnet/__init__.py
@@ -22,6 +22,7 @@
from __future__ import absolute_import
from .context import Context, current_context, cpu, gpu
+from . import engine
from .base import MXNetError
from . import base
from . import contrib
diff --git a/python/mxnet/engine.py b/python/mxnet/engine.py
new file mode 100644
index 0000000..d4d38f1
--- /dev/null
+++ b/python/mxnet/engine.py
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""Engine properties management."""
+from __future__ import absolute_import
+
+import ctypes
+from .base import _LIB, check_call
+
+
+def set_bulk_size(size):
+ """Set size limit on bulk execution.
+
+ Bulk execution bundles many operators to run together.
+ This can improve performance when running a lot of small
+ operators sequentially.
+
+ Parameters
+ ----------
+ size : int
+ Maximum number of operators that can be bundled in a bulk.
+
+ Returns
+ -------
+ int
+ Previous bulk size.
+ """
+ prev = ctypes.c_int()
+ check_call(_LIB.MXEngineSetBulkSize(
+ ctypes.c_int(size), ctypes.byref(prev)))
+ return prev.value
+
+
+class _BulkScope(object):
+ """Scope object for bulk execution."""
+ def __init__(self, size):
+ self._size = size
+ self._old_size = None
+
+ def __enter__(self):
+ self._old_size = set_bulk_size(self._size)
+ return self
+
+ def __exit__(self, ptype, value, trace):
+ set_bulk_size(self._old_size)
+
+
+def bulk(size):
+ """Bulk execution bundles many operators to run together.
+ This can improve performance when running a lot of small
+ operators sequentially.
+
+ Returns a scope for managing bulk size::
+
+ with mx.engine.bulk(10):
+ x = mx.nd.zeros((1,))
+ for i in range(100):
+ x += 1
+ """
+ return _BulkScope(size)
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index ef0d3bb..15cd061 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -136,6 +136,12 @@ int MXSetNumOMPThreads(int thread_num) {
API_END();
}
+int MXEngineSetBulkSize(int bulk_size, int* prev_bulk_size) {
+ API_BEGIN();
+ *prev_bulk_size = Engine::Get()->set_bulk_size(bulk_size);
+ API_END();
+}
+
int MXGetVersion(int *out) {
API_BEGIN();
*out = static_cast<int>(MXNET_VERSION);
diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc
index bc5b81c..a0d9f29 100644
--- a/src/engine/threaded_engine.cc
+++ b/src/engine/threaded_engine.cc
@@ -267,8 +267,9 @@ void ThreadedEngine::DeleteOperator(OprHandle op) {
deps.insert(deps.end(),
threaded_opr->mutable_vars.begin(),
threaded_opr->mutable_vars.end());
- this->PushSync([threaded_opr](RunContext) {
+ this->PushAsync([threaded_opr](RunContext, CallbackOnComplete on_complete) {
ThreadedOpr::Delete(threaded_opr);
+ on_complete();
}, Context::CPU(), {}, deps, FnProperty::kAsync, 0,
PROFILER_MESSAGE("DeleteOperator"));
}
@@ -304,6 +305,7 @@ void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx,
FnProperty prop,
int priority,
const char* opr_name) {
+ BulkFlush();
ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name);
opr->temporary = true;
#if MXNET_USE_PROFILER
@@ -316,20 +318,42 @@ void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx,
Push(opr, exec_ctx, priority, profiling);
}
+void ThreadedEngine::PushSync(SyncFn exec_fn, Context exec_ctx,
+ std::vector<VarHandle> const& const_vars,
+ std::vector<VarHandle> const& mutable_vars,
+ FnProperty prop,
+ int priority,
+ const char* opr_name) {
+ BulkStatus& bulk_status = *BulkStatusStore::Get();
+ if (!bulk_status.bulk_size || prop != FnProperty::kNormal || priority) {
+ this->PushAsync([exec_fn](RunContext ctx, CallbackOnComplete on_complete) {
+ exec_fn(ctx);
+ on_complete();
+ }, exec_ctx, const_vars, mutable_vars, prop, priority, opr_name);
+ return;
+ }
+
+ if (bulk_status.count && exec_ctx != bulk_status.ctx) BulkFlush();
+ BulkAppend(exec_fn, exec_ctx, const_vars, mutable_vars);
+ return;
+}
+
void ThreadedEngine::DeleteVariable(SyncFn delete_fn,
Context exec_ctx,
VarHandle var) {
ThreadedVar* threaded_var = ThreadedVar::CastFromBase(var);
- this->PushSync([delete_fn, threaded_var](RunContext ctx) {
+ this->PushAsync([delete_fn, threaded_var](RunContext ctx, CallbackOnComplete on_complete) {
// Mark variable as orphan,
// so during `ThreadedEngine::OnComplete` it could be recycled.
threaded_var->SetToDelete();
delete_fn(ctx);
+ on_complete();
}, exec_ctx, {}, {var}, FnProperty::kDeleteVar, 0,
PROFILER_MESSAGE("DeleteVariable"));
}
void ThreadedEngine::WaitForVar(VarHandle var) {
+ BulkFlush();
ThreadedVar* threaded_var = ThreadedVar::CastFromBase(var);
if (threaded_var->ready_to_read()) return;
if (engine_info_) {
@@ -337,7 +361,7 @@ void ThreadedEngine::WaitForVar(VarHandle var) {
debug_wait_var_ = threaded_var;
}
std::atomic<bool> done{false};
- this->PushSync([this, &done](RunContext) {
+ this->PushAsync([this, &done](RunContext, CallbackOnComplete on_complete) {
if (engine_info_) {
LOG(INFO) << "Sync is executed";
}
@@ -349,6 +373,7 @@ void ThreadedEngine::WaitForVar(VarHandle var) {
if (engine_info_) {
LOG(INFO) << "Sync is notified";
}
+ on_complete();
}, Context::CPU(), {var}, {}, FnProperty::kNormal, 0,
PROFILER_MESSAGE("WaitForVar"));
{
@@ -360,6 +385,7 @@ void ThreadedEngine::WaitForVar(VarHandle var) {
}
void ThreadedEngine::WaitForAll() {
+ BulkFlush();
std::unique_lock<std::mutex> lock{finished_m_};
finished_cv_.wait(lock, [this]() {
return pending_.load() == 0 || kill_.load();
diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h
index e000a22..bbb323d 100644
--- a/src/engine/threaded_engine.h
+++ b/src/engine/threaded_engine.h
@@ -33,6 +33,7 @@
#include <functional>
#include <condition_variable>
#include <atomic>
+#include <utility>
#include <mutex>
#include <string>
#include <thread>
@@ -272,6 +273,12 @@ class ThreadedEngine : public Engine {
FnProperty prop = FnProperty::kNormal,
int priority = 0,
const char* opr_name = nullptr) override;
+ void PushSync(SyncFn exec_fn, Context exec_ctx,
+ std::vector<VarHandle> const& const_vars,
+ std::vector<VarHandle> const& mutable_vars,
+ FnProperty prop = FnProperty::kNormal,
+ int priority = 0,
+ const char* opr_name = nullptr) override;
void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override;
void WaitForVar(VarHandle var) override;
void WaitForAll() override;
@@ -364,7 +371,35 @@ class ThreadedEngine : public Engine {
}
}
+ int bulk_size() const override {
+ return BulkStatusStore::Get()->bulk_size;
+ }
+
+ int set_bulk_size(int bulk_size) override {
+ BulkStatus& bulk_status = *BulkStatusStore::Get();
+ std::swap(bulk_status.bulk_size, bulk_size);
+ if (bulk_status.count >= bulk_status.bulk_size) BulkFlush();
+ return bulk_size;
+ }
+
private:
+ /*! \brief structure for holding bulk execution status */
+ struct BulkStatus {
+ /*! \brief maximum number of ops per bulk */
+ int bulk_size = 0;
+ /*! \brief current number of ops in bulk */
+ int count = 0;
+ /*! \brief context of current ops */
+ Context ctx;
+ /*! \brief current op functions */
+ SyncFn fn;
+ /*! \brief constant variables */
+ std::vector<VarHandle> const_vars;
+ /*! \brief mutable variables */
+ std::vector<VarHandle> mutable_vars;
+ };
+ /*! thread local store for bulk */
+ typedef dmlc::ThreadLocalStore<BulkStatus> BulkStatusStore;
/*!
* \brief check if thee is duplication in const_vars and mutable_vars.
* \param const_vars the variables to read from.
@@ -380,6 +415,46 @@ class ThreadedEngine : public Engine {
inline void OnComplete(ThreadedOpr* threaded_opr);
// callback to the threaded engine
static void OnCompleteStatic(Engine *engine, void *threaded_opr);
+ /*! \brief append an operator to bulk */
+ inline void BulkAppend(SyncFn exec_fn, Context exec_ctx,
+ std::vector<VarHandle> const& const_vars,
+ std::vector<VarHandle> const& mutable_vars) {
+ BulkStatus& bulk_status = *BulkStatusStore::Get();
+ if (!bulk_status.count) {
+ bulk_status.ctx = exec_ctx;
+ bulk_status.fn = std::move(exec_fn);
+ } else {
+ auto prev_fn = std::move(bulk_status.fn);
+ bulk_status.fn = [exec_fn, prev_fn](RunContext rctx) {
+ prev_fn(rctx);
+ exec_fn(rctx);
+ };
+ }
+
+ ++bulk_status.count;
+ bulk_status.const_vars.insert(
+ bulk_status.const_vars.end(), const_vars.begin(), const_vars.end());
+ bulk_status.mutable_vars.insert(
+ bulk_status.mutable_vars.end(), mutable_vars.begin(), mutable_vars.end());
+
+ if (bulk_status.count >= bulk_status.bulk_size) BulkFlush();
+ }
+ /*! \brief flush current bulk to execution */
+ inline void BulkFlush() {
+ BulkStatus& bulk_status = *BulkStatusStore::Get();
+ if (!bulk_status.count) return;
+ bulk_status.count = 0;
+ DeduplicateVarHandle(&bulk_status.const_vars, &bulk_status.mutable_vars);
+ auto fn = std::move(bulk_status.fn);
+ this->PushAsync([fn](RunContext ctx, CallbackOnComplete on_complete) {
+ fn(ctx);
+ on_complete();
+ }, bulk_status.ctx, bulk_status.const_vars, bulk_status.mutable_vars,
+ FnProperty::kNormal, 0, "ImperativeBulk");
+
+ bulk_status.const_vars.clear();
+ bulk_status.mutable_vars.clear();
+ }
/*!
* \brief Number of pending operations.
*/
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index dd48675..2fc7ce2 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -1297,8 +1297,10 @@ void GraphExecutor::InitCachedOps() {
std::copy(mutate_vars.begin(), mutate_vars.end(),
std::inserter(all_vars, all_vars.end()));
// setup exec vars
- Engine::Get()->PushSync([exec](RunContext rctx) {
+ Engine::Get()->PushAsync(
+ [exec](RunContext rctx, Engine::CallbackOnComplete on_complete) {
exec->Setup();
+ on_complete();
}, Context::CPU(), {}, all_vars, FnProperty::kNormal, 0,
PROFILER_MESSAGE("SetupExec"));
auto exec_fun = [exec, is_async, is_gpu] (
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index e9d801f..ec0b9c2 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -380,6 +380,7 @@ OpStatePtr Imperative::CachedOp::Forward(const std::vector<NDArray*>& inputs,
mem_plan, arrays, &array_reqs);
const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
+
Imperative::Get()->RunGraph(
false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
std::move(ref_count), &states, dispatch_modes);
@@ -450,6 +451,7 @@ void Imperative::CachedOp::Backward(
mem_plan, arrays, &array_reqs);
const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
+
Imperative::Get()->RunGraph(
retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index fc35c49..361b971 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -288,6 +288,8 @@ void Imperative::RunGraph(
DTypeVector arg_dtypes;
std::vector<OpReqType> req;
+ int prev_bulk_size = Engine::Get()->set_bulk_size(10);
+
for (size_t i = node_start; i < node_end; ++i) {
const nnvm::IndexedGraph::Node& node = idx[i];
if (node.source->op() == nullptr) continue;
@@ -351,6 +353,8 @@ void Imperative::RunGraph(
if (ref_count[eid] == 0) arrays[eid]->ptr_.reset();
}
}
+
+ Engine::Get()->set_bulk_size(prev_bulk_size);
}
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index dbae9c4..34099d0 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -340,10 +340,9 @@ inline void PushFCompute(const FCompute& fn,
bool is_train = Imperative::Get()->is_training();
std::vector<NDArray> inputs, outputs;
DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs);
- Engine::Get()->PushAsync(
+ Engine::Get()->PushSync(
[ctx, attrs, fn, inputs, outputs, requested, is_train, mutate_idx, req](
- RunContext rctx,
- engine::CallbackOnComplete on_complete) {
+ RunContext rctx) {
std::vector<TBlob> input_blobs, output_blobs;
// pre-fcompute and post-fcompute storage fallback src NDArrays and dst NDArrays
std::vector<NDArray> pre_temp_src, pre_temp_dst, post_temp_dst, post_temp_src;
@@ -364,7 +363,6 @@ inline void PushFCompute(const FCompute& fn,
if (is_gpu) {
rctx.get_stream<gpu>()->Wait();
}
- on_complete();
}, ctx, read_vars, write_vars, FnProperty::kNormal,
0, PROFILER_MESSAGE(op->name.c_str()));
}
@@ -389,21 +387,19 @@ inline void PushFComputeEx(const FComputeEx& fn,
std::vector<NDArray> inputs, outputs;
DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs);
const auto& run = [ctx, exec_type, is_train, attrs, fn, inputs, outputs, requested, req](
- RunContext rctx,
- engine::CallbackOnComplete on_complete) {
- OpContext opctx{is_train, rctx, on_complete, requested};
+ RunContext rctx) {
+ OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
fn(attrs, opctx, inputs, req, outputs);
if (exec_type == ExecType::kSync) {
if (rctx.get_ctx().dev_mask() == gpu::kDevMask) {
rctx.get_stream<gpu>()->Wait();
}
- on_complete();
}
};
if (exec_type == ExecType::kLocal) {
- run(RunContext{ctx, nullptr}, engine::CallbackOnComplete());
+ run(RunContext{ctx, nullptr});
} else {
- Engine::Get()->PushAsync(run, ctx, read_vars, write_vars, FnProperty::kNormal,
+ Engine::Get()->PushSync(run, ctx, read_vars, write_vars, FnProperty::kNormal,
0, PROFILER_MESSAGE(op->name.c_str()));
}
}
@@ -436,21 +432,19 @@ inline void PushOperator(const OpStatePtr& state,
if (fcompute_ex != nullptr && dispatch_mode == DispatchMode::kFComputeEx) {
const auto& run = [state, fcompute_ex, inputs, outputs, requested, is_train,
exec_type, req](
- RunContext rctx,
- engine::CallbackOnComplete on_complete) {
- OpContext opctx{is_train, rctx, on_complete, requested};
+ RunContext rctx) {
+ OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
fcompute_ex(state, opctx, inputs, req, outputs);
if (exec_type == ExecType::kSync) {
if (rctx.get_ctx().dev_mask() == gpu::kDevMask) {
rctx.get_stream<gpu>()->Wait();
}
- on_complete();
}
};
if (exec_type == ExecType::kLocal) {
- run(RunContext{ctx, nullptr}, engine::CallbackOnComplete());
+ run(RunContext{ctx, nullptr});
} else {
- Engine::Get()->PushAsync(run, ctx, read_vars, write_vars, FnProperty::kNormal,
+ Engine::Get()->PushSync(run, ctx, read_vars, write_vars, FnProperty::kNormal,
0, PROFILER_MESSAGE(op->name.c_str()));
}
} else {
@@ -458,11 +452,10 @@ inline void PushOperator(const OpStatePtr& state,
<< "One of FStatefulCompute and FStatefulComputeEx must be registered "
<< "for stateful operator " << op->name;
CHECK(exec_type == ExecType::kSync || exec_type == ExecType::kAsync);
- Engine::Get()->PushAsync(
+ Engine::Get()->PushSync(
[state, fcompute, inputs, outputs, requested, is_train, exec_type, mutate_idx, req](
- RunContext rctx,
- engine::CallbackOnComplete on_complete) {
- OpContext opctx{is_train, rctx, on_complete, requested};
+ RunContext rctx) {
+ OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
std::vector<TBlob> input_blobs, output_blobs;
// pre-fcompute and post-fcompute storage fallback src NDArrays and dst NDArrays
@@ -484,7 +477,6 @@ inline void PushOperator(const OpStatePtr& state,
if (is_gpu) {
rctx.get_stream<gpu>()->Wait();
}
- on_complete();
}
}, ctx, read_vars, write_vars, FnProperty::kNormal,
0, PROFILER_MESSAGE(op->name.c_str()));
diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h
index deed1a1..028ab59 100644
--- a/src/kvstore/comm.h
+++ b/src/kvstore/comm.h
@@ -140,8 +140,10 @@ class CommCPU : public Comm {
const_vars[i-1] = reduce[i].var();
}
- Engine::Get()->PushSync([reduce, this](RunContext rctx) {
+ Engine::Get()->PushAsync(
+ [reduce, this](RunContext rctx, Engine::CallbackOnComplete on_complete) {
ReduceSumCPU(reduce);
+ on_complete();
}, Context::CPU(), const_vars, {reduce[0].var()},
FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce"));
@@ -163,13 +165,15 @@ class CommCPU : public Comm {
const_vars[i] = reduce[i].var();
}
auto result = buf.merged;
- Engine::Get()->PushSync([reduce, result, this](RunContext rctx) {
+ Engine::Get()->PushAsync(
+ [reduce, result, this](RunContext rctx, Engine::CallbackOnComplete on_complete) {
NDArray out = result;
Resource rsc = ResourceManager::Get()->Request(rctx.ctx,
ResourceRequest(ResourceRequest::kTempSpace));
is_serial_push_?
ReduceSumCPUExSerial(reduce, &out)
: mxnet::ndarray::ElementwiseSum(rctx.get_stream<cpu>(), rsc, reduce, &out);
+ on_complete();
}, Context::CPU(), const_vars, {result.var()},
FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce"));
}
@@ -217,21 +221,25 @@ class CommCPU : public Comm {
const bool is_to_gpu = out->ctx().dev_mask() == Context::kGPU;
NDArray out_cpu = is_to_gpu? NDArray(kRowSparseStorage, src.shape(),
src.ctx(), true, src.dtype(), src.aux_types()) : *out;
- Engine::Get()->PushSync([=](RunContext rctx) {
+ Engine::Get()->PushAsync(
+ [=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
const TBlob& indices = row_id.data();
NDArray temp = out_cpu; // get rid of const qualifier
op::SparseRetainOpForwardRspImpl<cpu>(rctx.get_stream<cpu>(),
src, indices, kWriteTo,
&temp);
+ on_complete();
}, Context::CPU(), {src.var(), row_id.var()}, {out_cpu.var()},
FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreSparseRetain"));
if (is_to_gpu) {
CopyFromTo(out_cpu, out, priority);
}
} else { // direct copy rows
- Engine::Get()->PushSync([=](RunContext rctx) {
+ Engine::Get()->PushAsync(
+ [=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
CopyRetainedRowsToGPU(rctx.get_stream<cpu>(), rctx.get_stream<gpu>(),
src, row_id, out);
+ on_complete();
}, out->ctx(), {src.var(), row_id.var()}, {out->var()},
FnProperty::kCopyToGPU, priority, PROFILER_MESSAGE("KVStoreCopyRetainedRowsToGPU"));
}
diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h
index bedb539..1503408 100644
--- a/src/kvstore/kvstore_dist_server.h
+++ b/src/kvstore/kvstore_dist_server.h
@@ -230,13 +230,15 @@ class KVStoreDistServer {
TBlob recv_blob(data, dshape, cpu::kDevMask); // NOLINT(*)
NDArray recved = NDArray(recv_blob, 0);
stored = NDArray(kRowSparseStorage, dshape, Context());
- Engine::Get()->PushSync([recved, stored](RunContext ctx) {
+ Engine::Get()->PushAsync(
+ [recved, stored](RunContext ctx, Engine::CallbackOnComplete on_complete) {
NDArray rsp = stored;
stored.CheckAndAlloc({mshadow::Shape1(recved.shape()[0])});
mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
op::PopulateFullIdxRspImpl(s, &rsp);
mshadow::Copy(rsp.data().FlatTo1D<cpu, float>(),
recved.data().FlatTo1D<cpu, float>(), s);
+ on_complete();
}, recved.ctx(), {recved.var()}, {stored.var()},
FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
stored.WaitToRead();
@@ -285,15 +287,13 @@ class KVStoreDistServer {
// TODO(haibin) override + operator for row_sparse NDArray
// instead of calling BinaryComputeRspRsp directly
using namespace mshadow;
- Engine::Get()->PushSync([recved, merged, out](RunContext ctx) {
- std::vector<NDArray> inputs, outputs;
- inputs.push_back(recved);
- inputs.push_back(merged.array);
- outputs.push_back(out);
- op::ElemwiseBinaryOp::ComputeEx<cpu, mshadow::op::plus>(
- {}, {}, inputs, {kWriteTo}, outputs);
- }, recved.ctx(), const_vars, {out.var()},
- FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
+ Engine::Get()->PushAsync(
+ [recved, merged, out](RunContext ctx, Engine::CallbackOnComplete on_complete) {
+ op::ElemwiseBinaryOp::ComputeEx<cpu, mshadow::op::plus>(
+ {}, {}, {recved, merged.array}, {kWriteTo}, {out});
+ on_complete();
+ }, recved.ctx(), const_vars, {out.var()},
+ FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
CopyFromTo(out, &merged.array, 0);
}
merged.request.push_back(req_meta);
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index 15a4c60..4038185 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -347,7 +347,8 @@ class KVStoreLocal : public KVStore {
void Unique(NDArray *out, int priority = 0) {
CHECK_EQ(out->ctx().dev_mask(), pinned_ctx_.dev_mask())
<< "Unique expects input with `pinned_ctx_`";
- Engine::Get()->PushSync([out](RunContext rctx) {
+ Engine::Get()->PushAsync(
+ [out](RunContext rctx, Engine::CallbackOnComplete on_complete) {
NDArray *output = out;
CHECK_EQ(out->shape().ndim(), 1) << "Unique expects 1D inputs";
const auto size = out->shape()[0];
@@ -358,6 +359,7 @@ class KVStoreLocal : public KVStore {
auto num_unique_idx = std::unique(dptr, dptr + size) - dptr;
*output = output->Reshape(mshadow::Shape1(num_unique_idx));
});
+ on_complete();
}, pinned_ctx_, {}, {out->var()},
FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreUnique"));
out->WaitToRead();
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 797cc99..18ca211 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -527,25 +527,33 @@ void CopyFromTo(const NDArray& from, const NDArray& to, int priority) {
if (from.var() != to.var()) const_vars.push_back(from.var());
if (a == cpu::kDevMask && b == cpu::kDevMask) {
- Engine::Get()->PushSync([from, to](RunContext ctx) {
+ Engine::Get()->PushAsync(
+ [from, to](RunContext ctx, Engine::CallbackOnComplete on_complete) {
CopyFromToImpl<cpu, cpu>(from, to, ctx);
+ on_complete();
}, from.ctx(), const_vars, {to.var()},
FnProperty::kNormal, priority, PROFILER_MESSAGE("CopyCPU2CPU"));
} else {
#if MXNET_USE_CUDA
if (a == cpu::kDevMask && b == gpu::kDevMask) {
- Engine::Get()->PushSync([from, to](RunContext ctx) {
+ Engine::Get()->PushAsync(
+ [from, to](RunContext ctx, Engine::CallbackOnComplete on_complete) {
CopyFromToImpl<cpu, gpu>(from, to, ctx);
+ on_complete();
}, to.ctx(), const_vars, {to.var()},
FnProperty::kCopyToGPU, priority, PROFILER_MESSAGE("CopyCPU2GPU"));
} else if (a == gpu::kDevMask && b == cpu::kDevMask) {
- Engine::Get()->PushSync([from, to](RunContext ctx) {
+ Engine::Get()->PushAsync(
+ [from, to](RunContext ctx, Engine::CallbackOnComplete on_complete) {
CopyFromToImpl<gpu, cpu>(from, to, ctx);
+ on_complete();
}, from.ctx(), const_vars, {to.var()},
FnProperty::kCopyFromGPU, priority, PROFILER_MESSAGE("CopyGPU2CPU"));
} else if (a == gpu::kDevMask && b == gpu::kDevMask) {
- Engine::Get()->PushSync([from, to](RunContext ctx) {
+ Engine::Get()->PushAsync(
+ [from, to](RunContext ctx, Engine::CallbackOnComplete on_complete) {
CopyFromToImpl<gpu, gpu>(from, to, ctx);
+ on_complete();
}, from.ctx(), const_vars, {to.var()},
from.dtype() != to.dtype() ? FnProperty::kNormal : FnProperty::kCopyFromGPU,
priority, PROFILER_MESSAGE("CopyGPU2GPU"));
@@ -1077,12 +1085,14 @@ void NDArray::SyncCopyFromCPU(const void *data, size_t size) const {
ndarray::Copy<cpu, cpu>(src, &dst, Context::CPU(), Context::CPU(), rctx);
} else {
#if MXNET_USE_CUDA
- Engine::Get()->PushSync([&](RunContext rctx) {
+ Engine::Get()->PushAsync(
+ [&](RunContext rctx, Engine::CallbackOnComplete on_complete) {
TBlob dst = this->data();
ndarray::Copy<cpu, gpu>(src, &dst,
Context::CPU(), this->ctx(), rctx);
// Wait GPU kernel to complete
rctx.get_stream<gpu>()->Wait();
+ on_complete();
}, this->ctx(), {}, {this->var()},
FnProperty::kCopyToGPU, 0, PROFILER_MESSAGE("SyncCopyCPU2GPU"));
this->WaitToRead();
@@ -1145,27 +1155,33 @@ void NDArray::SyncCopyFromNDArray(const NDArray& src, int i, int j) {
} else {
#if MXNET_USE_CUDA
if (src_dev_mask == cpu::kDevMask && dst_dev_mask == gpu::kDevMask) {
- Engine::Get()->PushSync([&](RunContext rctx) {
+ Engine::Get()->PushAsync(
+ [&](RunContext rctx, Engine::CallbackOnComplete on_complete) {
const TBlob src_data = (i >= 0? src.aux_data(i) : src.data());
TBlob dst_data = get_dst_data(src_data.shape_);
ndarray::Copy<cpu, gpu>(src_data, &dst_data, src.ctx(), this->ctx(), rctx);
rctx.get_stream<gpu>()->Wait();
+ on_complete();
}, this->ctx(), const_vars, {this->var()},
FnProperty::kCopyToGPU, 0, PROFILER_MESSAGE("SyncCopyFromNDArrayCPU2GPU"));
} else if (src_dev_mask == gpu::kDevMask && dst_dev_mask == cpu::kDevMask) {
- Engine::Get()->PushSync([&](RunContext rctx) {
+ Engine::Get()->PushAsync(
+ [&](RunContext rctx, Engine::CallbackOnComplete on_complete) {
const TBlob src_data = (i >= 0? src.aux_data(i) : src.data());
TBlob dst_data = get_dst_data(src_data.shape_);
ndarray::Copy<gpu, cpu>(src_data, &dst_data, src.ctx(), this->ctx(), rctx);
rctx.get_stream<gpu>()->Wait();
+ on_complete();
}, this->ctx(), const_vars, {this->var()},
FnProperty::kCopyFromGPU, 0, PROFILER_MESSAGE("SyncCopyFromNDArrayGPU2CPU"));
} else if (src_dev_mask == gpu::kDevMask && dst_dev_mask == gpu::kDevMask) {
- Engine::Get()->PushSync([&](RunContext rctx) {
+ Engine::Get()->PushAsync(
+ [&](RunContext rctx, Engine::CallbackOnComplete on_complete) {
const TBlob src_data = (i >= 0? src.aux_data(i) : src.data());
TBlob dst_data = get_dst_data(src_data.shape_);
ndarray::Copy<gpu, gpu>(src_data, &dst_data, src.ctx(), this->ctx(), rctx);
rctx.get_stream<gpu>()->Wait();
+ on_complete();
}, this->ctx(), const_vars, {this->var()},
src.dtype() != this->dtype() ? FnProperty::kNormal : FnProperty::kCopyFromGPU,
0, PROFILER_MESSAGE("SyncCopyFromNDArrayGPU2GPU"));
@@ -1200,11 +1216,13 @@ void NDArray::SyncCopyToCPU(void *data, size_t size) const {
Context::CPU(), Context::CPU(), rctx);
} else {
#if MXNET_USE_CUDA
- Engine::Get()->PushSync([&](RunContext rctx) {
+ Engine::Get()->PushAsync(
+ [&](RunContext rctx, Engine::CallbackOnComplete on_complete) {
ndarray::Copy<gpu, cpu>(this->data(), &dst,
this->ctx(), Context::CPU(), rctx);
// Wait GPU kernel to complete
rctx.get_stream<gpu>()->Wait();
+ on_complete();
}, this->ctx(), {this->var()}, {},
FnProperty::kCopyFromGPU, 0, PROFILER_MESSAGE("SyncCopyGPU2CPU"));
this->WaitToWrite();
diff --git a/src/operator/cudnn_convolution-inl.h b/src/operator/cudnn_convolution-inl.h
index b2b5994..d42ec9c 100644
--- a/src/operator/cudnn_convolution-inl.h
+++ b/src/operator/cudnn_convolution-inl.h
@@ -586,7 +586,7 @@ class CuDNNConvolutionOp : public Operator {
&back_algo_w_)) {
// Not in algo registry, must determine via *Get*() or *Find*()
Engine::VarHandle var = Engine::Get()->NewVariable();
- Engine::Get()->PushSync([=](RunContext rctx) {
+ Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
mshadow::Stream<gpu> *s = rctx.get_stream<gpu>();
CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
size_t workspace_byte = static_cast<size_t>(param_.workspace * sizeof(DType));
@@ -776,6 +776,7 @@ class CuDNNConvolutionOp : public Operator {
cudnn_backward_compute_type,
SMArch(ctx.dev_id), this->forward_algo_,
this->back_algo_, this->back_algo_w_);
+ on_complete();
}, ctx, {}, {var});
Engine::Get()->WaitForVar(var);
Engine::Get()->DeleteVariable([](RunContext s) {}, ctx, var);
diff --git a/src/operator/cudnn_deconvolution-inl.h b/src/operator/cudnn_deconvolution-inl.h
index 5e9b7c5..95ab596 100644
--- a/src/operator/cudnn_deconvolution-inl.h
+++ b/src/operator/cudnn_deconvolution-inl.h
@@ -605,7 +605,7 @@ class CuDNNDeconvolutionOp : public Operator {
&back_algo_, &back_algo_w_)) {
// Not in algo registry, must determine via *Get*() or *Find*()
Engine::VarHandle var = Engine::Get()->NewVariable();
- Engine::Get()->PushSync([=](RunContext rctx) {
+ Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
mshadow::Stream <gpu> *s = rctx.get_stream<gpu>();
CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
size_t workspace_byte = static_cast<size_t>(param_.workspace * sizeof(DType));
@@ -798,6 +798,7 @@ class CuDNNDeconvolutionOp : public Operator {
cudnn_backward_compute_type,
SMArch(ctx.dev_id), this->forward_algo_,
this->back_algo_, this->back_algo_w_);
+ on_complete();
}, ctx, {}, {var});
Engine::Get()->WaitForVar(var);
Engine::Get()->DeleteVariable([](RunContext s) {}, ctx, var);
diff --git a/src/operator/custom/ndarray_op.cc b/src/operator/custom/ndarray_op.cc
index 48426ba..66bdfc7 100644
--- a/src/operator/custom/ndarray_op.cc
+++ b/src/operator/custom/ndarray_op.cc
@@ -84,9 +84,11 @@ void NDArrayOp<xpu>::Forward(const OpContext &ctx,
}
CHECK(param_.pinfo->forward(ptrs.size(), ptrs.data(), tags.data(), param_.pinfo->p_forward));
- Engine::Get()->PushSync([ndcpy, ctx](RunContext rctx) {ctx.async_on_complete(); },
- ndctx, ndvar, {}, FnProperty::kNormal, 0,
- PROFILER_MESSAGE("NDArrayOpForward"));
+ Engine::Get()->PushAsync(
+ [ndcpy, ctx](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+ ctx.async_on_complete();
+ on_complete();
+ }, ndctx, ndvar, {}, FnProperty::kNormal, 0, PROFILER_MESSAGE("NDArrayOpForward"));
}
template<typename xpu>
@@ -131,9 +133,11 @@ void NDArrayOp<xpu>::Backward(const OpContext &ctx,
}
CHECK(param_.pinfo->backward(ptrs.size(), ptrs.data(), tags.data(), param_.pinfo->p_backward));
- Engine::Get()->PushSync([ndcpy, ctx](RunContext rctx){ ctx.async_on_complete(); },
- ndctx, ndvar, {}, FnProperty::kNormal, 0,
- PROFILER_MESSAGE("NDArrayOpBackward"));
+ Engine::Get()->PushAsync(
+ [ndcpy, ctx](RunContext rctx, Engine::CallbackOnComplete on_complete){
+ ctx.async_on_complete();
+ on_complete();
+ }, ndctx, ndvar, {}, FnProperty::kNormal, 0, PROFILER_MESSAGE("NDArrayOpBackward"));
}
Operator* NDArrayOpProp::CreateOperator(Context ctx) const {
diff --git a/src/resource.cc b/src/resource.cc
index 4c2dbee..d591651 100644
--- a/src/resource.cc
+++ b/src/resource.cc
@@ -186,9 +186,11 @@ class ResourceManagerImpl : public ResourceManager {
inline void Seed(uint32_t global_seed) {
uint32_t seed = ctx.dev_id + global_seed * kRandMagic;
mshadow::Random<xpu> *r = prnd;
- Engine::Get()->PushSync([r, seed](RunContext rctx) {
+ Engine::Get()->PushAsync(
+ [r, seed](RunContext rctx, Engine::CallbackOnComplete on_complete) {
r->set_stream(rctx.get_stream<xpu>());
r->Seed(seed);
+ on_complete();
}, ctx, {}, {resource.var},
FnProperty::kNormal, 0, PROFILER_MESSAGE("ResourceRandomSetSeed"));
}
diff --git a/tests/python/unittest/test_engine.py b/tests/python/unittest/test_engine.py
new file mode 100644
index 0000000..29b7b82
--- /dev/null
+++ b/tests/python/unittest/test_engine.py
@@ -0,0 +1,36 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import nose
+import mxnet as mx
+
+def test_bulk():
+ with mx.engine.bulk(10):
+ x = mx.nd.ones((10,))
+ x *= 2
+ x += 1
+ x.wait_to_read()
+ x += 1
+ assert (x.asnumpy() == 4).all()
+ for i in range(100):
+ x += 1
+ assert (x.asnumpy() == 104).all()
+
+
+if __name__ == '__main__':
+ import nose
+ nose.runmodule()
--
To stop receiving notification emails like this one, please contact
"commits@mxnet.apache.org" <co...@mxnet.apache.org>.