You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/05/10 17:20:21 UTC

[GitHub] eric-haibin-lin closed pull request #10714: [MXNET-364] broadcast_add/sub between CSR and 1D dense vector on CPU

eric-haibin-lin closed pull request #10714: [MXNET-364] broadcast_add/sub between CSR and 1D dense vector on CPU
URL: https://github.com/apache/incubator-mxnet/pull/10714
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h
index 42a8f0f2c15..933d0be4658 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -105,6 +105,32 @@ inline bool BinaryBroadcastMulStorageType(const nnvm::NodeAttrs& attrs,
   return dispatched;
 }
 
+inline bool BinaryBroadcastAddStorageType(const nnvm::NodeAttrs& attrs,
+                                          const int dev_mask,  // unused here
+                                          DispatchMode* dispatch_mode,
+                                          std::vector<int>* in_attrs,
+                                          std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2U);  // storage-type inference for broadcast_add/sub: (lhs, rhs)
+  CHECK_EQ(out_attrs->size(), 1U);
+  const int lhs_stype = in_attrs->at(0);
+  const int rhs_stype = in_attrs->at(1);
+  int& out_stype = out_attrs->at(0);
+  bool dispatched = false;  // rules below are tried in order; the first match wins
+  if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
+    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFCompute);  // dns, dns -> dns
+  }
+  if (!dispatched && ((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
+                      (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage))) {
+    dispatched = storage_type_assign(&out_stype, kDefaultStorage,  // csr/dns in any order -> dns
+                                     dispatch_mode, DispatchMode::kFComputeEx);
+  }
+  if (!dispatched) {
+    dispatched = dispatch_fallback(out_attrs, dispatch_mode);  // any other storage combination
+  }
+  return dispatched;
+}
+
 #define BROADCAST_NDIM_SWITCH(ndim, NDim, ...)  \
   if (ndim <= 2) {                    \
     const int NDim = 2;               \
@@ -183,6 +209,24 @@ struct binary_broadcast_kernel {
       KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs[lidx], rhs[ridx]));
     }
   }
+
+  /*! \brief Map variant where lhs is one scalar value applied to every output element */
+  MSHADOW_XINLINE static void Map(int base, int length, OpReqType req,
+                                  const Shape <ndim> &lstride, const Shape <ndim> &rstride,
+                                  const Shape <ndim> &oshape, DType lhs, DType *rhs,
+                                  DType *out) {
+    Shape <ndim> coord = unravel(base, oshape);
+    auto lidx = static_cast<index_t>(dot(coord, lstride));
+    auto ridx = static_cast<index_t>(dot(coord, rstride));
+    KERNEL_ASSIGN(out[base], req, OP::Map(lhs, rhs[ridx]));
+    // starts from 1 to avoid extra inc at end of loop
+    for (int i = 1; i < length; ++i) {
+      inc(&coord, oshape, &lidx, lstride, &ridx, rstride);
+      // lhs is a scalar, so lidx is only maintained for inc(); ridx walks the
+      // broadcast rhs in step with the output position base + i.
+      KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs, rhs[ridx]));
+    }
+  }
 };
 
 template<int req, typename OP>
@@ -200,6 +244,25 @@ struct csr_dns_csr_broadcast_kernel {
   }
 };
 
+template<int req, typename OP, bool reverse = false>
+struct csr_dns_map_kernel {  // applies OP in place to out at each stored csr entry
+  template <typename DType, typename CType, typename RType>
+  MSHADOW_XINLINE static void Map(int row, const DType *csr_data, const CType *csr_indices,
+                                  const RType *csr_indptr, DType *out, const nnvm::dim_t num_rows,
+                                  const nnvm::dim_t num_cols) {
+    if (row < num_rows) {  // skip indices past the last csr row
+      const nnvm::dim_t curr_row_i = csr_indptr[row];
+      const nnvm::dim_t next_row_i = csr_indptr[row + 1];
+      for (nnvm::dim_t iter = curr_row_i; iter < next_row_i; iter++) {
+        const nnvm::dim_t target = row * num_cols + csr_indices[iter];  // row-major dense offset
+        KERNEL_ASSIGN(out[target], req,  // reverse: the dense value is the lhs operand
+                      reverse ? OP::Map(out[target], csr_data[iter]) :
+                                OP::Map(csr_data[iter], out[target]));
+      }
+    }
+  }
+};
+
 }  // namespace mxnet_op
 
 template<typename xpu, typename OP>
