You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this rendering.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/01/05 19:29:27 UTC

[GitHub] piiswrong closed pull request #9283: Fix custom op multi-gpu scaling

piiswrong closed pull request #9283: Fix custom op multi-gpu scaling
URL: https://github.com/apache/incubator-mxnet/pull/9283
 
 
   

This is a pull request merged from a forked repository.
Because GitHub hides the original diff once a pull request is merged,
the diff is reproduced below for provenance:

Because this is a foreign pull request (from a fork), the diff is supplied
below; GitHub would not otherwise display it after the merge:

diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h
index 8cb8a99b46..fb41d39609 100644
--- a/include/mxnet/op_attr_types.h
+++ b/include/mxnet/op_attr_types.h
@@ -92,8 +92,6 @@ enum class ExecType {
    *  will call OpContext.async_on_complete when operation finishes.
    */
   kAsync,
-  /*! \brief Run this operator on the scheduling thread without pushing to engine. */
-  kLocal,
   /*!
    * \brief Cross device copy operation, this is a special operator
    *  That indicates copy across devices, the input and output can sit on different device.
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 027f00ba87..c55f6c5781 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -1170,7 +1170,7 @@ int MXRtcFree(RtcHandle handle) {
 
 int MXCustomOpRegister(const char* op_type, CustomOpPropCreator creator) {
   API_BEGIN();
-  mxnet::op::custom::Registry::Get()->Register(op_type, creator);
+  mxnet::op::custom::CustomOperator::Get()->Register(op_type, creator);
   API_END();
 }
 
diff --git a/src/c_api/c_api_function.cc b/src/c_api/c_api_function.cc
index 3cd4f66a72..e8ca18944d 100644
--- a/src/c_api/c_api_function.cc
+++ b/src/c_api/c_api_function.cc
@@ -29,6 +29,7 @@
 
 #include "./c_api_common.h"
 #include "../operator/operator_common.h"
+#include "../operator/custom/custom-inl.h"
 
 namespace mxnet {
 namespace custom_function {
@@ -62,68 +63,55 @@ std::vector<nnvm::NodeEntry> Gradient(
 }
 
 OpStatePtr CreateState(const nnvm::NodeAttrs& attrs,
-                               Context ctx,
-                               const std::vector<TShape>& ishape,
-                               const std::vector<int>& itype) {
+                       Context ctx,
+                       const std::vector<TShape>& ishape,
+                       const std::vector<int>& itype) {
   LOG(FATAL) << "Not reached";
   return OpStatePtr::Create<void*>(nullptr);
 }
 
 void Forward(const OpStatePtr& state,
              const OpContext& ctx,
-             const std::vector<NDArray>& inputs,
+             const std::vector<TBlob>& inputs,
              const std::vector<OpReqType>& req,
-             const std::vector<NDArray>& outputs) {
+             const std::vector<TBlob>& outputs) {
   LOG(FATAL) << "Not reached";
 }
 
 void Backward(const OpStatePtr& state,
               const OpContext& ctx,
-              const std::vector<NDArray>& inputs,
+              const std::vector<TBlob>& inputs,
               const std::vector<OpReqType>& req,
-              const std::vector<NDArray>& outputs) {
+              const std::vector<TBlob>& outputs) {
   const CustomFunctionParam& params = state.get_state<CustomFunctionParam>();
 
   std::vector<NDArrayHandle> ptrs;
+  std::vector<NDArray> cpys;
+
+  auto dev_id = ctx.run_ctx.ctx.dev_id;
 
   for (const auto& i : inputs) {
-    NDArray* nd = new NDArray(i.Detach());
+    NDArray* nd = new NDArray(i, dev_id);
     ptrs.push_back(reinterpret_cast<NDArrayHandle>(nd));
+    cpys.push_back(*nd);
   }
   for (const auto& i : outputs) {
-    NDArray* nd = new NDArray(i.Detach());
+    NDArray* nd = new NDArray(i, dev_id);
     ptrs.push_back(reinterpret_cast<NDArrayHandle>(nd));
+    cpys.push_back(*nd);
   }
 
-  bool prev_recording = Imperative::Get()->set_is_recording(false);
-  bool prev_training = Imperative::Get()->set_is_training(ctx.is_train);
-
-  CHECK(reinterpret_cast<CustomFunctionBwdFunc>(
-      params.info->callbacks[kCustomFunctionBackward])(
-          inputs.size(), outputs.size(), ptrs.data(),
-          reinterpret_cast<const int*>(req.data()), ctx.is_train,
-          params.info->contexts[kCustomFunctionBackward]));
-
-  Imperative::Get()->set_is_training(prev_training);
-  Imperative::Get()->set_is_recording(prev_recording);
+  op::custom::CustomOperator::Get()->Push(
+    [=]() {
+      CHECK(reinterpret_cast<CustomFunctionBwdFunc>(
+          params.info->callbacks[kCustomFunctionBackward])(
+              inputs.size(), outputs.size(),
+              const_cast<NDArrayHandle*>(ptrs.data()),
+              reinterpret_cast<const int*>(req.data()), ctx.is_train,
+              params.info->contexts[kCustomFunctionBackward]));
+    }, ctx, false, ctx.is_train, cpys);
 }
 
-// infer storage function for custom op, which assigns kDefaultStorage for
-// all undefined stypes, and dispatch on DispatchMode::kFComputeEx.
-inline bool InferStorageType(const nnvm::NodeAttrs& attrs,
-                             const int dev_mask,
-                             DispatchMode* dispatch_mode,
-                             std::vector<int> *iattr,
-                             std::vector<int> *oattr) {
-  for (int& v : *oattr) {
-    if (v == -1) v = kDefaultStorage;
-  }
-  for (int& v : *iattr) {
-    if (v == -1) v = kDefaultStorage;
-  }
-  op::dispatch_mode_assign(dispatch_mode, DispatchMode::kFComputeEx);
-  return true;
-}
 
 NNVM_REGISTER_OP(_CustomFunction)
 .set_num_inputs([](const NodeAttrs& attrs) {
@@ -150,9 +138,8 @@ NNVM_REGISTER_OP(_CustomFunction)
   })
 .set_attr<FCreateOpState>("FCreateOpState", CreateState)
 .set_attr<nnvm::FGradient>("FGradient", Gradient)
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", Forward)
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", Forward)
-.set_attr<FInferStorageType>("FInferStorageType", InferStorageType);
+.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", Forward)
+.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", Forward);
 
 
 NNVM_REGISTER_OP(_backward_CustomFunction)
@@ -167,11 +154,10 @@ NNVM_REGISTER_OP(_backward_CustomFunction)
 .set_attr<bool>("TIsBackward", true)
 .set_attr<bool>("TIsLayerOpBackward", true)
 .set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
-    return ExecType::kLocal;
+    return ExecType::kAsync;
   })
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", Backward)
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", Backward)
-.set_attr<FInferStorageType>("FInferStorageType", InferStorageType);
+.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", Backward)
+.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", Backward);
 
 }  // namespace custom_function
 }  // namespace mxnet
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 77853a61ca..5f95df3d6a 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -1473,9 +1473,6 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
       CHECK_EQ(opnode.exec->in_array.size(), 1U);
       CHECK_EQ(opnode.exec->out_array.size(), 1U);
       CopyFromTo(opnode.exec->in_array[0], &(opnode.exec->out_array[0]));
-    } else if (opnode.exec->exec_type() == ExecType::kLocal) {
-      bool is_gpu = opnode.ctx.dev_mask() == gpu::kDevMask;
-      opnode.exec->Run(RunContext{opnode.ctx, nullptr}, is_gpu);
     } else if (opnode.cached_opr != nullptr) {
 #if MXNET_USE_PROFILER
       bool profiling = engine::Profiler::Get()->GetState() == engine::Profiler::kRunning;
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index e265cce28e..df8979ab79 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -337,12 +337,15 @@ inline void PushFCompute(const FCompute& fn,
                   const std::vector<uint32_t>& mutate_idx,
                   const std::vector<OpReqType>& req) {
   using namespace common;
+  static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
+
   bool is_train = Imperative::Get()->is_training();
+  ExecType exec_type = fexec_type.count(op) ? fexec_type[op](attrs) : ExecType::kSync;
+  CHECK(exec_type == ExecType::kSync);
   std::vector<NDArray> inputs, outputs;
   DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs);
   Engine::Get()->PushSync(
-    [ctx, attrs, fn, inputs, outputs, requested, is_train, mutate_idx, req](
-        RunContext rctx) {
+    [=](RunContext rctx) {
       std::vector<TBlob> input_blobs, output_blobs;
       // pre-fcompute and post-fcompute storage fallback src NDArrays and dst NDArrays
       std::vector<NDArray> pre_temp_src, pre_temp_dst, post_temp_dst, post_temp_src;
@@ -350,8 +353,8 @@ inline void PushFCompute(const FCompute& fn,
       std::unordered_map<uint32_t, uint32_t> in_temp_idx_map;
       // setup blobs
       SetupDefaultBlobsInOut(inputs, outputs, &input_blobs, &output_blobs,
-                             &pre_temp_src, &pre_temp_dst, &post_temp_src, &post_temp_dst,
-                             &in_temp_idx_map, mutate_idx);
+                             &pre_temp_src, &pre_temp_dst, &post_temp_src,
+                             &post_temp_dst, &in_temp_idx_map, mutate_idx);
       // setup context
       OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
       bool is_gpu = ctx.dev_mask() == gpu::kDevMask;
@@ -380,27 +383,23 @@ inline void PushFComputeEx(const FComputeEx& fn,
   static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
 
   bool is_train = Imperative::Get()->is_training();
-  ExecType exec_type = ExecType::kSync;
-  if (fexec_type.count(op)) {
-    exec_type = fexec_type[op](attrs);
-  }
+  ExecType exec_type = fexec_type.count(op) ? fexec_type[op](attrs) : ExecType::kSync;
   std::vector<NDArray> inputs, outputs;
   DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs);
-  const auto& run = [ctx, exec_type, is_train, attrs, fn, inputs, outputs, requested, req](
-        RunContext rctx) {
+  const auto& run = [=](RunContext rctx) {
       OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
       fn(attrs, opctx, inputs, req, outputs);
-      if (exec_type == ExecType::kSync) {
-        if (rctx.get_ctx().dev_mask() == gpu::kDevMask) {
-          rctx.get_stream<gpu>()->Wait();
-        }
+      if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync) {
+        rctx.get_stream<gpu>()->Wait();
       }
     };
-  if (exec_type == ExecType::kLocal) {
+
+  if (exec_type == ExecType::kCrossDeviceCopy) {
     run(RunContext{ctx, nullptr});
   } else {
+    CHECK(exec_type == ExecType::kSync);
     Engine::Get()->PushSync(run, ctx, read_vars, write_vars, FnProperty::kNormal,
-      0, PROFILER_MESSAGE(op->name.c_str()));
+                            0, PROFILER_MESSAGE(op->name.c_str()));
   }
 }
 
@@ -420,42 +419,30 @@ inline void PushOperator(const OpStatePtr& state,
   static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType");
 
   bool is_train = Imperative::Get()->is_training();
-  ExecType exec_type = ExecType::kSync;
-  if (fexec_type.count(op)) {
-    exec_type = fexec_type[op](attrs);
-  }
+  ExecType exec_type = fexec_type.count(op) ? fexec_type[op](attrs) : ExecType::kSync;
   std::vector<NDArray> inputs, outputs;
   DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs);
 
   auto fcompute = common::GetFCompute<FStatefulCompute>(op, "FStatefulCompute", ctx);
   auto fcompute_ex = common::GetFCompute<FStatefulComputeEx>(op, "FStatefulComputeEx", ctx);
   if (fcompute_ex != nullptr && dispatch_mode == DispatchMode::kFComputeEx) {
-    const auto& run = [state, fcompute_ex, inputs, outputs, requested, is_train,
-                       exec_type, req](
-          RunContext rctx) {
-        OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
-        fcompute_ex(state, opctx, inputs, req, outputs);
-        if (exec_type == ExecType::kSync) {
-          if (rctx.get_ctx().dev_mask() == gpu::kDevMask) {
+    CHECK(exec_type == ExecType::kSync);
+    Engine::Get()->PushSync(
+        [=](RunContext rctx) {
+          OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
+          fcompute_ex(state, opctx, inputs, req, outputs);
+          if (ctx.dev_mask() == gpu::kDevMask) {
             rctx.get_stream<gpu>()->Wait();
           }
-        }
-      };
-    if (exec_type == ExecType::kLocal) {
-      run(RunContext{ctx, nullptr});
-    } else {
-      Engine::Get()->PushSync(run, ctx, read_vars, write_vars, FnProperty::kNormal,
-                               0, PROFILER_MESSAGE(op->name.c_str()));
-    }
+        }, ctx, read_vars, write_vars, FnProperty::kNormal,
+        0, PROFILER_MESSAGE(op->name.c_str()));
   } else {
     CHECK(fcompute != nullptr)
         << "One of FStatefulCompute and FStatefulComputeEx must be registered "
         << "for stateful operator " << op->name;
-    CHECK(exec_type == ExecType::kSync || exec_type == ExecType::kAsync);
-    Engine::Get()->PushSync(
-      [state, fcompute, inputs, outputs, requested, is_train, exec_type, mutate_idx, req](
-          RunContext rctx) {
-        OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
+
+    const auto& run = [=](RunContext rctx, engine::CallbackOnComplete on_complete) {
+        OpContext opctx{is_train, rctx, on_complete, requested};
 
         std::vector<TBlob> input_blobs, output_blobs;
         // pre-fcompute and post-fcompute storage fallback src NDArrays and dst NDArrays
@@ -473,13 +460,23 @@ inline void PushOperator(const OpStatePtr& state,
         fcompute(state, opctx, input_blobs, req, output_blobs);
         // post-fcompute fallback, cast to original storage type, if necessary
         CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx, is_gpu);
-        if (exec_type == ExecType::kSync) {
-          if (is_gpu) {
-            rctx.get_stream<gpu>()->Wait();
-          }
+        if (is_gpu && exec_type == ExecType::kSync) {
+          rctx.get_stream<gpu>()->Wait();
         }
-      }, ctx, read_vars, write_vars, FnProperty::kNormal,
-      0, PROFILER_MESSAGE(op->name.c_str()));
+      };
+
+    if (exec_type == ExecType::kSync) {
+      Engine::Get()->PushSync(
+          [=](RunContext rctx) {
+            run(rctx, engine::CallbackOnComplete());
+          }, ctx, read_vars, write_vars, FnProperty::kNormal,
+          0, PROFILER_MESSAGE(op->name.c_str()));
+    } else {
+      CHECK(exec_type == ExecType::kAsync);
+      Engine::Get()->PushAsync(
+          run, ctx, read_vars, write_vars, FnProperty::kAsync,
+          0, PROFILER_MESSAGE(op->name.c_str()));
+    }
   }
 }
 
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index f09f168977..0be423d24f 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -1338,7 +1338,7 @@ NNVM_REGISTER_OP(_copyto)
     return true;
   })
 .set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
-    return ExecType::kLocal;
+    return ExecType::kCrossDeviceCopy;
   })
 .set_attr<nnvm::FGradient>("FGradient", op::ElemwiseGradUseNone{"_copyto"})
 .set_attr<bool>("TIsBackward", true)
diff --git a/src/operator/custom/custom-inl.h b/src/operator/custom/custom-inl.h
index 13101da61b..38aeefd66a 100644
--- a/src/operator/custom/custom-inl.h
+++ b/src/operator/custom/custom-inl.h
@@ -30,6 +30,7 @@
 #include <dmlc/parameter.h>
 #include <mxnet/operator.h>
 #include <mxnet/c_api.h>
+#include <mxnet/imperative.h>
 #include <map>
 #include <vector>
 #include <string>
@@ -46,7 +47,7 @@ namespace mxnet {
 namespace op {
 namespace custom {
 
-class Registry {
+class CustomOperator {
  public:
   void Register(const std::string &op_type, CustomOpPropCreator creator) {
     std::lock_guard<std::mutex> lock(mutex_);
@@ -63,11 +64,80 @@ class Registry {
     return nullptr;
   }
 
-  static Registry* Get();
+  template<typename Func>
+  void Push(const Func& func,
+            const OpContext& ctx,
+            bool recording,
+            bool training,
+            const std::vector<NDArray>& arrs) {
+    if (naive_engine_) {
+      func();
+      ctx.async_on_complete();
+      return;
+    }
+    std::unique_lock<std::mutex> lock(mutex_);
+    q_.push(
+      [=]() mutable {
+        bool prev_recording = Imperative::Get()->set_is_recording(recording);
+        bool prev_training = Imperative::Get()->set_is_training(training);
+
+        func();
+
+        Imperative::Get()->set_is_training(prev_training);
+        Imperative::Get()->set_is_recording(prev_recording);
+
+        std::vector<Engine::VarHandle> vars;
+        for (const auto& i : arrs) vars.push_back(i.var());
+        Engine::Get()->PushSync([=](RunContext rctx) {
+            ctx.async_on_complete();
+          }, ctx.run_ctx.ctx, vars, {},
+          FnProperty::kNormal, 0, PROFILER_MESSAGE("CustomOperator"));
+      });
+    cv_.notify_all();
+  }
+
+  ~CustomOperator() {
+    if (naive_engine_) return;
+    {
+      std::unique_lock<std::mutex> lock(mutex_);
+      destructing_ = true;
+      cv_.notify_all();
+    }
+    worker_.join();
+  }
+
+  static CustomOperator* Get();
+
  private:
-  Registry() {}
+  CustomOperator() {
+    destructing_ = false;
+    naive_engine_ = true;
+    if (std::string("NaiveEngine") != dmlc::GetEnv("MXNET_ENGINE_TYPE", std::string())) {
+      naive_engine_ = false;
+      worker_ = std::thread(
+        [&]() {
+          std::unique_lock<std::mutex> lock(mutex_);
+          while (!q_.empty() || !destructing_) {
+            cv_.wait(lock, [&] {return !q_.empty() || destructing_;});
+            while (!q_.empty()) {
+              auto fn = q_.front();
+              lock.unlock();
+              fn();
+              lock.lock();
+              q_.pop();
+            }
+          }
+        });
+    }
+  }
   std::mutex mutex_;
   std::map<std::string, CustomOpPropCreator> registry_;
+  // async worker
+  std::condition_variable cv_;
+  std::thread worker_;
+  std::queue<std::function<void(void)> > q_;
+  bool naive_engine_;
+  bool destructing_;
 };
 
 }  // namespace custom
diff --git a/src/operator/custom/custom.cc b/src/operator/custom/custom.cc
index 280b01b22e..beb5f3dc9f 100644
--- a/src/operator/custom/custom.cc
+++ b/src/operator/custom/custom.cc
@@ -26,7 +26,6 @@
 #include "./custom-inl.h"
 #include <mxnet/base.h>
 #include <mxnet/ndarray.h>
-#include <mxnet/imperative.h>
 
 #include "../elemwise_op_common.h"
 
@@ -34,8 +33,8 @@ namespace mxnet {
 namespace op {
 namespace custom {
 
-Registry* Registry::Get() {
-  static Registry inst;
+CustomOperator* CustomOperator::Get() {
+  static CustomOperator inst;
   return &inst;
 }
 
@@ -75,8 +74,8 @@ void AttrParser(NodeAttrs* attrs) {
     }
   }
   CHECK(!params.op_type.empty()) << "Required argument `op_type` is missing.";
-  CustomOpPropCreator creator = Registry::Get()->Find(params.op_type);
-  CHECK(Registry::Get()->Find(params.op_type) != nullptr)
+  CustomOpPropCreator creator = CustomOperator::Get()->Find(params.op_type);
+  CHECK(CustomOperator::Get()->Find(params.op_type) != nullptr)
       << "Cannot find custom operator " << params.op_type;
   params.info.reset(new MXCallbackList, [](MXCallbackList* ptr){
       reinterpret_cast<CustomOpDelFunc>(ptr->callbacks[kCustomOpPropDelete])(
@@ -269,103 +268,95 @@ OpStatePtr CreateState(const NodeAttrs& attrs, Context ctx,
 
 void Forward(const OpStatePtr& state,
              const OpContext& ctx,
-             const std::vector<NDArray>& inputs,
+             const std::vector<TBlob>& inputs,
              const std::vector<OpReqType>& req,
-             const std::vector<NDArray>& outputs) {
+             const std::vector<TBlob>& outputs) {
   const CustomParam& params = state.get_state<CustomParam>();
   std::vector<void*> ptrs;
   std::vector<int> tags;
+  std::vector<NDArray> cpys;
+
+  auto dev_id = ctx.run_ctx.ctx.dev_id;
 
   for (size_t i = 0; i < params.num_args; ++i) {
-    NDArray *nd = new NDArray(inputs[i].Detach());
+    NDArray *nd = new NDArray(inputs[i], dev_id);
+    cpys.push_back(*nd);
     ptrs.push_back(reinterpret_cast<void*>(nd));
     tags.push_back(0);
   }
 
   for (size_t i = 0; i < params.num_outs; ++i) {
-    NDArray *nd = new NDArray(outputs[i].Detach());
+    NDArray *nd = new NDArray(outputs[i], dev_id);
+    cpys.push_back(*nd);
     ptrs.push_back(reinterpret_cast<void*>(nd));
     tags.push_back(1);
   }
 
   for (size_t i = 0; i < params.num_auxs; ++i) {
-    NDArray *nd = new NDArray(inputs[i+params.num_args].Detach());
+    NDArray *nd = new NDArray(inputs[i+params.num_args], dev_id);
+    cpys.push_back(*nd);
     ptrs.push_back(reinterpret_cast<void*>(nd));
     tags.push_back(4);
   }
 
-  bool prev_recording = Imperative::Get()->set_is_recording(false);
-  bool prev_training = Imperative::Get()->set_is_training(ctx.is_train);
-
-  CHECK(reinterpret_cast<CustomOpFBFunc>(params.info->callbacks[kCustomOpForward])(
-    ptrs.size(), ptrs.data(), tags.data(), reinterpret_cast<const int*>(req.data()),
-    static_cast<int>(ctx.is_train), params.info->contexts[kCustomOpForward]));
-
-  Imperative::Get()->set_is_training(prev_training);
-  Imperative::Get()->set_is_recording(prev_recording);
+  CustomOperator::Get()->Push(
+    [=]() {
+      CHECK(reinterpret_cast<CustomOpFBFunc>(params.info->callbacks[kCustomOpForward])(
+        ptrs.size(), const_cast<void**>(ptrs.data()), const_cast<int*>(tags.data()),
+        reinterpret_cast<const int*>(req.data()), static_cast<int>(ctx.is_train),
+        params.info->contexts[kCustomOpForward]));
+    }, ctx, false, ctx.is_train, cpys);
 }
 
 
 void Backward(const OpStatePtr& state,
               const OpContext& ctx,
-              const std::vector<NDArray>& inputs,
+              const std::vector<TBlob>& inputs,
               const std::vector<OpReqType>& req,
-              const std::vector<NDArray>& outputs) {
+              const std::vector<TBlob>& outputs) {
   const CustomParam& params = state.get_state<CustomParam>();
 
   size_t total = 2*params.num_args + 2*params.num_outs + params.num_auxs;
   std::vector<void*> ptrs(params.num_args + 2*params.num_outs, nullptr);
   std::vector<int> tags;
+  std::vector<NDArray> cpys;
+
   ptrs.reserve(total);
   tags.reserve(total);
   for (size_t i = 0; i < params.num_outs; ++i) tags.push_back(3);
   for (size_t i = 0; i < params.num_args; ++i) tags.push_back(0);
   for (size_t i = 0; i < params.num_outs; ++i) tags.push_back(1);
 
+  auto dev_id = ctx.run_ctx.ctx.dev_id;
+
   for (size_t i = 0; i < params.bwd_idx.size(); ++i) {
-    NDArray *nd = new NDArray(inputs[i].Detach());
+    NDArray *nd = new NDArray(inputs[i], dev_id);
+    cpys.push_back(*nd);
     ptrs[params.bwd_idx[i]] = reinterpret_cast<void*>(nd);
   }
   for (size_t i = 0; i < ptrs.size(); ++i) {
     if (ptrs[i] == nullptr) ptrs[i] = reinterpret_cast<void*>(new NDArray());
   }
   for (const auto& i : outputs) {
-    NDArray* nd = new NDArray(i.Detach());
+    NDArray* nd = new NDArray(i, dev_id);
+    cpys.push_back(*nd);
     ptrs.push_back(reinterpret_cast<void*>(nd));
     tags.push_back(2);
   }
   for (size_t i = 0; i < params.num_auxs; ++i) {
-    NDArray* nd = new NDArray(inputs[inputs.size()-params.num_auxs+i].Detach());
+    NDArray* nd = new NDArray(inputs[inputs.size()-params.num_auxs+i], dev_id);
+    cpys.push_back(*nd);
     ptrs.push_back(reinterpret_cast<void*>(nd));
     tags.push_back(4);
   }
 
-  bool prev_recording = Imperative::Get()->set_is_recording(false);
-  bool prev_training = Imperative::Get()->set_is_training(ctx.is_train);
-
-  CHECK(reinterpret_cast<CustomOpFBFunc>(params.info->callbacks[kCustomOpBackward])(
-    ptrs.size(), ptrs.data(), tags.data(), reinterpret_cast<const int*>(req.data()),
-    static_cast<int>(ctx.is_train), params.info->contexts[kCustomOpBackward]));
-
-  Imperative::Get()->set_is_training(prev_training);
-  Imperative::Get()->set_is_recording(prev_recording);
-}
-
-// infer storage function for custom op, which assigns kDefaultStorage for
-// all undefined stypes, and dispatch on DispatchMode::kFComputeEx.
-inline bool InferStorageType(const nnvm::NodeAttrs& attrs,
-                             const int dev_mask,
-                             DispatchMode* dispatch_mode,
-                             std::vector<int> *iattr,
-                             std::vector<int> *oattr) {
-  for (int& v : *oattr) {
-    if (v == -1) v = kDefaultStorage;
-  }
-  for (int& v : *iattr) {
-    if (v == -1) v = kDefaultStorage;
-  }
-  dispatch_mode_assign(dispatch_mode, DispatchMode::kFComputeEx);
-  return true;
+  CustomOperator::Get()->Push(
+    [=]() {
+      CHECK(reinterpret_cast<CustomOpFBFunc>(params.info->callbacks[kCustomOpBackward])(
+        ptrs.size(), const_cast<void**>(ptrs.data()), const_cast<int*>(tags.data()),
+        reinterpret_cast<const int*>(req.data()), static_cast<int>(ctx.is_train),
+        params.info->contexts[kCustomOpBackward]));
+    }, ctx, false, ctx.is_train, cpys);
 }
 
 NNVM_REGISTER_OP(Custom)
@@ -401,13 +392,12 @@ Please check the tutorial here: http://mxnet.io/how_to/new_op.html.
     return ret;
   })
 .set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
-    return ExecType::kLocal;
+    return ExecType::kAsync;
   })
 .set_attr<nnvm::FGradient>("FGradient", Gradient)
 .set_attr<FCreateOpState>("FCreateOpState", CreateState)
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", Forward)
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", Forward)
-.set_attr<FInferStorageType>("FInferStorageType", InferStorageType)
+.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", Forward)
+.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", Forward)
 .add_argument("data", "NDArray-or-Symbol[]", "Input data for the custom operator.")
 .add_argument("op_type", "string", "Name of the custom operator. "
               "This is the name that is passed to `mx.operator.register` "
@@ -426,11 +416,10 @@ NNVM_REGISTER_OP(_backward_Custom)
 .set_attr<bool>("TIsLayerOpBackward", true)
 .set_attr<bool>("TIsBackward", true)
 .set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
-    return ExecType::kLocal;
+    return ExecType::kAsync;
   })
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", Backward)
-.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", Backward)
-.set_attr<FInferStorageType>("FInferStorageType", InferStorageType);
+.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", Backward)
+.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", Backward);
 
 }  // namespace custom
 }  // namespace op


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to this message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services