Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/01/05 19:29:26 UTC

[incubator-mxnet] branch master updated: Fix custom op multi-gpu scaling (#9283)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 004dead  Fix custom op multi-gpu scaling (#9283)
004dead is described below

commit 004dead77f7c731fdd7d32a1a123ab6044e4db59
Author: Eric Junyuan Xie <pi...@users.noreply.github.com>
AuthorDate: Fri Jan 5 11:29:22 2018 -0800

    Fix custom op multi-gpu scaling (#9283)
    
    * refactor custom op
    
    * fix
    
    * fix
    
    * fix
    
    * fix
---
 include/mxnet/op_attr_types.h     |   2 -
 src/c_api/c_api.cc                |   2 +-
 src/c_api/c_api_function.cc       |  72 +++++++++++---------------
 src/executor/graph_executor.cc    |   3 --
 src/imperative/imperative_utils.h |  89 ++++++++++++++++----------------
 src/ndarray/ndarray.cc            |   2 +-
 src/operator/custom/custom-inl.h  |  76 +++++++++++++++++++++++++--
 src/operator/custom/custom.cc     | 105 +++++++++++++++++---------------------
 8 files changed, 194 insertions(+), 157 deletions(-)

diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h
index 8cb8a99..fb41d39 100644
--- a/include/mxnet/op_attr_types.h
+++ b/include/mxnet/op_attr_types.h
@@ -92,8 +92,6 @@ enum class ExecType {
    *  will call OpContext.async_on_complete when operation finishes.
    */
   kAsync,
-  /*! \brief Run this operator on the scheduling thread without pushing to engine. */
-  kLocal,
   /*!
    * \brief Cross device copy operation, this is a special operator
    *  That indicates copy across devices, the input and output can sit on different device.
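
With this hunk, ExecType::kLocal is gone entirely: operators that used it now declare
either kAsync (the custom ops below, which signal completion themselves) or
kCrossDeviceCopy (_copyto, which the executor special-cases via CopyFromTo). A minimal,
std-only sketch of the resulting dispatch contract; the enum mirrors the header, but
dispatch() and its messages are illustrative only, not MXNet code:

    #include <iostream>

    // Mirror of the enum after this change; kLocal no longer exists.
    enum class ExecType { kSync, kAsync, kCrossDeviceCopy };

    // Illustrative dispatch only. Callers that previously special-cased
    // kLocal must now pick one of the remaining paths.
    void dispatch(ExecType t) {
      switch (t) {
        case ExecType::kSync:
          std::cout << "push to engine; done when the function returns\n";
          break;
        case ExecType::kAsync:
          std::cout << "push to engine; done when async_on_complete() fires\n";
          break;
        case ExecType::kCrossDeviceCopy:
          std::cout << "special-cased by the executor (CopyFromTo)\n";
          break;
      }
    }

    int main() { dispatch(ExecType::kAsync); }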
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 027f00b..c55f6c5 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -1170,7 +1170,7 @@ int MXRtcFree(RtcHandle handle) {
 
 int MXCustomOpRegister(const char* op_type, CustomOpPropCreator creator) {
   API_BEGIN();
-  mxnet::op::custom::Registry::Get()->Register(op_type, creator);
+  mxnet::op::custom::CustomOperator::Get()->Register(op_type, creator);
   API_END();
 }
 
diff --git a/src/c_api/c_api_function.cc b/src/c_api/c_api_function.cc
index 3cd4f66..e8ca189 100644
--- a/src/c_api/c_api_function.cc
+++ b/src/c_api/c_api_function.cc
@@ -29,6 +29,7 @@
 
 #include "./c_api_common.h"
 #include "../operator/operator_common.h"
+#include "../operator/custom/custom-inl.h"
 
 namespace mxnet {
 namespace custom_function {
@@ -62,68 +63,55 @@ std::vector<nnvm::NodeEntry> Gradient(
 }
 
 OpStatePtr CreateState(const nnvm::NodeAttrs& attrs,
-                               Context ctx,
-                               const std::vector<TShape>& ishape,
-                               const std::vector<int>& itype) {
+                       Context ctx,
+                       const std::vector<TShape>& ishape,
+                       const std::vector<int>& itype) {
   LOG(FATAL) << "Not reached";
   return OpStatePtr::Create<void*>(nullptr);
 }
 
 void Forward(const OpStatePtr& state,
              const OpContext& ctx,
-             const std::vector<NDArray>& inputs,
+             const std::vector<TBlob>& inputs,
              const std::vector<OpReqType>& req,
-             const std::vector<NDArray>& outputs) {
+             const std::vector<TBlob>& outputs) {
   LOG(FATAL) << "Not reached";
 }
 
 void Backward(const OpStatePtr& state,
               const OpContext& ctx,
-              const std::vector<NDArray>& inputs,
+              const std::vector<TBlob>& inputs,
               const std::vector<OpReqType>& req,
-              const std::vector<NDArray>& outputs) {
+              const std::vector<TBlob>& outputs) {
   const CustomFunctionParam& params = state.get_state<CustomFunctionParam>();
 
   std::vector<NDArrayHandle> ptrs;
+  std::vector<NDArray> cpys;
+
+  auto dev_id = ctx.run_ctx.ctx.dev_id;
 
   for (const auto& i : inputs) {
-    NDArray* nd = new NDArray(i.Detach());
+    NDArray* nd = new NDArray(i, dev_id);
     ptrs.push_back(reinterpret_cast<NDArrayHandle>(nd));
+    cpys.push_back(*nd);
   }
   for (const auto& i : outputs) {
-    NDArray* nd = new NDArray(i.Detach());
+    NDArray* nd = new NDArray(i, dev_id);
     ptrs.push_back(reinterpret_cast<NDArrayHandle>(nd));
+    cpys.push_back(*nd);
   }
 
-  bool prev_recording = Imperative::Get()->set_is_recording(false);
-  bool prev_training = Imperative::Get()->set_is_training(ctx.is_train);
-
-  CHECK(reinterpret_cast<CustomFunctionBwdFunc>(
-      params.info->callbacks[kCustomFunctionBackward])(
-          inputs.size(), outputs.size(), ptrs.data(),
-          reinterpret_cast<const int*>(req.data()), ctx.is_train,
-          params.info->contexts[kCustomFunctionBackward]));
-
-  Imperative::Get()->set_is_training(prev_training);
-  Imperative::Get()->set_is_recording(prev_recording);
+  op::custom::CustomOperator::Get()->Push(
+    [=]() {
+      CHECK(reinterpret_cast<CustomFunctionBwdFunc>(
+          params.info->callbacks[kCustomFunctionBackward])(
+              inputs.size(), outputs.size(),
+              const_cast<NDArrayHandle*>(ptrs.data()),
+              reinterpret_cast<const int*>(req.data()), ctx.is_train,
+              params.info->contexts[kCustomFunctionBackward]));
+    }, ctx, false, ctx.is_train, cpys);
 }
 
-// infer storage function for custom op, which assigns kDefaultStorage for
-// all undefined stypes, and dispatch on DispatchMode::kFComputeEx.
-inline bool InferStorageType(const nnvm::NodeAttrs& attrs,
-                             const int dev_mask,
-                             DispatchMode* dispatch_mode,
-                             std::vector<int> *iattr,
-                             std::vector<int> *oattr) {
-  for (int& v : *oattr) {
-    if (v == -1) v = kDefaultStorage;
-  }
-  for (int& v : *iattr) {
-    if (v == -1) v = kDefaultStorage;
-  }
-  op::dispatch_mode_assign(dispatch_mode, DispatchMode::kFComputeEx);
-  return true;
-}
 
 NNVM_REGISTER_OP(_CustomFunction)
 .set_num_inputs([](const NodeAttrs& attrs) {
@@ -150,9 +138,8 @@ NNVM_REGISTER_OP(_CustomFunction)
   })
 .set_attr<FCreateOpState>("FCreateOpState", CreateState)
 .set_attr<nnvm::FGradient>("FGradient", Gradient)
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", Forward)
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", Forward)
-.set_attr<FInferStorageType>("FInferStorageType", InferStorageType);
+.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", Forward)
+.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", Forward);
 
 
 NNVM_REGISTER_OP(_backward_CustomFunction)
@@ -167,11 +154,10 @@ NNVM_REGISTER_OP(_backward_CustomFunction)
 .set_attr<bool>("TIsBackward", true)
 .set_attr<bool>("TIsLayerOpBackward", true)
 .set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
-    return ExecType::kLocal;
+    return ExecType::kAsync;
   })
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", Backward)
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", Backward)
-.set_attr<FInferStorageType>("FInferStorageType", InferStorageType);
+.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", Backward)
+.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", Backward);
 
 }  // namespace custom_function
 }  // namespace mxnet
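
The Backward change above hands raw NDArray handles to the C callback but defers the
call to CustomOperator's worker thread; the cpys vector is captured by value so the
underlying arrays (and their engine variables) stay alive until the deferred lambda
runs. A minimal sketch of that lifetime pattern, with std::shared_ptr<int> standing in
for NDArray; defer_with_copies is a hypothetical name, nothing here is MXNet API:

    #include <functional>
    #include <memory>
    #include <vector>

    // Build a deferred task that may outlive the caller's handles. `cpys`
    // is captured by value, pinning the buffers until the task has run,
    // which is the same role the NDArray copies play above.
    std::function<void()> defer_with_copies(
        const std::vector<std::shared_ptr<int>>& inputs) {
      std::vector<void*> ptrs;                 // raw handles for a C callback
      std::vector<std::shared_ptr<int>> cpys;  // keep the data alive
      for (const auto& i : inputs) {
        cpys.push_back(i);
        ptrs.push_back(i.get());
      }
      return [ptrs, cpys]() {
        // ... invoke the C callback with ptrs.data() here; the pointees
        // are still valid because cpys holds a reference to each one ...
      };
    }

    int main() {
      auto task = defer_with_copies({std::make_shared<int>(7)});
      task();  // safe even after the caller's handles are gone
    }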
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 77853a6..5f95df3 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -1473,9 +1473,6 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
       CHECK_EQ(opnode.exec->in_array.size(), 1U);
       CHECK_EQ(opnode.exec->out_array.size(), 1U);
       CopyFromTo(opnode.exec->in_array[0], &(opnode.exec->out_array[0]));
-    } else if (opnode.exec->exec_type() == ExecType::kLocal) {
-      bool is_gpu = opnode.ctx.dev_mask() == gpu::kDevMask;
-      opnode.exec->Run(RunContext{opnode.ctx, nullptr}, is_gpu);
     } else if (opnode.cached_opr != nullptr) {
 #if MXNET_USE_PROFILER
       bool profiling = engine::Profiler::Get()->GetState() == engine::Profiler::kRunning;
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 8be1eb4..add568d 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -341,12 +341,15 @@ inline void PushFCompute(const FCompute& fn,
                   const std::vector<uint32_t>& mutate_idx,
                   const std::vector<OpReqType>& req) {
   using namespace common;
+  static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
+
   bool is_train = Imperative::Get()->is_training();
+  ExecType exec_type = fexec_type.count(op) ? fexec_type[op](attrs) : ExecType::kSync;
+  CHECK(exec_type == ExecType::kSync);
   std::vector<NDArray> inputs, outputs;
   DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs);
   Engine::Get()->PushSync(
-    [ctx, attrs, fn, inputs, outputs, requested, is_train, mutate_idx, req](
-        RunContext rctx) {
+    [=](RunContext rctx) {
       std::vector<TBlob> input_blobs, output_blobs;
       // pre-fcompute and post-fcompute storage fallback src NDArrays and dst NDArrays
       std::vector<NDArray> pre_temp_src, pre_temp_dst, post_temp_dst, post_temp_src;
@@ -354,8 +357,8 @@ inline void PushFCompute(const FCompute& fn,
       std::unordered_map<uint32_t, uint32_t> in_temp_idx_map;
       // setup blobs
       SetupDefaultBlobsInOut(inputs, outputs, &input_blobs, &output_blobs,
-                             &pre_temp_src, &pre_temp_dst, &post_temp_src, &post_temp_dst,
-                             &in_temp_idx_map, mutate_idx);
+                             &pre_temp_src, &pre_temp_dst, &post_temp_src,
+                             &post_temp_dst, &in_temp_idx_map, mutate_idx);
       // setup context
       OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
       bool is_gpu = ctx.dev_mask() == gpu::kDevMask;
@@ -384,27 +387,23 @@ inline void PushFComputeEx(const FComputeEx& fn,
   static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
 
   bool is_train = Imperative::Get()->is_training();
-  ExecType exec_type = ExecType::kSync;
-  if (fexec_type.count(op)) {
-    exec_type = fexec_type[op](attrs);
-  }
+  ExecType exec_type = fexec_type.count(op) ? fexec_type[op](attrs) : ExecType::kSync;
   std::vector<NDArray> inputs, outputs;
   DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs);
-  const auto& run = [ctx, exec_type, is_train, attrs, fn, inputs, outputs, requested, req](
-        RunContext rctx) {
+  const auto& run = [=](RunContext rctx) {
       OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
       fn(attrs, opctx, inputs, req, outputs);
-      if (exec_type == ExecType::kSync) {
-        if (rctx.get_ctx().dev_mask() == gpu::kDevMask) {
-          rctx.get_stream<gpu>()->Wait();
-        }
+      if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync) {
+        rctx.get_stream<gpu>()->Wait();
       }
     };
-  if (exec_type == ExecType::kLocal) {
+
+  if (exec_type == ExecType::kCrossDeviceCopy) {
     run(RunContext{ctx, nullptr});
   } else {
+    CHECK(exec_type == ExecType::kSync);
     Engine::Get()->PushSync(run, ctx, read_vars, write_vars, FnProperty::kNormal,
-      0, PROFILER_MESSAGE(op->name.c_str()));
+                            0, PROFILER_MESSAGE(op->name.c_str()));
   }
 }
 
@@ -424,42 +423,30 @@ inline void PushOperator(const OpStatePtr& state,
   static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
 
   bool is_train = Imperative::Get()->is_training();
-  ExecType exec_type = ExecType::kSync;
-  if (fexec_type.count(op)) {
-    exec_type = fexec_type[op](attrs);
-  }
+  ExecType exec_type = fexec_type.count(op) ? fexec_type[op](attrs) : ExecType::kSync;
   std::vector<NDArray> inputs, outputs;
   DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs);
 
   auto fcompute = common::GetFCompute<FStatefulCompute>(op, "FStatefulCompute", ctx);
   auto fcompute_ex = common::GetFCompute<FStatefulComputeEx>(op, "FStatefulComputeEx", ctx);
   if (fcompute_ex != nullptr && dispatch_mode == DispatchMode::kFComputeEx) {
-    const auto& run = [state, fcompute_ex, inputs, outputs, requested, is_train,
-                       exec_type, req](
-          RunContext rctx) {
-        OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
-        fcompute_ex(state, opctx, inputs, req, outputs);
-        if (exec_type == ExecType::kSync) {
-          if (rctx.get_ctx().dev_mask() == gpu::kDevMask) {
+    CHECK(exec_type == ExecType::kSync);
+    Engine::Get()->PushSync(
+        [=](RunContext rctx) {
+          OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
+          fcompute_ex(state, opctx, inputs, req, outputs);
+          if (ctx.dev_mask() == gpu::kDevMask) {
             rctx.get_stream<gpu>()->Wait();
           }
-        }
-      };
-    if (exec_type == ExecType::kLocal) {
-      run(RunContext{ctx, nullptr});
-    } else {
-      Engine::Get()->PushSync(run, ctx, read_vars, write_vars, FnProperty::kNormal,
-                               0, PROFILER_MESSAGE(op->name.c_str()));
-    }
+        }, ctx, read_vars, write_vars, FnProperty::kNormal,
+        0, PROFILER_MESSAGE(op->name.c_str()));
   } else {
     CHECK(fcompute != nullptr)
         << "One of FStatefulCompute and FStatefulComputeEx must be registered "
         << "for stateful operator " << op->name;
-    CHECK(exec_type == ExecType::kSync || exec_type == ExecType::kAsync);
-    Engine::Get()->PushSync(
-      [state, fcompute, inputs, outputs, requested, is_train, exec_type, mutate_idx, req](
-          RunContext rctx) {
-        OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
+
+    const auto& run = [=](RunContext rctx, engine::CallbackOnComplete on_complete) {
+        OpContext opctx{is_train, rctx, on_complete, requested};
 
         std::vector<TBlob> input_blobs, output_blobs;
         // pre-fcompute and post-fcompute storage fallback src NDArrays and dst NDArrays
@@ -477,13 +464,23 @@ inline void PushOperator(const OpStatePtr& state,
         fcompute(state, opctx, input_blobs, req, output_blobs);
         // post-fcompute fallback, cast to original storage type, if necessary
         CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx, is_gpu);
-        if (exec_type == ExecType::kSync) {
-          if (is_gpu) {
-            rctx.get_stream<gpu>()->Wait();
-          }
+        if (is_gpu && exec_type == ExecType::kSync) {
+          rctx.get_stream<gpu>()->Wait();
         }
-      }, ctx, read_vars, write_vars, FnProperty::kNormal,
-      0, PROFILER_MESSAGE(op->name.c_str()));
+      };
+
+    if (exec_type == ExecType::kSync) {
+      Engine::Get()->PushSync(
+          [=](RunContext rctx) {
+            run(rctx, engine::CallbackOnComplete());
+          }, ctx, read_vars, write_vars, FnProperty::kNormal,
+          0, PROFILER_MESSAGE(op->name.c_str()));
+    } else {
+      CHECK(exec_type == ExecType::kAsync);
+      Engine::Get()->PushAsync(
+          run, ctx, read_vars, write_vars, FnProperty::kAsync,
+          0, PROFILER_MESSAGE(op->name.c_str()));
+    }
   }
 }
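
After this refactor, PushOperator leaves exactly two legal paths for stateful ops:
kSync work goes through PushSync and is complete when the lambda returns, while kAsync
work goes through PushAsync and must fire its CallbackOnComplete itself. A std-only
sketch of that distinction, assuming nothing of MXNet beyond what the diff shows;
push_sync, push_async, and OnComplete are illustrative names:

    #include <functional>
    #include <future>
    #include <iostream>
    #include <thread>

    using OnComplete = std::function<void()>;

    // push_sync: completion is implicit; the op is done when fn returns.
    void push_sync(const std::function<void()>& fn) { fn(); }

    // push_async: fn receives an on_complete it must fire, possibly much
    // later and from a different thread.
    void push_async(const std::function<void(OnComplete)>& fn, OnComplete done) {
      fn(std::move(done));
    }

    int main() {
      push_sync([] { std::cout << "kSync op done on return\n"; });

      std::promise<void> finished;
      push_async(
          [](OnComplete done) {
            std::thread([done] {  // completes from another thread
              std::cout << "kAsync op done when it says so\n";
              done();
            }).detach();
          },
          [&finished] { finished.set_value(); });
      finished.get_future().wait();  // engine-style wait for async completion
    }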
 
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 212fd7c..8a3bb8d 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -1355,7 +1355,7 @@ NNVM_REGISTER_OP(_copyto)
     return true;
   })
 .set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
-    return ExecType::kLocal;
+    return ExecType::kCrossDeviceCopy;
   })
 .set_attr<nnvm::FGradient>("FGradient", op::ElemwiseGradUseNone{"_copyto"})
 .set_attr<bool>("TIsBackward", true)
diff --git a/src/operator/custom/custom-inl.h b/src/operator/custom/custom-inl.h
index 13101da..38aeefd 100644
--- a/src/operator/custom/custom-inl.h
+++ b/src/operator/custom/custom-inl.h
@@ -30,6 +30,7 @@
 #include <dmlc/parameter.h>
 #include <mxnet/operator.h>
 #include <mxnet/c_api.h>
+#include <mxnet/imperative.h>
 #include <map>
 #include <vector>
 #include <string>
@@ -46,7 +47,7 @@ namespace mxnet {
 namespace op {
 namespace custom {
 
-class Registry {
+class CustomOperator {
  public:
   void Register(const std::string &op_type, CustomOpPropCreator creator) {
     std::lock_guard<std::mutex> lock(mutex_);
@@ -63,11 +64,80 @@ class Registry {
     return nullptr;
   }
 
-  static Registry* Get();
+  template<typename Func>
+  void Push(const Func& func,
+            const OpContext& ctx,
+            bool recording,
+            bool training,
+            const std::vector<NDArray>& arrs) {
+    if (naive_engine_) {
+      func();
+      ctx.async_on_complete();
+      return;
+    }
+    std::unique_lock<std::mutex> lock(mutex_);
+    q_.push(
+      [=]() mutable {
+        bool prev_recording = Imperative::Get()->set_is_recording(recording);
+        bool prev_training = Imperative::Get()->set_is_training(training);
+
+        func();
+
+        Imperative::Get()->set_is_training(prev_training);
+        Imperative::Get()->set_is_recording(prev_recording);
+
+        std::vector<Engine::VarHandle> vars;
+        for (const auto& i : arrs) vars.push_back(i.var());
+        Engine::Get()->PushSync([=](RunContext rctx) {
+            ctx.async_on_complete();
+          }, ctx.run_ctx.ctx, vars, {},
+          FnProperty::kNormal, 0, PROFILER_MESSAGE("CustomOperator"));
+      });
+    cv_.notify_all();
+  }
+
+  ~CustomOperator() {
+    if (naive_engine_) return;
+    {
+      std::unique_lock<std::mutex> lock(mutex_);
+      destructing_ = true;
+      cv_.notify_all();
+    }
+    worker_.join();
+  }
+
+  static CustomOperator* Get();
+
  private:
-  Registry() {}
+  CustomOperator() {
+    destructing_ = false;
+    naive_engine_ = true;
+    if (std::string("NaiveEngine") != dmlc::GetEnv("MXNET_ENGINE_TYPE", std::string())) {
+      naive_engine_ = false;
+      worker_ = std::thread(
+        [&]() {
+          std::unique_lock<std::mutex> lock(mutex_);
+          while (!q_.empty() || !destructing_) {
+            cv_.wait(lock, [&] {return !q_.empty() || destructing_;});
+            while (!q_.empty()) {
+              auto fn = q_.front();
+              lock.unlock();
+              fn();
+              lock.lock();
+              q_.pop();
+            }
+          }
+        });
+    }
+  }
   std::mutex mutex_;
   std::map<std::string, CustomOpPropCreator> registry_;
+  // async worker
+  std::condition_variable cv_;
+  std::thread worker_;
+  std::queue<std::function<void(void)> > q_;
+  bool naive_engine_;
+  bool destructing_;
 };
 
 }  // namespace custom
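
This class is the heart of the fix: custom-op callbacks no longer run on (and block)
the engine's scheduling thread; they are queued to a dedicated worker. A minimal,
compilable sketch of just the queue mechanics; the real CustomOperator additionally
toggles Imperative recording/training state around each callback, signals the engine
on completion, and bypasses the queue entirely under NaiveEngine:

    #include <condition_variable>
    #include <functional>
    #include <iostream>
    #include <mutex>
    #include <queue>
    #include <thread>

    class Worker {
     public:
      Worker() : destructing_(false), thread_([this] { Loop(); }) {}

      ~Worker() {
        {
          std::unique_lock<std::mutex> lock(mutex_);
          destructing_ = true;
          cv_.notify_all();
        }
        thread_.join();  // loop drains remaining tasks, then exits
      }

      // Mirrors CustomOperator::Push: enqueue a task and wake the worker.
      void Push(std::function<void()> fn) {
        std::unique_lock<std::mutex> lock(mutex_);
        q_.push(std::move(fn));
        cv_.notify_all();
      }

     private:
      void Loop() {
        std::unique_lock<std::mutex> lock(mutex_);
        while (!q_.empty() || !destructing_) {
          cv_.wait(lock, [this] { return !q_.empty() || destructing_; });
          while (!q_.empty()) {
            auto fn = std::move(q_.front());
            q_.pop();
            lock.unlock();  // run user code without holding the lock
            fn();
            lock.lock();
          }
        }
      }

      std::mutex mutex_;
      std::condition_variable cv_;
      std::queue<std::function<void()>> q_;
      bool destructing_;
      std::thread thread_;  // started last, after the flags above
    };

    int main() {
      Worker w;
      w.Push([] { std::cout << "callback runs off the engine thread\n"; });
    }  // destructor drains the queue and joins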
diff --git a/src/operator/custom/custom.cc b/src/operator/custom/custom.cc
index 280b01b..beb5f3d 100644
--- a/src/operator/custom/custom.cc
+++ b/src/operator/custom/custom.cc
@@ -26,7 +26,6 @@
 #include "./custom-inl.h"
 #include <mxnet/base.h>
 #include <mxnet/ndarray.h>
-#include <mxnet/imperative.h>
 
 #include "../elemwise_op_common.h"
 
@@ -34,8 +33,8 @@ namespace mxnet {
 namespace op {
 namespace custom {
 
-Registry* Registry::Get() {
-  static Registry inst;
+CustomOperator* CustomOperator::Get() {
+  static CustomOperator inst;
   return &inst;
 }
 
@@ -75,8 +74,8 @@ void AttrParser(NodeAttrs* attrs) {
     }
   }
   CHECK(!params.op_type.empty()) << "Required argument `op_type` is missing.";
-  CustomOpPropCreator creator = Registry::Get()->Find(params.op_type);
-  CHECK(Registry::Get()->Find(params.op_type) != nullptr)
+  CustomOpPropCreator creator = CustomOperator::Get()->Find(params.op_type);
+  CHECK(CustomOperator::Get()->Find(params.op_type) != nullptr)
       << "Cannot find custom operator " << params.op_type;
   params.info.reset(new MXCallbackList, [](MXCallbackList* ptr){
       reinterpret_cast<CustomOpDelFunc>(ptr->callbacks[kCustomOpPropDelete])(
@@ -269,103 +268,95 @@ OpStatePtr CreateState(const NodeAttrs& attrs, Context ctx,
 
 void Forward(const OpStatePtr& state,
              const OpContext& ctx,
-             const std::vector<NDArray>& inputs,
+             const std::vector<TBlob>& inputs,
              const std::vector<OpReqType>& req,
-             const std::vector<NDArray>& outputs) {
+             const std::vector<TBlob>& outputs) {
   const CustomParam& params = state.get_state<CustomParam>();
   std::vector<void*> ptrs;
   std::vector<int> tags;
+  std::vector<NDArray> cpys;
+
+  auto dev_id = ctx.run_ctx.ctx.dev_id;
 
   for (size_t i = 0; i < params.num_args; ++i) {
-    NDArray *nd = new NDArray(inputs[i].Detach());
+    NDArray *nd = new NDArray(inputs[i], dev_id);
+    cpys.push_back(*nd);
     ptrs.push_back(reinterpret_cast<void*>(nd));
     tags.push_back(0);
   }
 
   for (size_t i = 0; i < params.num_outs; ++i) {
-    NDArray *nd = new NDArray(outputs[i].Detach());
+    NDArray *nd = new NDArray(outputs[i], dev_id);
+    cpys.push_back(*nd);
     ptrs.push_back(reinterpret_cast<void*>(nd));
     tags.push_back(1);
   }
 
   for (size_t i = 0; i < params.num_auxs; ++i) {
-    NDArray *nd = new NDArray(inputs[i+params.num_args].Detach());
+    NDArray *nd = new NDArray(inputs[i+params.num_args], dev_id);
+    cpys.push_back(*nd);
     ptrs.push_back(reinterpret_cast<void*>(nd));
     tags.push_back(4);
   }
 
-  bool prev_recording = Imperative::Get()->set_is_recording(false);
-  bool prev_training = Imperative::Get()->set_is_training(ctx.is_train);
-
-  CHECK(reinterpret_cast<CustomOpFBFunc>(params.info->callbacks[kCustomOpForward])(
-    ptrs.size(), ptrs.data(), tags.data(), reinterpret_cast<const int*>(req.data()),
-    static_cast<int>(ctx.is_train), params.info->contexts[kCustomOpForward]));
-
-  Imperative::Get()->set_is_training(prev_training);
-  Imperative::Get()->set_is_recording(prev_recording);
+  CustomOperator::Get()->Push(
+    [=]() {
+      CHECK(reinterpret_cast<CustomOpFBFunc>(params.info->callbacks[kCustomOpForward])(
+        ptrs.size(), const_cast<void**>(ptrs.data()), const_cast<int*>(tags.data()),
+        reinterpret_cast<const int*>(req.data()), static_cast<int>(ctx.is_train),
+        params.info->contexts[kCustomOpForward]));
+    }, ctx, false, ctx.is_train, cpys);
 }
 
 
 void Backward(const OpStatePtr& state,
               const OpContext& ctx,
-              const std::vector<NDArray>& inputs,
+              const std::vector<TBlob>& inputs,
               const std::vector<OpReqType>& req,
-              const std::vector<NDArray>& outputs) {
+              const std::vector<TBlob>& outputs) {
   const CustomParam& params = state.get_state<CustomParam>();
 
   size_t total = 2*params.num_args + 2*params.num_outs + params.num_auxs;
   std::vector<void*> ptrs(params.num_args + 2*params.num_outs, nullptr);
   std::vector<int> tags;
+  std::vector<NDArray> cpys;
+
   ptrs.reserve(total);
   tags.reserve(total);
   for (size_t i = 0; i < params.num_outs; ++i) tags.push_back(3);
   for (size_t i = 0; i < params.num_args; ++i) tags.push_back(0);
   for (size_t i = 0; i < params.num_outs; ++i) tags.push_back(1);
 
+  auto dev_id = ctx.run_ctx.ctx.dev_id;
+
   for (size_t i = 0; i < params.bwd_idx.size(); ++i) {
-    NDArray *nd = new NDArray(inputs[i].Detach());
+    NDArray *nd = new NDArray(inputs[i], dev_id);
+    cpys.push_back(*nd);
     ptrs[params.bwd_idx[i]] = reinterpret_cast<void*>(nd);
   }
   for (size_t i = 0; i < ptrs.size(); ++i) {
     if (ptrs[i] == nullptr) ptrs[i] = reinterpret_cast<void*>(new NDArray());
   }
   for (const auto& i : outputs) {
-    NDArray* nd = new NDArray(i.Detach());
+    NDArray* nd = new NDArray(i, dev_id);
+    cpys.push_back(*nd);
     ptrs.push_back(reinterpret_cast<void*>(nd));
     tags.push_back(2);
   }
   for (size_t i = 0; i < params.num_auxs; ++i) {
-    NDArray* nd = new NDArray(inputs[inputs.size()-params.num_auxs+i].Detach());
+    NDArray* nd = new NDArray(inputs[inputs.size()-params.num_auxs+i], dev_id);
+    cpys.push_back(*nd);
     ptrs.push_back(reinterpret_cast<void*>(nd));
     tags.push_back(4);
   }
 
-  bool prev_recording = Imperative::Get()->set_is_recording(false);
-  bool prev_training = Imperative::Get()->set_is_training(ctx.is_train);
-
-  CHECK(reinterpret_cast<CustomOpFBFunc>(params.info->callbacks[kCustomOpBackward])(
-    ptrs.size(), ptrs.data(), tags.data(), reinterpret_cast<const int*>(req.data()),
-    static_cast<int>(ctx.is_train), params.info->contexts[kCustomOpBackward]));
-
-  Imperative::Get()->set_is_training(prev_training);
-  Imperative::Get()->set_is_recording(prev_recording);
-}
-
-// infer storage function for custom op, which assigns kDefaultStorage for
-// all undefined stypes, and dispatch on DispatchMode::kFComputeEx.
-inline bool InferStorageType(const nnvm::NodeAttrs& attrs,
-                             const int dev_mask,
-                             DispatchMode* dispatch_mode,
-                             std::vector<int> *iattr,
-                             std::vector<int> *oattr) {
-  for (int& v : *oattr) {
-    if (v == -1) v = kDefaultStorage;
-  }
-  for (int& v : *iattr) {
-    if (v == -1) v = kDefaultStorage;
-  }
-  dispatch_mode_assign(dispatch_mode, DispatchMode::kFComputeEx);
-  return true;
+  CustomOperator::Get()->Push(
+    [=]() {
+      CHECK(reinterpret_cast<CustomOpFBFunc>(params.info->callbacks[kCustomOpBackward])(
+        ptrs.size(), const_cast<void**>(ptrs.data()), const_cast<int*>(tags.data()),
+        reinterpret_cast<const int*>(req.data()), static_cast<int>(ctx.is_train),
+        params.info->contexts[kCustomOpBackward]));
+    }, ctx, false, ctx.is_train, cpys);
 }
 
 NNVM_REGISTER_OP(Custom)
@@ -401,13 +392,12 @@ Please check the tutorial here: http://mxnet.io/how_to/new_op.html.
     return ret;
   })
 .set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
-    return ExecType::kLocal;
+    return ExecType::kAsync;
   })
 .set_attr<nnvm::FGradient>("FGradient", Gradient)
 .set_attr<FCreateOpState>("FCreateOpState", CreateState)
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", Forward)
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", Forward)
-.set_attr<FInferStorageType>("FInferStorageType", InferStorageType)
+.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", Forward)
+.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", Forward)
 .add_argument("data", "NDArray-or-Symbol[]", "Input data for the custom operator.")
 .add_argument("op_type", "string", "Name of the custom operator. "
               "This is the name that is passed to `mx.operator.register` "
@@ -426,11 +416,10 @@ NNVM_REGISTER_OP(_backward_Custom)
 .set_attr<bool>("TIsLayerOpBackward", true)
 .set_attr<bool>("TIsBackward", true)
 .set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
-    return ExecType::kLocal;
+    return ExecType::kAsync;
   })
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", Backward)
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", Backward)
-.set_attr<FInferStorageType>("FInferStorageType", InferStorageType);
+.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", Backward)
+.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", Backward);
 
 }  // namespace custom
 }  // namespace op
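
End to end, an ExecType::kAsync custom op now completes in two steps: the worker runs
the user callback, then re-enters the engine with a job that depends on the arrays'
variables and fires async_on_complete, so downstream reads stay correctly ordered. A
hedged, std-only sketch of that handoff; engine_push_sync and the variable names are
stand-ins for Engine::PushSync over the NDArray vars, not MXNet API:

    #include <functional>
    #include <future>
    #include <iostream>
    #include <string>
    #include <thread>
    #include <vector>

    using OnComplete = std::function<void()>;

    // Stand-in for the engine: a real one would run `job` only after all
    // pending writes to `vars` have finished.
    void engine_push_sync(const std::vector<std::string>& vars,
                          const std::function<void()>& job) {
      (void)vars;
      job();
    }

    int main() {
      std::promise<void> done;
      OnComplete on_complete = [&done] { done.set_value(); };

      // Worker thread: run the user callback, then signal completion
      // through the engine so the op is ordered after the callback's writes.
      std::thread worker([&] {
        std::cout << "user callback fills the output arrays\n";
        engine_push_sync({"output_var"}, [&on_complete] { on_complete(); });
      });

      done.get_future().wait();  // engine now treats the async op as finished
      worker.join();
    }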

-- 
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.