@@ -284,11 +347,77 @@ void BinaryBroadcastCsrDnsCsrImpl(const OpContext& ctx,
 }
 
 template<typename xpu, typename OP>
-void BinaryBroadcastComputeEx(const nnvm::NodeAttrs& attrs,
-                              const OpContext& ctx,
-                              const std::vector<NDArray>& inputs,
-                              const std::vector<OpReqType>& req,
-                              const std::vector<NDArray>& outputs) {
+void BinaryBroadcastCsrDnsDnsImpl(const OpContext& ctx,
+                                  const NDArray& csr,
+                                  const NDArray& dns,
+                                  const OpReqType req,
+                                  const NDArray& output,
+                                  const TShape& new_csrshape,
+                                  const TShape& new_dnsshape,
+                                  const TShape& new_oshape,
+                                  const int ndim,
+                                  const bool reverse) {  // reverse: dns is the lhs operand
+  using namespace mshadow;  // computes dense output = OP(csr, dns), dns broadcast to output shape
+  using namespace mxnet_op;
+  using namespace csr;
+  CHECK(req == kWriteTo) << "Only kWriteTo supported for broadcast(csr, dns) = dns";
+  const bool legal_op = std::is_same<OP, mshadow_op::plus>::value ||
+                        std::is_same<OP, mshadow_op::minus>::value;
+  CHECK(legal_op) << "Only add/sub are supported for broadcast(csr, dns) = dns";
+  CHECK_EQ(csr.shape()[0], output.shape()[0]);
+  CHECK_EQ(csr.shape()[1], output.shape()[1]);
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  const nnvm::dim_t num_rows = output.shape()[0];
+  const nnvm::dim_t num_cols = output.shape()[1];
+  const TBlob& csr_data = csr.data();
+  const TBlob& csr_indices = csr.aux_data(kIdx);
+  const TBlob& csr_indptr = csr.aux_data(kIndPtr);
+  TBlob dns_data = dns.data();
+  TBlob out_data = output.data();
+
+  MSHADOW_TYPE_SWITCH(output.dtype(), DType, {  // pass 1: write the broadcast dns into output
+    BROADCAST_NDIM_SWITCH(ndim, NDim, {
+      Shape<NDim> oshape = new_oshape.get<NDim>();
+      Shape<NDim> lstride = calc_stride(new_csrshape.get<NDim>());
+      Shape<NDim> rstride = calc_stride(new_dnsshape.get<NDim>());
+      if (reverse && std::is_same<OP, mshadow_op::minus>::value) {  // out = 0 + dns
+        Kernel<binary_broadcast_kernel<NDim, DType, mshadow_op::plus>, xpu>::
+        template LaunchEx(s, new_oshape.Size(), req, lstride, rstride, oshape,
+        DType(0), dns_data.dptr<DType>(), out_data.dptr<DType>());
+      } else {  // out = OP(0, dns)
+        Kernel<binary_broadcast_kernel<NDim, DType, OP>, xpu>::
+        template LaunchEx(s, new_oshape.Size(), req, lstride, rstride, oshape,
+        DType(0), dns_data.dptr<DType>(), out_data.dptr<DType>());
+      }
+    });
+  });
+  if (csr.storage_initialized()) {  // pass 2: combine stored csr values into output
+    MSHADOW_TYPE_SWITCH(csr.dtype(), DType, {
+      MSHADOW_IDX_TYPE_SWITCH(csr.aux_type(kIdx), CType, {
+        MSHADOW_IDX_TYPE_SWITCH(csr.aux_type(kIndPtr), RType, {
+          MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+            if (reverse && std::is_same<OP, mshadow_op::minus>::value) {  // out = out - csr
+              Kernel<csr_dns_map_kernel<req_type, mshadow_op::minus, true>, xpu>::Launch(
+                s, num_rows, csr_data.dptr<DType>(), csr_indices.dptr<CType>(),
+                csr_indptr.dptr<RType>(), out_data.dptr<DType>(), num_rows, num_cols);
+            } else {  // out = csr + out
+              Kernel<csr_dns_map_kernel<req_type, mshadow_op::plus>, xpu>::Launch(
+                s, num_rows, csr_data.dptr<DType>(), csr_indices.dptr<CType>(),
+                csr_indptr.dptr<RType>(), out_data.dptr<DType>(), num_rows, num_cols);
+            }
+          });
+        });
+      });
+    });
+  }
+}
+
+template<typename xpu, typename OP>
+void BinaryBroadcastComputeSparseEx(const nnvm::NodeAttrs& attrs,
+                                    const OpContext& ctx,
+                                    const std::vector<NDArray>& inputs,
+                                    const std::vector<OpReqType>& req,
+                                    const std::vector<NDArray>& outputs) {
   CHECK_EQ(inputs.size(), 2U);
   CHECK_EQ(outputs.size(), 1U);
   CHECK_EQ(req.size(), 1U);
@@ -317,6 +446,50 @@ void BinaryBroadcastComputeEx(const nnvm::NodeAttrs& attrs,
   }
 }
 
