You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/10/20 22:08:05 UTC

[incubator-mxnet] branch master updated: sparse support for take(csr, axis=0) (#12889)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 0ba259f  sparse support for take(csr, axis=0)  (#12889)
0ba259f is described below

commit 0ba259fe5fbc2c2ba355a9f6faf1a20cc57c6d07
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Sat Oct 20 15:07:48 2018 -0700

    sparse support for take(csr, axis=0)  (#12889)
    
    * initial commit
    
    * add test cases for mode
    
    * fix bug
    
    * add comment
    
    * more comments
---
 src/operator/tensor/indexing_op.cc           | 144 +++++++++++++++++++++++++++
 src/operator/tensor/indexing_op.h            |  65 ++++++++++++
 tests/python/unittest/test_sparse_ndarray.py |  18 ++++
 3 files changed, 227 insertions(+)

diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc
index b663ef0..98e2536 100644
--- a/src/operator/tensor/indexing_op.cc
+++ b/src/operator/tensor/indexing_op.cc
@@ -87,6 +87,142 @@ void SparseEmbeddingOpForwardRspImpl<cpu>(const OpContext& ctx,
   }
 }
 
+template<bool clip>
+struct CsrTakeDataKernel {
+  /*!
+   * \brief Map function that copies the non-zeros of one taken csr row to the output
+   * \param tid           global thread id, used as the output row
+   * \param out_idx       ptr to out idx
+   * \param out_data      ptr to out data
+   * \param out_indptr    ptr to out indptr, gives the write offset of each output row
+   * \param src_idx       ptr to original csr idx
+   * \param src_data      ptr to original csr data
+   * \param src_indptr    ptr to original csr indptr
+   * \param idx_ptr       ptr to indices
+   * \param num_rows      number of rows in src array
+   */
+  template<typename IType, typename DType, typename RType>
+  MSHADOW_XINLINE static void Map(int tid, RType* out_idx, DType* out_data,
+                                  const RType* out_indptr, const RType* src_idx,
+                                  const DType* src_data, const RType* src_indptr,
+                                  const IType* idx_ptr, const nnvm::dim_t num_rows) {
+    nnvm::dim_t idx = static_cast<nnvm::dim_t>(idx_ptr[tid]);
+    if (clip) {  // clip mode: saturate the index to [0, num_rows - 1]
+      if (idx < 0) idx = 0;
+      if (idx >= num_rows) idx = num_rows - 1;
+    } else {  // wrap mode: modulo, shifted so that the result is never negative
+      idx = idx % num_rows;
+      idx += (idx < 0) ? num_rows : 0;
+    }
+    // offsets stay in RType so that large indptr values are never narrowed to int
+    const RType src_end = src_indptr[idx + 1];
+    for (RType src = src_indptr[idx], dst = out_indptr[tid]; src < src_end; ++src, ++dst) {
+      out_data[dst] = src_data[src];
+      out_idx[dst] = src_idx[src];
+    }
+  }
+};
+
+template<bool clip>
+struct CsrTakeRowCountKernel {
+  /*!
+   * \brief Map function that stores the nnz of the row taken by idx_ptr[tid - 1] in out_indptr[tid]
+   * \param tid           global thread id; tid 0 only writes the leading 0 of out_indptr
+   * \param out_indptr    ptr to out indptr
+   * \param src_indptr    ptr to original csr indptr
+   * \param idx_ptr       ptr to indices
+   * \param num_rows      number of rows in src array
+   */
+  template<typename IType, typename RType>
+  MSHADOW_XINLINE static void Map(int tid, RType* out_indptr,
+                                  const RType* src_indptr, const IType* idx_ptr,
+                                  const nnvm::dim_t num_rows) {
+    // slot 0 has no index behind it: return before idx_ptr[tid - 1] would read idx_ptr[-1]
+    if (tid == 0) { out_indptr[0] = 0; return; }
+    nnvm::dim_t idx = static_cast<nnvm::dim_t>(idx_ptr[tid - 1]);
+    if (clip) {  // clip mode: saturate the index to [0, num_rows - 1]
+      if (idx < 0) idx = 0;
+      if (idx >= num_rows) idx = num_rows - 1;
+    } else {  // wrap mode: modulo, shifted so that the result is never negative
+      idx = idx % num_rows;
+      idx += (idx < 0) ? num_rows : 0;
+    }
+    // only the per-row count is written here, not yet an offset
+    out_indptr[tid] = src_indptr[idx + 1] - src_indptr[idx];
+  }
+};
+
+template<>
+void TakeOpForwardCsrImpl<cpu>(const TakeParam& params,
+                               const OpContext& ctx,
+                               const TBlob& idx,
+                               const NDArray& arr,
+                               OpReqType req,
+                               const NDArray& out) {
+  // CPU forward of take on a csr array along axis 0: rows of `arr` selected by `idx`
+  // (after clip / wrap) are copied into the csr output `out`.
+  using namespace csr;
+  using namespace mxnet_op;
+  using nnvm::dim_t;
+  Stream<cpu>* const stream = ctx.get_stream<cpu>();
+  if (req == kNullOp) return;
+  if (!arr.storage_initialized()) {  // nothing stored in the source: output is all zeros
+    FillZerosCsrImpl(stream, out);
+    return;
+  }
+  CHECK_EQ(idx.shape_.ndim(), 1U)
+          << "Take with CSR array only supports one-dimensional indices. "
+          << idx.shape_.ndim() << " dimensional input is given instead";
+  CHECK_EQ(req, kWriteTo) << "req = " << req << " is not supported for take(csr)";
+  CHECK_EQ(params.axis, 0) << "axis = " << params.axis << " is not supported for take(csr)";
+  CHECK(params.mode == take_::kClip || params.mode == take_::kWrap)
+    << "mode = " << params.mode << " is not supported";
+  const bool use_clip = params.mode == take_::kClip;
+  const dim_t out_rows = out.shape()[0];
+  const dim_t src_rows = arr.shape()[0];
+  out.CheckAndAllocAuxData(kIndPtr, {Shape1(out_rows + 1)});
+
+  MSHADOW_TYPE_SWITCH(idx.type_flag_, IType, {
+    MSHADOW_SGL_DBL_TYPE_SWITCH(arr.dtype(), DType, {
+      MSHADOW_IDX_TYPE_SWITCH(out.aux_type(kIdx), RType, {
+        RType* res_indptr = out.aux_data(kIndPtr).dptr<RType>();
+        const RType* in_indptr = arr.aux_data(kIndPtr).dptr<RType>();
+        const IType* taken = idx.dptr<IType>();
+        // step 1: write the nnz of the row picked for output row i to res_indptr[i + 1]
+        if (use_clip) {
+          Kernel<CsrTakeRowCountKernel<true>, cpu>::Launch(stream, out_rows + 1,
+              res_indptr, in_indptr, taken, src_rows);
+        } else {
+          Kernel<CsrTakeRowCountKernel<false>, cpu>::Launch(stream, out_rows + 1,
+              res_indptr, in_indptr, taken, src_rows);
+        }
+        // step 2: single-threaded in-place prefix sum turns the counts into row offsets
+        for (dim_t row = 1; row <= out_rows; ++row) {
+          res_indptr[row] += res_indptr[row - 1];
+        }
+        const dim_t total_nnz = res_indptr[out_rows];  // stored elements in the output
+        if (total_nnz == 0) {
+          FillZerosCsrImpl(stream, out);
+          return;
+        }
+        out.CheckAndAllocAuxData(kIdx, {Shape1(total_nnz)});
+        out.CheckAndAllocData(Shape1(total_nnz));
+        RType* res_idx = out.aux_data(kIdx).dptr<RType>();
+        DType* res_data = out.data().dptr<DType>();
+        const RType* in_idx = arr.aux_data(kIdx).dptr<RType>();
+        const DType* in_data = arr.data().dptr<DType>();
+        // step 3: copy indices and data of every taken row to its offset in the output
+        if (use_clip) {
+          Kernel<CsrTakeDataKernel<true>, cpu>::Launch(stream, out_rows, res_idx,
+              res_data, res_indptr, in_idx, in_data, in_indptr, taken, src_rows);
+        } else {
+          Kernel<CsrTakeDataKernel<false>, cpu>::Launch(stream, out_rows, res_idx,
+              res_data, res_indptr, in_idx, in_data, in_indptr, taken, src_rows);
+        }
+      });
+    });
+  });
+}
 
 template<>
 inline void SparseEmbeddingOpBackwardRspImpl<cpu>(const bool deterministic,
@@ -400,6 +536,7 @@ dimension of data (by default outer-most one as axis=0) indexed by indices, and
 in an output tensor of rank q + (r - 1).
 
 Examples::
+
   x = [4.  5.  6.]
 
   // Trivial case, take the second element along the first axis.
@@ -431,6 +568,11 @@ Examples::
                                                       [[ 3.,  4.],
                                                        [ 5.,  6.]]]
 
+The storage type of ``take`` output depends upon the input storage type:
+
+   - take(default, default) = default
+   - take(csr, default, axis=0) = csr
+
 )code" ADD_FILELINE)
 .set_num_inputs(2)
 .set_num_outputs(1)
