Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/11/16 23:13:05 UTC

[GitHub] piiswrong closed pull request #8647: sparse embedding operator, gpu implementation

URL: https://github.com/apache/incubator-mxnet/pull/8647

This is a PR merged from a forked repository. As GitHub hides the original
diff of a foreign (forked) pull request once it is merged, the full diff is
reproduced below for the sake of provenance:

diff --git a/example/sparse/matrix_factorization.py b/example/sparse/matrix_factorization.py
index cdb61643d3..3387706665 100644
--- a/example/sparse/matrix_factorization.py
+++ b/example/sparse/matrix_factorization.py
@@ -22,6 +22,8 @@
 import numpy as np
 from get_data import get_movielens_iter, get_movielens_data
 from matrix_fact_model import matrix_fact_net
+
+
 logging.basicConfig(level=logging.DEBUG)
 
 parser = argparse.ArgumentParser(description="Run matrix factorization with sparse embedding",
@@ -36,6 +38,8 @@
                     help="the factor size of the embedding operation")
 parser.add_argument('--use-dense', action='store_true',
                     help="use the dense embedding operator")
+parser.add_argument('--use-gpu', action='store_true',
+                    help="use gpu")
 parser.add_argument('--dummy-iter', action='store_true',
                     help="use the dummy data iterator for speed test")
 
@@ -63,7 +67,7 @@
     print_every = args.print_every
 
     momentum = 0.9
-    ctx = mx.cpu(0)
+    ctx = mx.gpu(0) if args.use_gpu else mx.cpu(0)
     learning_rate = 0.1
 
     # prepare dataset and iterators
@@ -75,7 +79,6 @@
 
     # construct the model
     net = matrix_fact_net(factor_size, factor_size, max_user, max_movies, sparse_embed=use_sparse)
-    a = time.time()
 
     # initialize the module
     mod = mx.module.Module(symbol=net, context=ctx, data_names=['user', 'item'],
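
With these changes the example can be switched to GPU from the command line,
e.g. "python matrix_factorization.py --use-gpu"; without the flag it keeps
running on mx.cpu(0) as before.
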
diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc
index 273ebec488..6fc54fc208 100644
--- a/src/operator/tensor/indexing_op.cc
+++ b/src/operator/tensor/indexing_op.cc
@@ -26,6 +26,115 @@
 #include "./indexing_op.h"
 namespace mxnet {
 namespace op {
+
+template<>
+void SparseEmbeddingOpForwardRspImpl<cpu>(mshadow::Stream<cpu>* s,
+                                          const TBlob& data,
+                                          const NDArray& weight,
+                                          const OpReqType req,
+                                          const TBlob& output) {
+  if (req == kNullOp) return;
+  using namespace rowsparse;
+  using namespace mxnet_op;
+  // zeros weight
+  if (req == kWriteTo && !weight.storage_initialized()) {
+    size_t out_size = output.shape_.Size();
+    MSHADOW_TYPE_SWITCH(output.type_flag_, DType, {
+      Fill<false>(s, TBlob(output.dptr<DType>(), mshadow::Shape1(out_size),
+          cpu::kDevMask), kWriteTo, 0);
+    })
+    return;
+  }
+  // check out-of-bound indices
+  bool is_valid = true;
+  MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
+    DType min = 0;
+    DType max = static_cast<DType>(weight.shape()[0] - 1);
+    // check with single thread is faster since data is small
+    DType* data_ptr = data.dptr<DType>();
+    size_t data_size = data.shape_.Size();
+    for (size_t i = 0; i < data_size; i++) {
+      if (data_ptr[i] > max || data_ptr[i] < min) is_valid = false;
+    }
+  })
+  CHECK(is_valid) << "SparseEmbedding input contains data out of bound";
+  // the weight is actually dense
+  if (weight.aux_shape(kIdx)[0] == weight.shape()[0]) {
+    EmbeddingOpForwardDnsImpl<cpu>(s, data, weight.data(), req, output);
+  } else {
+    EmbeddingOpForwardRspImpl<cpu>(s, data, weight, req, output);
+  }
+}
+
+
+template<>
+inline void SparseEmbeddingOpBackwardRspImpl<cpu>(const OpContext& ctx,
+                                                  const TBlob& ograd,
+                                                  const TBlob& data,
+                                                  const OpReqType req,
+                                                  const NDArray& output) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  using namespace mshadow::expr;
+  using namespace rowsparse;
+  using nnvm::dim_t;
+  if (req == kNullOp) return;
+  CHECK_EQ(req, kWriteTo) << "SparseEmbedding layer doesn't support "
+                          << "weight gradient calculation with req != write";
+
+  // Request temporary storage for marking non-zero rows and prefix sum
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  dim_t num_rows = output.shape()[0];
+  dim_t row_length = output.shape()[1];
+  // TODO(haibin) request less storage to save space in the future
+  size_t workspace_size = 2 * (num_rows * sizeof(dim_t));
+  Tensor<cpu, 1, char> workspace =
+    ctx.requested[embedding::kTempSpace].get_space_typed<cpu, 1, char>(
+      Shape1(workspace_size), s);
+  dim_t* row_flg = reinterpret_cast<dim_t*>(workspace.dptr_);
+  dim_t* prefix_sum = row_flg + num_rows;
+  dim_t data_size = static_cast<dim_t>(data.shape_.Size());
+
+  MSHADOW_TYPE_SWITCH(data.type_flag_, IType, {
+    MSHADOW_SGL_DBL_TYPE_SWITCH(ograd.type_flag_, DType, {
+      MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIdx), RType, {
+        // mark row flags
+        Fill<false>(s, TBlob(row_flg, Shape1(num_rows), cpu::kDevMask), kWriteTo, 0);
+        Kernel<MarkRowFlgKernel, cpu>::Launch(s, data_size, row_flg, data.dptr<IType>());
+        // calculate inclusive prefix sum
+        // TODO(haibin) ideally this is should be done in parallel
+        prefix_sum[0] = row_flg[0];
+        for (dim_t i = 1; i < num_rows; i++) {
+          prefix_sum[i] = prefix_sum[i - 1] + row_flg[i];
+        }
+        // total number of non-zero rows
+        dim_t nnr = prefix_sum[num_rows - 1];
+        if (nnr == 0) {
+          FillZerosRspImpl(s, output);
+          return;
+        }
+        output.CheckAndAlloc({Shape1(nnr)});
+        RType* grad_row_idx = output.aux_data(kIdx).dptr<RType>();
+        // fill row_idx array of output matrix, using the row_flg values
+        Kernel<FillRspRowIdxKernel, cpu>::Launch(s, num_rows,
+            grad_row_idx, prefix_sum, num_rows);
+        // prefill with zeros
+        DType* grad_data = output.data().dptr<DType>();
+        Fill<false>(s, TBlob(grad_data, Shape1(nnr * row_length),
+            cpu::kDevMask), kWriteTo, 0);
+        // add the final gradients
+        const int num_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+        dim_t segment_len = (nnr + num_threads - 1) / num_threads;
+        Kernel<AddTakeGradRspKernel, cpu>::Launch(s, num_threads, grad_data, prefix_sum,
+                                                  ograd.dptr<DType>(), row_length,
+                                                  data.dptr<IType>(), data_size, segment_len,
+                                                  num_rows);
+      });
+    });
+  });
+}
+
+
 DMLC_REGISTER_PARAMETER(EmbeddingParam);
 DMLC_REGISTER_PARAMETER(TakeParam);
 DMLC_REGISTER_PARAMETER(OneHotParam);
@@ -116,8 +225,7 @@ The storage type of weight must be `row_sparse`, and the gradient of the weight
 .. Note::
 
     `SparseEmbedding` is designed for the use case where `input_dim` is very large (e.g. 100k).
-    The `row_sparse` weight cannot be used in a `BucketingModule`.
-    The operator is only available on CPU.
+    The operator is available on both CPU and GPU.
 
 Examples::
 
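For reference, a minimal forward-pass sketch of the operator this file implements,
modeled on the unit test near the end of this diff (the sizes and variable names
are illustrative only, not taken from the operator's Examples section):

    import mxnet as mx
    import numpy as np

    data = mx.sym.Variable('data')
    weight = mx.sym.Variable('weight', stype='row_sparse')
    embed = mx.sym.contrib.SparseEmbedding(data=data, weight=weight,
                                           input_dim=50, output_dim=3)
    exe = embed.simple_bind(mx.cpu(), grad_req='null', data=(8,))
    exe.arg_dict['data'][:] = np.random.randint(0, 50, size=8)
    exe.arg_dict['weight'][:] = mx.nd.ones((50, 3)).tostype('row_sparse')
    exe.forward()
    print(exe.outputs[0].asnumpy())   # row i is the embedding of index data[i]
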
diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu
index 2cddd006a6..2aba122f65 100644
--- a/src/operator/tensor/indexing_op.cu
+++ b/src/operator/tensor/indexing_op.cu
@@ -24,14 +24,170 @@
 */
 
 #include "./indexing_op.h"
+#include "./util/tensor_util-inl.cuh"
+
 namespace mxnet {
 namespace op {
+
+/*! \brief If there are out-of-bound indices, out will be assigned to 1.
+ */
+
+struct is_valid_check {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, int32_t* out, const DType* data,
+                                  const DType min, const DType max) {
+    if (data[i] < min || data[i] > max) *out = 1;
+  }
+};
+
+
+struct AddTakeGradRspGPUKernel {
+  template<typename DType, typename IType>
+  __device__ __forceinline__ static void Map(int tid,
+                                             DType* out,
+                                             const nnvm::dim_t* prefix_sum,
+                                             const IType* data,
+                                             const DType* ograd,
+                                             const nnvm::dim_t row_length) {
+    using nnvm::dim_t;
+    const dim_t data_i = tid / row_length;
+    const dim_t grad_i = tid % row_length;
+    const dim_t irow = static_cast<dim_t>(data[data_i]);
+    const dim_t rsp_row = prefix_sum[irow] - 1;
+    const DType val = ograd[data_i * row_length + grad_i];
+    atomicAdd(static_cast<DType *>(&(out[rsp_row*row_length+grad_i])), val);
+  }
+};
+
+template<>
+void SparseEmbeddingOpForwardRspImpl<gpu>(mshadow::Stream<gpu>* s,
+                                          const TBlob& data,
+                                          const NDArray& weight,
+                                          const OpReqType req,
+                                          const TBlob& output) {
+  if (req == kNullOp) return;
+  using namespace rowsparse;
+  using namespace mxnet_op;
+  // zeros weight
+  if (req == kWriteTo && !weight.storage_initialized()) {
+    size_t out_size = output.shape_.Size();
+    MSHADOW_TYPE_SWITCH(output.type_flag_, DType, {
+      Fill<false>(s, TBlob(output.dptr<DType>(), mshadow::Shape1(out_size),
+          gpu::kDevMask), kWriteTo, 0);
+    })
+    return;
+  }
+  // check out-of-bound indices
+  int32_t is_valid = 0;
+  MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
+    DType min = 0;
+    DType max = static_cast<DType>(weight.shape()[0] - 1);
+    DType* data_ptr = data.dptr<DType>();
+    size_t data_size = data.shape_.Size();
+    int32_t* is_valid_ptr = NULL;
+    CUDA_CALL(cudaMalloc(&is_valid_ptr, sizeof(int32_t)));
+    Kernel<set_zero, gpu>::Launch(s, 1, is_valid_ptr);
+    Kernel<is_valid_check, gpu>::Launch(s, data_size, is_valid_ptr, data_ptr, min, max);
+    CUDA_CALL(cudaMemcpy(&is_valid, is_valid_ptr, sizeof(int32_t),
+              cudaMemcpyDeviceToHost));
+  })
+  CHECK_EQ(is_valid, 0) << "SparseEmbedding input contains data out of bound";
+  // the weight is actually dense
+  if (weight.aux_shape(kIdx)[0] == weight.shape()[0]) {
+    EmbeddingOpForwardDnsImpl<gpu>(s, data, weight.data(), req, output);
+  } else {
+    EmbeddingOpForwardRspImpl<gpu>(s, data, weight, req, output);
+  }
+}
+
+
+template<>
+inline void SparseEmbeddingOpBackwardRspImpl<gpu>(const OpContext& ctx,
+                                                  const TBlob& ograd,
+                                                  const TBlob& data,
+                                                  const OpReqType req,
+                                                  const NDArray& output) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  using namespace mshadow::expr;
+  using namespace rowsparse;
+  using nnvm::dim_t;
+  if (req == kNullOp) return;
+  CHECK_EQ(req, kWriteTo) << "SparseEmbedding layer doesn't support "
+                          << "weight gradient calculation with req != write";
+
+  // Request temporary storage for marking non-zero rows and prefix sum
+  Stream<gpu> *s = ctx.get_stream<gpu>();
+  dim_t num_rows = output.shape()[0];
+  dim_t row_length = output.shape()[1];
+  dim_t data_size = static_cast<dim_t>(data.shape_.Size());
+  dim_t num_threads;
+
+  MSHADOW_TYPE_SWITCH(data.type_flag_, IType, {
+    MSHADOW_SGL_DBL_TYPE_SWITCH(ograd.type_flag_, DType, {
+      MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIdx), RType, {
+        dim_t* prefix_sum = NULL;
+        void* d_temp_storage = NULL;
+        size_t temp_storage_bytes = 0;
+        cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                      temp_storage_bytes,
+                                      prefix_sum,
+                                      prefix_sum,
+                                      num_rows,
+                                      Stream<gpu>::GetStream(s));
+        Tensor<gpu, 1, char> workspace = ctx.requested[0]
+            .get_space_typed<gpu, 1, char>(Shape1(num_rows * sizeof(dim_t) +
+                                           temp_storage_bytes), s);
+        prefix_sum = reinterpret_cast<dim_t*>(workspace.dptr_);
+        d_temp_storage = workspace.dptr_ + num_rows*sizeof(dim_t);
+        num_threads = num_rows;
+        Fill<false>(s, TBlob(prefix_sum, Shape1(num_threads), gpu::kDevMask), kWriteTo, 0);
+        Kernel<MarkRowFlgKernel, gpu>::Launch(s, data_size, prefix_sum, data.dptr<IType>());
+
+        cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                      temp_storage_bytes,
+                                      prefix_sum,
+                                      prefix_sum,
+                                      num_rows,
+                                      mshadow::Stream<gpu>::GetStream(s));
+        dim_t nnr = 0;
+        CUDA_CALL(cudaMemcpy(&nnr, &prefix_sum[num_rows-1], sizeof(dim_t),
+            cudaMemcpyDeviceToHost));
+
+        if (nnr == 0) {
+          FillZerosRspImpl(s, output);
+          return;
+        }
+        output.CheckAndAlloc({Shape1(nnr)});
+        RType* grad_row_idx = output.aux_data(kIdx).dptr<RType>();
+        // fill row_idx array of output matrix, using the row_flg values
+        Kernel<FillRspRowIdxKernel, gpu>::Launch(s, num_rows,
+            grad_row_idx, prefix_sum, num_rows);
+        // prefill with zeros
+        DType* grad_data = output.data().dptr<DType>();
+        Fill<false>(s, TBlob(grad_data, Shape1(nnr * row_length), gpu::kDevMask),
+            kWriteTo, 0);
+        // add the final gradients
+        num_threads = row_length * data_size;
+        Kernel<AddTakeGradRspGPUKernel, gpu>::Launch(s, num_threads, grad_data, prefix_sum,
+            data.dptr<IType>(), ograd.dptr<DType>(), row_length);
+      });
+    });
+  });
+}
+
 NNVM_REGISTER_OP(Embedding)
 .set_attr<FCompute>("FCompute<gpu>", EmbeddingOpForward<gpu>);
 
+NNVM_REGISTER_OP(_contrib_SparseEmbedding)
+.set_attr<FComputeEx>("FComputeEx<gpu>", SparseEmbeddingOpForwardEx<gpu>);
+
 NNVM_REGISTER_OP(_backward_Embedding)
 .set_attr<FCompute>("FCompute<gpu>", EmbeddingOpBackward<gpu>);
 
+NNVM_REGISTER_OP(_backward_SparseEmbedding)
+.set_attr<FComputeEx>("FComputeEx<gpu>", SparseEmbeddingOpBackwardEx<gpu>);
+
 NNVM_REGISTER_OP(take)
 .set_attr<FCompute>("FCompute<gpu>", TakeOpForward<gpu>);
 
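Both backward implementations (the CPU one added to indexing_op.cc and the GPU one
here) follow the same compaction scheme: mark which weight rows are referenced by
the input indices, take an inclusive prefix sum of those flags so that each
referenced row gets a slot in the compacted gradient (the last entry gives the
number of non-zero rows), then scatter-add rows of the output gradient into those
slots. A rough NumPy sketch of the scheme, purely for illustration (the function
name is made up):

    import numpy as np

    def rsp_embedding_grad(data, ograd, num_rows):
        """Compacted (row_sparse) gradient of an embedding lookup."""
        row_flg = np.zeros(num_rows, dtype=np.int64)
        row_flg[data] = 1                    # mark rows referenced by `data`
        prefix_sum = np.cumsum(row_flg)      # inclusive prefix sum
        nnr = int(prefix_sum[-1])            # number of non-zero rows
        row_idx = np.flatnonzero(row_flg)    # row ids of the compacted gradient
        grad = np.zeros((nnr, ograd.shape[1]), dtype=ograd.dtype)
        for i, idx in enumerate(data):       # scatter-add, as in AddTakeGradRsp*Kernel
            grad[prefix_sum[idx] - 1] += ograd[i]
        return row_idx, grad

    row_idx, grad = rsp_embedding_grad(np.array([3, 1, 3]), np.ones((3, 4)), num_rows=5)
    # row_idx -> [1 3]; grad[1] accumulates both occurrences of index 3
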
diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index 684794bcd9..7af5edd157 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -187,9 +187,8 @@ inline bool SparseEmbeddingOpForwardStorageType(const nnvm::NodeAttrs& attrs,
   const int& weight_stype = in_attrs->at(embedding::kWeight);
   int& out_stype = out_attrs->at(embedding::kOut);
   bool dispatched = false;
-  const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask;
   if (!dispatched && data_stype == kDefaultStorage &&
-      weight_stype == kRowSparseStorage && !invalid_ctx) {
+      weight_stype == kRowSparseStorage) {
     // dns, rsp -> dns
     dispatched = storage_type_assign(&out_stype, kDefaultStorage,
                                      dispatch_mode, DispatchMode::kFComputeEx);
@@ -215,9 +214,8 @@ inline bool SparseEmbeddingOpBackwardStorageType(const nnvm::NodeAttrs& attrs,
   int& data_grad_stype = out_attrs->at(0);
   int& weight_grad_stype = out_attrs->at(1);
   bool dispatched = false;
-  const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask;
   if (!dispatched && ograd_stype == kDefaultStorage &&
-      data_stype == kDefaultStorage && !invalid_ctx) {
+      data_stype == kDefaultStorage) {
     // dns, dns -> dns, rsp
     if (type_assign(&data_grad_stype, kDefaultStorage) &&
         type_assign(&weight_grad_stype, kRowSparseStorage) &&
@@ -336,8 +334,8 @@ struct TakeRspKernel {
   }
 };
 
-inline void EmbeddingOpForwardRspImpl(mshadow::Stream<mshadow::cpu>* s,
-                                      const cpu& cpu_dev,
+template<typename xpu>
+inline void EmbeddingOpForwardRspImpl(mshadow::Stream<xpu>* s,
                                       const TBlob& data,
                                       const NDArray& weight,
                                       const OpReqType req,
@@ -351,7 +349,7 @@ inline void EmbeddingOpForwardRspImpl(mshadow::Stream<mshadow::cpu>* s,
           size_t data_size = data.shape_.Size();
           // only using the second dim since weight.ndim() == 2
           const nnvm::dim_t row_length = weight.shape()[1];
-          Kernel<TakeRspKernel<req_t>, cpu>::Launch(s, data_size, data.dptr<IType>(),
+          Kernel<TakeRspKernel<req_t>, xpu>::Launch(s, data_size, data.dptr<IType>(),
                                                     output.dptr<DType>(),
                                                     weight.aux_data(kIdx).dptr<RType>(),
                                                     weight.data().dptr<DType>(),
@@ -369,39 +367,7 @@ void SparseEmbeddingOpForwardRspImpl(mshadow::Stream<xpu>* s,
                                      const TBlob& data,
                                      const NDArray& weight,
                                      const OpReqType req,
-                                     const TBlob& output) {
-  if (req == kNullOp) return;
-  CHECK((std::is_same<xpu, mshadow::cpu>::value)) << "SparseEmbedding is only implemented for CPU";
-  using namespace rowsparse;
-  using namespace mxnet_op;
-  // zeros weight
-  if (req == kWriteTo && !weight.storage_initialized()) {
-    size_t out_size = output.shape_.Size();
-    MSHADOW_TYPE_SWITCH(output.type_flag_, DType, {
-      Kernel<set_zero, xpu>::Launch(s, out_size, output.dptr<DType>());
-    })
-    return;
-  }
-  // check out-of-bound indices
-  bool is_valid = true;
-  MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
-    DType min = 0;
-    DType max = static_cast<DType>(weight.shape()[0] - 1);
-    // check with single thread is faster since data is small
-    DType* data_ptr = data.dptr<DType>();
-    size_t data_size = data.shape_.Size();
-    for (size_t i = 0; i < data_size; i++) {
-      if (data_ptr[i] > max || data_ptr[i] < min) is_valid = false;
-    }
-  })
-  CHECK(is_valid) << "SparseEmbedding input contains data out of bound";
-  // the weight is actually dense
-  if (weight.aux_shape(kIdx)[0] == weight.shape()[0]) {
-    EmbeddingOpForwardDnsImpl(s, data, weight.data(), req, output);
-  } else {
-    EmbeddingOpForwardRspImpl(s, xpu(), data, weight, req, output);
-  }
-}
+                                     const TBlob& output);
 
 template<typename xpu>
 void EmbeddingOpForward(const nnvm::NodeAttrs& attrs,
@@ -603,71 +569,12 @@ struct AddTakeGradRspKernel {
   }
 };
 
+template<typename xpu>
 inline void SparseEmbeddingOpBackwardRspImpl(const OpContext& ctx,
-                                             const cpu& cpu_dev,
                                              const TBlob& ograd,
                                              const TBlob& data,
                                              const OpReqType req,
-                                             const NDArray& output) {
-  using namespace mshadow;
-  using namespace mxnet_op;
-  using namespace mshadow::expr;
-  using namespace rowsparse;
-  using nnvm::dim_t;
-  if (req == kNullOp) return;
-  CHECK_EQ(req, kWriteTo) << "SparseEmbedding layer doesn't support "
-                          << "weight gradient calculation with req != write";
-
-  // Request temporary storage for marking non-zero rows and prefix sum
-  Stream<cpu> *s = ctx.get_stream<cpu>();
-  dim_t num_rows = output.shape()[0];
-  dim_t row_length = output.shape()[1];
-  // TODO(haibin) request less storage to save space in the future
-  size_t workspace_size = 2 * (num_rows * sizeof(dim_t));
-  Tensor<cpu, 1, char> workspace =
-    ctx.requested[embedding::kTempSpace].get_space_typed<cpu, 1, char>(
-      Shape1(workspace_size), s);
-  dim_t* row_flg = reinterpret_cast<dim_t*>(workspace.dptr_);
-  dim_t* prefix_sum = row_flg + num_rows;
-  dim_t data_size = static_cast<dim_t>(data.shape_.Size());
-
-  MSHADOW_TYPE_SWITCH(data.type_flag_, IType, {
-    MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
-      MSHADOW_TYPE_SWITCH(output.aux_type(kIdx), RType, {
-        // mark row flags
-        Fill<false>(s, TBlob(row_flg, mshadow::Shape1(num_rows), cpu::kDevMask), kWriteTo, 0);
-        Kernel<MarkRowFlgKernel, cpu>::Launch(s, data_size, row_flg, data.dptr<IType>());
-        // calculate inclusive prefix sum
-        // TODO(haibin) ideally this is should be done in parallel
-        prefix_sum[0] = row_flg[0];
-        for (dim_t i = 1; i < num_rows; i++) {
-          prefix_sum[i] = prefix_sum[i - 1] + row_flg[i];
-        }
-        // total number of non-zero rows
-        dim_t nnr = prefix_sum[num_rows - 1];
-        if (nnr == 0) {
-          FillZerosRspImpl(s, output);
-          return;
-        }
-        output.CheckAndAlloc({Shape1(nnr)});
-        RType* grad_row_idx = output.aux_data(kIdx).dptr<RType>();
-        // fill row_idx array of output matrix, using the row_flg values
-        Kernel<FillRspRowIdxKernel, cpu>::Launch(s, num_rows,
-               grad_row_idx, prefix_sum, num_rows);
-        // prefill with zeros
-        DType* grad_data = output.data().dptr<DType>();
-        Kernel<set_zero, cpu>::Launch(s, nnr * row_length, grad_data);
-        // add the final gradients
-        const int num_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
-        dim_t segment_len = (nnr + num_threads - 1) / num_threads;
-        Kernel<AddTakeGradRspKernel, cpu>::Launch(s, num_threads, grad_data, prefix_sum,
-                                                  ograd.dptr<DType>(), row_length,
-                                                  data.dptr<IType>(), data_size, segment_len,
-                                                  num_rows);
-      });
-    });
-  });
-}
+                                             const NDArray& output);
 
 template<typename xpu>
 void SparseEmbeddingOpBackwardEx(const nnvm::NodeAttrs& attrs,
@@ -687,8 +594,8 @@ void SparseEmbeddingOpBackwardEx(const nnvm::NodeAttrs& attrs,
           << "SparseEmbedding layer doesn't support calculate data gradient";
   if (data.storage_type() == kDefaultStorage && ograd.storage_type() == kDefaultStorage &&
       weight_grad.storage_type() == kRowSparseStorage) {
-    SparseEmbeddingOpBackwardRspImpl(ctx, xpu(), ograd.data(), data.data(),
-                                     req[embedding::kWeight], weight_grad);
+    SparseEmbeddingOpBackwardRspImpl<xpu>(ctx, ograd.data(), data.data(),
+                                          req[embedding::kWeight], weight_grad);
   } else {
     LOG(FATAL) << "Not implemented: " << operator_string(attrs, ctx, inputs, req, outputs);
   }
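
With the invalid_ctx guard dropped from both storage-type inference functions, the
dns/rsp FComputeEx path is now selected on GPU as well as CPU, and the
device-specific implementations are reduced to template declarations here, with the
CPU and GPU specializations living in indexing_op.cc and indexing_op.cu. In
practice this means a symbol like the one in the sketch above can be bound with
mx.gpu(0) instead of mx.cpu() and still produce a row_sparse weight gradient
(assuming a CUDA-enabled build).
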
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 0db9f451dd..31c3c46889 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -1447,6 +1447,7 @@ def test_sparse_square_sum():
                     check_numeric_gradient(test, [rsp], grad_stype_dict={'data': 'row_sparse'},
                                            atol=1e-2, rtol=0.1)
 
+                    
 def test_sparse_storage_fallback():
     """ test operators which don't implement FComputeEx or FStatefulComputeEx """
     def check_broadcast_add(shape, lhs_stype, rhs_stype):
@@ -1551,39 +1552,38 @@ def check_sparse_embedding(executor, weight_ref, data_onehot, grad, density):
         # update weight based on density
         weight[:] = rand_ndarray(weight.shape, 'row_sparse', density=density)
         # check forward
-        exe_test.forward(is_train=True)
-        assert_almost_equal(exe_test.outputs[0].asnumpy(), np.dot(data_onehot, weight.asnumpy()))
+        executor.forward(is_train=True)
+        assert_almost_equal(executor.outputs[0].asnumpy(), np.dot(data_onehot, weight.asnumpy()))
         # check backward
         executor.backward([grad])
         assert_almost_equal(grad_map["embed_weight"].asnumpy(), np.dot(data_onehot.T, grad.asnumpy()))
 
-    if default_context().device_type == 'cpu':
-        densities = [0, 0.5, 1]
-        in_dim = 50
-        out_dim = 3
-        batch = 8
-        # init executor
-        data = mx.sym.Variable("data")
-        weight = mx.sym.Variable("embed_weight", stype='row_sparse')
-        embed = mx.sym.contrib.SparseEmbedding(data=data, weight=weight, input_dim=in_dim,
-                                               output_dim=out_dim, name="embed")
-        grad_req = {'data': 'null', 'embed_weight': 'write'}
-        exe_test = embed.simple_bind(default_context(), grad_req=grad_req, data=(batch,))
-        arg_map = dict(zip(embed.list_arguments(), exe_test.arg_arrays))
-        grad_map = dict(zip(embed.list_arguments(), exe_test.grad_arrays))
-        # init data
-        np_data = np.random.randint(low=0, high=in_dim, size=batch)
-        np_onehot = np.zeros((batch, in_dim))
-        np_onehot[np.arange(batch), np_data] = 1.0
-        arg_map["data"][:] = np_data
-        # init grad
-        np_grad = np.random.uniform(-1, 1, exe_test.outputs[0].shape)
-        grad = mx.nd.sparse.zeros('row_sparse', np_grad.shape)
-        grad[:] = np_grad
-        # weight
-        weight = arg_map["embed_weight"]
-        for density in densities:
-            check_sparse_embedding(exe_test, weight, np_onehot, grad, density)
+    densities = [0, 0.5, 1]
+    in_dim = 50
+    out_dim = 3
+    batch = 8
+    # init executor
+    data = mx.sym.Variable("data")
+    weight = mx.sym.Variable("embed_weight", stype='row_sparse')
+    embed = mx.sym.contrib.SparseEmbedding(data=data, weight=weight, input_dim=in_dim,
+                                           output_dim=out_dim, name="embed")
+    grad_req = {'data': 'null', 'embed_weight': 'write'}
+    exe_test = embed.simple_bind(default_context(), grad_req=grad_req, data=(batch,))
+    arg_map = dict(zip(embed.list_arguments(), exe_test.arg_arrays))
+    grad_map = dict(zip(embed.list_arguments(), exe_test.grad_arrays))
+    # init data
+    np_data = np.random.randint(low=0, high=in_dim, size=batch)
+    np_onehot = np.zeros((batch, in_dim))
+    np_onehot[np.arange(batch), np_data] = 1.0
+    arg_map["data"][:] = np_data
+    # init grad
+    np_grad = np.random.uniform(-1, 1, exe_test.outputs[0].shape)
+    grad = mx.nd.sparse.zeros('row_sparse', np_grad.shape)
+    grad[:] = np_grad
+    # weight
+    weight = arg_map["embed_weight"]
+    for density in densities:
+        check_sparse_embedding(exe_test, weight, np_onehot, grad, density)
 
 def test_scatter_ops():
     def csr_get_seen_points(name, csr_array, verbose=False):

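The check_sparse_embedding helper now uses its executor argument consistently
(instead of the enclosing exe_test), and the test body itself is only de-indented:
removing the default_context().device_type == 'cpu' guard lets the same
forward/backward checks run when the default context is a GPU.
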

 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services