+template<typename xpu, typename OP>
+void BinaryBroadcastComputeDenseEx(const nnvm::NodeAttrs& attrs,
+                                   const OpContext& ctx,
+                                   const std::vector<NDArray>& inputs,
+                                   const std::vector<OpReqType>& req,
+                                   const std::vector<NDArray>& outputs) {
+  CHECK_EQ(inputs.size(), 2U);  // FComputeEx for broadcast_add/sub on one csr + one dense input
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  CHECK(inputs[0].shape().ndim() <= 2U && inputs[1].shape().ndim() <= 2U)
+    << "input dense matrix should have less than or equal to 2 dimensions";
+  if (req[0] == kNullOp) return;
+  const NDArray& lhs = inputs[0];
+  const NDArray& rhs = inputs[1];
+  const NDArray& out = outputs[0];
+  const auto lhs_stype = lhs.storage_type();
+  const auto rhs_stype = rhs.storage_type();
+  const auto out_stype = out.storage_type();
+  bool reverse = (lhs_stype == kDefaultStorage);  // true when the dense operand is the lhs
+  const NDArray& dns = (reverse) ? lhs : rhs;
+  const NDArray& csr = (reverse) ? rhs : lhs;
+  TShape new_csrshape, new_dnsshape, new_oshape;
+  int ndim = BinaryBroadcastShapeCompact(csr.shape(), dns.shape(), out.shape(),
+                                         &new_csrshape, &new_dnsshape, &new_oshape);
+
+  if (((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
+       (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage)) &&
+      out_stype == kDefaultStorage) {
+    // ndim == 0: both operands already have the same shape, so use the elemwise dns-csr op
+    if (!ndim) {
+      mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
+      ElemwiseBinaryOp::DnsCsrDnsOp<xpu, OP>(
+        s, attrs, ctx, dns, csr, req[0], outputs[0], !reverse);
+    } else {
+      // broadcast(csr, dns) = dense output (operand order swapped when reverse)
+      BinaryBroadcastCsrDnsDnsImpl<xpu, OP>(ctx, csr, dns, req[0], out,
+                                            new_csrshape, new_dnsshape, new_oshape,
+                                            ndim, reverse);
+    }
+  } else {
+    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
+  }
+}
+
 template<typename xpu, typename LOP, typename ROP>
 void BinaryBroadcastBackwardUseNone(const nnvm::NodeAttrs& attrs,
                                     const OpContext& ctx,
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
index 6be4c265b9e..78b2d45567d 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
@@ -48,8 +48,14 @@ Example::
    broadcast_plus(x, y) = [[ 1.,  1.,  1.],
                            [ 2.,  2.,  2.]]
 
+Supported sparse operations:
+   broadcast_add(csr, dense(1D)) = dense
+   broadcast_add(dense(1D), csr) = dense
+
 )code" ADD_FILELINE)
 .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::plus>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryBroadcastComputeDenseEx<cpu, op::mshadow_op::plus>)
