Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/03/01 19:19:56 UTC

[GitHub] piiswrong closed pull request #9882: Add force_deterministic option for sparse embedding

piiswrong closed pull request #9882: Add force_deterministic option for sparse embedding
URL: https://github.com/apache/incubator-mxnet/pull/9882
 
 
   

This is a PR merged from a forked repository.
Because GitHub hides the original diff of a pull request from a fork once it
is merged, the diff is reproduced below for the sake of provenance:

diff --git a/src/operator/tensor/indexing_op-inl.cuh b/src/operator/tensor/indexing_op-inl.cuh
index 4458151f178..34cc2630254 100644
--- a/src/operator/tensor/indexing_op-inl.cuh
+++ b/src/operator/tensor/indexing_op-inl.cuh
@@ -200,57 +200,15 @@ AddTakeGradLargeBatchWorkspaceSize(size_t num_keys) {
 }
 
 template<typename IndexType, typename DType>
-inline void AddTakeGradLargeBatch(mshadow::Tensor<gpu, 2, DType> dst,
-                                  const mshadow::Tensor<gpu, 1, IndexType>& sorted,
-                                  const mshadow::Tensor<gpu, 1, IndexType>& index,
-                                  const mshadow::Tensor<gpu, 2, DType> &src,
-                                  mshadow::Tensor<gpu, 1, char>* workspace) {
-  CHECK_EQ(dst.CheckContiguous(), true);
-  CHECK_EQ(sorted.CheckContiguous(), true);
-  CHECK_EQ(index.CheckContiguous(), true);
-  CHECK_EQ(src.CheckContiguous(), true);
-  // const int kWarpBits = kMemUnitBits;
+inline void AddTakeGradLargeBatchKernelLaunch(mshadow::Tensor<gpu, 2, DType> dst,
+                                              const mshadow::Tensor<gpu, 1, IndexType>& sorted,
+                                              const mshadow::Tensor<gpu, 1, IndexType>& index,
+                                              const mshadow::Tensor<gpu, 2, DType> &src,
+                                              IndexType* sum_counts_ptr,
+                                              int* num_runs_ptr,
+                                              const mshadow::index_t num_rows) {
   cudaStream_t stream = mshadow::Stream<gpu>::GetStream(dst.stream_);
-  IndexType* sum_counts_ptr = NULL;
-  int* num_runs_ptr = NULL;
-  if (dst.size(0)*4 < src.size(0) && workspace != NULL) {
-    // Workspace given and potentially loops at least 4 times, use CUB to create sum_counts
-    CHECK_EQ(workspace->CheckContiguous(), true);
-    // workspace = [unique_out, counts_out, temporary_storage]
-    size_t unique_bytes = sorted.size(0)*sizeof(IndexType);
-    size_t counts_bytes = sorted.size(0)*sizeof(IndexType);
-    size_t num_runs_bytes = 1*sizeof(int);
-
-    size_t encode_bytes = 0;
-    cub::DeviceRunLengthEncode::Encode<IndexType*, IndexType*, IndexType*, int*>
-      (NULL, encode_bytes, NULL, NULL, NULL, NULL, sorted.size(0), stream);
-    size_t exclusivesum_bytes = 0;
-    cub::DeviceScan::ExclusiveSum<IndexType*, IndexType*>
-      (NULL, exclusivesum_bytes, NULL, NULL, sorted.size(0), stream);
-    size_t temporary_bytes = std::max(encode_bytes, exclusivesum_bytes);
-
-    // Check that we have enough storage
-    CHECK_GE(workspace->size(0), unique_bytes + counts_bytes +
-      num_runs_bytes + temporary_bytes);
-
-    IndexType* unique_out_ptr = reinterpret_cast<IndexType*>(workspace->dptr_);
-    IndexType* counts_out_ptr = reinterpret_cast<IndexType*>(workspace->dptr_ + unique_bytes);
-    num_runs_ptr = reinterpret_cast<int*>(workspace->dptr_ + unique_bytes +
-      counts_bytes);
-    void* temporary_storage = reinterpret_cast<void *>(workspace->dptr_ + unique_bytes +
-      counts_bytes + num_runs_bytes);
-
-    cub::DeviceRunLengthEncode::Encode<IndexType*, IndexType*, IndexType*, int*>
-    (temporary_storage, temporary_bytes, sorted.dptr_, unique_out_ptr, counts_out_ptr,
-      num_runs_ptr, sorted.size(0), stream);
-
-    sum_counts_ptr = unique_out_ptr;
-    cub::DeviceScan::ExclusiveSum<IndexType*, IndexType*>
-    (temporary_storage, temporary_bytes, counts_out_ptr, sum_counts_ptr,
-      sorted.size(0), stream);
-  }
-
-  const int num_unique_est = min(dst.size(0), src.size(0));
+  const int num_unique_est = min(num_rows, src.size(0));
   const int max_nthread = 128;
   const int num_y = max(src.size(0)/num_unique_est, 1);
   const int block_dim_x = kWarpSize;
@@ -307,6 +265,61 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor<gpu, 2, DType> dst,
   MSHADOW_CUDA_POST_KERNEL_CHECK(AddTakeGradLargeBatchKernel);
 }
 
+
+template<typename IndexType, typename DType>
+inline void AddTakeGradLargeBatch(mshadow::Tensor<gpu, 2, DType> dst,
+                                  const mshadow::Tensor<gpu, 1, IndexType>& sorted,
+                                  const mshadow::Tensor<gpu, 1, IndexType>& index,
+                                  const mshadow::Tensor<gpu, 2, DType> &src,
+                                  mshadow::Tensor<gpu, 1, char>* workspace) {
+  CHECK_EQ(dst.CheckContiguous(), true);
+  CHECK_EQ(sorted.CheckContiguous(), true);
+  CHECK_EQ(index.CheckContiguous(), true);
+  CHECK_EQ(src.CheckContiguous(), true);
+  // const int kWarpBits = kMemUnitBits;
+  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(dst.stream_);
+  IndexType* sum_counts_ptr = NULL;
+  int* num_runs_ptr = NULL;
+  if (dst.size(0)*4 < src.size(0) && workspace != NULL) {
+    // Workspace given and potentially loops at least 4 times, use CUB to create sum_counts
+    CHECK_EQ(workspace->CheckContiguous(), true);
+    // workspace = [unique_out, counts_out, temporary_storage]
+    size_t unique_bytes = sorted.size(0)*sizeof(IndexType);
+    size_t counts_bytes = sorted.size(0)*sizeof(IndexType);
+    size_t num_runs_bytes = 1*sizeof(int);
+
+    size_t encode_bytes = 0;
+    cub::DeviceRunLengthEncode::Encode<IndexType*, IndexType*, IndexType*, int*>
+      (NULL, encode_bytes, NULL, NULL, NULL, NULL, sorted.size(0), stream);
+    size_t exclusivesum_bytes = 0;
+    cub::DeviceScan::ExclusiveSum<IndexType*, IndexType*>
+      (NULL, exclusivesum_bytes, NULL, NULL, sorted.size(0), stream);
+    size_t temporary_bytes = std::max(encode_bytes, exclusivesum_bytes);
+
+    // Check that we have enough storage
+    CHECK_GE(workspace->size(0), unique_bytes + counts_bytes +
+      num_runs_bytes + temporary_bytes);
+
+    IndexType* unique_out_ptr = reinterpret_cast<IndexType*>(workspace->dptr_);
+    IndexType* counts_out_ptr = reinterpret_cast<IndexType*>(workspace->dptr_ + unique_bytes);
+    num_runs_ptr = reinterpret_cast<int*>(workspace->dptr_ + unique_bytes +
+      counts_bytes);
+    void* temporary_storage = reinterpret_cast<void *>(workspace->dptr_ + unique_bytes +
+      counts_bytes + num_runs_bytes);
+
+    cub::DeviceRunLengthEncode::Encode<IndexType*, IndexType*, IndexType*, int*>
+    (temporary_storage, temporary_bytes, sorted.dptr_, unique_out_ptr, counts_out_ptr,
+      num_runs_ptr, sorted.size(0), stream);
+
+    sum_counts_ptr = unique_out_ptr;
+    cub::DeviceScan::ExclusiveSum<IndexType*, IndexType*>
+    (temporary_storage, temporary_bytes, counts_out_ptr, sum_counts_ptr,
+      sorted.size(0), stream);
+  }
+  AddTakeGradLargeBatchKernelLaunch(dst, sorted, index, src, sum_counts_ptr,
+                                    num_runs_ptr, dst.size(0));
+}
+
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_OPERATOR_TENSOR_INDEXING_OP_CUH_
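
A note on the relocated workspace logic above: it runs two CUB passes over the
sorted indices, a run-length encode (cub::DeviceRunLengthEncode::Encode)
followed by an exclusive prefix sum of the run lengths
(cub::DeviceScan::ExclusiveSum), so that each run of duplicate row ids has a
known start offset. A minimal NumPy sketch of what those two passes compute
(the helper name segment_offsets is hypothetical and not part of this PR):

    import numpy as np

    def segment_offsets(sorted_idx):
        # Run-length encode the sorted ids (unique values + run lengths).
        unique, counts = np.unique(sorted_idx, return_counts=True)
        # Exclusive prefix sum of the run lengths: offsets[i] is where the
        # i-th unique id's run of duplicates begins in the sorted array.
        offsets = np.concatenate(([0], np.cumsum(counts)[:-1]))
        return unique, offsets

    sorted_idx = np.array([0, 0, 2, 2, 2, 5])
    unique, offsets = segment_offsets(sorted_idx)
    print(unique)   # [0 2 5]
    print(offsets)  # [0 2 5]
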
diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc
index cce4537ae3a..bb65419a79c 100644
--- a/src/operator/tensor/indexing_op.cc
+++ b/src/operator/tensor/indexing_op.cc
@@ -70,7 +70,8 @@ void SparseEmbeddingOpForwardRspImpl<cpu>(const OpContext& ctx,
 
 
 template<>
-inline void SparseEmbeddingOpBackwardRspImpl<cpu>(const OpContext& ctx,
+inline void SparseEmbeddingOpBackwardRspImpl<cpu>(const SparseEmbeddingParam& param,
+                                                  const OpContext& ctx,
                                                   const TBlob& ograd,
                                                   const TBlob& data,
                                                   const OpReqType req,
@@ -178,6 +179,7 @@ GatherNDBackwardImpl(int N, int M, int K,
 }
 
 DMLC_REGISTER_PARAMETER(EmbeddingParam);
+DMLC_REGISTER_PARAMETER(SparseEmbeddingParam);
 DMLC_REGISTER_PARAMETER(TakeParam);
 DMLC_REGISTER_PARAMETER(OneHotParam);
 DMLC_REGISTER_PARAMETER(ScatterNDParam);
@@ -230,8 +232,8 @@ Examples::
   [](const NodeAttrs& attrs) {
     return std::vector<std::string>{"data", "weight"};
   })
-.set_attr<nnvm::FInferShape>("FInferShape", EmbeddingOpShape)
-.set_attr<nnvm::FInferType>("FInferType", EmbeddingOpType)
+.set_attr<nnvm::FInferShape>("FInferShape", EmbeddingOpShape<EmbeddingParam>)
+.set_attr<nnvm::FInferType>("FInferType", EmbeddingOpType<EmbeddingParam>)
 .set_attr<FResourceRequest>("FResourceRequest",
   [](const NodeAttrs& attrs) {
     return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
@@ -268,6 +270,11 @@ The storage type of weight must be `row_sparse`, and the gradient of the weight
 
     `SparseEmbedding` is designed for the use case where `input_dim` is very large (e.g. 100k).
     The operator is available on both CPU and GPU.
+    When `deterministic` is set to `True`, the accumulation of gradients follows a
+    deterministic order if a feature appears multiple times in the input. However, the
+    accumulation is usually slower when the order is enforced.
+    When the operator is used in recurrent neural network models on the GPU,
+    the recommended value for `deterministic` is `True`.
 
 Examples::
 
@@ -294,7 +301,7 @@ Examples::
 )code" ADD_FILELINE)
 .set_num_inputs(2)
 .set_num_outputs(1)
-.set_attr_parser(ParamParser<EmbeddingParam>)
+.set_attr_parser(ParamParser<SparseEmbeddingParam>)
 .set_attr<nnvm::FListInputNames>("FListInputNames",
   [](const NodeAttrs& attrs) {
     return std::vector<std::string>{"data", "weight"};
@@ -303,8 +310,8 @@ Examples::
   [](const NodeAttrs& attrs) {
     return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
   })
-.set_attr<nnvm::FInferShape>("FInferShape", EmbeddingOpShape)
-.set_attr<nnvm::FInferType>("FInferType", EmbeddingOpType)
+.set_attr<nnvm::FInferShape>("FInferShape", EmbeddingOpShape<SparseEmbeddingParam>)
+.set_attr<nnvm::FInferType>("FInferType", EmbeddingOpType<SparseEmbeddingParam>)
 .set_attr<FInferStorageType>("FInferStorageType", SparseEmbeddingOpForwardStorageType)
 .set_attr<FComputeEx>("FComputeEx<cpu>", SparseEmbeddingOpForwardEx<cpu>)
 .set_attr<nnvm::FGradient>("FGradient",
@@ -327,6 +334,7 @@ NNVM_REGISTER_OP(_backward_Embedding)
 .set_attr<FCompute>("FCompute<cpu>", EmbeddingOpBackward<cpu>);
 
 NNVM_REGISTER_OP(_backward_SparseEmbedding)
+.set_attr_parser(ParamParser<SparseEmbeddingParam>)
 .set_num_inputs(2)
 .set_num_outputs(2)
 .set_attr<FResourceRequest>("FResourceRequest",
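
The registration changes above give SparseEmbedding its own
SparseEmbeddingParam and expose the new deterministic flag through the contrib
symbol API. A hedged usage sketch (MXNet 1.x symbol API, mirroring the pattern
of the unit test further below; the shapes and context here are made up for
illustration):

    import mxnet as mx

    # Declare a sparse embedding whose backward pass accumulates the weight
    # gradient in a fixed order (deterministic=True).
    data = mx.sym.Variable("data")
    weight = mx.sym.Variable("embed_weight", stype="row_sparse")
    embed = mx.sym.contrib.SparseEmbedding(data=data, weight=weight,
                                           input_dim=100000, output_dim=128,
                                           deterministic=True, name="embed")

    # Gradients flow only to the row_sparse weight; the data input gets none.
    exe = embed.simple_bind(mx.cpu(), data=(32,),
                            grad_req={"data": "null", "embed_weight": "write"})
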
diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu
index 762d8fd64c2..5cdf5060aec 100644
--- a/src/operator/tensor/indexing_op.cu
+++ b/src/operator/tensor/indexing_op.cu
@@ -60,6 +60,75 @@ struct AddTakeGradRspGPUKernel {
   }
 };
 
+/*
+ * \brief kernel for backward computation for take, executed in a deterministic order
+ * \param thread_id the thread id
+ * \param out the output gradient data
+ * \param lookup_table the table to lookup the position of an id in gradient array
+ * \param sorted_data the sorted data input
+ * \param original_idx the original indices of the sorted data input
+ * \param ograd head gradient
+ * \param row_length the output dimension
+ * \param num_threads_per_row the number of threads to process a row together
+ * \param SZ the number of features a thread is responsible for
+ */
+template<int SZ>
+struct AddTakeGradRspDeterministicKernel {
+  template<typename DType>
+  __device__ __forceinline__ static void Map(int thread_id,
+                                             DType* out,
+                                             const nnvm::dim_t* lookup_table,
+                                             const nnvm::dim_t* sorted_data,
+                                             const nnvm::dim_t data_size,
+                                             const nnvm::dim_t* original_idx,
+                                             const DType* ograd,
+                                             const nnvm::dim_t row_length,
+                                             const nnvm::dim_t num_threads_per_row) {
+    using nnvm::dim_t;
+    int tid = thread_id / num_threads_per_row;
+    const int feature_start = thread_id % num_threads_per_row * SZ;
+    int num_features = SZ;
+    if (feature_start + num_features > row_length) {
+      num_features = row_length - feature_start;
+    }
+    if (tid == 0 || sorted_data[tid - 1] != sorted_data[tid]) {
+      DType acc[SZ];
+      #pragma unroll
+      for (int i = 0; i < SZ; i++) {
+        acc[i] = 0;
+      }
+      const dim_t data = sorted_data[tid];
+      const dim_t row_id = lookup_table[data];
+      const dim_t out_offset = row_id * row_length + feature_start;
+      do {
+        const dim_t idx = original_idx[tid];
+        const dim_t ograd_offset = idx * row_length + feature_start;
+        for (int i = 0; i < num_features; i++) {
+          acc[i] += ograd[ograd_offset + i];
+        }
+        tid++;
+      } while (tid < data_size && sorted_data[tid - 1] == sorted_data[tid]);
+      for (int i = 0; i < num_features; i++) {
+        out[out_offset + i] += acc[i];
+      }
+    }
+  }
+};
+
+/*
+ * \brief the kernel to generate a lookup table for positions of row ids
+ * \param i thread id
+ * \param out output table
+ * \param data the input row id in sorted order
+ */
+struct mark_lookup_table {
+  template<typename IType, typename DType>
+  MSHADOW_XINLINE static void Map(int i, IType* out, const DType* data) {
+    out[static_cast<nnvm::dim_t>(data[i])] = i;
+  }
+};
+
+
 template<>
 void SparseEmbeddingOpForwardRspImpl<gpu>(const OpContext& ctx,
                                           const TBlob& data,
@@ -103,13 +172,138 @@ void SparseEmbeddingOpForwardRspImpl<gpu>(const OpContext& ctx,
   }
 }
 
+template<typename IType, typename DType, typename RType>
+void SparseEmbeddingDeterministicKernelLaunch(const OpContext& ctx,
+                                              const TBlob& ograd,
+                                              const TBlob& data,
+                                              const OpReqType req,
+                                              const NDArray& output) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  using namespace expr;
+  using namespace rowsparse;
+  using nnvm::dim_t;
+  mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
+  const dim_t num_rows = output.shape()[0];
+  const dim_t row_length = output.shape()[1];
+  const dim_t data_size = static_cast<dim_t>(data.shape_.Size());
+  // temp resource declarations
+  dim_t* lookup_table = NULL;
+  void* temp_storage = NULL;
+  dim_t* sorted_data = NULL;
+  dim_t* original_idx = NULL;
+  // calculate number of bytes for temp resources
+  size_t lookup_table_bytes = num_rows * sizeof(dim_t);
+  size_t sorted_data_storage_bytes = data_size * sizeof(dim_t);
+  size_t original_idx_storage_bytes = data_size * sizeof(dim_t);
+  size_t sort_workspace_size = SortByKeyWorkspaceSize<dim_t, dim_t, gpu>(data_size);
+  size_t unique_workspace_bytes = 0;
+  // estimate unique temp space
+  IType* data_ptr = data.dptr<IType>();
+  size_t *null_ptr = nullptr;
+  cub::DeviceSelect::Unique(NULL, unique_workspace_bytes, data_ptr, data_ptr,
+    null_ptr, data_size, Stream<gpu>::GetStream(s));
+  // One more space reserved for unique count
+  size_t temp_workspace_bytes = std::max(unique_workspace_bytes,
+                                         sort_workspace_size);
+  size_t total_storage_bytes = lookup_table_bytes + sorted_data_storage_bytes +
+                               original_idx_storage_bytes + temp_workspace_bytes;
+
+  // request resource and split it. layout is:
+  // lookup_table, sorted_data, original_idx, temp_storage
+  Tensor<gpu, 1, char> workspace = ctx.requested[0]
+      .get_space_typed<gpu, 1, char>(Shape1(total_storage_bytes), s);
+  lookup_table = reinterpret_cast<dim_t*>(workspace.dptr_);
+  sorted_data = reinterpret_cast<dim_t*>(workspace.dptr_ + lookup_table_bytes);
+  original_idx = reinterpret_cast<dim_t*>(workspace.dptr_ + lookup_table_bytes +
+                                          sorted_data_storage_bytes);
+  temp_storage = workspace.dptr_ + total_storage_bytes - temp_workspace_bytes;
+
+  // make a copy of the data, to be sorted
+  TBlob sorted_data_blob(sorted_data, Shape1(data_size), gpu::kDevMask);
+  auto sorted_data_tensor = sorted_data_blob.FlatTo1D<gpu, dim_t>(s);
+  mxnet_op::copy(s, sorted_data_blob, data);
+
+  // generate original idx
+  Tensor<gpu, 1, dim_t> original_idx_tensor(original_idx, Shape1(data_size), s);
+  Kernel<range_fwd, gpu>::Launch(s, data_size, 1, static_cast<dim_t>(0),
+                                 static_cast<dim_t>(1), kWriteTo, original_idx);
+  // sort data with its original idx
+  int num_bits = ilog2(num_rows - 1);
+  char* temp_storage_ptr = reinterpret_cast<char*>(temp_storage);
+  Tensor<gpu, 1, char> temp_storage_tensor(temp_storage_ptr,
+                                           Shape1(sort_workspace_size), s);
+  SortByKey(sorted_data_tensor, original_idx_tensor, true,
+            &temp_storage_tensor, 0, num_bits);
+
+  // compute unique row ids based on sorted values.
+  output.CheckAndAllocAuxData(kIdx, Shape1(data_size + 1));
+
+  // fill row_idx array of output matrix, using the row_flg values
+  RType* grad_row_idx = output.aux_data(kIdx).dptr<RType>();
+  cub::DeviceSelect::Unique(temp_storage_ptr, unique_workspace_bytes, sorted_data,
+      grad_row_idx, grad_row_idx + data_size, data_size, Stream<gpu>::GetStream(s));
+
+  dim_t nnr = 0;
+  CUDA_CALL(cudaMemcpy(&nnr, grad_row_idx + data_size, sizeof(RType),
+      cudaMemcpyDeviceToHost));
+  CHECK_EQ(output.shape().ndim(), 2) << "Unexpected ndim";
+  output.CheckAndAllocData(Shape2(nnr, output.shape()[1]));
+  output.set_aux_shape(kIdx, Shape1(nnr));
+
+  // generate lookup table
+  Kernel<mark_lookup_table, gpu>::Launch(s, nnr, lookup_table, grad_row_idx);
+
+  // accumulate gradients
+  DType* grad_data = output.data().dptr<DType>();
+  Fill<false>(s, TBlob(grad_data, Shape1(nnr * row_length), gpu::kDevMask),
+              kWriteTo, 0);
+  const int SZ = 4;
+  const nnvm::dim_t num_threads_per_row = (row_length + SZ - 1) / SZ;
+  Kernel<AddTakeGradRspDeterministicKernel<SZ>, gpu>::Launch(s, data_size * num_threads_per_row,
+                     grad_data, lookup_table, sorted_data, data_size, original_idx,
+                     ograd.dptr<DType>(), row_length, num_threads_per_row);
+}
+
+inline void SparseEmbeddingOpBackwardDeterministicRspImpl(const OpContext& ctx,
+                                                          const TBlob& ograd,
+                                                          const TBlob& data,
+                                                          const OpReqType req,
+                                                          const NDArray& output) {
+  using nnvm::dim_t;
+  if (req == kNullOp) return;
+  CHECK_EQ(req, kWriteTo) << "SparseEmbedding layer doesn't support "
+                          << "weight gradient calculation with req != write";
+
+  mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
+  const dim_t data_size = static_cast<dim_t>(data.shape_.Size());
+  if (data_size == 0) {
+    FillZerosRspImpl(s, output);
+    return;
+  }
+
+  MSHADOW_TYPE_SWITCH(data.type_flag_, IType, {
+    MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
+      MSHADOW_IDX_TYPE_SWITCH(output.aux_type(rowsparse::kIdx), RType, {
+        SparseEmbeddingDeterministicKernelLaunch<IType, DType, RType>(ctx, ograd, data,
+                                                                      req, output);
+      });
+    });
+  });
+}
+
 
 template<>
-inline void SparseEmbeddingOpBackwardRspImpl<gpu>(const OpContext& ctx,
+inline void SparseEmbeddingOpBackwardRspImpl<gpu>(const SparseEmbeddingParam& param,
+                                                  const OpContext& ctx,
                                                   const TBlob& ograd,
                                                   const TBlob& data,
                                                   const OpReqType req,
                                                   const NDArray& output) {
+  if (param.deterministic) {
+    SparseEmbeddingOpBackwardDeterministicRspImpl(ctx, ograd, data, req, output);
+    return;
+  }
   using namespace mshadow;
   using namespace mxnet_op;
   using namespace mshadow::expr;
@@ -156,7 +350,6 @@ inline void SparseEmbeddingOpBackwardRspImpl<gpu>(const OpContext& ctx,
         dim_t nnr = 0;
         CUDA_CALL(cudaMemcpy(&nnr, &prefix_sum[num_rows-1], sizeof(dim_t),
             cudaMemcpyDeviceToHost));
-
         if (nnr == 0) {
           FillZerosRspImpl(s, output);
           return;
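
The deterministic backward path added above sorts the input row ids together
with their original positions, builds a lookup table from row id to output
row, and lets a single thread sum each run of duplicates in sorted-position
order. A minimal NumPy reference of that accumulation order (sketch only; the
function name is hypothetical, and the real kernel additionally splits each
row across num_threads_per_row threads):

    import numpy as np

    def deterministic_sparse_embedding_grad(data, ograd):
        """Reference for the deterministic backward: each repeated row id's
        gradients are summed in the order of the sorted positions, so the
        result does not depend on GPU thread scheduling."""
        order = np.argsort(data, kind="stable")      # plays the role of original_idx
        sorted_data = data[order]                    # sorted row ids
        rows = np.unique(sorted_data)                # grad_row_idx (aux indices of the output)
        lookup = {r: i for i, r in enumerate(rows)}  # lookup_table: row id -> output row
        grad = np.zeros((len(rows), ograd.shape[1]), dtype=ograd.dtype)
        for pos, row_id in zip(order, sorted_data):  # fixed accumulation order
            grad[lookup[row_id]] += ograd[pos]
        return rows, grad

    data = np.array([3, 1, 3, 7])
    ograd = np.random.uniform(-1, 1, (4, 5)).astype(np.float32)
    row_idx, grad = deterministic_sparse_embedding_grad(data, ograd)
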
diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index 1888a417972..45bf45f14fc 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -57,6 +57,29 @@ enum EmbeddingOpResource {kTempSpace};
 }  // namespace embedding
 
 
+struct SparseEmbeddingParam: public dmlc::Parameter<SparseEmbeddingParam> {
+  int input_dim;
+  int output_dim;
+  int dtype;
+  bool deterministic;
+  DMLC_DECLARE_PARAMETER(SparseEmbeddingParam) {
+    DMLC_DECLARE_FIELD(input_dim).set_lower_bound(1)
+    .describe("Vocabulary size of the input indices.");
+    DMLC_DECLARE_FIELD(output_dim).set_lower_bound(1)
+    .describe("Dimension of the embedding vectors.");
+    DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32)
+    .add_enum("float32", mshadow::kFloat32)
+    .add_enum("float64", mshadow::kFloat64)
+    .add_enum("float16", mshadow::kFloat16)
+    .add_enum("uint8", mshadow::kUint8)
+    .add_enum("int32", mshadow::kInt32)
+    .describe("Data type of weight.");
+    DMLC_DECLARE_FIELD(deterministic).set_default(false)
+    .describe("Force the backward gradient calculation to be executed in a deterministic"
+               " order, at the cost of slower speed.");
+  }
+};
+
 struct EmbeddingParam: public dmlc::Parameter<EmbeddingParam> {
   int input_dim;
   int output_dim;
@@ -130,14 +153,14 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor<gpu, 2, DType> dst,
                                   const mshadow::Tensor<gpu, 1, IndexType>& index,
                                   const mshadow::Tensor<gpu, 2, DType> &src,
                                   mshadow::Tensor<gpu, 1, char>* workspace = NULL);
-
+template<typename ParamType>
 inline bool EmbeddingOpShape(const nnvm::NodeAttrs& attrs,
                              std::vector<TShape> *in_attrs,
                              std::vector<TShape> *out_attrs) {
   using namespace mshadow;
   const TShape &dshape = (*in_attrs)[embedding::kData];
   if (dshape.ndim() ==  0) return false;
-  const EmbeddingParam& param = nnvm::get<EmbeddingParam>(attrs.parsed);
+  const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
   SHAPE_ASSIGN_CHECK(*in_attrs, embedding::kWeight, Shape2(param.input_dim,
                                                            param.output_dim));
   out_attrs->clear();
@@ -152,10 +175,11 @@ inline bool EmbeddingOpShape(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+template<typename ParamType>
 inline bool EmbeddingOpType(const nnvm::NodeAttrs& attrs,
                             std::vector<int> *in_type,
                             std::vector<int> *out_type) {
-  const EmbeddingParam& param = nnvm::get<EmbeddingParam>(attrs.parsed);
+  const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
   CHECK_EQ(in_type->size(), 2U);
   CHECK_GE(out_type->size(), 1U);
   int itype = (*in_type)[0];
@@ -219,6 +243,11 @@ inline bool SparseEmbeddingOpBackwardStorageType(const nnvm::NodeAttrs& attrs,
       dispatched = true;
     }
   }
+  const SparseEmbeddingParam& param = nnvm::get<SparseEmbeddingParam>(attrs.parsed);
+  if (param.deterministic) {
+    common::LogOnce("_SparseEmbedding_backward with deterministic=True may reduce "
+                    "speed significantly");
+  }
   return dispatched;
 }
 /*! \brief name the struct Take instead of take
@@ -560,7 +589,8 @@ struct AddTakeGradRspKernel {
 };
 
 template<typename xpu>
-inline void SparseEmbeddingOpBackwardRspImpl(const OpContext& ctx,
+inline void SparseEmbeddingOpBackwardRspImpl(const SparseEmbeddingParam& param,
+                                             const OpContext& ctx,
                                              const TBlob& ograd,
                                              const TBlob& data,
                                              const OpReqType req,
@@ -582,9 +612,10 @@ void SparseEmbeddingOpBackwardEx(const nnvm::NodeAttrs& attrs,
   // check req
   CHECK_EQ(req[embedding::kData], kNullOp)
           << "SparseEmbedding layer doesn't support calculate data gradient";
+  const SparseEmbeddingParam& param = nnvm::get<SparseEmbeddingParam>(attrs.parsed);
   if (data.storage_type() == kDefaultStorage && ograd.storage_type() == kDefaultStorage &&
       weight_grad.storage_type() == kRowSparseStorage) {
-    SparseEmbeddingOpBackwardRspImpl<xpu>(ctx, ograd.data(), data.data(),
+    SparseEmbeddingOpBackwardRspImpl<xpu>(param, ctx, ograd.data(), data.data(),
                                           req[embedding::kWeight], weight_grad);
   } else {
     LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
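
With EmbeddingOpShape and EmbeddingOpType templated on the parameter type
above, Embedding and SparseEmbedding share the same shape and type inference.
A small sketch of how the inferred shapes can be inspected from Python
(assuming the standard symbol infer_shape API; the values in the comments
follow from the rules in EmbeddingOpShape):

    import mxnet as mx

    data = mx.sym.Variable("data")
    weight = mx.sym.Variable("embed_weight", stype="row_sparse")
    embed = mx.sym.contrib.SparseEmbedding(data=data, weight=weight, input_dim=50,
                                           output_dim=3, deterministic=False,
                                           name="embed")

    # The weight shape is fixed to (input_dim, output_dim); the output shape
    # appends output_dim to the data shape.
    arg_shapes, out_shapes, _ = embed.infer_shape(data=(8,))
    print(dict(zip(embed.list_arguments(), arg_shapes)))  # {'data': (8,), 'embed_weight': (50, 3)}
    print(out_shapes)                                     # [(8, 3)]
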
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 84dfc5878c2..e0d25da0344 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -1615,43 +1615,45 @@ def check_sparse_elementwise_sum_with_shape(stype, shape, n):
 
 
 def test_sparse_embedding():
-    ''' test sparse embedding op on cpu '''
-    def check_sparse_embedding(executor, weight_ref, data_onehot, grad, density):
-        # update weight based on density
-        weight[:] = rand_ndarray(weight.shape, 'row_sparse', density=density)
-        # check forward
-        executor.forward(is_train=True)
-        assert_almost_equal(executor.outputs[0].asnumpy(), np.dot(data_onehot, weight.asnumpy()))
-        # check backward
-        executor.backward([grad])
-        assert_almost_equal(grad_map["embed_weight"].asnumpy(), np.dot(data_onehot.T, grad.asnumpy()))
+    ''' test sparse embedding operator '''
+    def check_sparse_embedding(in_dim, out_dim, batch, densities, deterministic):
+        # init executor
+        data = mx.sym.Variable("data")
+        weight = mx.sym.Variable("embed_weight", stype='row_sparse')
+        embed = mx.sym.contrib.SparseEmbedding(data=data, weight=weight, input_dim=in_dim,
+                                               output_dim=out_dim, deterministic=deterministic,
+                                               name="embed")
+        grad_req = {'data': 'null', 'embed_weight': 'write'}
+        exe_test = embed.simple_bind(default_context(), grad_req=grad_req, data=(batch,))
+        arg_map = dict(zip(embed.list_arguments(), exe_test.arg_arrays))
+        grad_map = dict(zip(embed.list_arguments(), exe_test.grad_arrays))
+        # init data
+        np_data = np.random.randint(low=0, high=in_dim, size=batch)
+        np_onehot = np.zeros((batch, in_dim)).astype(np.float32)
+        np_onehot[np.arange(batch), np_data] = 1.0
+        arg_map["data"][:] = np_data
+        # init grad
+        np_grad = np.random.uniform(-1, 1, exe_test.outputs[0].shape)
+        grad = mx.nd.zeros(np_grad.shape)
+        grad[:] = np_grad
+        # weight
+        weight = arg_map["embed_weight"]
+        for density in densities:
+            # update weight based on density
+            weight[:] = rand_ndarray(weight.shape, 'row_sparse', density=density)
+            # check forward
+            exe_test.forward(is_train=True)
+            assert_almost_equal(exe_test.outputs[0].asnumpy(), np.dot(np_onehot, weight.asnumpy()))
+            # check backward
+            exe_test.backward([grad])
+            assert_almost_equal(grad_map["embed_weight"].asnumpy(), np.dot(np_onehot.T, grad.asnumpy()))
 
     densities = [0, 0.5, 1]
     in_dim = 50
     out_dim = 3
     batch = 8
-    # init executor
-    data = mx.sym.Variable("data")
-    weight = mx.sym.Variable("embed_weight", stype='row_sparse')
-    embed = mx.sym.contrib.SparseEmbedding(data=data, weight=weight, input_dim=in_dim,
-                                           output_dim=out_dim, name="embed")
-    grad_req = {'data': 'null', 'embed_weight': 'write'}
-    exe_test = embed.simple_bind(default_context(), grad_req=grad_req, data=(batch,))
-    arg_map = dict(zip(embed.list_arguments(), exe_test.arg_arrays))
-    grad_map = dict(zip(embed.list_arguments(), exe_test.grad_arrays))
-    # init data
-    np_data = np.random.randint(low=0, high=in_dim, size=batch)
-    np_onehot = np.zeros((batch, in_dim))
-    np_onehot[np.arange(batch), np_data] = 1.0
-    arg_map["data"][:] = np_data
-    # init grad
-    np_grad = np.random.uniform(-1, 1, exe_test.outputs[0].shape)
-    grad = mx.nd.sparse.zeros('row_sparse', np_grad.shape)
-    grad[:] = np_grad
-    # weight
-    weight = arg_map["embed_weight"]
-    for density in densities:
-        check_sparse_embedding(exe_test, weight, np_onehot, grad, density)
+    check_sparse_embedding(in_dim, out_dim, batch, densities, True)
+    check_sparse_embedding(in_dim, out_dim, batch, densities, False)
 
 
 def test_scatter_ops():
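
The updated test above exercises both settings of deterministic. A
complementary check, not part of this PR, is that the deterministic path
produces identical weight gradients across repeated runs on the same inputs;
the nondeterminism the flag guards against comes from GPU atomic adds, so the
check is most meaningful on a GPU context. A sketch under those assumptions:

    import mxnet as mx
    import numpy as np
    from mxnet.test_utils import default_context, rand_ndarray

    def sparse_embedding_grad(deterministic, seed=0):
        """Bind a small SparseEmbedding, run one forward/backward pass and
        return the dense view of the weight gradient."""
        mx.random.seed(seed)
        np.random.seed(seed)
        data = mx.sym.Variable("data")
        weight = mx.sym.Variable("embed_weight", stype="row_sparse")
        embed = mx.sym.contrib.SparseEmbedding(data=data, weight=weight, input_dim=50,
                                               output_dim=3, deterministic=deterministic,
                                               name="embed")
        exe = embed.simple_bind(default_context(), data=(8,),
                                grad_req={"data": "null", "embed_weight": "write"})
        arg_map = dict(zip(embed.list_arguments(), exe.arg_arrays))
        grad_map = dict(zip(embed.list_arguments(), exe.grad_arrays))
        arg_map["data"][:] = np.random.randint(0, 50, size=8)
        arg_map["embed_weight"][:] = rand_ndarray((50, 3), 'row_sparse', density=0.5)
        exe.forward(is_train=True)
        exe.backward([mx.nd.array(np.random.uniform(-1, 1, exe.outputs[0].shape))])
        return grad_map["embed_weight"].asnumpy()

    # With deterministic=True the accumulation order is fixed, so repeated runs
    # on the same inputs should produce bit-wise identical gradients.
    assert np.array_equal(sparse_embedding_grad(True), sparse_embedding_grad(True))
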


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services