You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/03/26 16:34:50 UTC

[GitHub] piiswrong closed pull request #9029: [MXNET-36] Update ndarray binary ops to use kernel launch instead of mshadow operations

piiswrong closed pull request #9029: [MXNET-36] Update ndarray binary ops to use kernel launch instead of mshadow operations
URL: https://github.com/apache/incubator-mxnet/pull/9029
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h
index fb41d396099..202220db265 100644
--- a/include/mxnet/op_attr_types.h
+++ b/include/mxnet/op_attr_types.h
@@ -49,8 +49,8 @@ enum OpReqType {
   kWriteTo,
   /*!
    * \brief perform an inplace write,
-   * Target shares memory with one of input arguments.
    * This option only happen when
+   * the target shares memory with one of the input arguments.
    */
   kWriteInplace,
   /*! \brief add to the provided space */
diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h
index f1637c4e57d..c2ddcd8708d 100644
--- a/src/kvstore/kvstore_dist_server.h
+++ b/src/kvstore/kvstore_dist_server.h
@@ -297,9 +297,6 @@ class KVStoreDistServer {
           CopyFromTo(recved, &merged.array, 0);
         } else {
           NDArray out(kRowSparseStorage, stored.shape(), Context());
-          std::vector<Engine::VarHandle> const_vars;
-          const_vars.push_back(recved.var());
-          const_vars.push_back(merged.array.var());
           // accumulate row_sparse gradients
           // TODO(haibin) override + operator for row_sparse NDArray
           // instead of calling BinaryComputeRspRsp directly
@@ -309,7 +306,7 @@ class KVStoreDistServer {
               op::ElemwiseBinaryOp::ComputeEx<cpu, op::mshadow_op::plus>(
                 {}, {}, {recved, merged.array}, {kWriteTo}, {out});
               on_complete();
-            }, recved.ctx(), const_vars, {out.var()},
+            }, recved.ctx(), {recved.var(), merged.array.var()}, {out.var()},
             FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
           CopyFromTo(out, &merged.array, 0);
         }
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index ae7209e272b..d92236c900a 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -781,16 +781,18 @@ void TernaryOp(const NDArray &lhs,
 }
 
 /*!
- * \brief run a binary operation
- * \param lhs left operand
- * \param rhs right operand
- * \param out the output ndarray
- * \param binary_op the real
- */
+* \brief Performs the preparation shared by the binary operators: checks
+* context and shape of the ndarrays (OP::GetShape gives the expected
+* output shape) and allocates space for the output.
+* \param lhs left operand
+* \param rhs right operand
+* \param out the output ndarray
+* \return vars of lhs/rhs that differ from out's var (engine const vars)
+*/
 template<typename OP>
-void BinaryOp(const NDArray &lhs,
-              const NDArray &rhs,
-              NDArray *out) {
+std::vector<Engine::VarHandle> BinaryOpPrepare(const NDArray &lhs,
+                                               const NDArray &rhs,
+                                               NDArray *out) {
   // no check if both of them are on cpu
   if (lhs.ctx().dev_mask() != cpu::kDevMask || rhs.ctx().dev_mask() != cpu::kDevMask) {
     CHECK(lhs.ctx() == rhs.ctx()) << "operands context mismatch";
@@ -805,15 +807,71 @@ void BinaryOp(const NDArray &lhs,
       CHECK(out->ctx() == lhs.ctx()) << "target context mismatch";
     }
     CHECK(out->shape() == OP::GetShape(lhs.shape(), rhs.shape()))
-        << "target shape mismatch";
+      << "target shape mismatch";
   }
+  std::vector<Engine::VarHandle> const_vars;
+  // prepare const variables for engine
+  if (lhs.var() != out->var()) const_vars.push_back(lhs.var());
+  if (rhs.var() != out->var()) const_vars.push_back(rhs.var());
+  return const_vars;
+}
+
+/*!
+* \brief Run a binary operation as an engine op via ndarray::BinaryOpKernelImpl
+* \param lhs left operand
+* \param rhs right operand
+* \param out the output ndarray
+* \tparam OP binary operator type (e.g. Plus, Minus, Mul, Div)
+*/
+template<typename OP>
+void BinaryOpKernel(const NDArray &lhs,
+                    const NDArray &rhs,
+                    NDArray *out) {
+  std::vector<Engine::VarHandle> const_vars = BinaryOpPrepare<OP>(lhs, rhs, out);
   // important: callback must always capture by value
   NDArray ret = *out;
-  // get the const variables
-  std::vector<Engine::VarHandle> const_vars;
-  if (lhs.var() != ret.var()) const_vars.push_back(lhs.var());
-  if (rhs.var() != ret.var()) const_vars.push_back(rhs.var());
+  switch (lhs.ctx().dev_mask()) {
+    case cpu::kDevMask: {
+      Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) {
+        TBlob tmp = ret.data();
+        mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+        ndarray::BinaryOpKernelImpl<OP>(s, lhs.data(), rhs.data(), &tmp);
+      },
+      lhs.ctx(), const_vars, {ret.var()},
+      FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
+      break;
+    }
+#if MXNET_USE_CUDA
+    case gpu::kDevMask: {
+      Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) {
+        TBlob tmp = ret.data();
+        mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
+        ndarray::BinaryOpKernelImpl<OP>(s, lhs.data(), rhs.data(), &tmp);
+        // Wait for the GPU kernel launched on s to complete
+        s->Wait();
+      }, lhs.ctx(), const_vars, {ret.var()},
+      FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
+      break;
+    }
+#endif
+    default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+  }
+}
 
