You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink itself is not preserved in this plain-text view).
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/11/15 18:28:53 UTC

[GitHub] piiswrong closed pull request #8520: Imperative bulk execution

piiswrong closed pull request #8520: Imperative bulk execution
URL: https://github.com/apache/incubator-mxnet/pull/8520
 
 
   

This is a PR merged from a forked repository.
Because GitHub hides the original diff once such a PR is merged, the diff
is displayed below for the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, since GitHub would not otherwise display it:

diff --git a/example/gluon/data.py b/example/gluon/data.py
index 30c1a8c59b..67519e6a20 100644
--- a/example/gluon/data.py
+++ b/example/gluon/data.py
@@ -115,7 +115,7 @@ def imagenet_iterator(train_data, val_data, batch_size, data_shape, resize=-1):
 
 
 class DummyIter(mx.io.DataIter):
-    def __init__(self, batch_size, data_shape, batches = 5):
+    def __init__(self, batch_size, data_shape, batches = 100):
         super(DummyIter, self).__init__(batch_size)
         self.data_shape = (batch_size,) + data_shape
         self.label_shape = (batch_size,)
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 55b840dd2c..c315bc7e88 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -233,6 +233,13 @@ MXNET_DLL int MXDumpProfile();
 MXNET_DLL int MXSetNumOMPThreads(int thread_num);
 
 /*!
+ * \brief set bulk execution limit
+ * \param bulk_size new bulk_size
+ * \param prev_bulk_size previous bulk_size
+ */
+MXNET_DLL int MXEngineSetBulkSize(int bulk_size, int* prev_bulk_size);
+
+/*!
  * \brief get the MXNet library version as an integer
  * \param pointer to the integer holding the version number
  * \return 0 when success, -1 when failure happens
diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h
index 4c2314e176..5a4697df4b 100644
--- a/include/mxnet/engine.h
+++ b/include/mxnet/engine.h
@@ -221,12 +221,12 @@ class MXNET_API Engine {
    * \param opr_name The operator name.
    * \tparam SyncFn the synchronous function to be pushed.
    */
