You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/03/03 12:13:57 UTC

[incubator-mxnet] branch master updated: sparse regression operators (#9625)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new dedfd2d  sparse regression operators (#9625)
dedfd2d is described below

commit dedfd2d60713319855c0b9df0aac57eee2d68f2d
Author: Ziyue Huang <zy...@gmail.com>
AuthorDate: Sat Mar 3 20:13:48 2018 +0800

    sparse regression operators (#9625)
    
    * sparse regression ops
    
    * add elemadd(dns, csr)
    
    * address comments and fix
    
    * replace copy with mshadow_op::identity
    
    * add kWriteInplace check
    
    * elemwise broadcast add
    
    * less template instantiation
    
    * not instantiate broadcast_add
    
    * remove DnsCsrOP instantiation in elemwise_binary
    
    * lint
    
    * remove two regression ops
    
    * enable binary op
    
    * disable binary broadcast
    
    * fix
    
    * duplicate some code in binary_broadcast
    
    * try to make names short
    
    * try to make names short for infer stype
    
    * disable sparse broadcast_add
    
    * revert binary broadcast
    
    * update
    
    * disable MAE
    
    * disable DnsCsrOp
    
    * remove IType
    
    * remove binary
    
    * update
    
    * address comments
    
    * update
    
    * try to fix R-test MF
    
    * Revert "try to fix R-test MF"
    
    This reverts commit f6d3e17ea7f5a71d23d81375bf345147a4373a93.
    
    * remove grad_req check for label
    
    * address comments
    
    * trigger CI
---
 docs/api/python/ndarray/sparse.md      |   2 +
 docs/api/python/symbol/sparse.md       |   2 +
 src/operator/regression_output-inl.h   | 179 +++++++++++++++++++++++++++++----
 src/operator/regression_output.cc      |  86 ++++++++++------
 src/operator/regression_output.cu      |  12 ++-
 tests/python/unittest/test_operator.py |  70 ++++++++-----
 6 files changed, 272 insertions(+), 79 deletions(-)

diff --git a/docs/api/python/ndarray/sparse.md b/docs/api/python/ndarray/sparse.md
index df33570..b0cdd88 100644
--- a/docs/api/python/ndarray/sparse.md
+++ b/docs/api/python/ndarray/sparse.md
@@ -496,6 +496,8 @@ We summarize the interface for each class in the following sections.
     make_loss
     stop_gradient
     mxnet.ndarray.contrib.SparseEmbedding
+    LinearRegressionOutput
+    LogisticRegressionOutput
 ```
 
 ## API Reference
diff --git a/docs/api/python/symbol/sparse.md b/docs/api/python/symbol/sparse.md
index b40276b..a44ff15 100644
--- a/docs/api/python/symbol/sparse.md
+++ b/docs/api/python/symbol/sparse.md
@@ -194,6 +194,8 @@ In the rest of this document, we list sparse related routines provided by the
     make_loss
     stop_gradient
     mxnet.symbol.contrib.SparseEmbedding
+    LinearRegressionOutput
+    LogisticRegressionOutput
 ```
 
 ## API Reference
diff --git a/src/operator/regression_output-inl.h b/src/operator/regression_output-inl.h
index 4642f8d..59cbde3 100644
--- a/src/operator/regression_output-inl.h
+++ b/src/operator/regression_output-inl.h
@@ -31,6 +31,7 @@
 #include "./mxnet_op.h"
 #include "./operator_common.h"
 
+
 namespace mxnet {
 namespace op {
 
@@ -77,22 +78,103 @@ inline bool RegressionOpShape(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+template<bool is_forward>
+inline bool RegressionInferStorageType(const nnvm::NodeAttrs& attrs,
+                                       const int dev_mask,
+                                       DispatchMode* dispatch_mode,
+                                       std::vector<int>* in_attrs,
+                                       std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), is_forward ? 1U : 2U);
+  const size_t label_pos = is_forward ? 1U : 0U;
+  const auto label_stype = in_attrs->at(label_pos);
+  const auto data_stype = in_attrs->at(1 - label_pos);
+  auto& out_stype = out_attrs->at(0);
+  bool dispatched = false;
+  if (!dispatched && data_stype == kDefaultStorage && label_stype == kDefaultStorage) {
+    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFCompute);
+  }
+
+  if (!dispatched && data_stype == kDefaultStorage && label_stype == kCSRStorage) {
+    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFComputeEx);
+  }
+
+  if (!dispatched) {
+    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
+  }
+  // In backward pass, although we don't care about gradients of label,
+  // a storage type should be assigned to it.
+  if (!is_forward) type_assign(&out_attrs->at(1), kDefaultStorage);
+
+  return dispatched;
+}
+
+/*!
+ * \brief Kernel for binary operator of dense -OP- csr ndarray.
+ * Only stored csr entries are written; other positions of out_data are left untouched.
+ * Parallelize by each row.
+ */
+template<typename OP, int req>
+struct DnsCsrSparseKernel {
+  template<typename DType, typename IType, typename RType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data,
+                                  const DType* dns_data,
+                                  const DType* csr_data,
+                                  const IType* csr_idx,
+                                  const RType* csr_indptr,
+                                  const nnvm::dim_t row_length) {
+    nnvm::dim_t row_i = i * row_length;
+    for (nnvm::dim_t j = csr_indptr[i]; j < csr_indptr[i + 1]; ++j) {
+      KERNEL_ASSIGN(out_data[row_i + csr_idx[j]], req,
+        OP::Map(dns_data[row_i + csr_idx[j]], csr_data[j]));
+    }
+  }
+};
+
+
+template<typename xpu, typename ForwardOp>
+inline void RegressionForwardImpl(mshadow::Stream<xpu> *s, const OpReqType req,
+                                  const TBlob &data, const TBlob &out) {
+  if (req == kNullOp) return;
+  MSHADOW_REAL_TYPE_SWITCH(data.type_flag_, DType, {
+    MXNET_ASSIGN_REQ_SWITCH(req, Req, {
+      const DType* in_data = data.dptr<DType>();
+      DType* out_data = out.dptr<DType>();
+      using namespace mxnet_op;
+      Kernel<op_with_req<ForwardOp, Req>, xpu>::Launch(
+        s, out.Size(), out_data, in_data);
+    });
+  });
+}
+
 template<typename xpu, typename ForwardOp>
 void RegressionForward(const nnvm::NodeAttrs& attrs,
                        const OpContext& ctx,
                        const std::vector<TBlob>& inputs,
                        const std::vector<OpReqType>& req,
                        const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-  MSHADOW_REAL_TYPE_SWITCH(inputs[reg_enum::kData].type_flag_, DType, {
-    MXNET_ASSIGN_REQ_SWITCH(req[reg_enum::kOut], Req, {
-      const DType* in_data = inputs[reg_enum::kData].dptr<DType>();
-      DType* out_data = outputs[reg_enum::kOut].dptr<DType>();
-      using namespace mxnet_op;
-      Kernel<op_with_req<ForwardOp, Req>, xpu>::Launch(
-        s, outputs[reg_enum::kOut].Size(), out_data, in_data);
-    });
-  });
+  RegressionForwardImpl<xpu, ForwardOp>(s, req[reg_enum::kOut],
+    inputs[reg_enum::kData], outputs[reg_enum::kOut]);
+}
+
+template<typename xpu, typename ForwardOp>
+void RegressionForwardEx(const nnvm::NodeAttrs& attrs,
+                         const OpContext& ctx,
+                         const std::vector<NDArray>& inputs,
+                         const std::vector<OpReqType>& req,
+                         const std::vector<NDArray>& outputs) {
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(inputs[reg_enum::kData].storage_type(), kDefaultStorage);
+  CHECK_EQ(outputs[reg_enum::kOut].storage_type(), kDefaultStorage);
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  RegressionForwardImpl<xpu, ForwardOp>(s, req[reg_enum::kOut],
+    inputs[reg_enum::kData].data(), outputs[reg_enum::kOut].data());
 }
 
 template<typename xpu, typename BackwardOp>
