You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/11/16 23:13:04 UTC

[incubator-mxnet] branch master updated: sparse embedding operator, gpu implementation (#8647)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new c8f7dce  sparse embedding operator, gpu implementation (#8647)
c8f7dce is described below

commit c8f7dce0eb49ab1a62ddc2c7e37b93e9b92c2ae4
Author: Ziyue Huang <zy...@gmail.com>
AuthorDate: Fri Nov 17 07:13:00 2017 +0800

    sparse embedding operator, gpu implementation (#8647)
    
    * sparse embedding, gpu implementation
    
    * minor fix for batch_size
    
    * test
    
    * fix indentation
    
    * update
    
    * fix lint
    
    * fix
    
    * address comments
    
    * update
    
    * update
    
    * Update test_sparse_operator.py
---
 example/sparse/matrix_factorization.py        |   7 +-
 src/operator/tensor/indexing_op.cc            | 112 +++++++++++++++++-
 src/operator/tensor/indexing_op.cu            | 156 ++++++++++++++++++++++++++
 src/operator/tensor/indexing_op.h             | 113 ++-----------------
 tests/python/unittest/test_sparse_operator.py |  58 +++++-----
 5 files changed, 310 insertions(+), 136 deletions(-)

diff --git a/example/sparse/matrix_factorization.py b/example/sparse/matrix_factorization.py
index cdb6164..3387706 100644
--- a/example/sparse/matrix_factorization.py
+++ b/example/sparse/matrix_factorization.py
@@ -22,6 +22,8 @@ import mxnet as mx
 import numpy as np
 from get_data import get_movielens_iter, get_movielens_data
 from matrix_fact_model import matrix_fact_net
+
+
 logging.basicConfig(level=logging.DEBUG)
 
 parser = argparse.ArgumentParser(description="Run matrix factorization with sparse embedding",
@@ -36,6 +38,8 @@ parser.add_argument('--factor-size', type=int, default=128,
                     help="the factor size of the embedding operation")
 parser.add_argument('--use-dense', action='store_true',
                     help="use the dense embedding operator")
+parser.add_argument('--use-gpu', action='store_true',
+                    help="use gpu")
 parser.add_argument('--dummy-iter', action='store_true',
                     help="use the dummy data iterator for speed test")
 
@@ -63,7 +67,7 @@ if __name__ == '__main__':
     print_every = args.print_every
 
     momentum = 0.9
-    ctx = mx.cpu(0)
+    ctx = mx.gpu(0) if args.use_gpu else mx.cpu(0)
     learning_rate = 0.1
 
     # prepare dataset and iterators
@@ -75,7 +79,6 @@ if __name__ == '__main__':
 
     # construct the model
     net = matrix_fact_net(factor_size, factor_size, max_user, max_movies, sparse_embed=use_sparse)
-    a = time.time()
 
     # initialize the module
     mod = mx.module.Module(symbol=net, context=ctx, data_names=['user', 'item'],
diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc
index 273ebec..6fc54fc 100644
--- a/src/operator/tensor/indexing_op.cc
+++ b/src/operator/tensor/indexing_op.cc
@@ -26,6 +26,131 @@
 #include "./indexing_op.h"
 namespace mxnet {
 namespace op {
+
+/*!
+ * \brief CPU forward pass of SparseEmbedding with a row_sparse weight.
+ * \param s      CPU stream used for the fills and kernels.
+ * \param data   indices to look up; each must lie in [0, weight.shape()[0] - 1].
+ * \param weight row_sparse weight matrix.
+ * \param req    write request for output; kNullOp returns immediately.
+ * \param output dense output that receives the looked-up rows.
+ */
+template<>
+void SparseEmbeddingOpForwardRspImpl<cpu>(mshadow::Stream<cpu>* s,
+                                          const TBlob& data,
+                                          const NDArray& weight,
+                                          const OpReqType req,
+                                          const TBlob& output) {
+  if (req == kNullOp) return;
+  using namespace rowsparse;
+  using namespace mxnet_op;
+  // the weight has no stored rows, so the whole output is zero
+  if (req == kWriteTo && !weight.storage_initialized()) {
+    size_t out_size = output.shape_.Size();
+    MSHADOW_TYPE_SWITCH(output.type_flag_, DType, {
+      Fill<false>(s, TBlob(output.dptr<DType>(), mshadow::Shape1(out_size),
+          cpu::kDevMask), kWriteTo, 0);
+    })
+    return;
+  }
+  // check out-of-bound indices
+  bool is_valid = true;
+  MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
+    DType min = 0;
+    DType max = static_cast<DType>(weight.shape()[0] - 1);
+    // a single-threaded scan is faster here since data is small; stop at the
+    // first out-of-range index because the CHECK below fails regardless
+    DType* data_ptr = data.dptr<DType>();
+    size_t data_size = data.shape_.Size();
+    for (size_t i = 0; is_valid && i < data_size; i++) {
+      if (data_ptr[i] > max || data_ptr[i] < min) is_valid = false;
+    }
+  })
+  CHECK(is_valid) << "SparseEmbedding input contains data out of bound";
+  // every row is stored, so the weight is effectively dense
+  if (weight.aux_shape(kIdx)[0] == weight.shape()[0]) {
+    EmbeddingOpForwardDnsImpl<cpu>(s, data, weight.data(), req, output);
+  } else {
+    EmbeddingOpForwardRspImpl<cpu>(s, data, weight, req, output);
+  }
+}
+
+/*!
+ * \brief CPU backward pass of SparseEmbedding: row_sparse gradient of the weight.
+ * \param ctx    operator context; requested[embedding::kTempSpace] provides the workspace.
+ * \param ograd  dense gradient of the forward output.
+ * \param data   indices used in the forward pass.
+ * \param req    only kNullOp and kWriteTo are supported.
+ * \param output row_sparse weight gradient; allocated with one row per distinct index.
+ */
+template<>
+inline void SparseEmbeddingOpBackwardRspImpl<cpu>(const OpContext& ctx,
+                                                  const TBlob& ograd,
+                                                  const TBlob& data,
+                                                  const OpReqType req,
+                                                  const NDArray& output) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  using namespace mshadow::expr;
+  using namespace rowsparse;
+  using nnvm::dim_t;
+  if (req == kNullOp) return;
+  CHECK_EQ(req, kWriteTo) << "SparseEmbedding layer doesn't support "
+                          << "weight gradient calculation with req != write";
+
+  // Request temporary storage for marking non-zero rows and prefix sum
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  dim_t num_rows = output.shape()[0];
+  dim_t row_length = output.shape()[1];
+  // TODO(haibin) request less storage to save space in the future
+  size_t workspace_size = 2 * (num_rows * sizeof(dim_t));
+  Tensor<cpu, 1, char> workspace =
+    ctx.requested[embedding::kTempSpace].get_space_typed<cpu, 1, char>(
+      Shape1(workspace_size), s);
+  dim_t* row_flg = reinterpret_cast<dim_t*>(workspace.dptr_);
+  dim_t* prefix_sum = row_flg + num_rows;
+  dim_t data_size = static_cast<dim_t>(data.shape_.Size());
+
+  MSHADOW_TYPE_SWITCH(data.type_flag_, IType, {
+    MSHADOW_SGL_DBL_TYPE_SWITCH(ograd.type_flag_, DType, {
+      MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIdx), RType, {
+        // mark row flags
+        Fill<false>(s, TBlob(row_flg, Shape1(num_rows), cpu::kDevMask), kWriteTo, 0);
+        Kernel<MarkRowFlgKernel, cpu>::Launch(s, data_size, row_flg, data.dptr<IType>());
+        // calculate inclusive prefix sum
+        // TODO(haibin) ideally this should be done in parallel
+        prefix_sum[0] = row_flg[0];
+        for (dim_t i = 1; i < num_rows; i++) {
+          prefix_sum[i] = prefix_sum[i - 1] + row_flg[i];
+        }
+        // total number of non-zero rows
+        dim_t nnr = prefix_sum[num_rows - 1];
+        if (nnr == 0) {
+          FillZerosRspImpl(s, output);
+          return;
+        }
+        output.CheckAndAlloc({Shape1(nnr)});
+        RType* grad_row_idx = output.aux_data(kIdx).dptr<RType>();
+        // fill row_idx array of output matrix, using the prefix sums
+        Kernel<FillRspRowIdxKernel, cpu>::Launch(s, num_rows,
+            grad_row_idx, prefix_sum, num_rows);
+        // prefill with zeros
+        DType* grad_data = output.data().dptr<DType>();
+        Fill<false>(s, TBlob(grad_data, Shape1(nnr * row_length),
+            cpu::kDevMask), kWriteTo, 0);
+        // add the final gradients
+        const int num_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+        dim_t segment_len = (nnr + num_threads - 1) / num_threads;
+        Kernel<AddTakeGradRspKernel, cpu>::Launch(s, num_threads, grad_data, prefix_sum,
+                                                  ograd.dptr<DType>(), row_length,
+                                                  data.dptr<IType>(), data_size, segment_len,
+                                                  num_rows);
+      });
+    });
+  });
+}
+
+
 DMLC_REGISTER_PARAMETER(EmbeddingParam);
 DMLC_REGISTER_PARAMETER(TakeParam);
 DMLC_REGISTER_PARAMETER(OneHotParam);
@@ -116,8 +225,7 @@ The storage type of weight must be `row_sparse`, and the gradient of the weight
 .. Note::
 
     `SparseEmbedding` is designed for the use case where `input_dim` is very large (e.g. 100k).
-    The `row_sparse` weight cannot be used in a `BucketingModule`.
-    The operator is only available on CPU.
+    The operator is available on both CPU and GPU.
 
 Examples::
 
diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu
index 2cddd00..2aba122 100644
--- a/src/operator/tensor/indexing_op.cu
+++ b/src/operator/tensor/indexing_op.cu
@@ -24,14 +24,200 @@
 */
 
 #include "./indexing_op.h"
+#include "./util/tensor_util-inl.cuh"
+
 namespace mxnet {
 namespace op {
+
+/*!
+ * \brief Flags out-of-range indices: sets *out to 1 when data[i] is outside [min, max].
+ *        *out is never cleared here, so it must be zeroed before the launch; every
+ *        thread that writes stores the same value 1.
+ */
+struct is_valid_check {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, int32_t* out, const DType* data,
+                                  const DType min, const DType max) {
+    if (data[i] < min || data[i] > max) *out = 1;
+  }
+};
+
+/*!
+ * \brief Scatter-adds ograd into the compacted row_sparse weight gradient.
+ *        Thread tid handles column (tid % row_length) of input position
+ *        (tid / row_length); prefix_sum[irow] - 1 is the output row holding weight
+ *        row irow. atomicAdd is needed because several inputs can hit the same row.
+ */
+struct AddTakeGradRspGPUKernel {
+  template<typename DType, typename IType>
+  __device__ __forceinline__ static void Map(int tid,
+                                             DType* out,
+                                             const nnvm::dim_t* prefix_sum,
+                                             const IType* data,
+                                             const DType* ograd,
+                                             const nnvm::dim_t row_length) {
+    using nnvm::dim_t;
+    const dim_t data_i = tid / row_length;
+    const dim_t grad_i = tid % row_length;
+    const dim_t irow = static_cast<dim_t>(data[data_i]);
+    const dim_t rsp_row = prefix_sum[irow] - 1;
+    const DType val = ograd[data_i * row_length + grad_i];
+    atomicAdd(&out[rsp_row * row_length + grad_i], val);
+  }
+};
+
+/*!
+ * \brief GPU forward pass of SparseEmbedding with a row_sparse weight.
+ * \param s      GPU stream on which all kernels and copies are issued.
+ * \param data   indices to look up; each must lie in [0, weight.shape()[0] - 1].
+ * \param weight row_sparse weight matrix.
+ * \param req    write request for output; kNullOp returns immediately.
+ * \param output dense output that receives the looked-up rows.
+ */
+template<>
+void SparseEmbeddingOpForwardRspImpl<gpu>(mshadow::Stream<gpu>* s,
+                                          const TBlob& data,
+                                          const NDArray& weight,
+                                          const OpReqType req,
+                                          const TBlob& output) {
+  if (req == kNullOp) return;
+  using namespace rowsparse;
+  using namespace mxnet_op;
+  // the weight has no stored rows, so the whole output is zero
+  if (req == kWriteTo && !weight.storage_initialized()) {
+    size_t out_size = output.shape_.Size();
+    MSHADOW_TYPE_SWITCH(output.type_flag_, DType, {
+      Fill<false>(s, TBlob(output.dptr<DType>(), mshadow::Shape1(out_size),
+          gpu::kDevMask), kWriteTo, 0);
+    })
+    return;
+  }
+  // check out-of-bound indices; found_invalid becomes 1 if any index is out of range
+  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
+  int32_t found_invalid = 0;
+  MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
+    DType min = 0;
+    DType max = static_cast<DType>(weight.shape()[0] - 1);
+    DType* data_ptr = data.dptr<DType>();
+    size_t data_size = data.shape_.Size();
+    int32_t* found_invalid_ptr = NULL;
+    CUDA_CALL(cudaMalloc(&found_invalid_ptr, sizeof(int32_t)));
+    Kernel<set_zero, gpu>::Launch(s, 1, found_invalid_ptr);
+    Kernel<is_valid_check, gpu>::Launch(s, data_size, found_invalid_ptr, data_ptr, min, max);
+    // copy on the stream the kernels ran on and wait, so the host sees the final flag
+    CUDA_CALL(cudaMemcpyAsync(&found_invalid, found_invalid_ptr, sizeof(int32_t),
+                              cudaMemcpyDeviceToHost, stream));
+    CUDA_CALL(cudaStreamSynchronize(stream));
+    // release the scratch flag, otherwise every forward call leaks device memory;
+    // TODO: use requested temp space instead once this function receives an OpContext
+    CUDA_CALL(cudaFree(found_invalid_ptr));
+  })
+  CHECK_EQ(found_invalid, 0) << "SparseEmbedding input contains data out of bound";
+  // every row is stored, so the weight is effectively dense
+  if (weight.aux_shape(kIdx)[0] == weight.shape()[0]) {
+    EmbeddingOpForwardDnsImpl<gpu>(s, data, weight.data(), req, output);
+  } else {
+    EmbeddingOpForwardRspImpl<gpu>(s, data, weight, req, output);
+  }
+}
+
+/*!
+ * \brief GPU backward pass of SparseEmbedding: row_sparse gradient of the weight.
+ * \param ctx    operator context; requested[embedding::kTempSpace] provides the workspace.
+ * \param ograd  dense gradient of the forward output.
+ * \param data   indices used in the forward pass.
+ * \param req    only kNullOp and kWriteTo are supported.
+ * \param output row_sparse weight gradient; allocated with one row per distinct index.
+ */
+template<>
+inline void SparseEmbeddingOpBackwardRspImpl<gpu>(const OpContext& ctx,
+                                                  const TBlob& ograd,
+                                                  const TBlob& data,
+                                                  const OpReqType req,
+                                                  const NDArray& output) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  using namespace mshadow::expr;
+  using namespace rowsparse;
+  using nnvm::dim_t;
+  if (req == kNullOp) return;
+  CHECK_EQ(req, kWriteTo) << "SparseEmbedding layer doesn't support "
+                          << "weight gradient calculation with req != write";
+
+  Stream<gpu> *s = ctx.get_stream<gpu>();
+  cudaStream_t stream = Stream<gpu>::GetStream(s);
+  dim_t num_rows = output.shape()[0];
+  dim_t row_length = output.shape()[1];
+  dim_t data_size = static_cast<dim_t>(data.shape_.Size());
+
+  MSHADOW_TYPE_SWITCH(data.type_flag_, IType, {
+    MSHADOW_SGL_DBL_TYPE_SWITCH(ograd.type_flag_, DType, {
+      MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIdx), RType, {
+        // query the scratch size cub needs for an inclusive scan over num_rows items
+        dim_t* prefix_sum = NULL;
+        void* d_temp_storage = NULL;
+        size_t temp_storage_bytes = 0;
+        cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                      temp_storage_bytes,
+                                      prefix_sum,
+                                      prefix_sum,
+                                      num_rows,
+                                      stream);
+        // one workspace: num_rows prefix-sum slots followed by cub's scratch space
+        Tensor<gpu, 1, char> workspace = ctx.requested[embedding::kTempSpace]
+            .get_space_typed<gpu, 1, char>(Shape1(num_rows * sizeof(dim_t) +
+                                           temp_storage_bytes), s);
+        prefix_sum = reinterpret_cast<dim_t*>(workspace.dptr_);
+        d_temp_storage = workspace.dptr_ + num_rows * sizeof(dim_t);
+        // flag the rows referenced by data in place; the scan below turns the
+        // flags into an inclusive prefix sum
+        Fill<false>(s, TBlob(prefix_sum, Shape1(num_rows), gpu::kDevMask), kWriteTo, 0);
+        Kernel<MarkRowFlgKernel, gpu>::Launch(s, data_size, prefix_sum, data.dptr<IType>());
+        cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                      temp_storage_bytes,
+                                      prefix_sum,
+                                      prefix_sum,
+                                      num_rows,
+                                      stream);
+        // total number of non-zero rows; read it on the same stream and wait
+        dim_t nnr = 0;
+        CUDA_CALL(cudaMemcpyAsync(&nnr, &prefix_sum[num_rows - 1], sizeof(dim_t),
+                                  cudaMemcpyDeviceToHost, stream));
+        CUDA_CALL(cudaStreamSynchronize(stream));
+        if (nnr == 0) {
+          FillZerosRspImpl(s, output);
+          return;
+        }
+        output.CheckAndAlloc({Shape1(nnr)});
+        RType* grad_row_idx = output.aux_data(kIdx).dptr<RType>();
+        // fill row_idx array of output matrix from the prefix sums
+        Kernel<FillRspRowIdxKernel, gpu>::Launch(s, num_rows,
+            grad_row_idx, prefix_sum, num_rows);
+        // prefill with zeros
+        DType* grad_data = output.data().dptr<DType>();
+        Fill<false>(s, TBlob(grad_data, Shape1(nnr * row_length), gpu::kDevMask),
+            kWriteTo, 0);
+        // add the final gradients: one thread per (input position, column) pair
+        const dim_t num_threads = row_length * data_size;
+        Kernel<AddTakeGradRspGPUKernel, gpu>::Launch(s, num_threads, grad_data, prefix_sum,
+            data.dptr<IType>(), ograd.dptr<DType>(), row_length);
+      });
+    });
+  });
+}
+
 NNVM_REGISTER_OP(Embedding)
 .set_attr<FCompute>("FCompute<gpu>", EmbeddingOpForward<gpu>);
 
