You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/01/04 03:30:40 UTC

[incubator-mxnet] branch master updated: Add operator for dot(dns, csr) = csr (#8938)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 8505442  Add operator for dot(dns, csr) = csr (#8938)
8505442 is described below

commit 85054429e7d8786d39153b051c24c9427dc6c084
Author: Anirudh Subramanian <an...@gmail.com>
AuthorDate: Wed Jan 3 19:30:35 2018 -0800

    Add operator for dot(dns, csr) = csr (#8938)
    
    * Add operator for dot(dns, csr) = csr
    
    * Fix whitespace
    
    * Add comments
    
    * Add comments and fix error message
    
    * Fixes for dot dns csr
    
    * Fixes
    
    * Remove non required statements
    
    * Add fallback for GPU
    
    * Remove unused if
    
    * Fix comments and casting
    
    * Add operator to the documentation
---
 benchmark/python/sparse/dot.py                |  51 +++++--
 include/mxnet/ndarray.h                       |   5 +-
 src/operator/tensor/dot-inl.h                 | 201 ++++++++++++++++++++++++--
 src/operator/tensor/dot.cc                    |   1 +
 tests/python/unittest/test_sparse_operator.py |  27 ++++
 5 files changed, 260 insertions(+), 25 deletions(-)

diff --git a/benchmark/python/sparse/dot.py b/benchmark/python/sparse/dot.py
index 164e50a..5cfd540 100644
--- a/benchmark/python/sparse/dot.py
+++ b/benchmark/python/sparse/dot.py
@@ -275,7 +275,10 @@ def test_dot_synthetic(data_dict):
         # Create matrix instances
         lhs_nd = rand_ndarray(lhs_shape, lhs_stype, density=lhs_den, distribution=distribution)
         # only uniform distribution supported for rhs
-        rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=rhs_den, distribution="uniform")
+        if rhs_stype == 'csr':
+            rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=rhs_den, distribution=distribution)
+        else:
+            rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=rhs_den, distribution="uniform")
         lhs_dns = None
         rhs_dns = None
         dense_cost = None
@@ -337,27 +340,41 @@ def test_dot_synthetic(data_dict):
 
     def run_benchmark(ctx=None, lhs="csr", lhs_trans=False, rhs="dns", fw="mxnet", rhs_density=1,
                       distribution="uniform"):
-        if lhs != "csr":
-            raise ValueError("Value other than csr for lhs not supported")
+
         if rhs_density > 1 or rhs_density < 0:
             raise ValueError("rhs_density has to be between 0 and 1")
 
         print_benchmark_info(lhs, rhs, lhs_trans, fw)
 
+        if rhs == "csr":
+            lhs_stype = "default"
+            rhs_stype = "csr"
+            assert (lhs_stype == 'default'), "Only dot(default, csr) supported"
+            # Arrange dimensions for this use case: the csr rhs below will have num_rows << num_cols
+            feature_dim_list = data_dict['batch_size']
+            batch_size_list = data_dict['m']
+            output_dim_list = data_dict['feature_dim']
+            density_list = data_dict['density']
+            default_output_index = data_dict['default_index']['feature_dim']
+            default_density_index = data_dict['default_index']['density']
+            default_feature_index = data_dict['default_index']['batch_size']
+            default_batch_size_index = data_dict['default_index']['output_dim']
+            num_repeat = data_dict['num_repeat']
 
-        lhs_stype = "csr"
-        rhs_stype = "row_sparse" if rhs == "rsp" else "default"
+        else:
+            lhs_stype = "csr"
+            rhs_stype = "row_sparse" if rhs == "rsp" else "default"
 
-        feature_dim_list = data_dict['feature_dim']
-        output_dim_list = data_dict['m']
-        batch_size_list = data_dict['batch_size']
-        density_list = data_dict['density']
+            feature_dim_list = data_dict['feature_dim']
+            output_dim_list = data_dict['m']
+            batch_size_list = data_dict['batch_size']
+            density_list = data_dict['density']
 
-        default_output_index = data_dict['default_index']['output_dim']
-        default_batch_size_index = data_dict['default_index']['batch_size']
-        default_feature_index = data_dict['default_index']['feature_dim']
-        default_density_index = data_dict['default_index']['density']
-        num_repeat = data_dict['num_repeat']
+            default_output_index = data_dict['default_index']['output_dim']
+            default_batch_size_index = data_dict['default_index']['batch_size']
+            default_feature_index = data_dict['default_index']['feature_dim']
+            default_density_index = data_dict['default_index']['density']
+            num_repeat = data_dict['num_repeat']
 
         for output_dim in output_dim_list:
             if lhs_trans:
