You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/06/10 17:43:39 UTC

[incubator-mxnet] branch master updated: Support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on CPU (#11113)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 935fc55  Support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on CPU (#11113)
935fc55 is described below

commit 935fc555e306594c0b977efb5381673176116bed
Author: XiaotaoChen <ch...@gmail.com>
AuthorDate: Mon Jun 11 01:43:28 2018 +0800

    Support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on CPU (#11113)
    
    * implement dot(dns, csr/csr.T)=dns on cpu
    
    * complete documentation related to dot(dns, csr/csr.T)=dns on cpu
    
    * support fp16 by replacing MSHADOW_SGL_DBL_TYPE_SWITCH with MSHADOW_REAL_TYPE_SWITCH
---
 src/operator/tensor/dot-inl.h | 161 +++++++++++++++++++++++++++++++++++++++---
 src/operator/tensor/dot.cc    |   4 +-
 2 files changed, 154 insertions(+), 11 deletions(-)

diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h
index ffdb706..675cbe8 100644
--- a/src/operator/tensor/dot-inl.h
+++ b/src/operator/tensor/dot-inl.h
@@ -264,11 +264,15 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
   if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
       !param.transpose_a) {
     target_stype = hint_has_value ? target_stype : kCSRStorage;
-    // dns, csr -> csr on CPU
-    if (dev_mask == mshadow::cpu::kDevMask && !param.transpose_b) {
-      if (target_stype == kCSRStorage) {
+    if (dev_mask == mshadow::cpu::kDevMask) {
+      // dns, csr -> csr on CPU
+      if (target_stype == kCSRStorage && !param.transpose_b) {
         dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode,
                                          DispatchMode::kFComputeEx);
+      // dns, csr/csr.T -> dns on CPU
+      } else if (target_stype == kDefaultStorage) {
+        dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
+                                         DispatchMode::kFComputeEx);
       }
     // dns, csr/csr.T -> dns on GPU
     } else if (dev_mask == mshadow::gpu::kDevMask) {
@@ -327,7 +331,7 @@ inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs,
       dispatched = true;
     }
   }
-  if (!dispatched && dev_mask == mshadow::gpu::kDevMask && !param.transpose_a &&
+  if (!dispatched && !param.transpose_a &&
       lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
       ograd_stype == kDefaultStorage) {
     if (type_assign(&lhs_grad_stype, kDefaultStorage) &&
@@ -655,7 +659,101 @@ struct DotDnsCsrCsrByRowBlocks {
   }
 };
 
+/*!
+ * \brief CPU Kernel of dot(dns1, csr) = dns2
+ * Parallelization by row blocks
+ */
+struct DotDnsCsrDnsByRowBlocks {
+  /*!
+   * \brief
+   * \param i           the i-th thread
+   * \param out         output matrix
+   * \param data_l      data of lhs
+   * \param data_r      values of csr
+   * \param indptr_r    row offsets of csr
+   * \param col_idx_r   column indices of csr
+   * \param seg_len     workload of this thread
+   * \param num_rows_l  number of rows in lhs
+   * \param num_cols_l  number of columns in lhs
+   * \param num_rows_r  number of rows in rhs
+   * \param num_cols_r  number of columns in rhs
+   */
+  template<typename DType, typename IType, typename CType>
+  MSHADOW_CINLINE static void Map(int i,
+                                  DType* out,
+                                  const DType* data_l,
+                                  const DType* data_r,
+                                  const IType* indptr_r,
+                                  const CType* col_idx_r,
+                                  const nnvm::dim_t seg_len,
+                                  const nnvm::dim_t num_rows_l,
+                                  const nnvm::dim_t num_cols_l,
+                                  const nnvm::dim_t num_rows_r,
+                                  const nnvm::dim_t num_cols_r) {
+    using nnvm::dim_t;
+    const dim_t seg_start = i * seg_len;
+    if (seg_start >= num_rows_l) return;
+    const dim_t seg_end = std::min(seg_start + seg_len, num_rows_l);
+    for (dim_t j = 0; j < num_rows_r; ++j) {
+      if (indptr_r[j] == indptr_r[j+1]) continue;
+      for (IType k = indptr_r[j]; k < indptr_r[j+1]; ++k) {
+        const CType col_idx = col_idx_r[k];
+        const DType val = data_r[k];
+        for (dim_t r = seg_start; r < seg_end; ++r) {
+          out[r*num_cols_r+col_idx] += data_l[r*num_cols_l+j] * val;
+        }
+      }
+    }
+  }
+};
 
+/*!
+ * \brief CPU Kernel of dot(dns1, csr.T) = dns2
+ * Parallelization by row blocks
+ */
+struct DotDnsCsrTransDnsByRowBlocks {
+  /*!
+   * \brief
+   * \param i           the i-th thread
+   * \param out         output matrix
+   * \param data_l      data of lhs
+   * \param data_r      values of csr
+   * \param indptr_r    row offsets of csr
+   * \param col_idx_r   column indices of csr
+   * \param seg_len     workload of this thread
+   * \param num_rows_l  number of rows in lhs
+   * \param num_cols_l  number of columns in lhs
+   * \param num_rows_r  number of rows in rhs
+   * \param num_cols_r  number of columns in rhs
+   */
+  template<typename DType, typename IType, typename CType>
+  MSHADOW_CINLINE static void Map(int i,
+                                  DType* out,
+                                  const DType* data_l,
+                                  const DType* data_r,
+                                  const IType* indptr_r,
+                                  const CType* col_idx_r,
+                                  const nnvm::dim_t seg_len,
+                                  const nnvm::dim_t num_rows_l,
+                                  const nnvm::dim_t num_cols_l,
+                                  const nnvm::dim_t num_rows_r,
+                                  const nnvm::dim_t num_cols_r) {
+    using nnvm::dim_t;
+    const dim_t seg_start = i * seg_len;
+    if (seg_start >= num_rows_l) return;
+    const dim_t seg_end = std::min(seg_start + seg_len, num_rows_l);
+    for (dim_t j = 0; j < num_rows_r; ++j) {
+      if (indptr_r[j] == indptr_r[j+1]) continue;
+      for (IType k = indptr_r[j]; k < indptr_r[j+1]; ++k) {
+        const CType col_idx = col_idx_r[k];
+        const DType val = data_r[k];
+        for (dim_t r = seg_start; r < seg_end; ++r) {
+          out[r*num_rows_r+j] += data_l[r*num_cols_l+col_idx] * val;
+        }
+      }
+    }
+  }
+};
 
 /*!
  * \brief CPU Impl of dot(csr, dns1) = dns2 and dot(csr.T, dns1) = dns2
@@ -1031,13 +1129,58 @@ inline void DotDnsCsrCsrImpl(const OpContext& ctx, const cpu& cpu_dev,
 }
 
 /*
- * \brief Impl of dot(dns, csr) = dense (GPU only)
+ * \brief Impl of dot(dns, csr) = dns and dot(dns, csr.T) = dns
  */
 inline void DotDnsCsrDnsImpl(const OpContext& ctx, const cpu& cpu_dev,
-                             const TBlob& dns, const NDArray& rhs,
-                             const OpReqType req, NDArray* ret,
-                             const bool transpose_b) {
-  LOG(FATAL) << "dot(dense, csr) = dense is not implemented on CPU";
+  const TBlob& dns, const NDArray& rhs,
+  const OpReqType req, NDArray* ret,
+  const bool transpose_b) {
+  if (req == kNullOp) return;
+  CHECK_EQ(rhs.storage_type(), kCSRStorage);
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  if (!rhs.storage_initialized()) {
+    FillZerosCsrImpl(s, *ret);
+    return;
+  }
+
+  using nnvm::dim_t;
+
+  const TBlob data_r = rhs.data();
+  const TBlob indptr_r = rhs.aux_data(csr::kIndPtr);
+  const TBlob col_idx_r = rhs.aux_data(csr::kIdx);
+  const TBlob& data_l = dns;
+  const TBlob data_out = ret->data();
+
+  MSHADOW_REAL_TYPE_SWITCH(data_r.type_flag_, DType, {  // data type
+    MSHADOW_IDX_TYPE_SWITCH(indptr_r.type_flag_, IType, {  // indptr type
+      MSHADOW_IDX_TYPE_SWITCH(col_idx_r.type_flag_, CType, {  // col idx type
+        dim_t num_threads;
+        if (req == kWriteTo || req == kWriteInplace) {
+          num_threads = data_out.Size();
+          mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(
+              s, num_threads, data_out.dptr<DType>());
+        }
+        num_threads = mxnet_op::get_num_threads<cpu>(data_out.shape_[0]);
+        // seg by output row
+        dim_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads;
+        if (transpose_b) {
+          mxnet_op::Kernel<DotDnsCsrTransDnsByRowBlocks, cpu>::Launch(s, num_threads,
+              data_out.dptr<DType>(), data_l.dptr<DType>(),
+              data_r.dptr<DType>(), indptr_r.dptr<IType>(),
+              col_idx_r.dptr<CType>(), seg_len,
+              dns.shape_[0], dns.shape_[1],
+              rhs.shape()[0], rhs.shape()[1]);
+        } else {
+          mxnet_op::Kernel<DotDnsCsrDnsByRowBlocks, cpu>::Launch(s, num_threads,
+              data_out.dptr<DType>(), data_l.dptr<DType>(),
+              data_r.dptr<DType>(), indptr_r.dptr<IType>(),
+              col_idx_r.dptr<CType>(), seg_len,
+              dns.shape_[0], dns.shape_[1],
+              rhs.shape()[0], rhs.shape()[1]);
+        }
+      });
+    });
+  });
 }
 
 inline bool DotShape(const nnvm::NodeAttrs& attrs,
diff --git a/src/operator/tensor/dot.cc b/src/operator/tensor/dot.cc
index 2f44f53..556fd1f 100644
--- a/src/operator/tensor/dot.cc
+++ b/src/operator/tensor/dot.cc
@@ -60,8 +60,8 @@ forward_stype option for output storage type. Implemented sparse operations incl
 - dot(csr, default) = default
 - dot(csr, row_sparse) = default
 - dot(default, csr) = csr (CPU only)
-- dot(default, csr, forward_stype='default') = default (GPU only)
-- dot(default, csr, transpose_b=True, forward_stype='default') = default (GPU only)
+- dot(default, csr, forward_stype='default') = default
+- dot(default, csr, transpose_b=True, forward_stype='default') = default
 
 If the combination of input storage types and forward_stype does not match any of the
 above patterns, ``dot`` will fallback and generate output with default storage.

-- 
To stop receiving notification emails like this one, please contact
haibin@apache.org.