+/*!
+ * \brief run a binary operation using mshadow operations
+ * \param lhs left operand
+ * \param rhs right operand
+ * \param out the output ndarray
+ * \tparam OP binary operator type (e.g. Plus, Minus, Mul, Div)
+ */
+template<typename OP>
+void BinaryOp(const NDArray &lhs,
+              const NDArray &rhs,
+              NDArray *out) {
+  std::vector<Engine::VarHandle> const_vars = BinaryOpPrepare<OP>(lhs, rhs, out);
+  // important: callback must always capture by value
+  NDArray ret = *out;
   // redirect everything to mshadow operations
   switch (lhs.ctx().dev_mask()) {
     case cpu::kDevMask: {
@@ -1377,7 +1435,7 @@ template<typename OP>
 inline NDArray BinaryOpRet(const NDArray &lhs,
                            const NDArray &rhs) {
   NDArray ret;
-  BinaryOp<OP>(lhs, rhs, &ret);
+  BinaryOpKernel<OP>(lhs, rhs, &ret);  // ret starts empty; BinaryOpPrepare allocates it
   return ret;
 }
 
@@ -1392,7 +1450,7 @@ inline NDArray ScalarOpRet(const NDArray &lhs,
 template<typename OP>
 inline NDArray &BinaryOpApply(NDArray *dst,
                               const NDArray &src) {
-  BinaryOp<OP>(*dst, src, dst);
+  BinaryOpKernel<OP>(*dst, src, dst);  // in place: *dst is both lhs and the output
   return *dst;
 }
 
diff --git a/src/ndarray/ndarray_function-inl.h b/src/ndarray/ndarray_function-inl.h
index a80d9db3637..d494f0882bb 100644
--- a/src/ndarray/ndarray_function-inl.h
+++ b/src/ndarray/ndarray_function-inl.h
@@ -47,6 +47,15 @@
   }
 #endif
 
+#ifndef DECL_BINARY_LAUNCH
+#define DECL_BINARY_LAUNCH(DEV, BINOP)                                  \
+  template <>                                                           \
+  void BinaryOpKernelImpl<BINOP, DEV>(mshadow::Stream<DEV> *s,          \
+      const TBlob& lhs, const TBlob& rhs, TBlob *out) {                 \
+    BinaryOpKernelLaunch<BINOP>(s, lhs, rhs, out);                      \
+  }
+#endif  // DECL_BINARY_LAUNCH
+
 #ifndef DECL_SCALAR
 #define DECL_SCALAR(XPU, OP, FUN, REVERSE)                           \
   template<>                                                         \
@@ -433,18 +442,31 @@ void EvalBroadcast<DEVICE>(TBlob const& src, TBlob* ret, int size, RunContext ct
   out = mshadow::expr::broadcast_with_axis(in, 0, size);
 }
 
+template<typename OP, typename xpu>  // applies OP elementwise to lhs/rhs, writing into out
+void BinaryOpKernelLaunch(mshadow::Stream<xpu>* s, const TBlob& lhs, const TBlob& rhs, TBlob *out) {
+  using namespace op::mxnet_op;
+  using namespace mshadow;
+  MSHADOW_TYPE_SWITCH(out->type_flag_, DType, {  // NOTE(review): lhs/rhs assumed same dtype as out
+    Kernel<op_with_req<OP, kWriteInplace>, xpu>::  // kWriteInplace: out may alias lhs
+    Launch(s,
+           out->Size(),  // bound the launch by the buffer being written
+           out->dptr<DType>(),
+           lhs.dptr<DType>(),
+           rhs.dptr<DType>());
+  });
+}
 // declarations
 DECL_BINARY(DEVICE, MatChooseRowElem, EvalMatChooseRowElem_)
 DECL_TERNARY(DEVICE, MatFillRowElem, EvalMatFillRowElem_)
 DECL_BINARY(DEVICE, OneHotEncode, EvalOneHot_)
-DECL_BINARY(DEVICE, Plus, EvalBinary_)
-DECL_BINARY(DEVICE, Minus, EvalBinary_)
-DECL_BINARY(DEVICE, Mul, EvalBinary_)
-DECL_BINARY(DEVICE, Div, EvalBinary_)
 DECL_SCALAR(DEVICE, Plus, EvalScalar_, true)
 DECL_SCALAR(DEVICE, Minus, EvalScalar_, true)
 DECL_SCALAR(DEVICE, Mul, EvalScalar_, true)
 DECL_SCALAR(DEVICE, Div, EvalScalar_, true)
+DECL_BINARY_LAUNCH(DEVICE, Plus)
+DECL_BINARY_LAUNCH(DEVICE, Minus)
+DECL_BINARY_LAUNCH(DEVICE, Mul)
+DECL_BINARY_LAUNCH(DEVICE, Div)
 
 // for reverse seq
 DECL_SCALAR(DEVICE, Plus, EvalScalar_, false)
diff --git a/src/ndarray/ndarray_function.h b/src/ndarray/ndarray_function.h
index 518bb773170..97c23b67592 100644
--- a/src/ndarray/ndarray_function.h
+++ b/src/ndarray/ndarray_function.h
@@ -46,20 +46,20 @@ struct BinaryBase {
 };
 
 // operators
-struct Plus : public BinaryBase {
-  typedef op::mshadow_op::plus mshadow_op;
+struct Plus : public BinaryBase, public mshadow::op::plus {  // base op presumably supplies Map()
+  typedef mshadow::op::plus mshadow_op;
 };
 
-struct Minus : public BinaryBase {
-  typedef op::mshadow_op::minus mshadow_op;
+struct Minus : public BinaryBase, public mshadow::op::minus {  // base op presumably supplies Map()
+  typedef mshadow::op::minus mshadow_op;
 };
 
-struct Mul : public BinaryBase {
-  typedef op::mshadow_op::mul mshadow_op;
+struct Mul : public BinaryBase, public mshadow::op::mul {  // base op presumably supplies Map()
+  typedef mshadow::op::mul mshadow_op;
 };
 
-struct Div : public BinaryBase {
-  typedef op::mshadow_op::div mshadow_op;
+struct Div : public BinaryBase, public mshadow::op::div {  // base op presumably supplies Map()
+  typedef mshadow::op::div mshadow_op;
 };
 
 struct Mod : public BinaryBase {
@@ -208,6 +208,10 @@ void Eval(mshadow::Stream<xpu> *s,
 template <typename Device>
 void EvalBroadcast(TBlob const& src, TBlob* ret, int size, RunContext ctx);
 
+template <typename OP, typename xpu>
+void BinaryOpKernelImpl(mshadow::Stream<xpu> *s, const TBlob& lhs,
+                        const TBlob& rhs, TBlob *out);  // specialized per OP by DECL_BINARY_LAUNCH
+
 }  // namespace ndarray
 }  // namespace mxnet
 #endif  // MXNET_NDARRAY_NDARRAY_FUNCTION_H_


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services