@@ -403,7 +420,7 @@ def test_dot_synthetic(data_dict):
                        feature_dim_list[default_feature_index]),
                       (output_row_dim,
                        output_dim_list[default_output_index]),
-                      lhs_stype, rhs_stype, density, rhs_density, lhs_trans, ctx,
+                      lhs_stype, rhs_stype, density, density, lhs_trans, ctx,
                       num_repeat=num_repeat, fw=fw, distribution=distribution)
 
     check_call(_LIB.MXSetNumOMPThreads(ctypes.c_int(ARGS.num_omp_threads)))
@@ -423,6 +440,10 @@ def test_dot_synthetic(data_dict):
                       rhs="rsp", lhs_trans=False,
                       fw="mxnet", rhs_density=0.05,
                       distribution=distribution)
+        run_benchmark(context, lhs="default",
+                      rhs="csr", lhs_trans=False,
+                      fw="mxnet", rhs_density=0.001,
+                      distribution=distribution)
         if not ARGS.gpu:
             run_benchmark(context, lhs="csr",
                           rhs="default", lhs_trans=False,
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index 8398b7b..a18d2da 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -305,7 +305,10 @@ class NDArray {
   bool fresh_out_grad() const;
   /*! \return updated grad state in entry_ */
   void set_fresh_out_grad(bool state) const;
-  // returns true if a sparse ndarray's aux_data and storage are initialized
+  /*! \brief Returns true if a sparse ndarray's aux_data and storage are initialized
+   * Throws an exception if the indices array shape is inconsistent
+   * Returns false if the indices array is empty (nnz = 0) for csr/row_sparse
+   */
   inline bool storage_initialized() const {
     if (is_none()) return false;
     auto stype = storage_type();
diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h
index 2432703..244f34e 100644
--- a/src/operator/tensor/dot-inl.h
+++ b/src/operator/tensor/dot-inl.h
@@ -202,22 +202,25 @@ void DotBackward_(const nnvm::NodeAttrs& attrs,
 inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
                                        const int dev_mask,
                                        DispatchMode* dispatch_mode,
-                                       std::vector<int> *in_attrs,
-                                       std::vector<int> *out_attrs) {
+                                       std::vector<int>* in_attrs,
+                                       std::vector<int>* out_attrs) {
   CHECK_EQ(in_attrs->size(), 2U);
   CHECK_EQ(out_attrs->size(), 1U);
   const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
-  // csr has many zero columns, so the result of dot(csr.T, matrix) should be rsp
+  // csr has many zero columns, so the result of dot(csr.T, matrix) should be
+  // rsp
   const auto& lhs_stype = in_attrs->at(0);
   const auto& rhs_stype = in_attrs->at(1);
   auto& out_stype = out_attrs->at(0);
   bool dispatched = false;
   bool only_lhs_transpose = param.transpose_a && !param.transpose_b;
-  bool rhs_rsp_or_dns  = rhs_stype == kRowSparseStorage || rhs_stype == kDefaultStorage;
-  if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kDefaultStorage) {
+  bool rhs_rsp_or_dns =
+      rhs_stype == kRowSparseStorage || rhs_stype == kDefaultStorage;
+  if (!dispatched && lhs_stype == kDefaultStorage &&
+      rhs_stype == kDefaultStorage) {
     // dns, dns -> dns
-    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
-                                     dispatch_mode, DispatchMode::kFCompute);
+    dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
+                                     DispatchMode::kFCompute);
   }
   if (!dispatched && lhs_stype == kCSRStorage && only_lhs_transpose &&
       (rhs_stype == kRowSparseStorage || rhs_stype == kDefaultStorage)) {
@@ -228,11 +231,22 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
   if (!dispatched && lhs_stype == kCSRStorage && rhs_rsp_or_dns &&
       !param.transpose_a && !param.transpose_b) {
     // csr, rsp/dns -> dns
-    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
-                                     dispatch_mode, DispatchMode::kFComputeEx);
+    dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
+                                     DispatchMode::kFComputeEx);
+  }
+  if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
+      !param.transpose_a && !param.transpose_b) {
+    // dns, csr -> csr
+    const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask;
+    const auto dispatch_ex = invalid_ctx ? DispatchMode::kFComputeFallback
+                                         : DispatchMode::kFComputeEx;
+    dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode,
+                                     dispatch_ex);
   }
   if (!dispatched) {
     dispatch_fallback(out_attrs, dispatch_mode);
+  }
+  if (static_cast<DispatchMode>(*dispatch_mode) == DispatchMode::kFComputeFallback) {
     LogStorageFallback(attrs, dev_mask, in_attrs, out_attrs);
   }
   return true;