@@ -101,26 +183,89 @@ void RegressionBackward(const nnvm::NodeAttrs& attrs,
                         const std::vector<TBlob>& inputs,
                         const std::vector<OpReqType>& req,
                         const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 2U);
+  if (req[reg_enum::kData] == kNullOp) return;
   const RegressionOutputParam& param = nnvm::get<RegressionOutputParam>(attrs.parsed);
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
   // inputs are in_label, out_data
   // outputs are data_grad, label_grad
-  MSHADOW_REAL_TYPE_SWITCH(inputs[1].type_flag_, DType, {
-    MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-      const DType* in_label = inputs[0].dptr<DType>();
-      const DType* out_data = inputs[1].dptr<DType>();
-      DType* data_grad = outputs[0].dptr<DType>();
-      const real_t num_output = inputs[0].Size()/inputs[0].shape_[0];
+  const TBlob &in_label = inputs[0], &out_data = inputs[1];
+  const TBlob& data_grad = outputs[0];
+  MSHADOW_REAL_TYPE_SWITCH(out_data.type_flag_, DType, {
+    MXNET_ASSIGN_REQ_SWITCH(req[reg_enum::kData], Req, {
+      const DType* in_label_ptr = in_label.dptr<DType>();
+      const DType* out_data_ptr = out_data.dptr<DType>();
+      DType* data_grad_ptr = data_grad.dptr<DType>();
+      const real_t num_output = in_label.Size()/in_label.shape_[0];
       using namespace mxnet_op;
       Kernel<op_with_req<BackwardOp, Req>, xpu>::Launch(
-        s, outputs[0].Size(), data_grad, out_data, in_label);
+        s, data_grad.Size(), data_grad_ptr, out_data_ptr, in_label_ptr);
       Kernel<op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
-        s, outputs[0].Size(), data_grad, data_grad,
+        s, data_grad.Size(), data_grad_ptr, data_grad_ptr,
         static_cast<DType>(param.grad_scale/num_output));
     });
   });
 }
 