+.set_attr<FInferStorageType>("FInferStorageType", BinaryBroadcastAddStorageType)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_add"});
 
 NNVM_REGISTER_OP(_backward_broadcast_add)
@@ -87,8 +93,14 @@ Example::
    broadcast_minus(x, y) = [[ 1.,  1.,  1.],
                             [ 0.,  0.,  0.]]
 
+Supported sparse operations:
+   broadcast_sub/minus(csr, dense(1D)) = dense
+   broadcast_sub/minus(dense(1D), csr) = dense
+
 )code" ADD_FILELINE)
 .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::minus>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryBroadcastComputeDenseEx<cpu, op::mshadow_op::minus>)
+.set_attr<FInferStorageType>("FInferStorageType", BinaryBroadcastAddStorageType)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"});
 
 NNVM_REGISTER_OP(_backward_broadcast_sub)
@@ -125,7 +137,7 @@ Supported sparse operations:
 
 )code" ADD_FILELINE)
 .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::mul>)
-.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryBroadcastComputeEx<cpu, op::mshadow_op::mul>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryBroadcastComputeSparseEx<cpu, op::mshadow_op::mul>)
 .set_attr<FInferStorageType>("FInferStorageType", BinaryBroadcastMulStorageType)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"});
 
@@ -164,7 +176,7 @@ Supported sparse operations:
 
 )code" ADD_FILELINE)
 .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::div>)
-.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryBroadcastComputeEx<cpu, op::mshadow_op::div>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryBroadcastComputeSparseEx<cpu, op::mshadow_op::div>)
 .set_attr<FInferStorageType>("FInferStorageType", BinaryBroadcastMulStorageType)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_div"});
 
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cu b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cu
index dc0ba021f56..976f09152a3 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cu
+++ b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cu
@@ -29,14 +29,16 @@
 namespace mxnet {
 namespace op {
 NNVM_REGISTER_OP(broadcast_add)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::plus>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::plus>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", BinaryBroadcastComputeDenseEx<gpu, op::mshadow_op::plus>);
 
 NNVM_REGISTER_OP(_backward_broadcast_add)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseNone<gpu, mshadow_op::identity,
                                                                 mshadow_op::identity>);
 
 NNVM_REGISTER_OP(broadcast_sub)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::minus>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::minus>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", BinaryBroadcastComputeDenseEx<gpu, op::mshadow_op::minus>);
 
 NNVM_REGISTER_OP(_backward_broadcast_sub)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseNone<gpu, mshadow_op::identity,
diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h
index 9d3f6e096dd..a5b73dadd3a 100644
--- a/src/operator/tensor/elemwise_binary_op.h
+++ b/src/operator/tensor/elemwise_binary_op.h
@@ -200,7 +200,7 @@ class ElemwiseBinaryOp : public OpBase {
     }
   }
 
