Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/11/07 21:15:47 UTC

[GitHub] eric-haibin-lin commented on a change in pull request #8331: [sparse] slice for csr on two dimensions, cpu implementation

URL: https://github.com/apache/incubator-mxnet/pull/8331#discussion_r149505376
 
 

 ##########
 File path: src/operator/tensor/matrix_op-inl.h
 ##########
 @@ -601,18 +595,151 @@ void SliceCsrImpl(const SliceParam &param, const OpContext& ctx,
   });
 }
 
+/*!
+ * \brief slice a CSRNDArray for two dimensions
+ */
+struct SliceDimTwoCsrAssign {
+  /*!
+   * \brief This function slices a CSRNDArray on axis one between begin_col and end_col
+   * \param i           loop index
+   * \param out_idx     output csr ndarray column indices    
+   * \param out_data    output csr ndarray data
+   * \param out_indptr  output csr ndarray row index pointer
+   * \param in_idx      input csr ndarray column indices
+   * \param in_data     input csr ndarray data
+   * \param in_indptr   input csr ndarray row index pointer
+   * \param begin_col   begin column index
+   * \param end_col     end column index
+   */
+  template<typename IType, typename RType, typename DType>
+  MSHADOW_XINLINE static void Map(int i,
+                                  IType* out_idx, DType* out_data,
+                                  const RType* out_indptr,
+                                  const IType* in_idx, const DType* in_data,
+                                  const RType* in_indptr,
+                                  const int begin_col, const int end_col) {
+    RType ind = out_indptr[i];
+    for (RType j = in_indptr[i]; j < in_indptr[i+1]; j++) {
+      // indices of CSRNDArray are in ascending order per row
+      if (in_idx[j] >= end_col) {
+        break;
+      } else if (in_idx[j] >= begin_col) {
+        out_idx[ind] = in_idx[j] - begin_col;
+        out_data[ind] = in_data[j];
+        ind++;
+      }
+    }
+  }
+};
+
+/*
+ * Slice a CSR NDArray for two dimensions
+ * Only implemented for CPU
+ */
+template<typename xpu>
+void SliceDimTwoCsrImpl(const TShape &begin, const TShape &end, const OpContext& ctx,
+                        const NDArray &in, const NDArray &out) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  using namespace csr;
+  CHECK((std::is_same<xpu, cpu>::value)) << "SliceDimTwoCsrImpl is only implemented for CPU";
+  nnvm::dim_t begin_row = begin[0], end_row = end[0];
+  nnvm::dim_t begin_col = begin[1], end_col = end[1];
+  nnvm::dim_t indptr_len = end_row - begin_row + 1;
+  out.CheckAndAllocAuxData(kIndPtr, Shape1(indptr_len));
+  // assume idx and indptr share the same type
+  MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIndPtr), RType, {
+    MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIdx), IType, {
+      MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
+        RType *in_indptr = in.aux_data(kIndPtr).dptr<RType>();
+        IType *in_idx = in.aux_data(kIdx).dptr<IType>();
+        DType *in_data = in.data().dptr<DType>();
+        // retrieve nnz (CPU implementation)
+        RType *out_indptr = out.aux_data(kIndPtr).dptr<RType>();
+        int nnz = 0;
+        out_indptr[0] = 0;
+        // loop through indptr array and corresponding indices to count nnz
+        for (nnvm::dim_t i = 0; i < indptr_len - 1; i++) {
+          out_indptr[i+1] = out_indptr[i];
+          for (RType j = in_indptr[i + begin_row];
+               j < in_indptr[i + begin_row + 1]; j++) {
+            // indices of CSRNDArray are in ascending order per row
+            if (in_idx[j] >= end_col) {
+              break;
+            } else if (in_idx[j] >= begin_col) {
+              out_indptr[i+1]++;
+              nnz++;
+            }
+          }
+        }
+        // returns zeros in csr format if nnz = 0
+        if (nnz == 0) {
+          out.set_aux_shape(kIdx, Shape1(0));
 
 Review comment:
   Same comment here
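
For readers skimming the patch: the hunk above implements a two-pass slice. SliceDimTwoCsrImpl first walks indptr and the per-row column indices to count the surviving non-zeros and build the output indptr, then SliceDimTwoCsrAssign copies the surviving entries with their column indices rebased to begin_col. Below is a minimal standalone sketch of that same two-pass idea, using plain std::vector containers and a hypothetical SliceDimTwo helper rather than MXNet's NDArray and Kernel machinery; it mirrors the patch's logic but is illustrative only.

    // Standalone sketch of the two-pass CSR slice (hypothetical helper, not part of the PR).
    #include <cstdint>
    #include <iostream>
    #include <vector>

    struct Csr {
      std::vector<int64_t> indptr;  // row pointer, length = rows + 1
      std::vector<int64_t> idx;     // column indices, ascending within each row
      std::vector<float>   data;    // non-zero values
    };

    // Slice rows [begin_row, end_row) and columns [begin_col, end_col).
    Csr SliceDimTwo(const Csr& in, int64_t begin_row, int64_t end_row,
                    int64_t begin_col, int64_t end_col) {
      Csr out;
      out.indptr.assign(end_row - begin_row + 1, 0);
      // Pass 1: count surviving non-zeros per row to build the output indptr.
      for (int64_t i = 0; i < end_row - begin_row; ++i) {
        out.indptr[i + 1] = out.indptr[i];
        for (int64_t j = in.indptr[i + begin_row]; j < in.indptr[i + begin_row + 1]; ++j) {
          if (in.idx[j] >= end_col) break;          // columns are sorted per row
          if (in.idx[j] >= begin_col) ++out.indptr[i + 1];
        }
      }
      const int64_t nnz = out.indptr.back();
      out.idx.resize(nnz);
      out.data.resize(nnz);
      // Pass 2: copy the surviving entries, rebasing column indices by begin_col.
      for (int64_t i = 0; i < end_row - begin_row; ++i) {
        int64_t ind = out.indptr[i];
        for (int64_t j = in.indptr[i + begin_row]; j < in.indptr[i + begin_row + 1]; ++j) {
          if (in.idx[j] >= end_col) break;
          if (in.idx[j] >= begin_col) {
            out.idx[ind] = in.idx[j] - begin_col;
            out.data[ind] = in.data[j];
            ++ind;
          }
        }
      }
      return out;
    }

    int main() {
      // 3x4 matrix: [[1 0 2 0], [0 3 0 4], [5 0 0 6]] in CSR form
      Csr m{{0, 2, 4, 6}, {0, 2, 1, 3, 0, 3}, {1, 2, 3, 4, 5, 6}};
      Csr s = SliceDimTwo(m, 0, 2, 1, 4);  // rows 0..1, cols 1..3 -> [[0 2 0], [3 0 4]]
      for (size_t v = 0; v < s.data.size(); ++v)
        std::cout << s.idx[v] << ":" << s.data[v] << " ";
      std::cout << "\n";  // prints: 1:2 0:3 2:4
    }

Note that the early break in both passes relies on the column indices of a CSRNDArray being sorted in ascending order within each row, which is the invariant the in-diff comments call out.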

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services