+// grad = scale * (label stored ? BackwardOp(data, label) : data); scale = grad_scale / row_length
+template<typename xpu, typename BackwardOp>
+inline void RegressionBackwardCSRImpl(mshadow::Stream<xpu> *s,
+                                      const RegressionOutputParam& param,
+                                      const OpReqType req,
+                                      const NDArray &data, const NDArray &label,
+                                      const NDArray &data_grad) {
+  if (req == kNullOp) return;
+  CHECK_NE(req, kAddTo) << "kAddTo is not supported for csr label";
+  using namespace mshadow;
+  using namespace mxnet_op;
+  using namespace csr;
+  const TShape dshape = data.shape();
+  const nnvm::dim_t num_rows = dshape[0];
+  const nnvm::dim_t row_length = dshape[1];
+  CHECK_EQ(label.aux_type(kIndPtr), label.aux_type(kIdx))
+    << "Type of indices array and index pointer array of the label should be the same";
+  MSHADOW_IDX_TYPE_SWITCH(label.aux_type(kIdx), IType, {
+    MSHADOW_REAL_TYPE_SWITCH(label.dtype(), DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req, Req, {
+        const IType* label_indptr = label.aux_data(kIndPtr).dptr<IType>();
+        const IType* label_idx = label.aux_data(kIdx).dptr<IType>();
+        const DType* label_data = label.data().dptr<DType>();
+        const DType* data_ptr = data.data().dptr<DType>();
+        DType* grad_ptr = data_grad.data().dptr<DType>();
+        if (req != kWriteInplace) {
+          Kernel<op_with_req<mshadow_op::identity, Req>, xpu>::Launch(s,
+            dshape.Size(), grad_ptr, data_ptr);
+        }
+        Kernel<DnsCsrSparseKernel<BackwardOp, Req>, xpu>::Launch(s, num_rows,
+          grad_ptr, data_ptr, label_data, label_idx, label_indptr, row_length);
+        Kernel<op_with_req<mshadow_op::mul, Req>, xpu>::Launch(s, dshape.Size(),
+          grad_ptr, grad_ptr, static_cast<DType>(param.grad_scale/row_length));
+      });
+    });
+  });
+}
+
+template<typename xpu, typename BackwardOp>
+void RegressionBackwardEx(const nnvm::NodeAttrs& attrs,
+                          const OpContext& ctx,
+                          const std::vector<NDArray>& inputs,
+                          const std::vector<OpReqType>& req,
+                          const std::vector<NDArray>& outputs) {
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 2U);
+  const RegressionOutputParam& param = nnvm::get<RegressionOutputParam>(attrs.parsed);
+  const auto label_stype = inputs[0].storage_type();
+  const auto data_stype = inputs[1].storage_type();
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  if (data_stype == kDefaultStorage && label_stype == kCSRStorage) {
+    RegressionBackwardCSRImpl<xpu, BackwardOp>(s, param, req[reg_enum::kData], inputs[1],
+      inputs[0], outputs[0]);
+  } else {
+    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
+  }
+}
+
 struct RegressionOpGrad {
   const char *op_name;
   std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
diff --git a/src/operator/regression_output.cc b/src/operator/regression_output.cc
index 0b8ce69..9539a15 100644
--- a/src/operator/regression_output.cc
+++ b/src/operator/regression_output.cc
@@ -26,37 +26,38 @@
 #include "./elemwise_op_common.h"
 
 
-#define MXNET_OPERATOR_REGISTER_REGRESSION_FWD(__name$, __kernel$, __bwdop$)   \
-  NNVM_REGISTER_OP(__name$)                                                    \
-  .set_num_inputs(2)                                                           \
-  .set_num_outputs(1)                                                          \
-  .set_attr<nnvm::FListInputNames>("FListInputNames",                          \
-    [](const NodeAttrs& attrs) {                                               \
-      return std::vector<std::string>{"data", "label"};                        \
-    })                                                                         \
-  .set_attr<nnvm::FInferShape>("FInferShape", RegressionOpShape)               \
-  .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)                \
-  .set_attr<nnvm::FGradient>("FGradient", RegressionOpGrad{__bwdop$})          \
-  .set_attr<nnvm::FInplaceOption>("FInplaceOption",                            \
-  [](const NodeAttrs& attrs){                                                  \
-    return std::vector<std::pair<int, int> >{{0, 0}};                          \
-  })                                                                           \
-  .set_attr<FCompute>("FCompute<cpu>", RegressionForward<cpu, __kernel$>)      \
-  .add_argument("data", "NDArray-or-Symbol", "Input data to the function.")    \
-  .add_argument("label", "NDArray-or-Symbol", "Input label to the function.")  \
+#define MXNET_OPERATOR_REGISTER_REGRESSION_FWD(__name$, __kernel$, __bwdop$)           \
+  NNVM_REGISTER_OP(__name$)                                                            \
+  MXNET_ADD_SPARSE_OP_ALIAS(__name$)                                                   \
+  .set_num_inputs(2)                                                                   \
+  .set_num_outputs(1)                                                                  \
+  .set_attr<nnvm::FListInputNames>("FListInputNames",                                  \
+    [](const NodeAttrs& attrs) {                                                       \
+      return std::vector<std::string>{"data", "label"};                                \
+    })                                                                                 \
+  .set_attr<nnvm::FInferShape>("FInferShape", RegressionOpShape)                       \
+  .set_attr<nnvm::FGradient>("FGradient", RegressionOpGrad{__bwdop$})                  \
+  .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)                        \
+  .set_attr<nnvm::FInplaceOption>("FInplaceOption",                                    \
+  [](const NodeAttrs& attrs){                                                          \
+    return std::vector<std::pair<int, int> >{{0, 0}};                                  \
+  })                                                                                   \
+  .set_attr<FCompute>("FCompute<cpu>", RegressionForward<cpu, __kernel$>)              \
+  .add_argument("data", "NDArray-or-Symbol", "Input data to the function.")            \
+  .add_argument("label", "NDArray-or-Symbol", "Input label to the function.")          \
   .add_arguments(RegressionOutputParam::__FIELDS__())
 
