Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/05/10 17:20:26 UTC

[incubator-mxnet] branch master updated: support broadcast_add/sub between csr and 1D dense vector (#10714)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 3972861  support broadcast_add/sub between csr and 1D dense vector (#10714)
3972861 is described below

commit 3972861ae6a95958b527331ace783a3c50f13072
Author: Hao Jin <ha...@users.noreply.github.com>
AuthorDate: Thu May 10 10:20:19 2018 -0700

    support broadcast_add/sub between csr and 1D dense vector (#10714)
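
    In user-facing terms, this commit adds a dense-output broadcast path for
    add/sub with one csr operand. A minimal usage sketch (values are
    illustrative; assumes a build that contains this commit):

        import mxnet as mx

        csr = mx.nd.sparse.csr_matrix(([1., 2.], [0, 2], [0, 1, 2]), shape=(2, 3))
        row = mx.nd.array([1., 2., 3.])      # 1D dense vector
        out = mx.nd.sparse.add(csr, row)     # broadcasts row across the csr rows
        print(out.stype)                     # 'default' -- the result is dense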
---
 src/operator/tensor/elemwise_binary_broadcast_op.h | 183 ++++++++++++++++++++-
 .../tensor/elemwise_binary_broadcast_op_basic.cc   |  16 +-
 .../tensor/elemwise_binary_broadcast_op_basic.cu   |   6 +-
 src/operator/tensor/elemwise_binary_op.h           |  10 +-
 tests/python/unittest/test_sparse_operator.py      |  25 +++
 5 files changed, 226 insertions(+), 14 deletions(-)

diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h
index 42a8f0f..933d0be 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -105,6 +105,32 @@ inline bool BinaryBroadcastMulStorageType(const nnvm::NodeAttrs& attrs,
   return dispatched;
 }
 
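+/*!
+ * \brief Storage type inference for broadcast_add/sub. Dense inputs dispatch
+ *        to FCompute; (csr, dense) or (dense, csr) dispatch to FComputeEx
+ *        with a dense output; anything else falls back.
+ */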
+inline bool BinaryBroadcastAddStorageType(const nnvm::NodeAttrs& attrs,
+                                          const int dev_mask,
+                                          DispatchMode* dispatch_mode,
+                                          std::vector<int>* in_attrs,
+                                          std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  const int lhs_stype = in_attrs->at(0);
+  const int rhs_stype = in_attrs->at(1);
+  int& out_stype = out_attrs->at(0);
+  bool dispatched = false;
+  if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
+    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFCompute);
+  }
+  if (!dispatched && ((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
+                      (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage))) {
+    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFComputeEx);
+  }
+  if (!dispatched) {
+    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
+  }
+  return dispatched;
+}
+
 #define BROADCAST_NDIM_SWITCH(ndim, NDim, ...)  \
   if (ndim <= 2) {                    \
     const int NDim = 2;               \
@@ -183,6 +209,24 @@ struct binary_broadcast_kernel {
       KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs[lidx], rhs[ridx]));
     }
   }
+
+  /*! \brief Map function for binary_broadcast_kernel with a scalar lhs */
+  MSHADOW_XINLINE static void Map(int base, int length, OpReqType req,
+                                  const Shape<ndim> &lstride, const Shape<ndim> &rstride,
+                                  const Shape<ndim> &oshape, DType lhs, DType *rhs,
+                                  DType *out) {
+    Shape<ndim> coord = unravel(base, oshape);
+    auto lidx = static_cast<index_t>(dot(coord, lstride));
+    auto ridx = static_cast<index_t>(dot(coord, rstride));
+    KERNEL_ASSIGN(out[base], req, OP::Map(lhs, rhs[ridx]));
+    // starts from 1 to avoid extra inc at end of loop
+    for (int i = 1; i < length; ++i) {
+      inc(&coord, oshape, &lidx, lstride, &ridx, rstride);
+      KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs, rhs[ridx]));
+    }
+  }
 };
 
 template<int req, typename OP>
@@ -200,6 +244,25 @@ struct csr_dns_csr_broadcast_kernel {
   }
 };
 
+template<int req, typename OP, bool reverse = false>
+struct csr_dns_map_kernel {
+  template <typename DType, typename CType, typename RType>
+  MSHADOW_XINLINE static void Map(int row, const DType *csr_data, const CType *csr_indices,
+                                  const RType *csr_indptr, DType *out, const nnvm::dim_t num_rows,
+                                  const nnvm::dim_t num_cols) {
+    if (row < num_rows) {
+      const nnvm::dim_t curr_row_i = csr_indptr[row];
+      const nnvm::dim_t next_row_i = csr_indptr[row + 1];
+      for (nnvm::dim_t iter = curr_row_i; iter < next_row_i; iter++) {
+        const nnvm::dim_t target = row * num_cols + csr_indices[iter];
+        KERNEL_ASSIGN(out[target], req,
+                      reverse ? OP::Map(out[target], csr_data[iter]) :
+                                OP::Map(csr_data[iter], out[target]));
+      }
+    }
+  }
+};
+
 }  // namespace mxnet_op
 
 template<typename xpu, typename OP>
@@ -284,11 +347,77 @@ void BinaryBroadcastCsrDnsCsrImpl(const OpContext& ctx,
 }
 
 template<typename xpu, typename OP>
-void BinaryBroadcastComputeEx(const nnvm::NodeAttrs& attrs,
-                              const OpContext& ctx,
-                              const std::vector<NDArray>& inputs,
-                              const std::vector<OpReqType>& req,
-                              const std::vector<NDArray>& outputs) {
+void BinaryBroadcastCsrDnsDnsImpl(const OpContext& ctx,
+                                  const NDArray& csr,
+                                  const NDArray& dns,
+                                  const OpReqType req,
+                                  const NDArray& output,
+                                  const TShape& new_csrshape,
+                                  const TShape& new_dnsshape,
+                                  const TShape& new_oshape,
+                                  const int ndim,
+                                  const bool reverse) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  using namespace csr;
+  CHECK(req == kWriteTo) << "Only kWriteTo supported for broadcast(csr, dns) = dns";
+  const bool legal_op = std::is_same<OP, mshadow_op::plus>::value ||
+                        std::is_same<OP, mshadow_op::minus>::value;
+  CHECK(legal_op) << "Only add/sub are supported for broadcast(csr, dns) = dns";
+  CHECK_EQ(csr.shape()[0], output.shape()[0]);
+  CHECK_EQ(csr.shape()[1], output.shape()[1]);
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  const nnvm::dim_t num_rows = output.shape()[0];
+  const nnvm::dim_t num_cols = output.shape()[1];
+  const TBlob& csr_data = csr.data();
+  const TBlob& csr_indices = csr.aux_data(kIdx);
+  const TBlob& csr_indptr = csr.aux_data(kIndPtr);
+  TBlob dns_data = dns.data();
+  TBlob out_data = output.data();
+
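+  // Phase 1: broadcast the dense operand into the output, treating the csr
+  // operand as an all-zero scalar lhs (DType(0) below). For the reverse
+  // subtraction case (dns - csr) the dense side must enter with '+'; the csr
+  // values are then subtracted in phase 2.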
+  MSHADOW_TYPE_SWITCH(output.dtype(), DType, {
+    BROADCAST_NDIM_SWITCH(ndim, NDim, {
+      Shape<NDim> oshape = new_oshape.get<NDim>();
+      Shape<NDim> lstride = calc_stride(new_csrshape.get<NDim>());
+      Shape<NDim> rstride = calc_stride(new_dnsshape.get<NDim>());
+      if (reverse && std::is_same<OP, mshadow_op::minus>::value) {
+        Kernel<binary_broadcast_kernel<NDim, DType, mshadow_op::plus>, xpu>::
+        template LaunchEx(s, new_oshape.Size(), req, lstride, rstride, oshape,
+        DType(0), dns_data.dptr<DType>(), out_data.dptr<DType>());
+      } else {
+        Kernel<binary_broadcast_kernel<NDim, DType, OP>, xpu>::
+        template LaunchEx(s, new_oshape.Size(), req, lstride, rstride, oshape,
+        DType(0), dns_data.dptr<DType>(), out_data.dptr<DType>());
+      }
+    });
+  });
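+  // Phase 2: fold the csr nonzeros into the pre-filled dense output.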
+  if (csr.storage_initialized()) {
+    MSHADOW_TYPE_SWITCH(csr.dtype(), DType, {
+      MSHADOW_IDX_TYPE_SWITCH(csr.aux_type(kIdx), CType, {
+        MSHADOW_IDX_TYPE_SWITCH(csr.aux_type(kIndPtr), RType, {
+          MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+            if (reverse && std::is_same<OP, mshadow_op::minus>::value) {
+              Kernel<csr_dns_map_kernel<req_type, mshadow_op::minus, true>, xpu>::Launch(
+                s, num_rows, csr_data.dptr<DType>(), csr_indices.dptr<CType>(),
+                csr_indptr.dptr<RType>(), out_data.dptr<DType>(), num_rows, num_cols);
+            } else {
+              Kernel<csr_dns_map_kernel<req_type, mshadow_op::plus>, xpu>::Launch(
+                s, num_rows, csr_data.dptr<DType>(), csr_indices.dptr<CType>(),
+                csr_indptr.dptr<RType>(), out_data.dptr<DType>(), num_rows, num_cols);
+            }
+          });
+        });
+      });
+    });
+  }
+}
+
+template<typename xpu, typename OP>
+void BinaryBroadcastComputeSparseEx(const nnvm::NodeAttrs& attrs,
+                                    const OpContext& ctx,
+                                    const std::vector<NDArray>& inputs,
+                                    const std::vector<OpReqType>& req,
+                                    const std::vector<NDArray>& outputs) {
   CHECK_EQ(inputs.size(), 2U);
   CHECK_EQ(outputs.size(), 1U);
   CHECK_EQ(req.size(), 1U);
@@ -317,6 +446,50 @@ void BinaryBroadcastComputeEx(const nnvm::NodeAttrs& attrs,
   }
 }
 
+template<typename xpu, typename OP>
+void BinaryBroadcastComputeDenseEx(const nnvm::NodeAttrs& attrs,
+                                   const OpContext& ctx,
+                                   const std::vector<NDArray>& inputs,
+                                   const std::vector<OpReqType>& req,
+                                   const std::vector<NDArray>& outputs) {
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  CHECK_LE(inputs[1].shape().ndim(), 2U)
+    << "input dense matrix should have less than or equal to 2 dimensions";
+  if (req[0] == kNullOp) return;
+  const NDArray& lhs = inputs[0];
+  const NDArray& rhs = inputs[1];
+  const NDArray& out = outputs[0];
+  const auto lhs_stype = lhs.storage_type();
+  const auto rhs_stype = rhs.storage_type();
+  const auto out_stype = out.storage_type();
+  bool reverse = (lhs_stype == kDefaultStorage);
+  const NDArray& dns = (reverse) ? lhs : rhs;
+  const NDArray& csr = (reverse) ? rhs : lhs;
+  TShape new_csrshape, new_dnsshape, new_oshape;
+  int ndim = BinaryBroadcastShapeCompact(csr.shape(), dns.shape(), out.shape(),
+                                         &new_csrshape, &new_dnsshape, &new_oshape);
+
+  if (((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
+       (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage)) &&
+      out_stype == kDefaultStorage) {
+    // If both inputs have the same shape, fall back to the elemwise path
+    if (!ndim) {
+      mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
+      ElemwiseBinaryOp::DnsCsrDnsOp<xpu, OP>(
+        s, attrs, ctx, dns, csr, req[0], outputs[0], !reverse);
+    } else {
+      // broadcast(CSR, Dense(1D)) = Dense
+      BinaryBroadcastCsrDnsDnsImpl<xpu, OP>(ctx, csr, dns, req[0], out,
+                                            new_csrshape, new_dnsshape, new_oshape,
+                                            ndim, reverse);
+    }
+  } else {
+    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
+  }
+}
+
 template<typename xpu, typename LOP, typename ROP>
 void BinaryBroadcastBackwardUseNone(const nnvm::NodeAttrs& attrs,
                                     const OpContext& ctx,
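
For intuition, the dense-output path added above can be mirrored in a few
lines of NumPy (an illustrative sketch of the two-phase algorithm, not the
actual implementation; all names below are made up):

    import numpy as np

    def broadcast_csr_dns_dns(data, indices, indptr, dns_row, num_cols,
                              subtract=False, reverse=False):
        """Mirror of the two-phase BinaryBroadcastCsrDnsDnsImpl sketch."""
        num_rows = len(indptr) - 1
        dns = np.asarray(dns_row, dtype=float)
        # Phase 1: out = OP(0, dns) broadcast over every row; for the reverse
        # subtraction (dns - csr) the dense side enters with '+'.
        sign = -1.0 if (subtract and not reverse) else 1.0
        out = np.broadcast_to(sign * dns, (num_rows, num_cols)).copy()
        # Phase 2: fold the csr nonzeros on top of the pre-filled output.
        for row in range(num_rows):
            for i in range(indptr[row], indptr[row + 1]):
                col = indices[i]
                if subtract and reverse:   # dns - csr
                    out[row, col] -= data[i]
                else:                      # csr + dns or csr - dns
                    out[row, col] += data[i]
        return out

    # e.g. broadcast_csr_dns_dns([1., 2.], [0, 2], [0, 1, 2], [1., 2., 3.], 3)
    # gives [[2., 2., 3.], [1., 2., 5.]], i.e. the dense csr plus the row vector.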
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
index 6be4c26..78b2d45 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
@@ -48,8 +48,14 @@ Example::
    broadcast_plus(x, y) = [[ 1.,  1.,  1.],
                            [ 2.,  2.,  2.]]
 
+Supported sparse operations:
+   broadcast_add(csr, dense(1D)) = dense
+   broadcast_add(dense(1D), csr) = dense
+
 )code" ADD_FILELINE)
 .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::plus>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryBroadcastComputeDenseEx<cpu, op::mshadow_op::plus>)
+.set_attr<FInferStorageType>("FInferStorageType", BinaryBroadcastAddStorageType)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_add"});
 
 NNVM_REGISTER_OP(_backward_broadcast_add)
@@ -87,8 +93,14 @@ Example::
    broadcast_minus(x, y) = [[ 1.,  1.,  1.],
                             [ 0.,  0.,  0.]]
 
+Supported sparse operations:
+   broadcast_sub/minus(csr, dense(1D)) = dense
+   broadcast_sub/minus(dense(1D), csr) = dense
+
 )code" ADD_FILELINE)
 .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::minus>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryBroadcastComputeDenseEx<cpu, op::mshadow_op::minus>)
+.set_attr<FInferStorageType>("FInferStorageType", BinaryBroadcastAddStorageType)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"});
 
 NNVM_REGISTER_OP(_backward_broadcast_sub)
@@ -125,7 +137,7 @@ Supported sparse operations:
 
 )code" ADD_FILELINE)
 .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::mul>)
-.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryBroadcastComputeEx<cpu, op::mshadow_op::mul>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryBroadcastComputeSparseEx<cpu, op::mshadow_op::mul>)
 .set_attr<FInferStorageType>("FInferStorageType", BinaryBroadcastMulStorageType)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"});
 
@@ -164,7 +176,7 @@ Supported sparse operations:
 
 )code" ADD_FILELINE)
 .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::div>)
-.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryBroadcastComputeEx<cpu, op::mshadow_op::div>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryBroadcastComputeSparseEx<cpu, op::mshadow_op::div>)
 .set_attr<FInferStorageType>("FInferStorageType", BinaryBroadcastMulStorageType)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_div"});
 
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cu b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cu
index dc0ba02..976f091 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cu
+++ b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cu
@@ -29,14 +29,16 @@
 namespace mxnet {
 namespace op {
 NNVM_REGISTER_OP(broadcast_add)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::plus>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::plus>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", BinaryBroadcastComputeDenseEx<gpu, op::mshadow_op::plus>);
 
 NNVM_REGISTER_OP(_backward_broadcast_add)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseNone<gpu, mshadow_op::identity,
                                                                 mshadow_op::identity>);
 
 NNVM_REGISTER_OP(broadcast_sub)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::minus>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::minus>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", BinaryBroadcastComputeDenseEx<gpu, op::mshadow_op::minus>);
 
 NNVM_REGISTER_OP(_backward_broadcast_sub)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseNone<gpu, mshadow_op::identity,
diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h
index 9d3f6e0..a5b73da 100644
--- a/src/operator/tensor/elemwise_binary_op.h
+++ b/src/operator/tensor/elemwise_binary_op.h
@@ -200,7 +200,7 @@ class ElemwiseBinaryOp : public OpBase {
     }
   }
 
- protected:
+ public:
   /*! \brief Binary op handling for lhs/rhs: RspDns, RspRsp, DnsRsp, or RspRsp->Dns result */
   template<typename OP>
   static void RspRspOp(mshadow::Stream<cpu> *s,
@@ -231,7 +231,7 @@ class ElemwiseBinaryOp : public OpBase {
 
   /*! \brief CSR -op- CSR binary operator for non-canonical NDArray */
   template<typename OP>
-  static inline void CsrCsrOp(mshadow::Stream<cpu> *s,
+  static void CsrCsrOp(mshadow::Stream<cpu> *s,
                               const nnvm::NodeAttrs &attrs,
                               const OpContext &ctx,
                               const NDArray &lhs,
@@ -241,7 +241,7 @@ class ElemwiseBinaryOp : public OpBase {
 
   /*! \brief CSR -op- CSR binary operator for non-canonical NDArray */
   template<typename OP>
-  static inline void CsrCsrOp(mshadow::Stream<gpu> *s,
+  static void CsrCsrOp(mshadow::Stream<gpu> *s,
                               const nnvm::NodeAttrs &attrs,
                               const OpContext &ctx,
                               const NDArray &lhs,
@@ -251,7 +251,7 @@ class ElemwiseBinaryOp : public OpBase {
 
   /*! \brief DNS -op- CSR binary operator for non-canonical NDArray */
   template<typename xpu, typename OP>
-  static inline void DnsCsrDnsOp(mshadow::Stream<xpu> *s,
+  static void DnsCsrDnsOp(mshadow::Stream<xpu> *s,
                                  const nnvm::NodeAttrs &attrs,
                                  const OpContext &ctx,
                                  const NDArray &lhs,
@@ -262,7 +262,7 @@ class ElemwiseBinaryOp : public OpBase {
 
   /*! \brief DNS -op- RSP binary operator for non-canonical NDArray */
   template<typename xpu, typename OP>
-  static inline void DnsRspDnsOp(mshadow::Stream<xpu> *s,
+  static void DnsRspDnsOp(mshadow::Stream<xpu> *s,
                                  const nnvm::NodeAttrs &attrs,
                                  const OpContext &ctx,
                                  const NDArray &lhs,
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 2d5ed5a..f21334e 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -1769,6 +1769,31 @@ def test_sparse_embedding():
             check_sparse_embedding(in_dim, out_dim, batch, densities, deterministic, stype)
 
 @with_seed()
+def test_sparse_broadcast_add_sub():
+    def check_broadcast_add(mx_lhs, mx_rhs, np_lhs, np_rhs, dtype):
+        assert_almost_equal(mx.nd.sparse.add(mx_lhs, mx_rhs).asnumpy(), np.add(np_lhs, np_rhs), atol=1e-4)
+    def check_broadcast_sub(mx_lhs, mx_rhs, np_lhs, np_rhs, dtype):
+        assert_almost_equal(mx.nd.sparse.subtract(mx_lhs, mx_rhs).asnumpy(), np.subtract(np_lhs, np_rhs), atol=1e-4)
+    stype = 'csr'
+    shape = rand_shape_2d()
+    num_rows = shape[0]
+    num_cols = shape[1]
+    for density in [0.1 * i for i in range(10)]:
+        mx_lhs = rand_ndarray(shape, stype, density)
+        np_lhs = mx_lhs.asnumpy()
+        mx_rhs_row_2D = rand_ndarray((1, num_cols), 'default')
+        mx_rhs_row_1D = mx_rhs_row_2D.reshape((num_cols,))
+        mx_rhs_col = rand_ndarray((num_rows, 1), 'default')
+        mx_rhs_scalar_2D = rand_ndarray((1, 1), 'default')
+        mx_rhs_scalar_1D = mx_rhs_scalar_2D.reshape((1, ))
+        for mx_rhs in [mx_rhs_row_2D, mx_rhs_row_1D, mx_rhs_col, mx_rhs_scalar_2D, mx_rhs_scalar_1D]:
+            np_rhs = mx_rhs.asnumpy()
+            check_broadcast_add(mx_lhs, mx_rhs, np_lhs, np_rhs, np.float32)
+            check_broadcast_sub(mx_lhs, mx_rhs, np_lhs, np_rhs, np.float32)
+            check_broadcast_add(mx_rhs, mx_lhs, np_rhs, np_lhs, np.float32)
+            check_broadcast_sub(mx_rhs, mx_lhs, np_rhs, np_lhs, np.float32)
+
+@with_seed()
 def test_sparse_broadcast_mul_div():
     def check_broadcast_mul(mx_lhs, mx_rhs, np_lhs, np_rhs, dtype):
         assert_almost_equal(mx.nd.sparse.multiply(mx_lhs, mx_rhs).asnumpy(), np.multiply(np_lhs, np_rhs), atol=1e-4)

-- 
To stop receiving notification emails like this one, please contact
haibin@apache.org.