Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/12/18 20:23:45 UTC

[incubator-mxnet] branch master updated: csr slice operator, gpu implementation (#8814)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 9b0a5ca  csr slice operator, gpu implementation (#8814)
9b0a5ca is described below

commit 9b0a5caf58a1312d1cb520ece861ca4e4601f8fa
Author: Ziyue Huang <zy...@gmail.com>
AuthorDate: Tue Dec 19 04:23:42 2017 +0800

    csr slice operator, gpu implementation (#8814)
    
    * csr slice, gpu implementation
    
    * update comments
    
    * test already exists
    
    * common impl of csr slice on dim one
    
    * remove unnecessary stream->wait
    
    * add doc
    
    * trigger
---
 src/operator/tensor/matrix_op-inl.h |  84 ++++++---------------------
 src/operator/tensor/matrix_op.cc    |  62 ++++++++++++++++++++
 src/operator/tensor/matrix_op.cu    | 113 +++++++++++++++++++++++++++++++++++-
 3 files changed, 191 insertions(+), 68 deletions(-)
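
Background for the diffs below: slicing a CSR matrix along the first (row)
dimension only needs the input's indptr values, because the idx/data entries
of the selected rows already form one contiguous range. A minimal host-side
C++ sketch of that scheme (array contents are made up for illustration):

    #include <cstdio>
    #include <vector>

    int main() {
      // 4x3 CSR matrix; slice rows [1, 3)
      std::vector<int>   in_indptr = {0, 2, 3, 5, 6};
      std::vector<int>   in_idx    = {0, 2, 1, 0, 2, 1};
      std::vector<float> in_data   = {1, 2, 3, 4, 5, 6};
      int begin_row = 1, end_row = 3;

      // idx/data of the slice are one contiguous range of the input
      int offset = in_indptr[begin_row];
      int nnz    = in_indptr[end_row] - offset;

      // rebase indptr so the output starts at zero
      std::vector<int> out_indptr(end_row - begin_row + 1);
      for (size_t i = 0; i < out_indptr.size(); ++i)
        out_indptr[i] = in_indptr[begin_row + i] - offset;

      std::vector<int>   out_idx(in_idx.begin() + offset,
                                 in_idx.begin() + offset + nnz);
      std::vector<float> out_data(in_data.begin() + offset,
                                  in_data.begin() + offset + nnz);

      for (int p : out_indptr) printf("%d ", p);  // prints: 0 1 3
      printf("\n");
      return 0;
    }

The patch keeps this structure but swaps the raw host pointer reads for
device-aware copies, so the same template now serves both CPU and GPU.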

diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 367f8de..51cffb1 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -397,9 +397,7 @@ inline bool SliceForwardInferStorageType(const nnvm::NodeAttrs& attrs,
   const auto& in_stype = in_attrs->at(0);
   auto& out_stype = out_attrs->at(0);
   bool dispatched = false;
-  const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask;
-  const auto dispatch_ex = invalid_ctx ? DispatchMode::kFComputeFallback :
-                                         DispatchMode::kFComputeEx;
+  const auto dispatch_ex = DispatchMode::kFComputeEx;
   // If step = 1, no need to fallback; otherwise fallback to dense
   bool trivial_step = false;
   if (param.step.ndim() == 0U) {
@@ -452,7 +450,6 @@ void SliceCsrIndPtrImpl(const int begin, const int end, RunContext ctx,
 
 /*
  * Slice a CSR NDArray for first dimension
- * Only implemented for CPU
  */
 template<typename xpu>
 void SliceDimOneCsrImpl(const TShape &begin, const TShape &end, const OpContext& ctx,
@@ -460,7 +457,6 @@ void SliceDimOneCsrImpl(const TShape &begin, const TShape &end, const OpContext&
   using namespace mshadow;
   using namespace mxnet_op;
   using namespace csr;
-  CHECK((std::is_same<xpu, cpu>::value)) << "SliceDimOneCsrImpl is only implemented for CPU";
   nnvm::dim_t begin_row = begin[0];
   nnvm::dim_t end_row = end[0];
   nnvm::dim_t indptr_len = end_row - begin_row + 1;
@@ -471,10 +467,13 @@ void SliceDimOneCsrImpl(const TShape &begin, const TShape &end, const OpContext&
       MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
         RType* in_indptr = in.aux_data(kIndPtr).dptr<RType>();
         RType* out_indptr = out.aux_data(kIndPtr).dptr<RType>();
-        SliceCsrIndPtrImpl<cpu, RType>(begin_row, end_row, ctx.run_ctx, in_indptr, out_indptr);
+        SliceCsrIndPtrImpl<xpu, RType>(begin_row, end_row, ctx.run_ctx, in_indptr, out_indptr);
 
-        // retrieve nnz (CPU implementation)
-        int nnz = out_indptr[indptr_len - 1];
+        Stream<xpu> *s = ctx.get_stream<xpu>();
+
+        RType nnz = 0;
+        mshadow::Copy(Tensor<cpu, 1, RType>(&nnz, Shape1(1)),
+                      Tensor<xpu, 1, RType>(out_indptr + indptr_len - 1, Shape1(1), s));
         // return csr zeros if nnz = 0
         if (nnz == 0) {
           out.set_aux_shape(kIdx, Shape1(0));
@@ -487,10 +486,15 @@ void SliceDimOneCsrImpl(const TShape &begin, const TShape &end, const OpContext&
         IType* out_idx = out.aux_data(kIdx).dptr<IType>();
         DType* in_data = in.data().dptr<DType>();
         DType* out_data = out.data().dptr<DType>();
-        int offset = in_indptr[begin_row];
-        // this is also a CPU-only implementation
-        memcpy(out_idx, in_idx + offset, nnz * sizeof(IType));
-        memcpy(out_data, in_data + offset, nnz * sizeof(DType));
+
+        RType offset = 0;
+        mshadow::Copy(Tensor<cpu, 1, RType>(&offset, Shape1(1)),
+                      Tensor<xpu, 1, RType>(in_indptr + begin_row, Shape1(1), s));
+
+        mshadow::Copy(Tensor<xpu, 1, IType>(out_idx, Shape1(nnz), s),
+                      Tensor<xpu, 1, IType>(in_idx + offset, Shape1(nnz), s), s);
+        mshadow::Copy(Tensor<xpu, 1, DType>(out_data, Shape1(nnz), s),
+                      Tensor<xpu, 1, DType>(in_data + offset, Shape1(nnz), s), s);
       });
     });
   });
@@ -535,69 +539,15 @@ struct SliceDimTwoCsrAssign {
 
 /*
  * Slice a CSR NDArray for two dimensions
- * Only implemented for CPU
  */
 template<typename xpu>
 void SliceDimTwoCsrImpl(const TShape &begin, const TShape &end, const OpContext& ctx,
-                        const NDArray &in, const NDArray &out) {
-  using namespace mshadow;
-  using namespace mxnet_op;
-  using namespace csr;
-  CHECK((std::is_same<xpu, cpu>::value)) << "SliceDimTwoCsrImpl is only implemented for CPU";
-  nnvm::dim_t begin_row = begin[0], end_row = end[0];
-  nnvm::dim_t begin_col = begin[1], end_col = end[1];
-  nnvm::dim_t indptr_len = end_row - begin_row + 1;
-  out.CheckAndAllocAuxData(kIndPtr, Shape1(indptr_len));
-  // assume idx indptr share the same type
-  MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIndPtr), RType, {
-    MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIdx), IType, {
-      MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
-        RType *in_indptr = in.aux_data(kIndPtr).dptr<RType>();
-        IType *in_idx = in.aux_data(kIdx).dptr<IType>();
-        DType *in_data = in.data().dptr<DType>();
-        // retrieve nnz (CPU implementation)
-        RType *out_indptr = out.aux_data(kIndPtr).dptr<RType>();
-        int nnz = 0;
-        out_indptr[0] = 0;
-        // loop through indptr array and corresponding indices to count for nnz
-        for (nnvm::dim_t i = 0; i < indptr_len - 1; i++) {
-          out_indptr[i+1] = out_indptr[i];
-          for (RType j = in_indptr[i + begin_row];
-               j < in_indptr[i + begin_row + 1]; j++) {
-            // indices of CSRNDArray are in ascending order per row
-            if (in_idx[j] >= end_col) {
-              break;
-            } else if (in_idx[j] >= begin_col) {
-              out_indptr[i+1]++;
-              nnz++;
-            }
-          }
-        }
-        // returns zeros in csr format if nnz = 0
-        if (nnz == 0) {
-          out.set_aux_shape(kIdx, Shape1(0));
-          return;
-        }
-        out.CheckAndAllocAuxData(kIdx, Shape1(nnz));
-        out.CheckAndAllocData(Shape1(nnz));
-        IType *out_idx = out.aux_data(kIdx).dptr<IType>();
-        DType *out_data = out.data().dptr<DType>();
-
-        Stream<xpu> *s = ctx.get_stream<xpu>();
-        Kernel<SliceDimTwoCsrAssign, xpu>::Launch(s, indptr_len - 1, out_idx, out_data,
-                                                  out_indptr, in_idx, in_data,
-                                                  in_indptr + begin_row,
-                                                  begin_col, end_col);
-      });
-    });
-  });
-}
+                        const NDArray &in, const NDArray &out);
 
 
 template<typename xpu>
 void SliceCsrImpl(const SliceParam &param, const OpContext& ctx,
                   const NDArray &in, OpReqType req, const NDArray &out) {
-  CHECK((std::is_same<xpu, cpu>::value)) << "Slice for CSR input only implemented for CPU";
   if (req == kNullOp) return;
   CHECK_NE(req, kAddTo) << "kAddTo for Slice on CSR input is not supported";
   CHECK_NE(req, kWriteInplace) << "kWriteInplace for Slice on CSR input is not supported";
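
The hunks above drop the CPU-only guards and fetch the two scalars the host
needs (nnz and the row offset) via mshadow::Copy instead of dereferencing the
indptr pointer directly, which would fault once it points at device memory.
For xpu = gpu that copy amounts to a synchronous device-to-host transfer; a
raw-CUDA sketch of the same read, reusing the names from the diff:

    // read the last indptr entry back to the host to learn nnz
    RType nnz = 0;
    cudaMemcpy(&nnz, out_indptr + indptr_len - 1, sizeof(RType),
               cudaMemcpyDeviceToHost);

The host must know nnz before it can allocate the output's idx/data buffers,
so this round-trip is unavoidable here.
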
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index 8f36e35..e8fdce4 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -28,6 +28,64 @@
 
 namespace mxnet {
 namespace op {
+
+
+template<>
+void SliceDimTwoCsrImpl<cpu>(const TShape &begin, const TShape &end, const OpContext& ctx,
+                             const NDArray &in, const NDArray &out) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  using namespace csr;
+  nnvm::dim_t begin_row = begin[0], end_row = end[0];
+  nnvm::dim_t begin_col = begin[1], end_col = end[1];
+  nnvm::dim_t indptr_len = end_row - begin_row + 1;
+  out.CheckAndAllocAuxData(kIndPtr, Shape1(indptr_len));
+  // assume idx and indptr share the same type
+  MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIndPtr), RType, {
+    MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIdx), IType, {
+      MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
+        RType *in_indptr = in.aux_data(kIndPtr).dptr<RType>();
+        IType *in_idx = in.aux_data(kIdx).dptr<IType>();
+        DType *in_data = in.data().dptr<DType>();
+        // retrieve nnz (CPU implementation)
+        RType *out_indptr = out.aux_data(kIndPtr).dptr<RType>();
+        int nnz = 0;
+        out_indptr[0] = 0;
+        // loop through the indptr array and corresponding indices to count nnz
+        for (nnvm::dim_t i = 0; i < indptr_len - 1; i++) {
+          out_indptr[i+1] = out_indptr[i];
+          for (RType j = in_indptr[i + begin_row];
+               j < in_indptr[i + begin_row + 1]; j++) {
+            // indices of CSRNDArray are in ascending order per row
+            if (in_idx[j] >= end_col) {
+              break;
+            } else if (in_idx[j] >= begin_col) {
+              out_indptr[i+1]++;
+              nnz++;
+            }
+          }
+        }
+        // returns zeros in csr format if nnz = 0
+        if (nnz == 0) {
+          out.set_aux_shape(kIdx, Shape1(0));
+          return;
+        }
+        out.CheckAndAllocAuxData(kIdx, Shape1(nnz));
+        out.CheckAndAllocData(Shape1(nnz));
+        IType *out_idx = out.aux_data(kIdx).dptr<IType>();
+        DType *out_data = out.data().dptr<DType>();
+
+        Stream<cpu> *s = ctx.get_stream<cpu>();
+        Kernel<SliceDimTwoCsrAssign, cpu>::Launch(s, indptr_len - 1, out_idx, out_data,
+                                                  out_indptr, in_idx, in_data,
+                                                  in_indptr + begin_row,
+                                                  begin_col, end_col);
+      });
+    });
+  });
+}
+
+
 DMLC_REGISTER_PARAMETER(ReshapeParam);
 DMLC_REGISTER_PARAMETER(TransposeParam);
 DMLC_REGISTER_PARAMETER(ExpandDimParam);
@@ -298,6 +356,10 @@ Example::
 .set_attr_parser(ParamParser<SliceParam>)
 .set_attr<nnvm::FInferShape>("FInferShape", SliceOpShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
 .set_attr<FInferStorageType>("FInferStorageType", SliceForwardInferStorageType)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_slice"})
 .set_attr<FCompute>("FCompute<cpu>", SliceOpForward<cpu>)
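
Declaring FResourceRequest with kTempSpace is what lets the GPU
implementation below draw scratch memory from MXNet's temp-space pool instead
of calling cudaMalloc inside the operator. The consuming side, as it appears
in the GPU hunk that follows (names taken from the diff):

    // ctx.requested[0] is the kTempSpace resource declared above
    Tensor<gpu, 1, char> workspace = ctx.requested[0]
        .get_space_typed<gpu, 1, char>(Shape1(temp_storage_bytes), s);
    void* d_temp_storage = workspace.dptr_;
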
diff --git a/src/operator/tensor/matrix_op.cu b/src/operator/tensor/matrix_op.cu
index 30eaf23..b6597be 100644
--- a/src/operator/tensor/matrix_op.cu
+++ b/src/operator/tensor/matrix_op.cu
@@ -22,11 +22,121 @@
  * \file matrix_op.cu
  * \brief GPU Implementation of matrix operations
  */
+#include <cub/cub.cuh>
 #include "./matrix_op-inl.h"
 #include "./elemwise_unary_op.h"
 
+
 namespace mxnet {
 namespace op {
+
+/*!
+ * \brief Compute the number of elements in each row that fall in [begin_col, end_col).
+ */
+struct SliceMarkCsrIndPtr {
+  /*! 
+   * \brief
+   * \param i           the i-th row of the output csr ndarray
+   * \param prefix_sum  indptr array of the output csr ndarray
+   * \param in_idx      indices array of the input csr ndarray
+   * \param in_indptr   indptr array of the input csr ndarray
+   * \param begin_col   starting column index
+   * \param end_col     ending column index
+   */
+  template<typename IType, typename RType>
+  MSHADOW_XINLINE static void Map(int i,
+                                  RType* prefix_sum,
+                                  const IType* in_idx,
+                                  const RType* in_indptr,
+                                  const int begin_col, const int end_col) {
+    if (i == 0) {
+      prefix_sum[0] = 0;
+    }
+    RType size = 0;
+    for (RType j = in_indptr[i]; j < in_indptr[i+1]; j++) {
+      // indices of CSRNDArray are in ascending order per row
+      if (in_idx[j] >= end_col) {
+        break;
+      } else if (in_idx[j] >= begin_col) {
+        size++;
+      }
+    }
+    prefix_sum[i+1] = size;
+  }
+};
+
+
+template<>
+void SliceDimTwoCsrImpl<gpu>(const TShape &begin, const TShape &end, const OpContext& ctx,
+                             const NDArray &in, const NDArray &out) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  using namespace csr;
+
+  Stream<gpu> *s = ctx.get_stream<gpu>();
+
+  nnvm::dim_t begin_row = begin[0], end_row = end[0];
+  nnvm::dim_t begin_col = begin[1], end_col = end[1];
+  nnvm::dim_t indptr_len = end_row - begin_row + 1;
+  out.CheckAndAllocAuxData(kIndPtr, Shape1(indptr_len));
+  // assume idx and indptr share the same type
+  MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIndPtr), RType, {
+    MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIdx), IType, {
+      MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
+        RType *in_indptr = in.aux_data(kIndPtr).dptr<RType>();
+        IType *in_idx = in.aux_data(kIdx).dptr<IType>();
+        DType *in_data = in.data().dptr<DType>();
+
+        RType *out_indptr = out.aux_data(kIndPtr).dptr<RType>();
+
+        Kernel<SliceMarkCsrIndPtr, gpu>::Launch(s, indptr_len - 1,
+                                                out_indptr,
+                                                in_idx,
+                                                in_indptr + begin_row,
+                                                begin_col, end_col);
+        void* d_temp_storage = NULL;
+        size_t temp_storage_bytes = 0;
+        cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                      temp_storage_bytes,
+                                      out_indptr,
+                                      out_indptr,
+                                      indptr_len,
+                                      Stream<gpu>::GetStream(s));
+        Tensor<gpu, 1, char> workspace = ctx.requested[0]
+            .get_space_typed<gpu, 1, char>(Shape1(temp_storage_bytes), s);
+        d_temp_storage = workspace.dptr_;
+
+        cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                      temp_storage_bytes,
+                                      out_indptr,
+                                      out_indptr,
+                                      indptr_len,
+                                      Stream<gpu>::GetStream(s));
+        // retrieve nnr
+        RType nnr = 0;
+        CUDA_CALL(cudaMemcpy(&nnr, &out_indptr[indptr_len-1], sizeof(RType),
+            cudaMemcpyDeviceToHost));
+
+        // returns zeros in csr format if nnr = 0
+        if (nnr == 0) {
+          out.set_aux_shape(kIdx, Shape1(0));
+          return;
+        }
+        out.CheckAndAllocAuxData(kIdx, Shape1(nnr));
+        out.CheckAndAllocData(Shape1(nnr));
+        IType *out_idx = out.aux_data(kIdx).dptr<IType>();
+        DType *out_data = out.data().dptr<DType>();
+
+        Kernel<SliceDimTwoCsrAssign, gpu>::Launch(s, indptr_len - 1, out_idx, out_data,
+                                                  out_indptr, in_idx, in_data,
+                                                  in_indptr + begin_row,
+                                                  begin_col, end_col);
+      });
+    });
+  });
+}
+
+
 NNVM_REGISTER_OP(Reshape)
 .set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);
 
@@ -40,7 +150,8 @@ NNVM_REGISTER_OP(expand_dims)
 .set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);
 
 NNVM_REGISTER_OP(slice)
-.set_attr<FCompute>("FCompute<gpu>", SliceOpForward<gpu>);
+.set_attr<FCompute>("FCompute<gpu>", SliceOpForward<gpu>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", SliceEx<gpu>);
 
 NNVM_REGISTER_OP(_backward_slice)
 .set_attr<FCompute>("FCompute<gpu>", SliceOpBackward<gpu>);
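
The GPU path builds the output indptr in two parallel steps:
SliceMarkCsrIndPtr writes each row's kept-element count one slot past its
row index (with a leading zero), and cub::DeviceScan::InclusiveSum turns the
counts into offsets in place. CUB scans are always issued twice: the first
call, with a NULL temp-storage pointer, only reports how many scratch bytes
are needed; the second call does the work. A self-contained sketch of that
pattern (toy data, default stream; compile with nvcc):

    #include <cub/cub.cuh>
    #include <cuda_runtime.h>
    #include <cstdio>

    int main() {
      const int n = 5;
      int h[n] = {0, 1, 2, 3, 4};
      int* d = NULL;
      cudaMalloc(&d, n * sizeof(int));
      cudaMemcpy(d, h, n * sizeof(int), cudaMemcpyHostToDevice);

      // pass 1: d_temp_storage == NULL, so CUB only sets temp_storage_bytes
      void* d_temp_storage = NULL;
      size_t temp_storage_bytes = 0;
      cub::DeviceScan::InclusiveSum(d_temp_storage, temp_storage_bytes, d, d, n);
      cudaMalloc(&d_temp_storage, temp_storage_bytes);

      // pass 2: same arguments plus real scratch space run the in-place scan
      cub::DeviceScan::InclusiveSum(d_temp_storage, temp_storage_bytes, d, d, n);

      cudaMemcpy(h, d, n * sizeof(int), cudaMemcpyDeviceToHost);
      for (int i = 0; i < n; ++i) printf("%d ", h[i]);  // prints: 0 1 3 6 10
      printf("\n");
      cudaFree(d_temp_storage);
      cudaFree(d);
      return 0;
    }

In the operator the scratch buffer comes from the requested temp space rather
than cudaMalloc, the scanned array is the output's own indptr, and its last
element then yields the slice's nnz.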

-- 
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.