-#define MXNET_OPERATOR_REGISTER_REGRESSION_BWD(__name$, __kernel$)         \
-  NNVM_REGISTER_OP(__name$)                                                \
-  .set_num_inputs(2)                                                       \
-  .set_num_outputs(2)                                                      \
-  .set_attr_parser(ParamParser<RegressionOutputParam>)                     \
-  .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 2>)            \
-  .set_attr<nnvm::TIsBackward>("TIsBackward", true)                        \
-  .set_attr<nnvm::FInplaceOption>("FInplaceOption",                        \
-  [](const NodeAttrs& attrs){                                              \
-    return std::vector<std::pair<int, int> >{{1, 0}};                      \
-  })                                                                       \
+#define MXNET_OPERATOR_REGISTER_REGRESSION_BWD(__name$, __kernel$)                      \
+  NNVM_REGISTER_OP(__name$)                                                             \
+  .set_num_inputs(2)                                                                    \
+  .set_num_outputs(2)                                                                   \
+  .set_attr_parser(ParamParser<RegressionOutputParam>)                                  \
+  .set_attr<nnvm::TIsBackward>("TIsBackward", true)                                     \
+  .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 2>)                         \
+  .set_attr<nnvm::FInplaceOption>("FInplaceOption",                                     \
+  [](const NodeAttrs& attrs){                                                           \
+    return std::vector<std::pair<int, int> >{{1, 0}};                                   \
+  })                                                                                    \
   .set_attr<FCompute>("FCompute<cpu>", RegressionBackward<cpu, __kernel$>)
 
 namespace mxnet {
@@ -67,6 +68,8 @@ DMLC_REGISTER_PARAMETER(RegressionOutputParam);
 
 MXNET_OPERATOR_REGISTER_REGRESSION_FWD(LinearRegressionOutput,
   mshadow_op::identity, "_backward_linear_reg_out")
+.set_attr<FInferStorageType>("FInferStorageType", RegressionInferStorageType<true>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", RegressionForwardEx<cpu, mshadow_op::identity>)
 .describe(R"code(Computes and optimizes for squared loss during backward propagation.
 Just outputs ``data`` during forward propagation.
 
@@ -78,12 +81,19 @@ then the squared loss estimated over :math:`n` samples is defined as
 .. note::
    Use the LinearRegressionOutput as the final output layer of a net.
 
+The storage type of ``label`` can be ``default`` or ``csr``
+
+- LinearRegressionOutput(default, default) = default
+- LinearRegressionOutput(default, csr) = default
+
 By default, gradients of this loss function are scaled by factor `1/m`, where m is the number of regression outputs of a training example.
 The parameter `grad_scale` can be used to change this scale to `grad_scale/m`.
 
 )code" ADD_FILELINE);
 
-MXNET_OPERATOR_REGISTER_REGRESSION_BWD(_backward_linear_reg_out, mshadow_op::minus);
+MXNET_OPERATOR_REGISTER_REGRESSION_BWD(_backward_linear_reg_out, mshadow_op::minus)
+.set_attr<FInferStorageType>("FInferStorageType", RegressionInferStorageType<false>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", RegressionBackwardEx<cpu, mshadow_op::minus>);
 
 MXNET_OPERATOR_REGISTER_REGRESSION_FWD(MAERegressionOutput,
   mshadow_op::identity, "_backward_mae_reg_out")
@@ -99,6 +109,11 @@ then the mean absolute error (MAE) estimated over :math:`n` samples is defined a
 .. note::
    Use the MAERegressionOutput as the final output layer of a net.
 
+The storage type of ``label`` can be ``default`` or ``csr``
+
+- MAERegressionOutput(default, default) = default
+- MAERegressionOutput(default, csr) = default
+
 By default, gradients of this loss function are scaled by factor `1/m`, where m is the number of regression outputs of a training example.
 The parameter `grad_scale` can be used to change this scale to `grad_scale/m`.
 
@@ -108,6 +123,8 @@ MXNET_OPERATOR_REGISTER_REGRESSION_BWD(_backward_mae_reg_out, mshadow_op::minus_
 
 MXNET_OPERATOR_REGISTER_REGRESSION_FWD(LogisticRegressionOutput,
   mshadow_op::sigmoid, "_backward_logistic_reg_out")
+.set_attr<FInferStorageType>("FInferStorageType", RegressionInferStorageType<true>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", RegressionForwardEx<cpu, mshadow_op::sigmoid>)
 .describe(R"code(Applies a logistic function to the input.
 
 The logistic function, also known as the sigmoid function, is computed as
@@ -120,12 +137,19 @@ It is suitable for binary classification or probability prediction tasks.
 .. note::
    Use the LogisticRegressionOutput as the final output layer of a net.
 
+The storage type of ``label`` can be ``default`` or ``csr``
+
+- LogisticRegressionOutput(default, default) = default
+- LogisticRegressionOutput(default, csr) = default
+
 By default, gradients of this loss function are scaled by factor `1/m`, where m is the number of regression outputs of a training example.
 The parameter `grad_scale` can be used to change this scale to `grad_scale/m`.
 
 )code" ADD_FILELINE);
 
-MXNET_OPERATOR_REGISTER_REGRESSION_BWD(_backward_logistic_reg_out, mshadow_op::minus);
+MXNET_OPERATOR_REGISTER_REGRESSION_BWD(_backward_logistic_reg_out, mshadow_op::minus)
+.set_attr<FInferStorageType>("FInferStorageType", RegressionInferStorageType<false>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", RegressionBackwardEx<cpu, mshadow_op::minus>);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/regression_output.cu b/src/operator/regression_output.cu
index e3a2e7e..ca11b84 100644
--- a/src/operator/regression_output.cu
+++ b/src/operator/regression_output.cu
@@ -28,10 +28,12 @@ namespace mxnet {
 namespace op {
 
 NNVM_REGISTER_OP(LinearRegressionOutput)
-.set_attr<FCompute>("FCompute<gpu>", RegressionForward<gpu, mshadow_op::identity>);
+.set_attr<FCompute>("FCompute<gpu>", RegressionForward<gpu, mshadow_op::identity>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", RegressionForwardEx<gpu, mshadow_op::identity>);
 
 NNVM_REGISTER_OP(_backward_linear_reg_out)
-.set_attr<FCompute>("FCompute<gpu>", RegressionBackward<gpu, mshadow_op::minus>);
+.set_attr<FCompute>("FCompute<gpu>", RegressionBackward<gpu, mshadow_op::minus>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", RegressionBackwardEx<gpu, mshadow_op::minus>);
 
 NNVM_REGISTER_OP(MAERegressionOutput)
 .set_attr<FCompute>("FCompute<gpu>", RegressionForward<gpu, mshadow_op::identity>);
@@ -40,10 +42,12 @@ NNVM_REGISTER_OP(_backward_mae_reg_out)
 .set_attr<FCompute>("FCompute<gpu>", RegressionBackward<gpu, mshadow_op::minus_sign>);
 
 NNVM_REGISTER_OP(LogisticRegressionOutput)
-.set_attr<FCompute>("FCompute<gpu>", RegressionForward<gpu, mshadow_op::sigmoid>);
+.set_attr<FCompute>("FCompute<gpu>", RegressionForward<gpu, mshadow_op::sigmoid>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", RegressionForwardEx<gpu, mshadow_op::sigmoid>);
 
 NNVM_REGISTER_OP(_backward_logistic_reg_out)
-.set_attr<FCompute>("FCompute<gpu>", RegressionBackward<gpu, mshadow_op::minus>);
+.set_attr<FCompute>("FCompute<gpu>", RegressionBackward<gpu, mshadow_op::minus>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", RegressionBackwardEx<gpu, mshadow_op::minus>);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 7889e08..1a04e8e 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -218,41 +218,57 @@ def test_slice_channel():
     check_slice_channel(data_ndim=3, axis=-1, num_outputs=2, squeeze_axis=False)
     check_slice_channel(data_ndim=5, axis=-2, num_outputs=3, squeeze_axis=True)
 
-
-def check_regression(symbol, forward, backward):
-    data = mx.symbol.Variable('data')
-    label = mx.symbol.Variable('label')
-    out = symbol(data, label)
-    shape = (3, 1)
-    arr_data = mx.random.uniform(-1, 1, shape, ctx=mx.cpu()).copyto(default_context())
-    arr_label = mx.random.uniform(0, 1, shape[0], ctx=mx.cpu()).copyto(default_context())
-    arr_grad = mx.nd.empty(shape)
-    exec1 = out.bind(default_context(),
-                     args=[arr_data, arr_label],
-                     args_grad={"data" : arr_grad})
-    exec1.forward(is_train=True)
-    out1 = exec1.outputs[0].asnumpy()
-    npout = forward(arr_data.asnumpy())
-    # Non-zero atol required by test_operator_gpu.py:test_regression with seed 651640549
-    atol = 1e-5
-    assert_almost_equal(npout, out1, atol=atol)
-
-    exec1.backward()
-    npout = backward(npout,  arr_label.asnumpy().reshape(npout.shape))
-    assert_almost_equal(npout, arr_grad.asnumpy(), atol=atol)
-
-
 @with_seed()
 def test_regression():
+    ''' test regression operator '''
+    def check_regression(symbol, forward, backward, shape, stype='default', densities=(0, 0.5, 1)):
+        # init executor
+        data = mx.symbol.Variable('data')
+        label = mx.symbol.Variable('label', stype=stype)
+        out = symbol(data, label)
+        grad_req = {'data': 'write', 'label': 'null'}
+        out_exec = out.simple_bind(default_context(), grad_req=grad_req,
+            data=shape, label=shape)
+        arg_map = dict(zip(out.list_arguments(), out_exec.arg_arrays))
+        grad_map = dict(zip(out.list_arguments(), out_exec.grad_arrays))
+        # init data
+        arr_data = mx.random.uniform(-1, 1, shape)
+        arg_map["data"][:] = arr_data
+        # init label based on density
+        arr_label = arg_map["label"]
+        atol = 1e-5
+        for density in densities:
+            arr_label[:] = rand_ndarray(shape, stype, density=density)
+            out_exec.forward(is_train=True)
+            out_exec.backward()
+            np_out = forward(arr_data.asnumpy())
+            out_grad = backward(np_out, arr_label.asnumpy().reshape(np_out.shape)) / shape[1]
+            assert_almost_equal(out_exec.outputs[0].asnumpy(), np_out, atol=atol)
+            assert_almost_equal(grad_map["data"].asnumpy(), out_grad, atol=atol)
+
+    shape = (50, 30)
+
     check_regression(mx.symbol.LogisticRegressionOutput,
                      lambda x: 1.0 / (1.0 + np.exp(-x)),
-                     lambda x, y : x - y)
+                     lambda x, y : x - y,
+                     shape)
     check_regression(mx.symbol.LinearRegressionOutput,
                      lambda x: x,
-                     lambda x, y : x - y)
+                     lambda x, y : x - y,
+                     shape)
     check_regression(mx.symbol.MAERegressionOutput,
                      lambda x: x,
-                     lambda x, y : np.where(x > y, np.ones(x.shape), -np.ones(x.shape)))
+                     lambda x, y : np.where(x > y, np.ones(x.shape), -np.ones(x.shape)),
+                     shape)
+    check_regression(mx.symbol.LogisticRegressionOutput,
+                     lambda x: 1.0 / (1.0 + np.exp(-x)),
+                     lambda x, y : x - y,
+                     shape, stype='csr')
+    check_regression(mx.symbol.LinearRegressionOutput,
+                     lambda x: x,
+                     lambda x, y : x - y,
+                     shape, stype='csr')
+
 
 def check_softmax_grad(xpu):
     x = mx.sym.Variable('x')

-- 
To stop receiving notification emails like this one, please contact
haibin@apache.org.