@@ -528,6 +542,80 @@ struct DotCsrTransRspRspByRowBlocks {
 };
 
 /*!
+ * \brief CPU kernel PopulateCsrForNNC
+ * Parallelized over individual rows.
+ * Populates the indptr and indices arrays
+ * based on the number of non-zero columns
+ */
+struct PopulateCsrForNNC {
+  /*!
+   * \brief Fill the indptr entry and column indices of output row i
+   * \param i the i-th thread
+   * \param nnc_idx all non zero column indexes
+   * \param indptr_out indptr array for output
+   * \param col_idx_out column indices for output
+   * \param nnc number of non zero columns in the output
+   * \param num_rows_l number of rows in lhs
+   */
+  template <typename IType, typename CType>
+  MSHADOW_CINLINE static void Map(int i, const CType* nnc_idx,
+                                  IType* indptr_out, CType* col_idx_out,
+                                  const nnvm::dim_t nnc,
+                                  const nnvm::dim_t num_rows_l) {
+    const CType start_idx = i * nnc;
+    nnvm::dim_t cur = 0;
+    indptr_out[i] = start_idx;
+    if (static_cast<nnvm::dim_t>(i) == (num_rows_l - 1)) indptr_out[i + 1] = indptr_out[i] + nnc;
+    for (IType idx = start_idx; idx < (start_idx + nnc); idx++) {
+      col_idx_out[idx] = nnc_idx[cur++];
+    }
+  }
+};
+
+/*!
+ * \brief CPU Impl of dot(dns, csr) = csr
+ */
+struct DotDnsCsrCsrByRowBlocks {
+  /*!
+   * \brief Accumulate the output values for the i-th segment of lhs rows
+   * \param i the i-th thread
+   * \param num_rows_r number of rows in rhs
+   * \param num_rows_l number of rows in lhs
+   * \param num_cols number of columns in output
+   * \param nnc number of non zero columns
+   */
+
+  template <typename DType, typename IType, typename CType>
+  MSHADOW_CINLINE static void Map(
+      int i, DType* out, const DType* data_l, const IType* indptr_r,
+      const CType* col_idx_r, const DType* data_r, const nnvm::dim_t seg_len,
+      const IType num_rows_r, const IType num_rows_l,
+      const nnvm::dim_t num_cols, const nnvm::dim_t nnc,
+      const CType* prefix_sum) {
+    using nnvm::dim_t;
+    const dim_t seg_start = i * seg_len;
+    if (seg_start >= num_rows_l) return;
+    const dim_t seg_end = std::min(seg_start + seg_len, num_rows_l);
+
+    for (dim_t j = seg_start; j < seg_end; j++) {
+      for (dim_t k = 0; k < num_rows_r; k++) {
+        const dim_t working_idx = j * num_rows_r + k;
+        const DType val = data_l[working_idx];
+        if (indptr_r[k] == indptr_r[k + 1]) continue;
+        const dim_t row_start = j * nnc;
+        for (dim_t cur = indptr_r[k]; cur < indptr_r[k + 1]; cur++) {
+          dim_t cur_col_idx_r = col_idx_r[cur];
+          const dim_t out_idx = row_start + prefix_sum[cur_col_idx_r] - 1;
+          out[out_idx] += val * data_r[cur];
+        }
+      }
+    }
+  }
+};
+
+
+
+/*!
  * \brief CPU Impl of dot(csr, dns1) = dns2 and dot(csr.T, dns1) = dns2
  */
 inline void DotCsrDnsDnsImpl(const OpContext& ctx,
@@ -811,6 +899,96 @@ inline void DotCsrRspRspImpl(const OpContext& ctx,
   });
 }
 
