You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/11/08 12:09:44 UTC

[GitHub] ZiyueHuang commented on a change in pull request #8259: check_format of sparse ndarray

ZiyueHuang commented on a change in pull request #8259: check_format of sparse ndarray
URL: https://github.com/apache/incubator-mxnet/pull/8259#discussion_r149650525
 
 

 ##########
 File path: src/common/utils.h
 ##########
 @@ -43,9 +43,170 @@
 #include <algorithm>
 #include <functional>
 
+#include "../operator/mxnet_op.h"
+
 namespace mxnet {
 namespace common {
 
+
+/*!
+ * \brief Per-row kernel validating the indptr array of a CSRNDArray:
+ *        entries must be non-negative and non-decreasing, the first entry
+ *        must be 0, and the last entry must equal the number of indices.
+ *        Writes kCSRIndPtrErr to *out when row i violates any rule.
+ */
+struct csr_indptr_check {
+  template<typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, DType* out, const IType* indptr,
+                                  const nnvm::dim_t end, const nnvm::dim_t idx_size) {
+    // indptr[i+1] must be non-negative and not smaller than indptr[i].
+    const bool bad_step = indptr[i+1] < 0 || indptr[i+1] < indptr[i];
+    // Only the first row checks the leading 0 entry.
+    const bool bad_head = (i == 0) && (indptr[i] != 0);
+    // Only the last row checks that the final entry matches idx_size.
+    const bool bad_tail = (i == end - 1) && (indptr[end] != idx_size);
+    if (bad_step || bad_head || bad_tail) {
+      *out = kCSRIndPtrErr;
+    }
+  }
+};
+
+/*!
+ *  \brief Per-row kernel validating the column indices of a CSRNDArray:
+ *           within row i, every index must lie in [0, ncols) and indices
+ *           must be strictly increasing. Writes kCSRIdxErr to *out on the
+ *           first violation found in the row.
+ */
+struct csr_idx_check {
+  template<typename DType, typename IType, typename RType>
+  MSHADOW_XINLINE static void Map(int i, DType* out, const IType* idx,
+                                  const RType* indptr, const nnvm::dim_t ncols) {
+    const RType row_begin = indptr[i];
+    const RType row_end = indptr[i+1];
+    for (RType pos = row_begin; pos < row_end; ++pos) {
+      const bool out_of_range = idx[pos] < 0 || idx[pos] >= ncols;
+      // Compare with the next index only while it still belongs to this row.
+      const bool not_sorted = pos + 1 < row_end && idx[pos] >= idx[pos+1];
+      if (out_of_range || not_sorted) {
+        *out = kCSRIdxErr;
+        return;
+      }
+    }
+  }
+};
+
+/*!
+ *  \brief Per-element kernel validating the row indices of a RSPNDArray:
+ *           each index must lie in [0, nrows), and for i < end the index
+ *           must be strictly smaller than its successor. Writes kRSPIdxErr
+ *           to *out when element i violates either rule.
+ */
+struct rsp_idx_check {
+  template<typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, DType* out, const IType* idx,
+                                  const nnvm::dim_t end, const nnvm::dim_t nrows) {
+    const bool in_bounds = idx[i] >= 0 && idx[i] < nrows;
+    // idx[i+1] is only read when i < end.
+    const bool ordered = i >= end || idx[i] < idx[i+1];
+    if (!(in_bounds && ordered)) {
+      *out = kRSPIdxErr;
+    }
+  }
+};
+
+/*!
+ * \brief Forward declaration of CheckFormatWrapper; its definition is not
+ *        part of this hunk. NOTE(review): presumably it dispatches to the
+ *        storage-specific check (e.g. CheckFormatCSRImpl) — confirm.
+ */
+template<typename xpu>
+void CheckFormatWrapper(const RunContext &rctx,
+                        const NDArray &input,
+                        const TBlob &err_cpu,
+                        const bool full_check);
+
+/*!
+ * \brief Check the validity of CSRNDArray.
+ * \param rctx Execution context.
+ * \param input Input NDArray of CSRStorage.
+ * \param err_cpu Error number on cpu.
+ * \param full_check If true, rigorous check, O(N) operations,
+ *          otherwise basic check, O(1) operations.
+ */
+template<typename xpu>
+void CheckFormatCSRImpl(const RunContext &rctx, const NDArray &input,
+                        const TBlob &err_cpu, const bool full_check) {
+  using namespace op::mxnet_op;
+  CHECK_EQ(input.storage_type(), kCSRStorage)
+          << "CheckFormatCSRImpl is for CSRNDArray";
+  const TShape shape = input.shape();
+  const TShape idx_shape = input.aux_shape(csr::kIdx);
+  const TShape indptr_shape = input.aux_shape(csr::kIndPtr);
+  const TShape storage_shape = input.storage_shape();
+  if ((shape.ndim() != 2) ||
+      (idx_shape.ndim() != 1 || indptr_shape.ndim() != 1 || storage_shape.ndim() != 1) ||
+      (indptr_shape[0] != shape[0] + 1) ||
+      (idx_shape[0] != storage_shape[0])) {
+     MSHADOW_TYPE_SWITCH(err_cpu.type_flag_, DType, {
+       auto err = err_cpu.dptr<DType>();
+       *err = kCSRShapeErr;
+     });
+     return;
+  }
+  if (full_check) {
+    MSHADOW_TYPE_SWITCH(err_cpu.type_flag_, DType, {
+      MSHADOW_IDX_TYPE_SWITCH(input.aux_type(csr::kIndPtr), RType, {
+        MSHADOW_IDX_TYPE_SWITCH(input.aux_type(csr::kIdx), IType, {
+          mshadow::Stream<xpu> *s = rctx.get_stream<xpu>();
+          NDArray ret_xpu = NDArray(mshadow::Shape1(1),
+                                    rctx.get_ctx(), false, err_cpu.type_flag_);
+          TBlob val_xpu = ret_xpu.data();
+          Kernel<set_to_int<kNormalErr>, xpu>::Launch(s, val_xpu.Size(), val_xpu.dptr<DType>());
+          Kernel<csr_indptr_check, xpu>::Launch(s, indptr_shape[0] - 1, val_xpu.dptr<DType>(),
 
 Review comment:
   Done.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services