-  inline void PushSync(SyncFn exec_fn, Context exec_ctx,
-                       std::vector<VarHandle> const& const_vars,
-                       std::vector<VarHandle> const& mutable_vars,
-                       FnProperty prop = FnProperty::kNormal,
-                       int priority = 0,
-                       const char* opr_name = nullptr) {
+  virtual void PushSync(SyncFn exec_fn, Context exec_ctx,
+                        std::vector<VarHandle> const& const_vars,
+                        std::vector<VarHandle> const& mutable_vars,
+                        FnProperty prop = FnProperty::kNormal,
+                        int priority = 0,
+                        const char* opr_name = nullptr) {
     this->PushAsync([exec_fn](RunContext ctx, CallbackOnComplete on_complete) {
         exec_fn(ctx);
         on_complete();
@@ -267,6 +267,14 @@ class MXNET_API Engine {
     }
     read_vars->resize(rtop - read_vars->begin());
   }
+  /*! \brief query current limit for bulk size */
+  virtual int bulk_size() const {
+    return 0;
+  }
+  /*! \brief set maximum limit for bulk size */
+  virtual int set_bulk_size(int) {
+    return 0;
+  }
 };  // class Engine
 #endif  // DMLC_USE_CXX11
 }  // namespace mxnet
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index 84ee9fa5e4..4636eeb1f8 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -326,7 +326,10 @@ class NDArray {
      * Push an empty mutable function to flush all preceding reads to the
      * variable.
      */
-    Engine::Get()->PushSync([](RunContext) {}, Context{}, {}, {ptr_->var});
+    Engine::Get()->PushAsync(
+      [](RunContext, Engine::CallbackOnComplete on_complete) {
+        on_complete();
+      }, Context{}, {}, {ptr_->var});
     Engine::Get()->WaitForVar(ptr_->var);
   }
   /*! \return the associated variable of the ndarray.*/
diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py
index cf0ba37ab9..4e2c4f0134 100644
--- a/python/mxnet/__init__.py
+++ b/python/mxnet/__init__.py
@@ -22,6 +22,7 @@
 from __future__ import absolute_import
 
 from .context import Context, current_context, cpu, gpu
+from . import engine
 from .base import MXNetError
 from . import base
 from . import contrib
diff --git a/python/mxnet/engine.py b/python/mxnet/engine.py
new file mode 100644
index 0000000000..d4d38f1f29
--- /dev/null
+++ b/python/mxnet/engine.py
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""Engine properties management."""
+from __future__ import absolute_import
+
+import ctypes
+from .base import _LIB, check_call
+
+
+def set_bulk_size(size):
+    """Set size limit on bulk execution.
+
+    Bulk execution bundles many operators to run together.
+    This can improve performance when running a lot of small
+    operators sequentially.
+
+    Parameters
+    ----------
+    size : int
+        Maximum number of operators that can be bundled in a bulk.
+
+    Returns
+    -------
+    int
+        Previous bulk size.
+    """
+    prev = ctypes.c_int()
+    check_call(_LIB.MXEngineSetBulkSize(
+        ctypes.c_int(size), ctypes.byref(prev)))
+    return prev.value
+
+
+class _BulkScope(object):
+    """Scope object for bulk execution."""
+    def __init__(self, size):
+        self._size = size
+        self._old_size = None
+
+    def __enter__(self):
+        self._old_size = set_bulk_size(self._size)
+        return self
+
+    def __exit__(self, ptype, value, trace):
+        set_bulk_size(self._old_size)
+
+
+def bulk(size):
+    """Bulk execution bundles many operators to run together.
+    This can improve performance when running a lot of small
+    operators sequentially.
+
+    Returns a scope for managing bulk size::
+
+        with mx.engine.bulk(10):
+            x = mx.nd.zeros((1,))
+            for i in range(100):
+                x += 1
+    """
+    return _BulkScope(size)
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 1d348a5b40..b7a98f8290 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -136,6 +136,12 @@ int MXSetNumOMPThreads(int thread_num) {
   API_END();
 }
 
+int MXEngineSetBulkSize(int bulk_size, int* prev_bulk_size) {
+  API_BEGIN();
+  *prev_bulk_size = Engine::Get()->set_bulk_size(bulk_size);
+  API_END();
+}
+
 int MXGetVersion(int *out) {
   API_BEGIN();
   *out = static_cast<int>(MXNET_VERSION);
diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc
index bc5b81c568..a0d9f29984 100644
--- a/src/engine/threaded_engine.cc
+++ b/src/engine/threaded_engine.cc
@@ -267,8 +267,9 @@ void ThreadedEngine::DeleteOperator(OprHandle op) {
   deps.insert(deps.end(),
               threaded_opr->mutable_vars.begin(),
               threaded_opr->mutable_vars.end());
-  this->PushSync([threaded_opr](RunContext) {
+  this->PushAsync([threaded_opr](RunContext, CallbackOnComplete on_complete) {
       ThreadedOpr::Delete(threaded_opr);
+      on_complete();
     }, Context::CPU(), {}, deps, FnProperty::kAsync, 0,
     PROFILER_MESSAGE("DeleteOperator"));
 }
@@ -304,6 +305,7 @@ void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx,
                                FnProperty prop,
                                int priority,
                                const char* opr_name) {
+  BulkFlush();
   ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name);
   opr->temporary = true;
 #if MXNET_USE_PROFILER
@@ -316,20 +318,42 @@ void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx,
   Push(opr, exec_ctx, priority, profiling);
 }
 
+void ThreadedEngine::PushSync(SyncFn exec_fn, Context exec_ctx,
+                              std::vector<VarHandle> const& const_vars,
+                              std::vector<VarHandle> const& mutable_vars,
+                              FnProperty prop,
+                              int priority,
+                              const char* opr_name) {
+  BulkStatus& bulk_status = *BulkStatusStore::Get();
+  if (!bulk_status.bulk_size || prop != FnProperty::kNormal || priority) {
+    this->PushAsync([exec_fn](RunContext ctx, CallbackOnComplete on_complete) {
+        exec_fn(ctx);
+        on_complete();
+      }, exec_ctx, const_vars, mutable_vars, prop, priority, opr_name);
+    return;
+  }
+
+  if (bulk_status.count && exec_ctx != bulk_status.ctx) BulkFlush();
+  BulkAppend(exec_fn, exec_ctx, const_vars, mutable_vars);
+  return;
+}
+
 void ThreadedEngine::DeleteVariable(SyncFn delete_fn,
                                     Context exec_ctx,
                                     VarHandle var) {
   ThreadedVar* threaded_var = ThreadedVar::CastFromBase(var);
-  this->PushSync([delete_fn, threaded_var](RunContext ctx) {
+  this->PushAsync([delete_fn, threaded_var](RunContext ctx, CallbackOnComplete on_complete) {
       // Mark variable as orphan,
       // so during `ThreadedEngine::OnComplete` it could be recycled.
       threaded_var->SetToDelete();
       delete_fn(ctx);
+      on_complete();
     }, exec_ctx, {}, {var}, FnProperty::kDeleteVar, 0,
     PROFILER_MESSAGE("DeleteVariable"));
 }
 
 void ThreadedEngine::WaitForVar(VarHandle var) {
+  BulkFlush();
   ThreadedVar* threaded_var = ThreadedVar::CastFromBase(var);
   if (threaded_var->ready_to_read()) return;
   if (engine_info_) {
@@ -337,7 +361,7 @@ void ThreadedEngine::WaitForVar(VarHandle var) {
     debug_wait_var_ = threaded_var;
   }
   std::atomic<bool> done{false};
-  this->PushSync([this, &done](RunContext) {
+  this->PushAsync([this, &done](RunContext, CallbackOnComplete on_complete) {
       if (engine_info_) {
         LOG(INFO) << "Sync is executed";
       }
@@ -349,6 +373,7 @@ void ThreadedEngine::WaitForVar(VarHandle var) {
       if (engine_info_) {
         LOG(INFO) << "Sync is notified";
       }
+      on_complete();
     }, Context::CPU(), {var}, {}, FnProperty::kNormal, 0,
     PROFILER_MESSAGE("WaitForVar"));
   {
@@ -360,6 +385,7 @@ void ThreadedEngine::WaitForVar(VarHandle var) {
 }
 
 void ThreadedEngine::WaitForAll() {
+  BulkFlush();
   std::unique_lock<std::mutex> lock{finished_m_};
   finished_cv_.wait(lock, [this]() {
       return pending_.load() == 0 || kill_.load();
diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h
index e000a22c22..bbb323d722 100644
--- a/src/engine/threaded_engine.h
+++ b/src/engine/threaded_engine.h
@@ -33,6 +33,7 @@
 #include <functional>
 #include <condition_variable>
 #include <atomic>
+#include <utility>
 #include <mutex>
 #include <string>
 #include <thread>
@@ -272,6 +273,12 @@ class ThreadedEngine : public Engine {
                  FnProperty prop = FnProperty::kNormal,
                  int priority = 0,
                  const char* opr_name = nullptr) override;
+  void PushSync(SyncFn exec_fn, Context exec_ctx,
+                std::vector<VarHandle> const& const_vars,
+                std::vector<VarHandle> const& mutable_vars,
+                FnProperty prop = FnProperty::kNormal,
+                int priority = 0,
+                const char* opr_name = nullptr) override;
   void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override;
   void WaitForVar(VarHandle var) override;
   void WaitForAll() override;
@@ -364,7 +371,35 @@ class ThreadedEngine : public Engine {
     }
   }
 
+  int bulk_size() const override {
+    return BulkStatusStore::Get()->bulk_size;
+  }
+
+  int set_bulk_size(int bulk_size) override {
+    BulkStatus& bulk_status = *BulkStatusStore::Get();
+    std::swap(bulk_status.bulk_size, bulk_size);
+    if (bulk_status.count >= bulk_status.bulk_size) BulkFlush();
+    return bulk_size;
+  }
+
  private:
+  /*! \brief structure for holding bulk execution status */
+  struct BulkStatus {
+    /*! \brief maximum number of ops per bulk */
+    int bulk_size = 0;
+    /*! \brief current number of ops in bulk */
+    int count = 0;
+    /*! \brief context of current ops */
+    Context ctx;
+    /*! \brief current op functions */
+    SyncFn fn;
+    /*! \brief constant variables */
+    std::vector<VarHandle> const_vars;
+    /*! \brief mutable variables */
+    std::vector<VarHandle> mutable_vars;
+  };
+  /*! thread local store for bulk */
+  typedef dmlc::ThreadLocalStore<BulkStatus> BulkStatusStore;
   /*!
    * \brief check if thee is duplication in const_vars and mutable_vars.
    * \param const_vars the variables to read from.
@@ -380,6 +415,46 @@ class ThreadedEngine : public Engine {
   inline void OnComplete(ThreadedOpr* threaded_opr);
   // callback to the threaded engine
   static void OnCompleteStatic(Engine *engine, void *threaded_opr);
+  /*! \brief append an operator to bulk */
+  inline void BulkAppend(SyncFn exec_fn, Context exec_ctx,
+                         std::vector<VarHandle> const& const_vars,
+                         std::vector<VarHandle> const& mutable_vars) {
+    BulkStatus& bulk_status = *BulkStatusStore::Get();
+    if (!bulk_status.count) {
+      bulk_status.ctx = exec_ctx;
+      bulk_status.fn = std::move(exec_fn);
+    } else {
+      auto prev_fn = std::move(bulk_status.fn);
+      bulk_status.fn = [exec_fn, prev_fn](RunContext rctx) {
+          prev_fn(rctx);
+          exec_fn(rctx);
+        };
+    }
+
+    ++bulk_status.count;
+    bulk_status.const_vars.insert(
+        bulk_status.const_vars.end(), const_vars.begin(), const_vars.end());
+    bulk_status.mutable_vars.insert(
+        bulk_status.mutable_vars.end(), mutable_vars.begin(), mutable_vars.end());
+
+    if (bulk_status.count >= bulk_status.bulk_size) BulkFlush();
+  }
+  /*! \brief flush current bulk to execution */
+  inline void BulkFlush() {
+    BulkStatus& bulk_status = *BulkStatusStore::Get();
+    if (!bulk_status.count) return;
+    bulk_status.count = 0;
+    DeduplicateVarHandle(&bulk_status.const_vars, &bulk_status.mutable_vars);
+    auto fn = std::move(bulk_status.fn);
+    this->PushAsync([fn](RunContext ctx, CallbackOnComplete on_complete) {
+        fn(ctx);
+        on_complete();
+      }, bulk_status.ctx, bulk_status.const_vars, bulk_status.mutable_vars,
+      FnProperty::kNormal, 0, "ImperativeBulk");
+
+    bulk_status.const_vars.clear();
+    bulk_status.mutable_vars.clear();
+  }
   /*!
    * \brief Number of pending operations.
    */
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index dd4867559d..2fc7ce2338 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -1297,8 +1297,10 @@ void GraphExecutor::InitCachedOps() {
     std::copy(mutate_vars.begin(), mutate_vars.end(),
               std::inserter(all_vars, all_vars.end()));
     // setup exec vars
-    Engine::Get()->PushSync([exec](RunContext rctx) {
+    Engine::Get()->PushAsync(
+      [exec](RunContext rctx, Engine::CallbackOnComplete on_complete) {
         exec->Setup();
+        on_complete();
       }, Context::CPU(), {}, all_vars, FnProperty::kNormal, 0,
       PROFILER_MESSAGE("SetupExec"));
     auto exec_fun = [exec, is_async, is_gpu] (
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 60d66db485..2a8702f789 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -381,6 +381,7 @@ OpStatePtr Imperative::CachedOp::Forward(const std::vector<NDArray*>& inputs,
                  mem_plan, arrays, &array_reqs);
 
   const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
+
   Imperative::Get()->RunGraph(
       false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
       std::move(ref_count), &states, dispatch_modes);
@@ -451,6 +452,7 @@ void Imperative::CachedOp::Backward(
                  mem_plan, arrays, &array_reqs);
 
   const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
+
   Imperative::Get()->RunGraph(
       retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
       std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index fc35c492f7..361b971a2d 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -288,6 +288,8 @@ void Imperative::RunGraph(
   DTypeVector arg_dtypes;
   std::vector<OpReqType> req;
 
+  int prev_bulk_size = Engine::Get()->set_bulk_size(10);
+
   for (size_t i = node_start; i < node_end; ++i) {
     const nnvm::IndexedGraph::Node& node = idx[i];
     if (node.source->op() == nullptr) continue;
@@ -351,6 +353,8 @@ void Imperative::RunGraph(
       if (ref_count[eid] == 0) arrays[eid]->ptr_.reset();
     }
   }
+
+  Engine::Get()->set_bulk_size(prev_bulk_size);
 }
 
 
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index dbae9c4f4d..34099d029c 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -340,10 +340,9 @@ inline void PushFCompute(const FCompute& fn,
   bool is_train = Imperative::Get()->is_training();
   std::vector<NDArray> inputs, outputs;
   DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs);
-  Engine::Get()->PushAsync(
+  Engine::Get()->PushSync(
     [ctx, attrs, fn, inputs, outputs, requested, is_train, mutate_idx, req](
-        RunContext rctx,
-        engine::CallbackOnComplete on_complete) {
+        RunContext rctx) {
       std::vector<TBlob> input_blobs, output_blobs;
       // pre-fcompute and post-fcompute storage fallback src NDArrays and dst NDArrays
       std::vector<NDArray> pre_temp_src, pre_temp_dst, post_temp_dst, post_temp_src;
@@ -364,7 +363,6 @@ inline void PushFCompute(const FCompute& fn,
       if (is_gpu) {
         rctx.get_stream<gpu>()->Wait();
       }
-      on_complete();
     }, ctx, read_vars, write_vars, FnProperty::kNormal,
     0, PROFILER_MESSAGE(op->name.c_str()));
 }
@@ -389,21 +387,19 @@ inline void PushFComputeEx(const FComputeEx& fn,
   std::vector<NDArray> inputs, outputs;
   DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs);
   const auto& run = [ctx, exec_type, is_train, attrs, fn, inputs, outputs, requested, req](
-        RunContext rctx,
-        engine::CallbackOnComplete on_complete) {
-      OpContext opctx{is_train, rctx, on_complete, requested};
+        RunContext rctx) {
+      OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
       fn(attrs, opctx, inputs, req, outputs);
       if (exec_type == ExecType::kSync) {
         if (rctx.get_ctx().dev_mask() == gpu::kDevMask) {
           rctx.get_stream<gpu>()->Wait();
         }
-        on_complete();
       }
     };
   if (exec_type == ExecType::kLocal) {
-    run(RunContext{ctx, nullptr}, engine::CallbackOnComplete());
+    run(RunContext{ctx, nullptr});
   } else {
-    Engine::Get()->PushAsync(run, ctx, read_vars, write_vars, FnProperty::kNormal,
+    Engine::Get()->PushSync(run, ctx, read_vars, write_vars, FnProperty::kNormal,
       0, PROFILER_MESSAGE(op->name.c_str()));
   }
 }
@@ -436,21 +432,19 @@ inline void PushOperator(const OpStatePtr& state,
   if (fcompute_ex != nullptr && dispatch_mode == DispatchMode::kFComputeEx) {
     const auto& run = [state, fcompute_ex, inputs, outputs, requested, is_train,
                        exec_type, req](
-          RunContext rctx,
-          engine::CallbackOnComplete on_complete) {
-        OpContext opctx{is_train, rctx, on_complete, requested};
+          RunContext rctx) {
+        OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
         fcompute_ex(state, opctx, inputs, req, outputs);
         if (exec_type == ExecType::kSync) {
           if (rctx.get_ctx().dev_mask() == gpu::kDevMask) {
             rctx.get_stream<gpu>()->Wait();
           }
-          on_complete();
         }
       };
     if (exec_type == ExecType::kLocal) {
-      run(RunContext{ctx, nullptr}, engine::CallbackOnComplete());
+      run(RunContext{ctx, nullptr});
     } else {
-      Engine::Get()->PushAsync(run, ctx, read_vars, write_vars, FnProperty::kNormal,
+      Engine::Get()->PushSync(run, ctx, read_vars, write_vars, FnProperty::kNormal,
                                0, PROFILER_MESSAGE(op->name.c_str()));
     }
   } else {
@@ -458,11 +452,10 @@ inline void PushOperator(const OpStatePtr& state,
         << "One of FStatefulCompute and FStatefulComputeEx must be registered "
         << "for stateful operator " << op->name;
     CHECK(exec_type == ExecType::kSync || exec_type == ExecType::kAsync);
-    Engine::Get()->PushAsync(
+    Engine::Get()->PushSync(
       [state, fcompute, inputs, outputs, requested, is_train, exec_type, mutate_idx, req](
-          RunContext rctx,
-          engine::CallbackOnComplete on_complete) {
-        OpContext opctx{is_train, rctx, on_complete, requested};
+          RunContext rctx) {
+        OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
 
         std::vector<TBlob> input_blobs, output_blobs;
         // pre-fcompute and post-fcompute storage fallback src NDArrays and dst NDArrays
@@ -484,7 +477,6 @@ inline void PushOperator(const OpStatePtr& state,
           if (is_gpu) {
             rctx.get_stream<gpu>()->Wait();
           }
-          on_complete();
         }
       }, ctx, read_vars, write_vars, FnProperty::kNormal,
       0, PROFILER_MESSAGE(op->name.c_str()));
diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h
index deed1a15c9..028ab5992c 100644
--- a/src/kvstore/comm.h
+++ b/src/kvstore/comm.h
@@ -140,8 +140,10 @@ class CommCPU : public Comm {
         const_vars[i-1] = reduce[i].var();
       }
 
-      Engine::Get()->PushSync([reduce, this](RunContext rctx) {
+      Engine::Get()->PushAsync(
+        [reduce, this](RunContext rctx, Engine::CallbackOnComplete on_complete) {
           ReduceSumCPU(reduce);
+          on_complete();
         }, Context::CPU(), const_vars, {reduce[0].var()},
         FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce"));
 
@@ -163,13 +165,15 @@ class CommCPU : public Comm {
         const_vars[i] = reduce[i].var();
       }
       auto result = buf.merged;
-      Engine::Get()->PushSync([reduce, result, this](RunContext rctx) {
+      Engine::Get()->PushAsync(
+        [reduce, result, this](RunContext rctx, Engine::CallbackOnComplete on_complete) {
           NDArray out = result;
           Resource rsc = ResourceManager::Get()->Request(rctx.ctx,
               ResourceRequest(ResourceRequest::kTempSpace));
           is_serial_push_?
             ReduceSumCPUExSerial(reduce, &out)
             : mxnet::ndarray::ElementwiseSum(rctx.get_stream<cpu>(), rsc, reduce, &out);
+          on_complete();
         }, Context::CPU(), const_vars, {result.var()},
         FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce"));
     }
@@ -217,21 +221,25 @@ class CommCPU : public Comm {
           const bool is_to_gpu = out->ctx().dev_mask() == Context::kGPU;
           NDArray out_cpu = is_to_gpu? NDArray(kRowSparseStorage, src.shape(),
               src.ctx(), true, src.dtype(), src.aux_types()) : *out;
-          Engine::Get()->PushSync([=](RunContext rctx) {
+          Engine::Get()->PushAsync(
+            [=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
               const TBlob& indices = row_id.data();
               NDArray temp = out_cpu;  // get rid of const qualifier
               op::SparseRetainOpForwardRspImpl<cpu>(rctx.get_stream<cpu>(),
                                                     src, indices, kWriteTo,
                                                     &temp);
+              on_complete();
             }, Context::CPU(), {src.var(), row_id.var()}, {out_cpu.var()},
             FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreSparseRetain"));
           if (is_to_gpu) {
             CopyFromTo(out_cpu, out, priority);
           }
         } else {  // direct copy rows
-          Engine::Get()->PushSync([=](RunContext rctx) {
+          Engine::Get()->PushAsync(
+            [=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
               CopyRetainedRowsToGPU(rctx.get_stream<cpu>(), rctx.get_stream<gpu>(),
                                     src, row_id, out);
+              on_complete();
             }, out->ctx(), {src.var(), row_id.var()}, {out->var()},
             FnProperty::kCopyToGPU, priority, PROFILER_MESSAGE("KVStoreCopyRetainedRowsToGPU"));
         }
diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h
index bedb5398a0..1503408618 100644
--- a/src/kvstore/kvstore_dist_server.h
+++ b/src/kvstore/kvstore_dist_server.h
@@ -230,13 +230,15 @@ class KVStoreDistServer {
         TBlob recv_blob(data, dshape, cpu::kDevMask);  // NOLINT(*)
         NDArray recved = NDArray(recv_blob, 0);
         stored = NDArray(kRowSparseStorage, dshape, Context());
-        Engine::Get()->PushSync([recved, stored](RunContext ctx) {
+        Engine::Get()->PushAsync(
+          [recved, stored](RunContext ctx, Engine::CallbackOnComplete on_complete) {
             NDArray rsp = stored;
             stored.CheckAndAlloc({mshadow::Shape1(recved.shape()[0])});
             mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
             op::PopulateFullIdxRspImpl(s, &rsp);
             mshadow::Copy(rsp.data().FlatTo1D<cpu, float>(),
                           recved.data().FlatTo1D<cpu, float>(), s);
+            on_complete();
           }, recved.ctx(), {recved.var()}, {stored.var()},
           FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
         stored.WaitToRead();
@@ -285,15 +287,13 @@ class KVStoreDistServer {
           // TODO(haibin) override + operator for row_sparse NDArray
           // instead of calling BinaryComputeRspRsp directly
           using namespace mshadow;
-          Engine::Get()->PushSync([recved, merged, out](RunContext ctx) {
-                                    std::vector<NDArray> inputs, outputs;
-                                    inputs.push_back(recved);
-                                    inputs.push_back(merged.array);
-                                    outputs.push_back(out);
-                                    op::ElemwiseBinaryOp::ComputeEx<cpu, mshadow::op::plus>(
-                                      {}, {}, inputs, {kWriteTo}, outputs);
-                                  }, recved.ctx(), const_vars, {out.var()},
-                                  FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
+          Engine::Get()->PushAsync(
+            [recved, merged, out](RunContext ctx, Engine::CallbackOnComplete on_complete) {
+              op::ElemwiseBinaryOp::ComputeEx<cpu, mshadow::op::plus>(
+                {}, {}, {recved, merged.array}, {kWriteTo}, {out});
+              on_complete();
+            }, recved.ctx(), const_vars, {out.var()},
+            FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
           CopyFromTo(out, &merged.array, 0);
         }
         merged.request.push_back(req_meta);
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index 15a4c6055b..4038185244 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -347,7 +347,8 @@ class KVStoreLocal : public KVStore {
   void Unique(NDArray *out, int priority = 0) {
     CHECK_EQ(out->ctx().dev_mask(), pinned_ctx_.dev_mask())
              << "Unique expects input with `pinned_ctx_`";
-    Engine::Get()->PushSync([out](RunContext rctx) {
+    Engine::Get()->PushAsync(
+      [out](RunContext rctx, Engine::CallbackOnComplete on_complete) {
         NDArray *output = out;
         CHECK_EQ(out->shape().ndim(), 1) << "Unique expects 1D inputs";
         const auto size = out->shape()[0];
@@ -358,6 +359,7 @@ class KVStoreLocal : public KVStore {
           auto num_unique_idx = std::unique(dptr, dptr + size) - dptr;
           *output = output->Reshape(mshadow::Shape1(num_unique_idx));
         });
+        on_complete();
       }, pinned_ctx_, {}, {out->var()},
       FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreUnique"));
     out->WaitToRead();
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 275cf40380..79d1324119 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -527,25 +527,33 @@ void CopyFromTo(const NDArray& from, const NDArray& to, int priority) {
   if (from.var() != to.var()) const_vars.push_back(from.var());
 
   if (a == cpu::kDevMask && b == cpu::kDevMask) {
-    Engine::Get()->PushSync([from, to](RunContext ctx) {
+    Engine::Get()->PushAsync(
+      [from, to](RunContext ctx, Engine::CallbackOnComplete on_complete) {
         CopyFromToImpl<cpu, cpu>(from, to, ctx);
+        on_complete();
       }, from.ctx(), const_vars, {to.var()},
       FnProperty::kNormal, priority, PROFILER_MESSAGE("CopyCPU2CPU"));
   } else {
 #if MXNET_USE_CUDA
     if (a == cpu::kDevMask && b == gpu::kDevMask) {
-      Engine::Get()->PushSync([from, to](RunContext ctx) {
+      Engine::Get()->PushAsync(
+        [from, to](RunContext ctx, Engine::CallbackOnComplete on_complete) {
           CopyFromToImpl<cpu, gpu>(from, to, ctx);
+          on_complete();
         }, to.ctx(), const_vars, {to.var()},
         FnProperty::kCopyToGPU, priority, PROFILER_MESSAGE("CopyCPU2GPU"));
     } else if (a == gpu::kDevMask && b == cpu::kDevMask) {
-      Engine::Get()->PushSync([from, to](RunContext ctx) {
+      Engine::Get()->PushAsync(
+        [from, to](RunContext ctx, Engine::CallbackOnComplete on_complete) {
           CopyFromToImpl<gpu, cpu>(from, to, ctx);
+          on_complete();
         }, from.ctx(), const_vars, {to.var()},
         FnProperty::kCopyFromGPU, priority, PROFILER_MESSAGE("CopyGPU2CPU"));
     } else if (a == gpu::kDevMask && b == gpu::kDevMask) {
-      Engine::Get()->PushSync([from, to](RunContext ctx) {
+      Engine::Get()->PushAsync(
+        [from, to](RunContext ctx, Engine::CallbackOnComplete on_complete) {
           CopyFromToImpl<gpu, gpu>(from, to, ctx);
+          on_complete();
         }, from.ctx(), const_vars, {to.var()},
         from.dtype() != to.dtype() ? FnProperty::kNormal : FnProperty::kCopyFromGPU,
         priority, PROFILER_MESSAGE("CopyGPU2GPU"));
@@ -1077,12 +1085,14 @@ void NDArray::SyncCopyFromCPU(const void *data, size_t size) const {
     ndarray::Copy<cpu, cpu>(src, &dst, Context::CPU(), Context::CPU(), rctx);
   } else {
 #if MXNET_USE_CUDA
-    Engine::Get()->PushSync([&](RunContext rctx) {
+    Engine::Get()->PushAsync(
+      [&](RunContext rctx, Engine::CallbackOnComplete on_complete) {
         TBlob dst = this->data();
         ndarray::Copy<cpu, gpu>(src, &dst,
                                 Context::CPU(), this->ctx(), rctx);
         // Wait GPU kernel to complete
         rctx.get_stream<gpu>()->Wait();
+        on_complete();
       }, this->ctx(), {}, {this->var()},
       FnProperty::kCopyToGPU, 0, PROFILER_MESSAGE("SyncCopyCPU2GPU"));
     this->WaitToRead();
@@ -1145,27 +1155,33 @@ void NDArray::SyncCopyFromNDArray(const NDArray& src, int i, int j) {
   } else {
 #if MXNET_USE_CUDA
     if (src_dev_mask == cpu::kDevMask && dst_dev_mask == gpu::kDevMask) {
-      Engine::Get()->PushSync([&](RunContext rctx) {
+      Engine::Get()->PushAsync(
+        [&](RunContext rctx, Engine::CallbackOnComplete on_complete) {
           const TBlob src_data = (i >= 0? src.aux_data(i) : src.data());
           TBlob dst_data = get_dst_data(src_data.shape_);
           ndarray::Copy<cpu, gpu>(src_data, &dst_data, src.ctx(), this->ctx(), rctx);
           rctx.get_stream<gpu>()->Wait();
+          on_complete();
         }, this->ctx(), const_vars, {this->var()},
         FnProperty::kCopyToGPU, 0, PROFILER_MESSAGE("SyncCopyFromNDArrayCPU2GPU"));
     } else if (src_dev_mask == gpu::kDevMask && dst_dev_mask == cpu::kDevMask) {
-      Engine::Get()->PushSync([&](RunContext rctx) {
+      Engine::Get()->PushAsync(
+        [&](RunContext rctx, Engine::CallbackOnComplete on_complete) {
           const TBlob src_data = (i >= 0? src.aux_data(i) : src.data());
           TBlob dst_data = get_dst_data(src_data.shape_);
           ndarray::Copy<gpu, cpu>(src_data, &dst_data, src.ctx(), this->ctx(), rctx);
           rctx.get_stream<gpu>()->Wait();
+          on_complete();
         }, this->ctx(), const_vars, {this->var()},
         FnProperty::kCopyFromGPU, 0, PROFILER_MESSAGE("SyncCopyFromNDArrayGPU2CPU"));
     } else if (src_dev_mask == gpu::kDevMask && dst_dev_mask == gpu::kDevMask) {
-      Engine::Get()->PushSync([&](RunContext rctx) {
+      Engine::Get()->PushAsync(
+        [&](RunContext rctx, Engine::CallbackOnComplete on_complete) {
           const TBlob src_data = (i >= 0? src.aux_data(i) : src.data());
           TBlob dst_data = get_dst_data(src_data.shape_);
           ndarray::Copy<gpu, gpu>(src_data, &dst_data, src.ctx(), this->ctx(), rctx);
           rctx.get_stream<gpu>()->Wait();
+          on_complete();
         }, this->ctx(), const_vars, {this->var()},
         src.dtype() != this->dtype() ? FnProperty::kNormal : FnProperty::kCopyFromGPU,
         0, PROFILER_MESSAGE("SyncCopyFromNDArrayGPU2GPU"));
@@ -1200,11 +1216,13 @@ void NDArray::SyncCopyToCPU(void *data, size_t size) const {
                             Context::CPU(), Context::CPU(), rctx);
   } else {
 #if MXNET_USE_CUDA
-    Engine::Get()->PushSync([&](RunContext rctx) {
+    Engine::Get()->PushAsync(
+      [&](RunContext rctx, Engine::CallbackOnComplete on_complete) {
         ndarray::Copy<gpu, cpu>(this->data(), &dst,
                                 this->ctx(), Context::CPU(), rctx);
         // Wait GPU kernel to complete
         rctx.get_stream<gpu>()->Wait();
+        on_complete();
       }, this->ctx(), {this->var()}, {},
       FnProperty::kCopyFromGPU, 0, PROFILER_MESSAGE("SyncCopyGPU2CPU"));
     this->WaitToWrite();
diff --git a/src/operator/cudnn_convolution-inl.h b/src/operator/cudnn_convolution-inl.h
index b2b59944e8..d42ec9caa5 100644
--- a/src/operator/cudnn_convolution-inl.h
+++ b/src/operator/cudnn_convolution-inl.h
@@ -586,7 +586,7 @@ class CuDNNConvolutionOp : public Operator {
                                        &back_algo_w_)) {
       // Not in algo registry, must determine via *Get*() or *Find*()
       Engine::VarHandle var = Engine::Get()->NewVariable();
-      Engine::Get()->PushSync([=](RunContext rctx) {
+      Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
         mshadow::Stream<gpu> *s = rctx.get_stream<gpu>();
         CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
         size_t workspace_byte = static_cast<size_t>(param_.workspace * sizeof(DType));
@@ -776,6 +776,7 @@ class CuDNNConvolutionOp : public Operator {
                                           cudnn_backward_compute_type,
                                           SMArch(ctx.dev_id), this->forward_algo_,
                                           this->back_algo_, this->back_algo_w_);
+        on_complete();
       }, ctx, {}, {var});
       Engine::Get()->WaitForVar(var);
       Engine::Get()->DeleteVariable([](RunContext s) {}, ctx, var);
diff --git a/src/operator/cudnn_deconvolution-inl.h b/src/operator/cudnn_deconvolution-inl.h
index 5e9b7c5704..95ab596f27 100644
--- a/src/operator/cudnn_deconvolution-inl.h
+++ b/src/operator/cudnn_deconvolution-inl.h
@@ -605,7 +605,7 @@ class CuDNNDeconvolutionOp : public Operator {
                                          &back_algo_, &back_algo_w_)) {
       // Not in algo registry, must determine via *Get*() or *Find*()
       Engine::VarHandle var = Engine::Get()->NewVariable();
-      Engine::Get()->PushSync([=](RunContext rctx) {
+      Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
         mshadow::Stream <gpu> *s = rctx.get_stream<gpu>();
         CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
         size_t workspace_byte = static_cast<size_t>(param_.workspace * sizeof(DType));
@@ -798,6 +798,7 @@ class CuDNNDeconvolutionOp : public Operator {
                                             cudnn_backward_compute_type,
                                             SMArch(ctx.dev_id), this->forward_algo_,
                                             this->back_algo_, this->back_algo_w_);
+        on_complete();
       }, ctx, {}, {var});
       Engine::Get()->WaitForVar(var);
       Engine::Get()->DeleteVariable([](RunContext s) {}, ctx, var);
diff --git a/src/operator/custom/ndarray_op.cc b/src/operator/custom/ndarray_op.cc
index 48426baea8..66bdfc78f2 100644
--- a/src/operator/custom/ndarray_op.cc
+++ b/src/operator/custom/ndarray_op.cc
@@ -84,9 +84,11 @@ void NDArrayOp<xpu>::Forward(const OpContext &ctx,
   }
 
   CHECK(param_.pinfo->forward(ptrs.size(), ptrs.data(), tags.data(), param_.pinfo->p_forward));
-  Engine::Get()->PushSync([ndcpy, ctx](RunContext rctx) {ctx.async_on_complete(); },
-                          ndctx, ndvar, {}, FnProperty::kNormal, 0,
-                          PROFILER_MESSAGE("NDArrayOpForward"));
+  Engine::Get()->PushAsync(
+      [ndcpy, ctx](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+        ctx.async_on_complete();
+        on_complete();
+      }, ndctx, ndvar, {}, FnProperty::kNormal, 0, PROFILER_MESSAGE("NDArrayOpForward"));
 }
 
 template<typename xpu>
@@ -131,9 +133,11 @@ void NDArrayOp<xpu>::Backward(const OpContext &ctx,
   }
 
   CHECK(param_.pinfo->backward(ptrs.size(), ptrs.data(), tags.data(), param_.pinfo->p_backward));
-  Engine::Get()->PushSync([ndcpy, ctx](RunContext rctx){ ctx.async_on_complete(); },
-                          ndctx, ndvar, {}, FnProperty::kNormal, 0,
-                          PROFILER_MESSAGE("NDArrayOpBackward"));
+  Engine::Get()->PushAsync(
+      [ndcpy, ctx](RunContext rctx, Engine::CallbackOnComplete on_complete){
+        ctx.async_on_complete();
+        on_complete();
+      }, ndctx, ndvar, {}, FnProperty::kNormal, 0, PROFILER_MESSAGE("NDArrayOpBackward"));
 }
 
 Operator* NDArrayOpProp::CreateOperator(Context ctx) const {
diff --git a/src/resource.cc b/src/resource.cc
index 4c2dbee33f..d591651145 100644
--- a/src/resource.cc
+++ b/src/resource.cc
@@ -186,9 +186,11 @@ class ResourceManagerImpl : public ResourceManager {
     inline void Seed(uint32_t global_seed) {
       uint32_t seed = ctx.dev_id + global_seed * kRandMagic;
       mshadow::Random<xpu> *r = prnd;
-      Engine::Get()->PushSync([r, seed](RunContext rctx) {
+      Engine::Get()->PushAsync(
+        [r, seed](RunContext rctx, Engine::CallbackOnComplete on_complete) {
           r->set_stream(rctx.get_stream<xpu>());
           r->Seed(seed);
+          on_complete();
         }, ctx, {}, {resource.var},
         FnProperty::kNormal, 0, PROFILER_MESSAGE("ResourceRandomSetSeed"));
     }
diff --git a/tests/python/unittest/test_engine.py b/tests/python/unittest/test_engine.py
new file mode 100644
index 0000000000..29b7b822b3
--- /dev/null
+++ b/tests/python/unittest/test_engine.py
@@ -0,0 +1,36 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import nose
+import mxnet as mx
+
+def test_bulk():
+    with mx.engine.bulk(10):
+        x = mx.nd.ones((10,))
+        x *= 2
+        x += 1
+        x.wait_to_read()
+        x += 1
+        assert (x.asnumpy() == 4).all()
+        for i in range(100):
+            x += 1
+    assert (x.asnumpy() == 104).all()
+
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services