Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/10 17:46:27 UTC

[GitHub] eric-haibin-lin closed pull request #10939: [MXNET-420] broadcast_mul/div between csr and 1D dense on GPU

URL: https://github.com/apache/incubator-mxnet/pull/10939
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

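Before the diff itself, the operation this PR enables on GPU is simple to state: broadcast_mul(csr, dense(1D)) (and broadcast_div) scales every stored value of the CSR matrix by the matching element of the dense vector, per row when the vector is a column vector and per column when it is a row vector. The following self-contained C++ sketch illustrates that semantics only; the function and variable names are illustrative and are not MXNet internals.

#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical standalone sketch (not MXNet code): broadcast-multiply a CSR
// matrix by a 1-D dense vector. col_vec == true means the dense operand has
// one entry per row (a column vector); col_vec == false means one entry per
// column (a row vector). Only the stored values change, so a real
// implementation can reuse the input's indices and indptr arrays unchanged.
std::vector<float> csr_broadcast_mul(const std::vector<float>& csr_data,
                                     const std::vector<int>& csr_indices,
                                     const std::vector<int>& csr_indptr,
                                     const std::vector<float>& dns,
                                     bool col_vec) {
  std::vector<float> out(csr_data.size());
  const std::size_t num_rows = csr_indptr.size() - 1;
  for (std::size_t row = 0; row < num_rows; ++row) {
    for (int i = csr_indptr[row]; i < csr_indptr[row + 1]; ++i) {
      const float scale = col_vec ? dns[row] : dns[csr_indices[i]];
      out[i] = csr_data[i] * scale;
    }
  }
  return out;
}

int main() {
  // 2x3 CSR matrix [[1, 0, 2], [0, 3, 0]] broadcast-multiplied by the
  // row vector [10, 20, 30].
  const std::vector<float> data{1.f, 2.f, 3.f};
  const std::vector<int> indices{0, 2, 1};
  const std::vector<int> indptr{0, 2, 3};
  const std::vector<float> row_vec{10.f, 20.f, 30.f};
  for (float v : csr_broadcast_mul(data, indices, indptr, row_vec, false)) {
    std::cout << v << ' ';  // prints: 10 60 60
  }
  std::cout << '\n';
  return 0;
}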

diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h
index 933d0be4658..e5b77e11283 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -88,16 +88,13 @@ inline bool BinaryBroadcastMulStorageType(const nnvm::NodeAttrs& attrs,
   const int rhs_stype = in_attrs->at(1);
   int& out_stype = out_attrs->at(0);
   bool dispatched = false;
-  // For GPU, directly fallback
-  const auto dispatch_ex = (dev_mask == mshadow::gpu::kDevMask)? DispatchMode::kFComputeFallback :
-                           DispatchMode::kFComputeEx;
   if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
     dispatched = storage_type_assign(&out_stype, kDefaultStorage,
                                      dispatch_mode, DispatchMode::kFCompute);
   }
   if (!dispatched && lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) {
     dispatched = storage_type_assign(&out_stype, kCSRStorage,
-                                     dispatch_mode, dispatch_ex);
+                                     dispatch_mode, DispatchMode::kFComputeEx);
   }
   if (!dispatched) {
     dispatched = dispatch_fallback(out_attrs, dispatch_mode);
@@ -229,12 +226,20 @@ struct binary_broadcast_kernel {
   }
 };
 
-template<int req, typename OP>
+template<int req, typename OP, bool col_vec>
 struct csr_dns_csr_broadcast_kernel {
-  template <typename DType, typename CType, typename RType>
+  /*!
+   * \brief Map function for broadcast between csr and 1D vector
+   * \param row          global thread id/assigned row id
+   * \param csr_data     ptr to data buffer of csr matrix
+   * \param csr_indices  ptr to indices buffer of csr matrix
+   * \param csr_indptr   ptr to indptr buffer of csr matrix
+   * \param dns          ptr to data buffer of the dense vector
+   * \param out          ptr to the data buffer of the result csr matrix
+   */
+  template<typename DType, typename CType, typename RType>
   MSHADOW_XINLINE static void Map(int row, const DType *csr_data, const CType *csr_indices,
-                                  const RType *csr_indptr, const DType *dns,
-                                  DType *out, const nnvm::dim_t row_length, bool col_vec) {
+                                  const RType *csr_indptr, const DType *dns, DType *out) {
     const nnvm::dim_t curr_row_i = csr_indptr[row];
     const nnvm::dim_t next_row_i = csr_indptr[row + 1];
     for (nnvm::dim_t iter = curr_row_i; iter < next_row_i; iter++) {
@@ -242,6 +247,23 @@ struct csr_dns_csr_broadcast_kernel {
                     (col_vec)? dns[row] : dns[csr_indices[iter]]));
     }
   }
+
+  /*!
+   * \brief Map function for broadcast between csr and a scalar
+   * \param i           global thread id
+   * \param csr_data    ptr to data buffer of csr matrix
+   * \param scalar_ptr  ptr to data buffer of the scalar tensor, only the 0-th element is used
+   * \param out         ptr to the data buffer of output csr matrix
+   * \param nnz         number of non-zero elements in input csr matrix
+   */
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, const DType *csr_data, const DType* scalar_ptr,
+                                  DType *out, const nnvm::dim_t nnz) {
+    const DType scale = scalar_ptr[0];
+    if (i < nnz) {
+      KERNEL_ASSIGN(out[i], req, OP::Map(csr_data[i], scale));
+    }
+  }
 };
 
 template<int req, typename OP, bool reverse = false>
@@ -320,21 +342,31 @@ void BinaryBroadcastCsrDnsCsrImpl(const OpContext& ctx,
       MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIdx), CType, {
         MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIndPtr), RType, {
           MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+            // broadcast_mul/div between csr and a scalar case
             if ((dns.shape().ndim() == 2 && dns.shape()[0] == 1 && dns.shape()[1] == 1) ||
                 (dns.shape().ndim() == 1 && dns.shape()[0] == 1)) {
-              Kernel<op_with_req<OP, req_type>, xpu>::Launch(
-                s, nnz, output.data().dptr<DType>(), csr.data().dptr<DType>(),
-                dns.data().dptr<DType>()[0]);
+              Kernel<csr_dns_csr_broadcast_kernel<req_type, OP, false>, xpu>::Launch(
+                s, nnz, csr.data().dptr<DType>(), dns.data().dptr<DType>(),
+                output.data().dptr<DType>(), nnz);
             } else {
-              Kernel<csr_dns_csr_broadcast_kernel<req_type, OP>, xpu>::Launch(
-                s, num_rows, csr.data().dptr<DType>(), csr.aux_data(kIdx).dptr<CType>(),
-                csr.aux_data(kIndPtr).dptr<RType>(), dns.data().dptr<DType>(),
-                output.data().dptr<DType>(), csr.shape()[1], col_vec);
+              // broadcast_mul/div between csr and column vector
+              if (col_vec) {
+                Kernel<csr_dns_csr_broadcast_kernel<req_type, OP, true>, xpu>::Launch(
+                  s, num_rows, csr.data().dptr<DType>(), csr.aux_data(kIdx).dptr<CType>(),
+                  csr.aux_data(kIndPtr).dptr<RType>(), dns.data().dptr<DType>(),
+                  output.data().dptr<DType>());
+              // broadcast_mul/div between csr and row vector
+              } else {
+                Kernel<csr_dns_csr_broadcast_kernel<req_type, OP, false>, xpu>::Launch(
+                  s, num_rows, csr.data().dptr<DType>(), csr.aux_data(kIdx).dptr<CType>(),
+                  csr.aux_data(kIndPtr).dptr<RType>(), dns.data().dptr<DType>(),
+                  output.data().dptr<DType>());
+              }
             }
             Copy(output.aux_data(kIdx).FlatTo1D<xpu, CType>(),
-                 csr.aux_data(kIdx).FlatTo1D<xpu, CType>());
+                 csr.aux_data(kIdx).FlatTo1D<xpu, CType>(), s);
             Copy(output.aux_data(kIndPtr).FlatTo1D<xpu, RType>(),
-                 csr.aux_data(kIndPtr).FlatTo1D<xpu, RType>());
+                 csr.aux_data(kIndPtr).FlatTo1D<xpu, RType>(), s);
           });
         });
       });
@@ -432,10 +464,12 @@ void BinaryBroadcastComputeSparseEx(const nnvm::NodeAttrs& attrs,
   const auto out_stype = out.storage_type();
   // If the input is a matrix with the same shape, should be elemwise
   if ((rhs.shape().ndim() != 1U) && (rhs.shape()[0] != 1) && (rhs.shape()[1] != 1)) {
-    // Currently do not support elementwise_mul/div(csr, dense) = csr, log and exit
-    using common::operator_string;
-    LOG(FATAL) << operator_string(attrs, ctx, inputs, req, outputs)
-               << "\nIf shape of lhs and rhs match, please explicitly use elemwise_mul/div\n";
+    if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage && out_stype == kCSRStorage) {
+      const bool supported_op = std::is_same<OP, mshadow_op::mul>::value;
+      CHECK(supported_op)
+        << "Please use elemwise_div for division between csr and dense of the same shape";
+      ElemwiseBinaryOp::DnsCsrCsrOp<xpu, mshadow_op::mul>(attrs, ctx, rhs, lhs, req[0], out, true);
+    }
   } else {
     // broadcast(CSR, Dense(1D)) = CSR
     if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage && out_stype == kCSRStorage) {
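One design note on the kernel rewrite in this file: the runtime col_vec argument becomes a compile-time template parameter, so the row-vector/column-vector choice is fixed per instantiation instead of being re-evaluated inside the per-row loop. A minimal sketch of that pattern follows, with illustrative names rather than the actual MXNet types.

// Illustrative sketch of resolving the broadcast axis at compile time, in the
// spirit of csr_dns_csr_broadcast_kernel<req, OP, col_vec>. Each
// instantiation folds the ternary on col_vec away entirely.
template <bool col_vec>
struct BroadcastMulRow {
  static void Map(int row, const float* csr_data, const int* csr_indices,
                  const int* csr_indptr, const float* dns, float* out) {
    for (int i = csr_indptr[row]; i < csr_indptr[row + 1]; ++i) {
      out[i] = csr_data[i] * (col_vec ? dns[row] : dns[csr_indices[i]]);
    }
  }
};

// The caller selects the instantiation once, outside the hot loop:
//   BroadcastMulRow<true>::Map(...);   // dense operand is a column vector
//   BroadcastMulRow<false>::Map(...);  // dense operand is a row vector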
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
index 61bc94e4df1..7301eced179 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
@@ -141,7 +141,7 @@ Example::
 
 Supported sparse operations:
 
-   broadcast_mul(csr, dense(1D)) = csr (CPU only)
+   broadcast_mul(csr, dense(1D)) = csr
 
 )code" ADD_FILELINE)
 .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::mul>)
@@ -182,7 +182,7 @@ Example::
 
 Supported sparse operations:
 
-   broadcast_div(csr, dense(1D)) = csr (CPU only)
+   broadcast_div(csr, dense(1D)) = csr
 
 )code" ADD_FILELINE)
 .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::div>)
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cu b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cu
index 976f09152a3..a00330d2446 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cu
+++ b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cu
@@ -45,14 +45,16 @@ NNVM_REGISTER_OP(_backward_broadcast_sub)
                                                                 mshadow_op::negation>);
 
 NNVM_REGISTER_OP(broadcast_mul)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::mul>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::mul>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", BinaryBroadcastComputeSparseEx<gpu, op::mshadow_op::mul>);
 
 NNVM_REGISTER_OP(_backward_broadcast_mul)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseIn<gpu, mshadow_op::right,
                                                                 mshadow_op::left>);
 
 NNVM_REGISTER_OP(broadcast_div)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::div>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::div>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", BinaryBroadcastComputeSparseEx<gpu, op::mshadow_op::div>);
 
 NNVM_REGISTER_OP(_backward_broadcast_div)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseIn<gpu, mshadow_op::div_grad,


 

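For completeness, the second Map overload added to csr_dns_csr_broadcast_kernel earlier in the diff covers the degenerate single-element dense operand: every non-zero of the CSR input is scaled by that one value, and the indices and indptr buffers are copied through unchanged. A plain CPU equivalent is sketched below; the names are illustrative and this is not MXNet code.

// Illustrative CPU equivalent of the scalar branch: scale every stored value
// of the CSR matrix by the single element of the dense operand. The loop
// bound plays the role of the i < nnz bounds check in the kernel's Map
// overload; division works the same way with the div operator.
void csr_scalar_mul(const float* csr_data, const float* scalar_ptr,
                    float* out, long long nnz) {
  const float scale = scalar_ptr[0];
  for (long long i = 0; i < nnz; ++i) {
    out[i] = csr_data[i] * scale;
  }
}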
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services