+/*!
+ * \brief CPU Impl of dot(dns, csr) = csr
+ */
+template<typename xpu>
+inline void DotDnsCsrCsrImpl(const OpContext& ctx,
+                             const TBlob& lhs, const NDArray& rhs,
+                             const OpReqType req, NDArray* ret) {
+  if (kNullOp == req) return;
+
+  CHECK_EQ(req, kWriteTo);
+  CHECK_EQ(rhs.storage_type(), kCSRStorage);
+
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using nnvm::dim_t;
+
+  /* Initialize data structures */
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  const NDArray& out = *ret;
+  const TBlob data_l = lhs;
+  const TBlob data_r = rhs.data();
+  const TBlob indptr_r = rhs.aux_data(csr::kIndPtr);
+  const TBlob col_idx_r = rhs.aux_data(csr::kIdx);
+  if (!rhs.storage_initialized()) {
+    FillZerosCsrImpl(s, *ret);
+    return;
+  }
+
+  MSHADOW_SGL_DBL_TYPE_SWITCH(data_r.type_flag_, DType, {     // data type
+    MSHADOW_IDX_TYPE_SWITCH(indptr_r.type_flag_, IType, {     // indptr type
+      MSHADOW_IDX_TYPE_SWITCH(col_idx_r.type_flag_, CType, {  // colidx type
+        /* Allocate workspace */
+        CType num_cols_out = out.shape()[1];
+        CType rhs_data_size = static_cast<CType>(col_idx_r.shape_.Size());
+        size_t workspace_size = 2 * num_cols_out * sizeof(CType);
+        Tensor<cpu, 1, char> workspace =
+            ctx.requested[0].get_space_typed<cpu, 1, char>(
+                Shape1(workspace_size), s);
+        CType* col_flg = reinterpret_cast<dim_t*>(workspace.dptr_);
+
+        CType* prefix_sum = col_flg;
+        CType* nnc_idx = prefix_sum + num_cols_out;
+
+        /* Set the column flags for non-zero columns */
+        mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(s, num_cols_out,
+                                                          col_flg);
+        mxnet_op::Kernel<MarkRowFlgKernel, cpu>::Launch(
+            s, rhs_data_size, col_flg, col_idx_r.dptr<CType>());
+
+        /* 1. Calculate the prefix sum from the column flags
+         * 2. Store all non-zero column indexes in nnc_idx
+         */
+        CType cur = 0;
+        prefix_sum[0] = col_flg[0];
+        if (prefix_sum[0]) nnc_idx[cur++] = 0;
+        for (CType i = 1; i < num_cols_out; i++) {
+          prefix_sum[i] = prefix_sum[i - 1] + col_flg[i];
+          if (prefix_sum[i] > prefix_sum[i - 1]) nnc_idx[cur++] = i;
+        }
+
+        /* Allocate aux data for out */
+        IType num_rows_l = lhs.shape_[0];
+        dim_t nnc = prefix_sum[num_cols_out - 1];
+        dim_t nnz = nnc * num_rows_l;
+        out.CheckAndAllocAuxData(csr::kIndPtr, Shape1(num_rows_l + 1));
+        out.CheckAndAllocAuxData(csr::kIdx, Shape1(nnz));
+        out.CheckAndAllocData(Shape1(nnz));
+
+        /* Set csr indptr and indices according to nnc_idx */
+        IType* indptr_out = out.aux_data(csr::kIndPtr).dptr<IType>();
+        CType* col_idx_out = out.aux_data(csr::kIdx).dptr<CType>();
+        DType* data_out = out.data().dptr<DType>();
+        mxnet_op::Kernel<PopulateCsrForNNC, cpu>::Launch(
+            s, num_rows_l, nnc_idx, indptr_out, col_idx_out, nnc, num_rows_l);
+        mxnet_op::Kernel<mxnet_op::set_zero, cpu>::Launch(s, nnz, data_out);
+
+        const dim_t num_threads = mxnet_op::get_num_threads<cpu>(num_rows_l);
+        const dim_t seg_len = (num_rows_l + num_threads - 1) / num_threads;
+
+        IType num_rows_r = rhs.shape()[0];
+        mxnet_op::Kernel<DotDnsCsrCsrByRowBlocks, cpu>::Launch(
+            s, num_threads, data_out, data_l.dptr<DType>(),
+            indptr_r.dptr<IType>(), col_idx_r.dptr<CType>(),
+            data_r.dptr<DType>(), seg_len, num_rows_r, num_rows_l, num_cols_out,
+            nnc, prefix_sum);
+      });
+    });
+  });
+}
+
 inline bool DotShape(const nnvm::NodeAttrs& attrs,
                      std::vector<TShape> *in_attrs,
                      std::vector<TShape> *out_attrs) {
@@ -886,6 +1064,11 @@ void DotForwardEx(const nnvm::NodeAttrs& attrs,
       && out_stype == kRowSparseStorage && !param.transpose_b) {
     NDArray ret = outputs[0];
     DotCsrRspRspImpl(ctx, xpu(), inputs[0], inputs[1], req[0], param.transpose_a, &ret);
+  } else if (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
+             out_stype == kCSRStorage &&
+             !(param.transpose_a || param.transpose_b)) {
+    NDArray ret = outputs[0];
+    DotDnsCsrCsrImpl<xpu>(ctx, inputs[0].data(), inputs[1], req[0], &ret);
   } else {
     LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
   }
diff --git a/src/operator/tensor/dot.cc b/src/operator/tensor/dot.cc
index a7fa2c7..834b559 100644
--- a/src/operator/tensor/dot.cc
+++ b/src/operator/tensor/dot.cc
@@ -56,6 +56,7 @@ The storage type of ``dot`` output depends on storage types of inputs and transp
 - dot(csr, default) = default
 - dot(csr.T, default) = row_sparse
 - dot(csr, row_sparse) = default
+- dot(default, csr) = csr
 - otherwise, ``dot`` generates output with default storage
 
 )doc" ADD_FILELINE)
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index a56677c..134cb26 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -1223,6 +1223,31 @@ def test_sparse_dot():
                                 grad_req={'lhs': 'null', 'rhs': 'write'},
                                 rtol=1e-3, atol=1e-4)
 