@@ -441,11 +583,13 @@ Examples::
   })
 .set_attr<nnvm::FInferShape>("FInferShape", TakeOpShape)
 .set_attr<nnvm::FInferType>("FInferType", TakeOpType)
+.set_attr<FInferStorageType>("FInferStorageType", TakeOpForwardStorageType)
 .set_attr<FResourceRequest>("FResourceRequest",
   [](const NodeAttrs& attrs) {
     return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
   })
 .set_attr<FCompute>("FCompute<cpu>", TakeOpForward<cpu>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", TakeOpForwardEx<cpu>)
 .set_attr<nnvm::FGradient>("FGradient",
   [](const nnvm::NodePtr& n,  const std::vector<nnvm::NodeEntry>& ograds) {
     return MakeNonlossGradNode("_backward_take", n, ograds,
diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index 1daf0a2..5282a7e 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -755,6 +755,71 @@ inline bool TakeOpType(const nnvm::NodeAttrs& attrs,
   return (*in_attrs)[0] != -1;
 }
 
+// storage type inference for take: assigns the output storage type and the dispatch mode
+inline bool TakeOpForwardStorageType(const nnvm::NodeAttrs& attrs,
+                                     const int dev_mask,
+                                     DispatchMode* dispatch_mode,
+                                     std::vector<int>* in_attrs,
+                                     std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  const int idx_stype = in_attrs->at(take_::kIdx);
+  const int arr_stype = in_attrs->at(take_::kArr);
+  int& out_stype = out_attrs->at(take_::kOut);
+  const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);
+  const bool dense_idx = idx_stype == kDefaultStorage;
+  const bool csr_axis0 = arr_stype == kCSRStorage && param.axis == 0 &&
+                         (param.mode == take_::kWrap || param.mode == take_::kClip);
+  bool ok = false;
+  if (dense_idx && arr_stype == kDefaultStorage) {
+    // dns indices, dns data -> dns
+    ok = storage_type_assign(&out_stype, kDefaultStorage,
+                             dispatch_mode, DispatchMode::kFCompute);
+  } else if (dense_idx && csr_axis0) {
+    // dns indices, csr data, axis=0 -> csr
+    ok = storage_type_assign(&out_stype, kCSRStorage,
+                             dispatch_mode, DispatchMode::kFComputeEx);
+  }
+  // every other combination, or a failed assignment, is left to dispatch_fallback
+  if (!ok) ok = dispatch_fallback(out_attrs, dispatch_mode);
+  return ok;
+}
+
+
+template<typename xpu>  // declaration only; the cpu specialization is in indexing_op.cc
+void TakeOpForwardCsrImpl(const TakeParam& params,  // axis and mode of the take
+                          const OpContext& ctx,  // supplies the device stream
+                          const TBlob& idx,  // indices, passed as a dense blob
+                          const NDArray& arr,  // csr array the rows are taken from
+                          OpReqType req,  // write request for the output
+                          const NDArray& output);  // csr array receiving the taken rows
+
+
+template<typename xpu>
+void TakeOpForwardEx(const nnvm::NodeAttrs& attrs,
+                     const OpContext& ctx,
+                     const std::vector<NDArray>& inputs,
+                     const std::vector<OpReqType>& req,
+                     const std::vector<NDArray>& outputs) {
+  // FComputeEx of take: only dense indices + csr data -> csr output is implemented
+  CHECK_EQ(req[take_::kOut], kWriteTo);
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  const NDArray& indices = inputs[take_::kIdx];
+  const NDArray& data = inputs[take_::kArr];
+  const NDArray& result = outputs[take_::kOut];
+  const bool csr_take = indices.storage_type() == kDefaultStorage &&
+                        data.storage_type() == kCSRStorage &&
+                        result.storage_type() == kCSRStorage;
+  const auto params = nnvm::get<TakeParam>(attrs.parsed);
+  if (!csr_take) {
+    // no sparse implementation for this combination of storage types
+    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
+    return;
+  }
+  TakeOpForwardCsrImpl<xpu>(params, ctx, indices.data(), data, req[0], result);
+}
+
 template<typename xpu>
 void TakeOpForward(const nnvm::NodeAttrs& attrs,
                    const OpContext& ctx,
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index 875dea7..8dd250c 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -1013,6 +1013,24 @@ def test_sparse_fc():
     # test FC with row_sparse weight w/ density=1, csr data (fallback)
     check_sparse_fc(5, 10, 8, 'csr')
 
+@with_seed()
+def test_sparse_take():
+    def check_sparse_take(density, mode):
+        data_shape = rand_shape_2d()
+        idx_shape = (np.random.randint(low=1, high=10),)
+        data = rand_ndarray(data_shape, 'csr', density=density)
+        idx = mx.nd.array(np.random.randint(low=-5, high=15, size=idx_shape))
+        result = mx.nd.take(data, idx, mode=mode)
+        data_np = data.asnumpy()
+        idx_np = idx.asnumpy().astype('int32')
+        expected_result = np.take(data_np, idx_np, mode=mode, axis=0)
+        assert_almost_equal(result.asnumpy(), expected_result)
+    densities = [0, 0.5, 1]
+    modes = ['clip', 'wrap']
+    for d in densities:
+        for m in modes:
+            check_sparse_take(d, m)
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()