You are viewing a plain-text version of this content. The canonical version is the linked archive page (the hyperlink is not preserved in plain text).
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/11/04 07:14:12 UTC

[incubator-mxnet] branch master updated: customized take forward for CPU (#12997)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 31ebb95  customized take forward for CPU (#12997)
31ebb95 is described below

commit 31ebb9571efb987a7e5e482d30012cb3e94c9567
Author: HyperZealot <40...@users.noreply.github.com>
AuthorDate: Sun Nov 4 02:13:56 2018 -0500

    customized take forward for CPU (#12997)
---
 src/operator/tensor/indexing_op.cc | 113 +++++++++++++++++++++++++++++++++++
 src/operator/tensor/indexing_op.cu | 115 ++++++++++++++++++++++++++++++++++++
 src/operator/tensor/indexing_op.h  | 117 ++++---------------------------------
 3 files changed, 239 insertions(+), 106 deletions(-)

diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc
index de0ede3..77236e0 100644
--- a/src/operator/tensor/indexing_op.cc
+++ b/src/operator/tensor/indexing_op.cc
@@ -28,6 +28,28 @@
 namespace mxnet {
 namespace op {
 
+template<bool clip = true>
+struct TakeCPU {
+  // Row gather: copies in_data row idx[i] (after clip/wrap) into out_data row i.
+  // assume that idx has been flattened to a 1-D tensor (N,)
+  // assume that out_data and in_data have been flattened to 2-D tensors, (N, M) and (K, M)
+  // M is the number of columns of in_data and out_data; K is the number of rows of in_data
+  // i is the ROW index (0 <= i < N): one call copies a whole row of M elements
+  template<typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data,
+                                  const IType* idx, const size_t M, const int64_t K) {
+    int64_t j = static_cast<int64_t>(idx[i]);
+    if (clip) {  // clamp out-of-range indices to [0, K - 1]
+      if (j <= 0) j = 0;
+      else if (j >= K) j = K - 1;
+    } else {  // wrap: C++ % keeps the sign of j, so shift negatives back into [0, K)
+      j = j % K;
+      j += (j < 0) ? K : 0;
+    }
+    std::memcpy(out_data + i * M, in_data + j * M, M * sizeof(DType));
+  }
+};
+
 /*
  * \brief returns true if all indices are between [min, max]
  * \param data_ptr the indices to check
@@ -48,6 +70,29 @@ bool CheckIndexOutOfBound(const DType* data_ptr, size_t data_size,
   return is_valid;
 }
 
+// Embedding forward, dense weight (CPU): out row r = weight row clip(data[r]); req is not read.
+template<>
+void EmbeddingOpForwardDnsImpl<cpu>(mshadow::Stream<cpu>* s,
+                                    const TBlob& data,
+                                    const TBlob& weight,
+                                    const OpReqType req,
+                                    const TBlob& output) {
+  using namespace mxnet_op;
+  const TShape& ishape = data.shape_;
+  const TShape& oshape = output.shape_;
+  // One TakeCPU task per output row; unlike oshape.Size() / wmat.shape_[1], never divides by 0.
+  MSHADOW_TYPE_SWITCH(output.type_flag_, DType, {
+    MSHADOW_TYPE_SWITCH(data.type_flag_, IType, {
+      Tensor<cpu, 1, IType> idx = data.get_with_shape<cpu, 1, IType>(
+        Shape1(ishape.ProdShape(0, ishape.ndim())), s);
+      Tensor<cpu, 2, DType> wmat = weight.get<cpu, 2, DType>(s);
+      Tensor<cpu, 2, DType> out = output.get_with_shape<cpu, 2, DType>(
+        Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s);
+      Kernel<TakeCPU<true>, cpu>::Launch(s, out.shape_[0], out.dptr_, wmat.dptr_,
+                                         idx.dptr_, wmat.shape_[1], wmat.shape_[0]);
+    });
+  });
+}
 
 template<>
 void SparseEmbeddingOpForwardRspImpl<cpu>(const OpContext& ctx,
@@ -228,6 +273,74 @@ void TakeOpForwardCsrImpl<cpu>(const TakeParam& params,
 }
 
 template<>
+void TakeOpForward<cpu>(const nnvm::NodeAttrs& attrs,
+                        const OpContext& ctx,
+                        const std::vector<TBlob>& inputs,
+                        const std::vector<OpReqType>& req,
+                        const std::vector<TBlob>& outputs) {
+  using namespace mxnet_op;
+  if (req[take_::kOut] == kNullOp) return;
+  const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  // CPU take: axis 0 gathers whole rows with memcpy (TakeCPU); other axes use the strided Take.
+  const TShape& idxshape = inputs[take_::kIdx].shape_;
+  const TShape& arrshape = inputs[take_::kArr].shape_;
+  const TShape& oshape = outputs[take_::kOut].shape_;
+  if (oshape.Size() == 0) return;  // nothing to write (empty idx => empty out; avoids 0/0 below)
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0);
+  // For axis 0, M = oshape.Size() / idxshape.Size() is the number of elements per gathered row.
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {  // output data type
+    MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {  // index data type
+      if (actual_axis == 0) {
+        if (param.mode == take_::kClip) {
+          Kernel<TakeCPU<true>, cpu>::Launch(s, idxshape.Size(),
+                                             outputs[take_::kOut].dptr<DType>(),
+                                             inputs[take_::kArr].dptr<DType>(),
+                                             inputs[take_::kIdx].dptr<IType>(),
+                                             oshape.Size()/idxshape.Size(), arrshape[0]);
+        } else {
+          Kernel<TakeCPU<false>, cpu>::Launch(s, idxshape.Size(),
+                                              outputs[take_::kOut].dptr<DType>(),
+                                              inputs[take_::kArr].dptr<DType>(),
+                                              inputs[take_::kIdx].dptr<IType>(),
+                                              oshape.Size()/idxshape.Size(), arrshape[0]);
+        }
+      } else {
+        mshadow::Shape<10> in_strides;
+        int stride = 1;
+        for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) {
+          in_strides[i] = stride;
+        }
+        mshadow::Shape<10> out_strides;
+        stride = 1;
+        for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) {
+          out_strides[i] = stride;
+        }
+        if (param.mode == take_::kClip) {
+          Kernel<Take<true>, cpu>::Launch(s, oshape.Size(),
+                                          outputs[take_::kOut].dptr<DType>(),
+                                          inputs[take_::kArr].dptr<DType>(),
+                                          inputs[take_::kIdx].dptr<IType>(),
+                                          in_strides, out_strides, arrshape.ndim(),
+                                          oshape.ndim(), idxshape.ndim(),
+                                          arrshape[actual_axis], actual_axis);
+        } else if (param.mode == take_::kWrap) {
+          Kernel<Take<false>, cpu>::Launch(s, oshape.Size(),
+                                           outputs[take_::kOut].dptr<DType>(),
+                                           inputs[take_::kArr].dptr<DType>(),
+                                           inputs[take_::kIdx].dptr<IType>(),
+                                           in_strides, out_strides, arrshape.ndim(),
+                                           oshape.ndim(), idxshape.ndim(),
+                                           arrshape[actual_axis], actual_axis);
+        }
+      }
+    });
+  });
+}
+
+template<>
 inline void SparseEmbeddingOpBackwardRspImpl<cpu>(const bool deterministic,
                                                   const OpContext& ctx,
                                                   const TBlob& ograd,
diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu
index 16cc697..0d72b18 100644
--- a/src/operator/tensor/indexing_op.cu
+++ b/src/operator/tensor/indexing_op.cu
@@ -116,6 +116,31 @@ struct AddTakeGradRspDeterministicKernel {
   }
 };
 
+/*! \brief Element-wise gather kernel for the axis-0 take / embedding forward on GPU.
+ * (Named TakeGPU, not take, to avoid clashing with mshadow's take and the generic Take.)
+ */
+template<bool clip = true>
+struct TakeGPU {
+  // assume that idx has been flattened to a 1-D tensor (N,)
+  // assume that out_data and in_data have been flattened to 2-D tensors, (N, M) and (K, M)
+  // M is the number of columns of in_data and out_data
+  // K is the number of rows of in_data
+  // i is the flat element index of out_data: row i / M, column i % M
+  template<typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data,
+                                  const IType* idx, const int64_t M, const int64_t K) {
+    int64_t j = static_cast<int64_t>(idx[i/M]);
+    if (clip) {  // clamp out-of-range indices to [0, K - 1]
+      if (j <= 0) j = 0;
+      else if (j >= K) j = K - 1;
+    } else {  // wrap: C++ % keeps the sign of j, so shift negatives back into [0, K)
+      j = j % K;
+      j += (j < 0) ? K : 0;
+    }
+    out_data[i] = in_data[j * M + i % M];
+  }
+};
+
 /*
  * \brief returns true if all indices are between [min, max]
  * \param s the stream
@@ -137,6 +162,30 @@ bool CheckIndexOutOfBound(mshadow::Stream<gpu> *s, const DType* data_ptr, size_t
   return is_valid == 0;
 }
 
+// Dense-weight embedding forward (GPU): output row r is weight row data[r], clipped to range.
+template<>
+void EmbeddingOpForwardDnsImpl<gpu>(mshadow::Stream<gpu>* s,
+                                    const TBlob& data,
+                                    const TBlob& weight,
+                                    const OpReqType req,
+                                    const TBlob& output) {
+  using namespace mxnet_op;
+  const TShape& in_shape = data.shape_;
+  const TShape& out_shape = output.shape_;
+  // One TakeGPU task per output element; the output is viewed as a 2-D (rows, last-dim) matrix.
+  MSHADOW_TYPE_SWITCH(output.type_flag_, DType, {
+    MSHADOW_TYPE_SWITCH(data.type_flag_, IType, {
+      Tensor<gpu, 1, IType> indices = data.get_with_shape<gpu, 1, IType>(
+        Shape1(in_shape.ProdShape(0, in_shape.ndim())), s);
+      Tensor<gpu, 2, DType> table = weight.get<gpu, 2, DType>(s);
+      Tensor<gpu, 2, DType> dst = output.get_with_shape<gpu, 2, DType>(
+        Shape2(out_shape.ProdShape(0, out_shape.ndim()-1), out_shape[out_shape.ndim()-1]), s);
+      Kernel<TakeGPU<true>, gpu>::Launch(s, out_shape.Size(), dst.dptr_, table.dptr_,
+                                         indices.dptr_, table.shape_[1], table.shape_[0]);
+    });
+  });
+}
+
 template<>
 void SparseEmbeddingOpForwardRspImpl<gpu>(const OpContext& ctx,
                                           const TBlob& data,
@@ -414,6 +463,72 @@ inline void GatherNDBackwardImpl(int N, int M, int K,
   mxnet_op::Kernel<backward_gather_nd_gpu, gpu>::Launch(s, N, N, M, K, strides, out, data, indices);
 }
 
+template<>
+void TakeOpForward<gpu>(const nnvm::NodeAttrs& attrs,
+                        const OpContext& ctx,
+                        const std::vector<TBlob>& inputs,
+                        const std::vector<OpReqType>& req,
+                        const std::vector<TBlob>& outputs) {
+  using namespace mxnet_op;
+  if (req[take_::kOut] == kNullOp) return;
+  const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  // GPU take: axis 0 uses the element-wise TakeGPU gather; other axes use the strided Take.
+  const TShape& idxshape = inputs[take_::kIdx].shape_;
+  const TShape& arrshape = inputs[take_::kArr].shape_;
+  const TShape& oshape = outputs[take_::kOut].shape_;
+  if (oshape.Size() == 0) return;  // nothing to write (empty idx => empty out; avoids 0/0 below)
+  Stream<gpu> *s = ctx.get_stream<gpu>();
+  const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0);
+  // For axis 0, M = oshape.Size() / idxshape.Size() is the number of elements per gathered row.
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {  // output data type
+    MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {  // index data type
+      if (actual_axis == 0) {
+        if (param.mode == take_::kClip) {
+          Kernel<TakeGPU<true>, gpu>::Launch(s, oshape.Size(),
+                                             outputs[take_::kOut].dptr<DType>(),
+                                             inputs[take_::kArr].dptr<DType>(),
+                                             inputs[take_::kIdx].dptr<IType>(),
+                                             oshape.Size()/idxshape.Size(), arrshape[0]);
+        } else {
+          Kernel<TakeGPU<false>, gpu>::Launch(s, oshape.Size(),
+                                              outputs[take_::kOut].dptr<DType>(),
+                                              inputs[take_::kArr].dptr<DType>(),
+                                              inputs[take_::kIdx].dptr<IType>(),
+                                              oshape.Size()/idxshape.Size(), arrshape[0]);
+        }
+      } else {
+        mshadow::Shape<10> in_strides;
+        int stride = 1;
+        for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) {
+          in_strides[i] = stride;
+        }
+        mshadow::Shape<10> out_strides;
+        stride = 1;
+        for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) {
+          out_strides[i] = stride;
+        }
+        if (param.mode == take_::kClip) {
+          Kernel<Take<true>, gpu>::Launch(s, oshape.Size(),
+                                          outputs[take_::kOut].dptr<DType>(),
+                                          inputs[take_::kArr].dptr<DType>(),
+                                          inputs[take_::kIdx].dptr<IType>(),
+                                          in_strides, out_strides, arrshape.ndim(), oshape.ndim(),
+                                          idxshape.ndim(), arrshape[actual_axis], actual_axis);
+        } else if (param.mode == take_::kWrap) {
+          Kernel<Take<false>, gpu>::Launch(s, oshape.Size(),
+                                           outputs[take_::kOut].dptr<DType>(),
+                                           inputs[take_::kArr].dptr<DType>(),
+                                           inputs[take_::kIdx].dptr<IType>(),
+                                           in_strides, out_strides, arrshape.ndim(), oshape.ndim(),
+                                           idxshape.ndim(), arrshape[actual_axis], actual_axis);
+        }
+      }
+    });
+  });
+}
+
 NNVM_REGISTER_OP(Embedding)
 .set_attr<FCompute>("FCompute<gpu>", EmbeddingOpForward<gpu>)
 .set_attr<FComputeEx>("FComputeEx<gpu>", SparseEmbeddingOpForwardEx<gpu>);
diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index 2a419e7..92b6e21 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -301,25 +301,6 @@ inline bool SparseEmbeddingOpBackwardStorageType(const nnvm::NodeAttrs& attrs,
  */
 template<bool clip = true>
 struct Take {
-  // assume that idx have been flattened to a 1-D tensor (N,)
-  // assume that out_data and in_data have been flattened to 2-D tensors, (N, M) and (K, M)
-  // M is the number of columns of in_data and out_data
-  // K is the number of rows of in_data
-  // i is the index of out_data
-  template<typename DType, typename IType>
-  MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data,
-                                  const IType* idx, const int M, const int K) {
-    int j = static_cast<int>(idx[i/M]);
-    if (clip) {
-      if (j <= 0) j = 0;
-      else if (j >= K) j = K - 1;
-    } else {
-      j = j % K;
-      j += (j < 0) ? K : 0;
-    }
-    out_data[i] = in_data[j * M + i % M];
-  }
-
   /*!
    * \brief Map function for take operator
    * \param i           global thread id
@@ -339,21 +320,21 @@ struct Take {
                                   const int in_ndims, const int out_ndims, const int idx_ndims,
                                   const int axis_dim, const int axis) {
     // i is the global flattened index in the output
-    const int out_head_index = (axis == 0) ? 0 : (i / out_stride[axis - 1]);
-    const int out_rest_index = (axis == 0) ? i : (i % out_stride[axis - 1]);
-    const int out_mid_index = out_rest_index / in_stride[axis];
-    const int out_tail_index = (axis == in_ndims - 1) ?
-                               0 : (out_rest_index % in_stride[axis]);
-    int idx_index = static_cast<int>(idx[out_mid_index]);
+    const int64_t out_head_index = (axis == 0) ? 0 : (i / out_stride[axis - 1]);
+    const int64_t out_rest_index = (axis == 0) ? i : (i % out_stride[axis - 1]);
+    const int64_t out_mid_index = out_rest_index / in_stride[axis];
+    const int64_t out_tail_index = (axis == in_ndims - 1) ?
+                                   0 : (out_rest_index % in_stride[axis]);
+    int64_t idx_index = static_cast<int64_t>(idx[out_mid_index]);
     if (clip) {
       idx_index = (idx_index < 0) ? 0 : idx_index;
       idx_index = (idx_index > axis_dim - 1) ? (axis_dim - 1) : idx_index;
     }
     idx_index %= axis_dim;
     idx_index += (idx_index < 0) ? axis_dim : 0;
-    const int in_tail_index = out_tail_index;
-    const int in_head_index = out_head_index;
-    int in_src_index = in_tail_index + idx_index * in_stride[axis];
+    const int64_t in_tail_index = out_tail_index;
+    const int64_t in_head_index = out_head_index;
+    int64_t in_src_index = in_tail_index + idx_index * in_stride[axis];
     in_src_index += (axis == 0) ? 0 : in_head_index * in_stride[axis - 1];
     out_data[i] = in_data[in_src_index];
   }
@@ -365,24 +346,7 @@ void EmbeddingOpForwardDnsImpl(mshadow::Stream<xpu>* s,
                                const TBlob& data,
                                const TBlob& weight,
                                const OpReqType req,
-                               const TBlob& output) {
-  using namespace mxnet_op;
-  const TShape& ishape = data.shape_;
-  const TShape& oshape = output.shape_;
-
-  MSHADOW_TYPE_SWITCH(output.type_flag_, DType, {
-    MSHADOW_TYPE_SWITCH(data.type_flag_, IType, {
-      Tensor<xpu, 1, IType> idx = data.get_with_shape<xpu, 1, IType>(
-        Shape1(ishape.ProdShape(0, ishape.ndim())), s);
-      Tensor<xpu, 2, DType> wmat = weight.get<xpu, 2, DType>(s);
-      Tensor<xpu, 2, DType> out = output.get_with_shape<xpu, 2, DType>(
-        Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s);
-      Kernel<Take<true>, xpu>::Launch(s, oshape.Size(), out.dptr_, wmat.dptr_,
-                                idx.dptr_, wmat.shape_[1], wmat.shape_[0]);
-    });
-  });
-}
-
+                               const TBlob& output);
 
 template<int req>
 struct TakeRspKernel {
@@ -825,66 +789,7 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs,
                    const OpContext& ctx,
                    const std::vector<TBlob>& inputs,
                    const std::vector<OpReqType>& req,
-                   const std::vector<TBlob>& outputs) {
-  using namespace mxnet_op;
-  if (req[take_::kOut] == kNullOp) return;
-  const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);
-  CHECK_EQ(inputs.size(), 2U);
-  CHECK_EQ(outputs.size(), 1U);
-
-  const TShape& idxshape = inputs[take_::kIdx].shape_;
-  const TShape& arrshape = inputs[take_::kArr].shape_;
-  const TShape& oshape = outputs[take_::kOut].shape_;
-
-  Stream<xpu> *s = ctx.get_stream<xpu>();
-  const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0);
-
-  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {  // output data type
-    MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {  // index data type
-      if (actual_axis == 0) {
-        if (param.mode == take_::kClip) {
-          Kernel<Take<true>, xpu>::Launch(s, oshape.Size(),
-                                          outputs[take_::kOut].dptr<DType>(),
-                                          inputs[take_::kArr].dptr<DType>(),
-                                          inputs[take_::kIdx].dptr<IType>(),
-                                          oshape.Size()/idxshape.Size(), arrshape[0]);
-        } else {
-          Kernel<Take<false>, xpu>::Launch(s, oshape.Size(),
-                                           outputs[take_::kOut].dptr<DType>(),
-                                           inputs[take_::kArr].dptr<DType>(),
-                                           inputs[take_::kIdx].dptr<IType>(),
-                                           oshape.Size()/idxshape.Size(), arrshape[0]);
-        }
-      } else {
-        mshadow::Shape<10> in_strides;
-        int stride = 1;
-        for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) {
-          in_strides[i] = stride;
-        }
-        mshadow::Shape<10> out_strides;
-        stride = 1;
-        for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) {
-          out_strides[i] = stride;
-        }
-        if (param.mode == take_::kClip) {
-          Kernel<Take<true>, xpu>::Launch(s, oshape.Size(),
-                                          outputs[take_::kOut].dptr<DType>(),
-                                          inputs[take_::kArr].dptr<DType>(),
-                                          inputs[take_::kIdx].dptr<IType>(),
-                                          in_strides, out_strides, arrshape.ndim(), oshape.ndim(),
-                                          idxshape.ndim(), arrshape[actual_axis], actual_axis);
-        } else if (param.mode == take_::kWrap) {
-          Kernel<Take<false>, xpu>::Launch(s, oshape.Size(),
-                                           outputs[take_::kOut].dptr<DType>(),
-                                           inputs[take_::kArr].dptr<DType>(),
-                                           inputs[take_::kIdx].dptr<IType>(),
-                                           in_strides, out_strides, arrshape.ndim(), oshape.ndim(),
-                                           idxshape.ndim(), arrshape[actual_axis], actual_axis);
-        }
-      }
-    });
-  });
-}
+                   const std::vector<TBlob>& outputs);
 
 struct TakeGradGeneralKernel {
   /*!