+    def test_dot_dns_csr(lhs_shape, rhs_shape, lhs_density, rhs_density, trans_lhs=False, trans_rhs=False):
+        lhs_nd = rand_ndarray(lhs_shape, stype='default', density=lhs_density)
+        rhs_nd = rand_ndarray(rhs_shape, stype='csr', density=rhs_density)
+        rhs_dns = rhs_nd.tostype('default')
+
+        out = mx.nd.sparse.dot(lhs_nd, rhs_nd, transpose_a=trans_lhs, transpose_b=trans_rhs)
+        out_dns = mx.nd.dot(lhs_nd, rhs_dns, transpose_a=trans_lhs, transpose_b=trans_rhs)
+        out_np = out_dns.asnumpy()
+        assert_almost_equal(out.asnumpy(), out_np, rtol=1e-4, atol=1e-5)
+
+        # test symbolic forward
+        lhs = mx.symbol.Variable('lhs', stype='default')
+        rhs = mx.symbol.Variable('rhs', stype='csr')
+        out = mx.symbol.sparse.dot(lhs, rhs, transpose_a=trans_lhs, transpose_b=trans_rhs)
+        location = {'lhs': lhs_nd, 'rhs': rhs_nd}
+        check_symbolic_forward(out, location, [out_np], rtol=1e-3, atol=1e-4)
+
+        # test symbolic backward
+        backward_trans = not trans_lhs
+        rhs_backward_grad = mx.nd.dot(lhs_nd, out_dns, transpose_a=backward_trans).asnumpy()
+        expected = {'rhs': rhs_backward_grad}
+        check_symbolic_backward(out, location, [out_np], expected,
+                                grad_req={'lhs': 'null', 'rhs': 'write'},
+                                rtol=1e-3, atol=1e-4)
+
     def test_sparse_dot_zero_output(lhs_shape, trans_lhs, rhs_num_cols):
         """Test for nnr_out = 0. Before the fix, the test would fail."""
         lhs = mx.nd.zeros(lhs_shape)
@@ -1248,10 +1273,12 @@ def test_sparse_dot():
         test_dot_csr(lhs_shape, (lhs_shape[0], 1), 'default', True,  lhs_d, rhs_d)  # (vector kernel)
         test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(5, 10)), 'default', False, lhs_d, rhs_d)  # test gpu SpMM
         test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(5, 10)), 'default', True, lhs_d, rhs_d)  # (scalar kernel)
+        test_dot_dns_csr(lhs_shape, (lhs_shape[1], rnd.randint(50, 200)), lhs_d, lhs_d)
         for rhs_d in density:
             test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False, lhs_d, rhs_d)
             test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True, lhs_d, rhs_d)
 
+
     test_sparse_dot_zero_output(rand_shape_2d(50, 200), False, 40)
     test_sparse_dot_zero_output(rand_shape_2d(50, 200), True, 40)
 

-- 
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].