- protected:
+ public:
   /*! \brief Binary op handling for lhr/rhs: RspDns, RspRsp, DnsRsp, or RspRsp->Dns result */
   template<typename OP>
   static void RspRspOp(mshadow::Stream<cpu> *s,
@@ -231,7 +231,7 @@ class ElemwiseBinaryOp : public OpBase {
 
   /*! \brief CSR -op- CSR binary operator for non-canonical NDArray */
   template<typename OP>
-  static inline void CsrCsrOp(mshadow::Stream<cpu> *s,
+  static void CsrCsrOp(mshadow::Stream<cpu> *s,
                               const nnvm::NodeAttrs &attrs,
                               const OpContext &ctx,
                               const NDArray &lhs,
@@ -241,7 +241,7 @@ class ElemwiseBinaryOp : public OpBase {
 
   /*! \brief CSR -op- CSR binary operator for non-canonical NDArray */
   template<typename OP>
-  static inline void CsrCsrOp(mshadow::Stream<gpu> *s,
+  static void CsrCsrOp(mshadow::Stream<gpu> *s,
                               const nnvm::NodeAttrs &attrs,
                               const OpContext &ctx,
                               const NDArray &lhs,
@@ -251,7 +251,7 @@ class ElemwiseBinaryOp : public OpBase {
 
   /*! \brief DNS -op- CSR binary operator for non-canonical NDArray */
   template<typename xpu, typename OP>
-  static inline void DnsCsrDnsOp(mshadow::Stream<xpu> *s,
+  static void DnsCsrDnsOp(mshadow::Stream<xpu> *s,
                                  const nnvm::NodeAttrs &attrs,
                                  const OpContext &ctx,
                                  const NDArray &lhs,
@@ -262,7 +262,7 @@ class ElemwiseBinaryOp : public OpBase {
 
   /*! \brief DNS -op- RSP binary operator for non-canonical NDArray */
   template<typename xpu, typename OP>
-  static inline void DnsRspDnsOp(mshadow::Stream<xpu> *s,
+  static void DnsRspDnsOp(mshadow::Stream<xpu> *s,
                                  const nnvm::NodeAttrs &attrs,
                                  const OpContext &ctx,
                                  const NDArray &lhs,
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 2d5ed5a1ee7..f21334e65cd 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -1768,6 +1768,31 @@ def check_sparse_embedding(in_dim, out_dim, batch, densities, deterministic, wei
             check_sparse_embedding(in_dim, out_dim, batch, densities, deterministic, stype)
             check_sparse_embedding(in_dim, out_dim, batch, densities, deterministic, stype)
 
+@with_seed()
+def test_sparse_broadcast_add_sub():
+    def check_broadcast_add(mx_lhs, mx_rhs, np_lhs, np_rhs, dtype):
+        assert_almost_equal(mx.nd.sparse.add(mx_lhs, mx_rhs).asnumpy(), np.add(np_lhs, np_rhs), atol=1e-4)
+    def check_broadcast_sub(mx_lhs, mx_rhs, np_lhs, np_rhs, dtype):
+        assert_almost_equal(mx.nd.sparse.subtract(mx_lhs, mx_rhs).asnumpy(), np.subtract(np_lhs, np_rhs), atol=1e-4)
+    stype = 'csr'
+    shape = rand_shape_2d()
+    num_rows = shape[0]
+    num_cols = shape[1]
+    for density in [0.1 * i for i in range(10)]:
+        mx_lhs = rand_ndarray(shape, stype, density)
+        np_lhs = mx_lhs.asnumpy()
+        mx_rhs_row_2D = rand_ndarray((1, num_cols), 'default')
+        mx_rhs_row_1D = mx_rhs_row_2D.reshape((num_cols))
+        mx_rhs_col = rand_ndarray((num_rows, 1), 'default')
+        mx_rhs_scalar_2D = rand_ndarray((1, 1), 'default')
+        mx_rhs_scalar_1D = mx_rhs_scalar_2D.reshape((1, ))
+        for mx_rhs in [mx_rhs_row_2D, mx_rhs_row_1D, mx_rhs_col, mx_rhs_scalar_2D, mx_rhs_scalar_1D]:
+            np_rhs = mx_rhs.asnumpy()
+            check_broadcast_add(mx_lhs, mx_rhs, np_lhs, np_rhs, np.float32)
+            check_broadcast_sub(mx_lhs, mx_rhs, np_lhs, np_rhs, np.float32)
+            check_broadcast_add(mx_rhs, mx_lhs, np_rhs, np_lhs, np.float32)
+            check_broadcast_sub(mx_rhs, mx_lhs, np_rhs, np_lhs, np.float32)
+
 @with_seed()
 def test_sparse_broadcast_mul_div():
     def check_broadcast_mul(mx_lhs, mx_rhs, np_lhs, np_rhs, dtype):


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services