Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/03/27 21:00:53 UTC

[incubator-mxnet] branch master updated: speedup SequenceMask on GPU (#14445)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 09daf22  speedup SequenceMask on GPU (#14445)
09daf22 is described below

commit 09daf22c35de1ab42744690a50eae64cc8967c5b
Author: Hao Jin <hj...@gmail.com>
AuthorDate: Wed Mar 27 14:00:36 2019 -0700

    speedup SequenceMask on GPU (#14445)
---
 src/operator/sequence_mask-inl.h | 79 +++++++++-------------------------------
 src/operator/sequence_mask.cc    | 64 ++++++++++++++++++++++++++++++++
 src/operator/sequence_mask.cu    | 59 ++++++++++++++++++++++++++++++
 3 files changed, 140 insertions(+), 62 deletions(-)
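
Context for the change: the previous implementation launched one kernel thread per batch entry and looped over the sequence and trailing dimensions inside that thread, which leaves most of a GPU idle when the batch is small. The new GPU kernels in sequence_mask.cu instead launch one thread per element of the data tensor (data.shape_.Size()) and recover the (seq, batch) coordinates from the flat index. A minimal standalone sketch of that index arithmetic, run on the CPU for illustration (variable names here are chosen for the sketch, not taken from the patch):

    #include <cstdio>

    // Per-element masking as in SequenceMask0GPUKernel, i.e. the
    // (seqlen, batch, rest) layout used when axis == 0.
    int main() {
      const int max_s_len = 4, batch_size = 2, restsize = 3;
      float data[max_s_len * batch_size * restsize];
      for (int i = 0; i < max_s_len * batch_size * restsize; ++i) data[i] = 1.0f;
      const int lengths[batch_size] = {2, 3};  // valid sequence length per batch entry
      const float value = 0.0f;                // fill value for masked positions

      // One "thread" per element: decompose the flat index exactly as the kernel does.
      for (int i = 0; i < max_s_len * batch_size * restsize; ++i) {
        int batch = i / restsize % batch_size;
        int seq   = i / restsize / batch_size;
        if (seq >= lengths[batch]) data[i] = value;  // KERNEL_ASSIGN with kWriteInplace
      }
      printf("unmasked data[0]=%g, masked data[18]=%g\n", data[0], data[18]);
      return 0;
    }

Every element beyond its batch's valid length is overwritten independently, so the work maps onto the GPU with full parallelism across seq, batch, and the trailing dimensions, rather than only across the batch.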

diff --git a/src/operator/sequence_mask-inl.h b/src/operator/sequence_mask-inl.h
index 372cf57..05a9424 100644
--- a/src/operator/sequence_mask-inl.h
+++ b/src/operator/sequence_mask-inl.h
@@ -65,70 +65,24 @@ struct SequenceMaskParam : public dmlc::Parameter<SequenceMaskParam> {
   }
 };
 
-// (seqlen, batch, rest) case
-template <int req>
-struct SequenceMask0Kernel {
-  template <typename DType, typename IType>
-  MSHADOW_XINLINE static void Map(int b, DType *in, const IType *idx,
-                                  index_t max_s_len, index_t batch_size,
-                                  index_t restsize, DType value) {
-    const index_t seqpos = static_cast<int>(idx[b]);
-#pragma unroll
-    for (index_t s = seqpos; s < max_s_len; ++s) {
-      index_t incr = (s * batch_size * restsize) + (b * restsize);
-#pragma unroll
-      for (index_t r = 0; r < restsize; ++r)
-        KERNEL_ASSIGN(in[incr + r], req, value);
-    }
-  }
-};
-
-// (batch, seqlen, rest) case
-template <int req>
-struct SequenceMask1Kernel {
-  template <typename DType, typename IType>
-  MSHADOW_XINLINE static void Map(int b, DType *in, const IType *idx,
-                                  index_t max_s_len, index_t batch_size,
-                                  index_t restsize, DType value) {
-    const index_t seqpos = static_cast<int>(idx[b]);
-#pragma unroll
-    for (index_t s = seqpos; s < max_s_len; ++s) {
-      index_t incr = (b * max_s_len * restsize) + (s * restsize);
-#pragma unroll
-      for (index_t r = 0; r < restsize; ++r)
-        KERNEL_ASSIGN(in[incr + r], req, value);
-    }
-  }
-};
+template<typename DType, typename IType>
+void SequenceMaskExec(const mshadow::Tensor<cpu, 3, DType> &data,
+                  const mshadow::Tensor<cpu, 1, IType> &indices,
+                  const OpReqType req, mshadow::Stream<cpu> *const s,
+                  int axis, DType val);
+#ifdef __CUDACC__
+template<typename DType, typename IType>
+void SequenceMaskExec(const mshadow::Tensor<gpu, 3, DType> &data,
+                  const mshadow::Tensor<gpu, 1, IType> &indices,
+                  const OpReqType req, mshadow::Stream<gpu> *const s,
+                  int axis, DType val);
+#endif
 
 template <typename xpu, typename DType, typename IType>
 class SequenceMaskOp : public Operator {
  public:
   explicit SequenceMaskOp(SequenceMaskParam p) { this->param_ = p; }
 
-  void sequence_mask(const mshadow::Tensor<xpu, 3, DType> &data,
-                     const mshadow::Tensor<xpu, 1, IType> &indices,
-                     const OpReqType req, mshadow::Stream<xpu> *const s,
-                     DType val) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-
-    index_t batch = indices.size(0);
-    index_t max_seq_len = data.size(param_.axis);
-    index_t restsize = data.size(2);
-
-    MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
-      if (param_.axis == 1)
-        mxnet_op::Kernel<SequenceMask1Kernel<req_type>, xpu>::Launch(
-            s, batch, data.dptr_, indices.dptr_, max_seq_len, batch, restsize,
-            val);
-      else
-        mxnet_op::Kernel<SequenceMask0Kernel<req_type>, xpu>::Launch(
-            s, batch, data.dptr_, indices.dptr_, max_seq_len, batch, restsize,
-            val);
-    });
-  }
-
   virtual void Forward(const OpContext &ctx, const std::vector<TBlob> &in_data,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &out_data,
@@ -155,8 +109,8 @@ class SequenceMaskOp : public Operator {
     if (param_.use_sequence_length) {
       Tensor<xpu, 1, IType> indices =
           in_data[seq_mask::kSequenceLength].get<xpu, 1, IType>(s);
-      sequence_mask(out, indices, req[seq_mask::kOut], s,
-                    static_cast<DType>(param_.value));
+      SequenceMaskExec<DType, IType>(out, indices, req[seq_mask::kOut], s,
+                   param_.axis, static_cast<DType>(param_.value));
     }
   }
 
@@ -198,11 +152,12 @@ class SequenceMaskOp : public Operator {
                 s3, s);
         out_g_temp = F<mshadow_op::identity>(out_g);
         out_g = out_g_temp;
-        sequence_mask(out_g, indices, kWriteInplace, s, DType(0.));
+        SequenceMaskExec<DType, IType>(out_g, indices, kWriteInplace, s, param_.axis, DType(0.));
         Assign(data_g, kAddTo, F<mshadow_op::identity>(out_g));
       } else {
         Assign(data_g, req[seq_mask::kData], F<mshadow_op::identity>(out_g));
-        sequence_mask(data_g, indices, req[seq_mask::kData], s, DType(0.));
+        SequenceMaskExec<DType, IType>(
+          data_g, indices, req[seq_mask::kData], s, param_.axis, DType(0.));
       }
     }
   }
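
The header change above replaces the member function sequence_mask with a free function SequenceMaskExec that is only declared here, with separate cpu and gpu overloads (the gpu one guarded by __CUDACC__); the definitions live in sequence_mask.cc and sequence_mask.cu below, so each backend can use a different launch strategy. A minimal sketch of that dispatch pattern, with Tensor, cpu, and gpu as simplified stand-ins rather than mshadow's real types:

    struct cpu {};  // stand-ins for mshadow's device tags
    struct gpu {};

    template <typename Device, typename DType>
    struct Tensor { DType *dptr_; };  // heavily simplified placeholder

    template <typename DType>
    void SequenceMaskExec(const Tensor<cpu, DType> &data, DType val);  // defined in the .cc file
    template <typename DType>
    void SequenceMaskExec(const Tensor<gpu, DType> &data, DType val);  // defined in the .cu file

    // Device-agnostic operator code just calls the function; overload
    // resolution on the tensor's device tag picks the backend at compile time.
    template <typename Device, typename DType>
    void Forward(const Tensor<Device, DType> &out, DType val) {
      SequenceMaskExec(out, val);
    }
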
diff --git a/src/operator/sequence_mask.cc b/src/operator/sequence_mask.cc
index c3bf12d..f4f81a8 100644
--- a/src/operator/sequence_mask.cc
+++ b/src/operator/sequence_mask.cc
@@ -27,6 +27,70 @@
 
 namespace mxnet {
 namespace op {
+
+// (seqlen, batch, rest) case
+template <int req>
+struct SequenceMask0CPUKernel {
+  template <typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int batch, DType *in, const IType *idx,
+                                  index_t max_s_len, index_t batch_size,
+                                  index_t restsize, DType value) {
+    const index_t seqpos = static_cast<int>(idx[batch]);
+#pragma unroll
+    for (index_t s = seqpos; s < max_s_len; ++s) {
+      index_t incr = (s * batch_size * restsize) + (batch * restsize);
+#pragma unroll
+      for (index_t r = 0; r < restsize; ++r)
+        KERNEL_ASSIGN(in[incr + r], req, value);
+    }
+  }
+};
+
+// (batch, seqlen, rest) case
+template <int req>
+struct SequenceMask1CPUKernel {
+  template <typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int batch, DType *in, const IType *idx,
+                                  index_t max_s_len, index_t batch_size,
+                                  index_t restsize, DType value) {
+    const index_t seqpos = static_cast<int>(idx[batch]);
+#pragma unroll
+    for (index_t s = seqpos; s < max_s_len; ++s) {
+      index_t incr = (batch * max_s_len * restsize) + (s * restsize);
+#pragma unroll
+      for (index_t r = 0; r < restsize; ++r)
+        KERNEL_ASSIGN(in[incr + r], req, value);
+    }
+  }
+};
+
+template<typename DType, typename IType>
+void SequenceMaskExec(
+       const mshadow::Tensor<cpu, 3, DType> &data,
+       const mshadow::Tensor<cpu, 1, IType> &indices,
+       const OpReqType req, mshadow::Stream<cpu> *const s,
+       int axis, DType val) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace mxnet_op;
+
+  index_t batch = indices.size(0);
+  index_t max_seq_len = data.size(axis);
+  index_t restsize = data.size(2);
+
+  MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+    if (axis == 1) {
+      Kernel<SequenceMask1CPUKernel<req_type>, cpu>::Launch(
+        s, batch, data.dptr_, indices.dptr_, max_seq_len, batch, restsize,
+        val);
+    } else {
+      Kernel<SequenceMask0CPUKernel<req_type>, cpu>::Launch(
+        s, batch, data.dptr_, indices.dptr_, max_seq_len, batch, restsize,
+        val);
+    }
+  });
+}
+
 template <>
 Operator *CreateOp<cpu>(SequenceMaskParam param, int dtype, int itype) {
   Operator *op = nullptr;
diff --git a/src/operator/sequence_mask.cu b/src/operator/sequence_mask.cu
index cec627c..8f196b4 100644
--- a/src/operator/sequence_mask.cu
+++ b/src/operator/sequence_mask.cu
@@ -29,6 +29,65 @@
 namespace mxnet {
 namespace op {
 
+// (seqlen, batch, rest) case
+template <int req>
+struct SequenceMask0GPUKernel {
+  template <typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, DType *in, const IType *idx,
+                                  index_t max_s_len, index_t batch_size,
+                                  index_t restsize, DType value) {
+    index_t batch = i / restsize % batch_size;
+    const index_t seqpos = static_cast<int>(idx[batch]);
+    index_t seq = i / restsize / batch_size;
+    if (seq >= seqpos) {
+      KERNEL_ASSIGN(in[i], req, value);
+    }
+  }
+};
+
+// (batch, seqlen, rest) case
+template <int req>
+struct SequenceMask1GPUKernel {
+  template <typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, DType *in, const IType *idx,
+                                  index_t max_s_len, index_t batch_size,
+                                  index_t restsize, DType value) {
+    index_t batch = i / restsize / max_s_len;
+    const index_t seqpos = static_cast<int>(idx[batch]);
+    index_t seq = i / restsize % max_s_len;
+    if (seq >= seqpos) {
+      KERNEL_ASSIGN(in[i], req, value);
+    }
+  }
+};
+
+template<typename DType, typename IType>
+void SequenceMaskExec(
+       const mshadow::Tensor<gpu, 3, DType> &data,
+       const mshadow::Tensor<gpu, 1, IType> &indices,
+       const OpReqType req, mshadow::Stream<gpu> *const s,
+       int axis, DType val) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace mxnet_op;
+
+  index_t batch = indices.size(0);
+  index_t max_seq_len = data.size(axis);
+  index_t restsize = data.size(2);
+
+  MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+    if (axis == 1) {
+      Kernel<SequenceMask1GPUKernel<req_type>, gpu>::Launch(
+        s, data.shape_.Size(), data.dptr_, indices.dptr_, max_seq_len, batch, restsize,
+        val);
+    } else {
+      Kernel<SequenceMask0GPUKernel<req_type>, gpu>::Launch(
+        s, data.shape_.Size(), data.dptr_, indices.dptr_, max_seq_len, batch, restsize,
+        val);
+    }
+  });
+}
+
 template <> Operator *CreateOp<gpu>(SequenceMaskParam param, int dtype, int itype) {
   Operator *op = NULL;
   MSHADOW_TYPE_SWITCH(dtype, DType, {