+NNVM_REGISTER_OP(_contrib_SparseEmbedding)
+.set_attr<FComputeEx>("FComputeEx<gpu>", SparseEmbeddingOpForwardEx<gpu>);
+
 NNVM_REGISTER_OP(_backward_Embedding)
 .set_attr<FCompute>("FCompute<gpu>", EmbeddingOpBackward<gpu>);
 
+NNVM_REGISTER_OP(_backward_SparseEmbedding)
+.set_attr<FComputeEx>("FComputeEx<gpu>", SparseEmbeddingOpBackwardEx<gpu>);
+
 NNVM_REGISTER_OP(take)
 .set_attr<FCompute>("FCompute<gpu>", TakeOpForward<gpu>);
 
diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index 684794b..7af5edd 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -187,9 +187,8 @@ inline bool SparseEmbeddingOpForwardStorageType(const nnvm::NodeAttrs& attrs,
   const int& weight_stype = in_attrs->at(embedding::kWeight);
   int& out_stype = out_attrs->at(embedding::kOut);
   bool dispatched = false;
-  const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask;
   if (!dispatched && data_stype == kDefaultStorage &&
-      weight_stype == kRowSparseStorage && !invalid_ctx) {
+      weight_stype == kRowSparseStorage) {
     // dns, rsp -> dns
     dispatched = storage_type_assign(&out_stype, kDefaultStorage,
                                      dispatch_mode, DispatchMode::kFComputeEx);
@@ -215,9 +214,8 @@ inline bool SparseEmbeddingOpBackwardStorageType(const nnvm::NodeAttrs& attrs,
   int& data_grad_stype = out_attrs->at(0);
   int& weight_grad_stype = out_attrs->at(1);
   bool dispatched = false;
-  const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask;
   if (!dispatched && ograd_stype == kDefaultStorage &&
-      data_stype == kDefaultStorage && !invalid_ctx) {
+      data_stype == kDefaultStorage) {
     // dns, dns -> dns, rsp
     if (type_assign(&data_grad_stype, kDefaultStorage) &&
         type_assign(&weight_grad_stype, kRowSparseStorage) &&
@@ -336,8 +334,8 @@ struct TakeRspKernel {
   }
 };
 
-inline void EmbeddingOpForwardRspImpl(mshadow::Stream<mshadow::cpu>* s,
-                                      const cpu& cpu_dev,
+template<typename xpu>
+inline void EmbeddingOpForwardRspImpl(mshadow::Stream<xpu>* s,
                                       const TBlob& data,
                                       const NDArray& weight,
                                       const OpReqType req,
@@ -351,7 +349,7 @@ inline void EmbeddingOpForwardRspImpl(mshadow::Stream<mshadow::cpu>* s,
           size_t data_size = data.shape_.Size();
           // only using the second dim since weight.ndim() == 2
           const nnvm::dim_t row_length = weight.shape()[1];
-          Kernel<TakeRspKernel<req_t>, cpu>::Launch(s, data_size, data.dptr<IType>(),
+          Kernel<TakeRspKernel<req_t>, xpu>::Launch(s, data_size, data.dptr<IType>(),
                                                     output.dptr<DType>(),
                                                     weight.aux_data(kIdx).dptr<RType>(),
                                                     weight.data().dptr<DType>(),
@@ -369,39 +367,7 @@ void SparseEmbeddingOpForwardRspImpl(mshadow::Stream<xpu>* s,
                                      const TBlob& data,
                                      const NDArray& weight,
                                      const OpReqType req,
-                                     const TBlob& output) {
-  if (req == kNullOp) return;
-  CHECK((std::is_same<xpu, mshadow::cpu>::value)) << "SparseEmbedding is only implemented for CPU";
-  using namespace rowsparse;
-  using namespace mxnet_op;
-  // zeros weight
-  if (req == kWriteTo && !weight.storage_initialized()) {
-    size_t out_size = output.shape_.Size();
-    MSHADOW_TYPE_SWITCH(output.type_flag_, DType, {
-      Kernel<set_zero, xpu>::Launch(s, out_size, output.dptr<DType>());
-    })
-    return;
-  }
-  // check out-of-bound indices
-  bool is_valid = true;
-  MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
-    DType min = 0;
-    DType max = static_cast<DType>(weight.shape()[0] - 1);
-    // check with single thread is faster since data is small
-    DType* data_ptr = data.dptr<DType>();
-    size_t data_size = data.shape_.Size();
-    for (size_t i = 0; i < data_size; i++) {
-      if (data_ptr[i] > max || data_ptr[i] < min) is_valid = false;
-    }
-  })
-  CHECK(is_valid) << "SparseEmbedding input contains data out of bound";
-  // the weight is actually dense
-  if (weight.aux_shape(kIdx)[0] == weight.shape()[0]) {
-    EmbeddingOpForwardDnsImpl(s, data, weight.data(), req, output);
-  } else {
-    EmbeddingOpForwardRspImpl(s, xpu(), data, weight, req, output);
-  }
-}
+                                     const TBlob& output);
 
 template<typename xpu>
 void EmbeddingOpForward(const nnvm::NodeAttrs& attrs,
@@ -603,71 +569,12 @@ struct AddTakeGradRspKernel {
   }
 };
 
+template<typename xpu>
 inline void SparseEmbeddingOpBackwardRspImpl(const OpContext& ctx,
-                                             const cpu& cpu_dev,
                                              const TBlob& ograd,
                                              const TBlob& data,
                                              const OpReqType req,
-                                             const NDArray& output) {
-  using namespace mshadow;
-  using namespace mxnet_op;
-  using namespace mshadow::expr;
-  using namespace rowsparse;
-  using nnvm::dim_t;
-  if (req == kNullOp) return;
-  CHECK_EQ(req, kWriteTo) << "SparseEmbedding layer doesn't support "
-                          << "weight gradient calculation with req != write";
-
-  // Request temporary storage for marking non-zero rows and prefix sum
-  Stream<cpu> *s = ctx.get_stream<cpu>();
-  dim_t num_rows = output.shape()[0];
-  dim_t row_length = output.shape()[1];
-  // TODO(haibin) request less storage to save space in the future
-  size_t workspace_size = 2 * (num_rows * sizeof(dim_t));
-  Tensor<cpu, 1, char> workspace =
-    ctx.requested[embedding::kTempSpace].get_space_typed<cpu, 1, char>(
-      Shape1(workspace_size), s);
-  dim_t* row_flg = reinterpret_cast<dim_t*>(workspace.dptr_);
-  dim_t* prefix_sum = row_flg + num_rows;
-  dim_t data_size = static_cast<dim_t>(data.shape_.Size());
-
-  MSHADOW_TYPE_SWITCH(data.type_flag_, IType, {
-    MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
-      MSHADOW_TYPE_SWITCH(output.aux_type(kIdx), RType, {
-        // mark row flags
-        Fill<false>(s, TBlob(row_flg, mshadow::Shape1(num_rows), cpu::kDevMask), kWriteTo, 0);
-        Kernel<MarkRowFlgKernel, cpu>::Launch(s, data_size, row_flg, data.dptr<IType>());
-        // calculate inclusive prefix sum
-        // TODO(haibin) ideally this is should be done in parallel
-        prefix_sum[0] = row_flg[0];
-        for (dim_t i = 1; i < num_rows; i++) {
-          prefix_sum[i] = prefix_sum[i - 1] + row_flg[i];
-        }
-        // total number of non-zero rows
-        dim_t nnr = prefix_sum[num_rows - 1];
-        if (nnr == 0) {
-          FillZerosRspImpl(s, output);
-          return;
-        }
-        output.CheckAndAlloc({Shape1(nnr)});
-        RType* grad_row_idx = output.aux_data(kIdx).dptr<RType>();
-        // fill row_idx array of output matrix, using the row_flg values
-        Kernel<FillRspRowIdxKernel, cpu>::Launch(s, num_rows,
-               grad_row_idx, prefix_sum, num_rows);
-        // prefill with zeros
-        DType* grad_data = output.data().dptr<DType>();
-        Kernel<set_zero, cpu>::Launch(s, nnr * row_length, grad_data);
-        // add the final gradients
-        const int num_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
-        dim_t segment_len = (nnr + num_threads - 1) / num_threads;
-        Kernel<AddTakeGradRspKernel, cpu>::Launch(s, num_threads, grad_data, prefix_sum,
-                                                  ograd.dptr<DType>(), row_length,
-                                                  data.dptr<IType>(), data_size, segment_len,
-                                                  num_rows);
-      });
-    });
-  });
-}
+                                             const NDArray& output);
 
 template<typename xpu>
 void SparseEmbeddingOpBackwardEx(const nnvm::NodeAttrs& attrs,
@@ -687,8 +594,8 @@ void SparseEmbeddingOpBackwardEx(const nnvm::NodeAttrs& attrs,
           << "SparseEmbedding layer doesn't support calculate data gradient";
   if (data.storage_type() == kDefaultStorage && ograd.storage_type() == kDefaultStorage &&
       weight_grad.storage_type() == kRowSparseStorage) {
-    SparseEmbeddingOpBackwardRspImpl(ctx, xpu(), ograd.data(), data.data(),
-                                     req[embedding::kWeight], weight_grad);
+    SparseEmbeddingOpBackwardRspImpl<xpu>(ctx, ograd.data(), data.data(),
+                                          req[embedding::kWeight], weight_grad);
   } else {
     LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
   }
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 0db9f45..31c3c46 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -1447,6 +1447,7 @@ def test_sparse_square_sum():
                     check_numeric_gradient(test, [rsp], grad_stype_dict={'data': 'row_sparse'},
                                            atol=1e-2, rtol=0.1)
 
+                    
 def test_sparse_storage_fallback():
     """ test operators which don't implement FComputeEx or FStatefulComputeEx """
     def check_broadcast_add(shape, lhs_stype, rhs_stype):
@@ -1551,39 +1552,38 @@ def test_sparse_embedding():
         # update weight based on density
         weight[:] = rand_ndarray(weight.shape, 'row_sparse', density=density)
         # check forward
-        exe_test.forward(is_train=True)
-        assert_almost_equal(exe_test.outputs[0].asnumpy(), np.dot(data_onehot, weight.asnumpy()))
+        executor.forward(is_train=True)
+        assert_almost_equal(executor.outputs[0].asnumpy(), np.dot(data_onehot, weight.asnumpy()))
         # check backward
         executor.backward([grad])
         assert_almost_equal(grad_map["embed_weight"].asnumpy(), np.dot(data_onehot.T, grad.asnumpy()))
 
-    if default_context().device_type == 'cpu':
-        densities = [0, 0.5, 1]
-        in_dim = 50
-        out_dim = 3
-        batch = 8
-        # init executor
-        data = mx.sym.Variable("data")
-        weight = mx.sym.Variable("embed_weight", stype='row_sparse')
-        embed = mx.sym.contrib.SparseEmbedding(data=data, weight=weight, input_dim=in_dim,
-                                               output_dim=out_dim, name="embed")
-        grad_req = {'data': 'null', 'embed_weight': 'write'}
-        exe_test = embed.simple_bind(default_context(), grad_req=grad_req, data=(batch,))
-        arg_map = dict(zip(embed.list_arguments(), exe_test.arg_arrays))
-        grad_map = dict(zip(embed.list_arguments(), exe_test.grad_arrays))
-        # init data
-        np_data = np.random.randint(low=0, high=in_dim, size=batch)
-        np_onehot = np.zeros((batch, in_dim))
-        np_onehot[np.arange(batch), np_data] = 1.0
-        arg_map["data"][:] = np_data
-        # init grad
-        np_grad = np.random.uniform(-1, 1, exe_test.outputs[0].shape)
-        grad = mx.nd.sparse.zeros('row_sparse', np_grad.shape)
-        grad[:] = np_grad
-        # weight
-        weight = arg_map["embed_weight"]
-        for density in densities:
-            check_sparse_embedding(exe_test, weight, np_onehot, grad, density)
+    densities = [0, 0.5, 1]
+    in_dim = 50
+    out_dim = 3
+    batch = 8
+    # init executor
+    data = mx.sym.Variable("data")
+    weight = mx.sym.Variable("embed_weight", stype='row_sparse')
+    embed = mx.sym.contrib.SparseEmbedding(data=data, weight=weight, input_dim=in_dim,
+                                           output_dim=out_dim, name="embed")
+    grad_req = {'data': 'null', 'embed_weight': 'write'}
+    exe_test = embed.simple_bind(default_context(), grad_req=grad_req, data=(batch,))
+    arg_map = dict(zip(embed.list_arguments(), exe_test.arg_arrays))
+    grad_map = dict(zip(embed.list_arguments(), exe_test.grad_arrays))
+    # init data
+    np_data = np.random.randint(low=0, high=in_dim, size=batch)
+    np_onehot = np.zeros((batch, in_dim))
+    np_onehot[np.arange(batch), np_data] = 1.0
+    arg_map["data"][:] = np_data
+    # init grad
+    np_grad = np.random.uniform(-1, 1, exe_test.outputs[0].shape)
+    grad = mx.nd.sparse.zeros('row_sparse', np_grad.shape)
+    grad[:] = np_grad
+    # weight
+    weight = arg_map["embed_weight"]
+    for density in densities:
+        check_sparse_embedding(exe_test, weight, np_onehot, grad, density)
 
 def test_scatter_ops():
     def csr_get_seen_points(name, csr_array, verbose=False):

-- 
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.