You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/11/19 21:07:13 UTC

[incubator-mxnet] branch master updated: optimization for dot(csr.T, dense) = rsp (#8611)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new f79d22d  optimization for dot(csr.T, dense) = rsp (#8611)
f79d22d is described below

commit f79d22db25847453a9a286eb19e9064c246a82d4
Author: Ziyue Huang <zy...@gmail.com>
AuthorDate: Mon Nov 20 05:07:10 2017 +0800

    optimization for dot(csr.T, dense) = rsp (#8611)
    
    * optimization for dot(csr.T, dense) = rsp
    
    * remove unnecessary headers
    
    * load balance
    
    * minor fix and update comments
    
    * resolve
    
    * trigger
    
    * trigger
---
 src/operator/tensor/dot-inl.h | 93 ++++++++++++++++++++++++-------------------
 1 file changed, 52 insertions(+), 41 deletions(-)

diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h
index 7ab4710..2432703 100644
--- a/src/operator/tensor/dot-inl.h
+++ b/src/operator/tensor/dot-inl.h
@@ -30,9 +30,10 @@
 #include <algorithm>
 #include <utility>
 #include <type_traits>
-#include "./init_op.h"
+#include "./util/tensor_util-inl.h"
 #include "../mshadow_op.h"
 #include "../elemwise_op_common.h"
+#include "./init_op.h"
 #include "../mxnet_op.h"
 #ifdef __CUDACC__
 #include "./dot-inl.cuh"
@@ -364,19 +365,17 @@ struct DotCsrTransDnsDnsByRowBlocks {
 
 /*!
  * \brief CPU Kernel of dot(csr.T(), dns) = rsp
- * Parallelization by row blocks.
- * This kernel fills up the row_idx array of the rsp
- * with 1 for nonzero rows and 0 for zero rows.
- * The matrix will be compacted after this kernel call.
+ * Parallelization by row blocks which evenly partition the non-zero rows.
  */
 struct DotCsrTransDnsRspByRowBlocks {
   /*!
    * \brief
    * \param i the i-th thread
    */
-  template<typename DType, typename RType, typename IType, typename CType>
+  template<typename DType, typename IType, typename CType, typename RType>
   MSHADOW_CINLINE static void Map(int i,
                                   DType* out,
+                                  nnvm::dim_t* row_flg_sum,
                                   RType* row_idx,
                                   const DType* data_l,
                                   const IType* indptr_l,
@@ -384,21 +383,25 @@ struct DotCsrTransDnsRspByRowBlocks {
                                   const DType* data_r,
                                   const nnvm::dim_t seg_len,
                                   const nnvm::dim_t num_rows_l,
-                                  const nnvm::dim_t num_rows,
+                                  const nnvm::dim_t nnr,
                                   const nnvm::dim_t num_cols) {
     using nnvm::dim_t;
     const dim_t seg_start = i * seg_len;
-    if (seg_start >= num_rows) return;
+    if (seg_start >= nnr) return;
     const dim_t seg_end = (i + 1) * seg_len;
+    const dim_t col_start = row_idx[seg_start];
+    const dim_t col_end = seg_end >= nnr ? (row_idx[nnr-1] + 1) : row_idx[seg_end];
     for (dim_t j = 0; j < num_rows_l; ++j) {
       if (indptr_l[j] == indptr_l[j+1]) continue;
       const dim_t offset_r = j * num_cols;
       for (IType k = indptr_l[j]; k < indptr_l[j+1]; ++k) {
         const CType col_idx = col_idx_l[k];
-        if (col_idx < seg_start || col_idx >= seg_end) continue;
-        const dim_t offset_out = col_idx * num_cols;
-        row_idx[col_idx] = 1;
+        if (col_idx < col_start || col_idx >= col_end) continue;
+
+        const nnvm::dim_t rsp_row = row_flg_sum[col_idx] - 1;
+        const nnvm::dim_t offset_out = rsp_row * num_cols;
         const DType val = data_l[k];
+
         for (dim_t l = 0; l < num_cols; ++l) {
           out[offset_out+l] += data_r[offset_r+l] * val;
         }
@@ -605,43 +608,51 @@ inline void DotCsrDnsRspImpl(const OpContext& ctx,
   const TBlob col_idx_l = lhs.aux_data(csr::kIdx);
   const TBlob& data_r = rhs;
 
-  // pre-allocate spaces for ret using the dense dimension size
-  ret->CheckAndAlloc({mshadow::Shape1(lhs.shape()[1])});
-  const TBlob data_out = ret->data();
-  const TBlob row_idx_out = ret->aux_data(rowsparse::kIdx);
-
   MSHADOW_SGL_DBL_TYPE_SWITCH(data_l.type_flag_, DType, {  // data type
     MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, {  // indptr type
       MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, {  // col idx type
-        MSHADOW_IDX_TYPE_SWITCH(row_idx_out.type_flag_, RType, {  // row idx type
+        MSHADOW_IDX_TYPE_SWITCH(ret->aux_type(rowsparse::kIdx), RType, {  // row idx type
+          const dim_t num_rows = lhs.shape()[1];
+          size_t workspace_size = 2 * (num_rows * sizeof(dim_t));
+          mshadow::Tensor<cpu, 1, char> workspace =
+            ctx.requested[0].get_space_typed<cpu, 1, char>(
+            mshadow::Shape1(workspace_size), s);
+          dim_t* row_flg = reinterpret_cast<dim_t*>(workspace.dptr_);
+          dim_t* prefix_sum = row_flg + num_rows;
+
+          Fill<false>(s, TBlob(row_flg, mshadow::Shape1(num_rows), cpu::kDevMask), kWriteTo, 0);
+          mxnet_op::Kernel<MarkRowFlgKernel, cpu>::Launch(s, lhs.aux_shape(csr::kIdx)[0], row_flg,
+            col_idx_l.dptr<CType>());
+
+          prefix_sum[0] = row_flg[0];
+          for (nnvm::dim_t i = 1; i < num_rows; i++) {
+            prefix_sum[i] = prefix_sum[i - 1] + row_flg[i];
+          }
+          dim_t nnr = prefix_sum[num_rows - 1];
+
+          if (nnr == 0) {
+            FillZerosRspImpl(s, *ret);
+            return;
+          }
+
+          ret->CheckAndAlloc({mshadow::Shape1(nnr)});
+          const TBlob& data_out = ret->data();
+          const TBlob& row_idx = ret->aux_data(rowsparse::kIdx);
+
           dim_t num_threads = data_out.Size();
           mxnet_op::Kernel<set_zero, cpu>::Launch(s, num_threads, data_out.dptr<DType>());
-          RType* row_idx = row_idx_out.dptr<RType>();
-          num_threads = row_idx_out.Size();
-          mxnet_op::Kernel<set_zero, cpu>::Launch(s, num_threads, row_idx);
-          num_threads = mxnet_op::get_num_threads<cpu>(data_out.shape_[0]);
-          dim_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads;
+          RType* row_idx_out = row_idx.dptr<RType>();
+
+          mxnet_op::Kernel<FillRspRowIdxKernel, cpu>::Launch(s, num_rows,
+            row_idx_out, prefix_sum, num_rows);
+
+          num_threads = mxnet_op::get_num_threads<cpu>(nnr);
+          dim_t seg_len = (nnr + num_threads - 1) / num_threads;
           if (trans_lhs) {
             mxnet_op::Kernel<DotCsrTransDnsRspByRowBlocks, cpu>::Launch(s, num_threads,
-                data_out.dptr<DType>(), row_idx, data_l.dptr<DType>(),
-                indptr_l.dptr<IType>(), col_idx_l.dptr<CType>(), data_r.dptr<DType>(),
-                seg_len, lhs.shape()[0], data_out.shape_[0], data_out.shape_[1]);
-            dim_t nnr = 0;
-            nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr);
-            if (0 == nnr) {
-              FillZerosRspImpl(s, *ret);
-              return;
-            }
-            ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr));
-            mshadow::Tensor<cpu, 2, DType> rsp_data = data_out.FlatTo2D<cpu, DType>(s);
-            dim_t idx = 0;
-            for (index_t i = 0; i < ret->shape()[0]; ++i) {
-              if (row_idx[i] > 0) {
-                row_idx[idx] = i;
-                mshadow::Copy(rsp_data[idx], rsp_data[i], s);
-                ++idx;
-              }
-            }
+              data_out.dptr<DType>(), prefix_sum, row_idx_out, data_l.dptr<DType>(),
+              indptr_l.dptr<IType>(), col_idx_l.dptr<CType>(), data_r.dptr<DType>(),
+              seg_len, lhs.shape()[0], nnr, ret->shape()[1]);
           } else {
             LOG(FATAL) << "DotCsrDnsRspImpl has not implemented dot(csr, dns)=rsp yet.";
           }

-- 
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].