Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/03/03 12:13:49 UTC

[GitHub] eric-haibin-lin closed pull request #9625: sparse regression operators

eric-haibin-lin closed pull request #9625: sparse regression operators
URL: https://github.com/apache/incubator-mxnet/pull/9625

This is a PR merged from a forked repository. As GitHub hides the original
diff on merge, it is reproduced below for the sake of provenance:

diff --git a/docs/api/python/ndarray/sparse.md b/docs/api/python/ndarray/sparse.md
index a7aaa1fd41d..dc44111cdad 100644
--- a/docs/api/python/ndarray/sparse.md
+++ b/docs/api/python/ndarray/sparse.md
@@ -495,6 +495,8 @@ We summarize the interface for each class in the following sections.
     make_loss
     stop_gradient
     mxnet.ndarray.contrib.SparseEmbedding
+    LinearRegressionOutput
+    LogisticRegressionOutput
 ```
 
 ## API Reference
diff --git a/docs/api/python/symbol/sparse.md b/docs/api/python/symbol/sparse.md
index b40276b9f1a..a44ff150356 100644
--- a/docs/api/python/symbol/sparse.md
+++ b/docs/api/python/symbol/sparse.md
@@ -194,6 +194,8 @@ In the rest of this document, we list sparse related routines provided by the
     make_loss
     stop_gradient
     mxnet.symbol.contrib.SparseEmbedding
+    LinearRegressionOutput
+    LogisticRegressionOutput
 ```
 
 ## API Reference
diff --git a/src/operator/regression_output-inl.h b/src/operator/regression_output-inl.h
index 4642f8dc467..59cbde3de20 100644
--- a/src/operator/regression_output-inl.h
+++ b/src/operator/regression_output-inl.h
@@ -31,6 +31,7 @@
 #include "./mxnet_op.h"
 #include "./operator_common.h"
 
+
 namespace mxnet {
 namespace op {
 
@@ -77,22 +78,103 @@ inline bool RegressionOpShape(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+template<bool is_forward>
+inline bool RegressionInferStorageType(const nnvm::NodeAttrs& attrs,
+                                       const int dev_mask,
+                                       DispatchMode* dispatch_mode,
+                                       std::vector<int>* in_attrs,
+                                       std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), is_forward ? 1U : 2U);
+  const size_t label_pos = is_forward ? 1U : 0U;
+  const auto label_stype = in_attrs->at(label_pos);
+  const auto data_stype = in_attrs->at(1 - label_pos);
+  auto& out_stype = out_attrs->at(0);
+  bool dispatched = false;
+  if (!dispatched && data_stype == kDefaultStorage && label_stype == kDefaultStorage) {
+    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFCompute);
+  }
+
+  if (!dispatched && data_stype == kDefaultStorage && label_stype == kCSRStorage) {
+    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFComputeEx);
+  }
+
+  if (!dispatched) {
+    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
+  }
+  // In backward pass, although we don't care about gradients of label,
+  // a storage type should be assigned to it.
+  if (!is_forward) type_assign(&out_attrs->at(1), kDefaultStorage);
+
+  return dispatched;
+}
+
+/*!
+ * \brief Kernel for binary operator of dense -OP- csr ndarray.
+ * Right hand side of OP has no effect.
+ * Parallelize by each row.
+ */
+template<typename OP, int req>
+struct DnsCsrSparseKernel {
+  template<typename DType, typename IType, typename RType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data,
+                                  const DType* dns_data,
+                                  const DType* csr_data,
+                                  const IType* csr_idx,
+                                  const RType* csr_indptr,
+                                  const nnvm::dim_t row_length) {
+    nnvm::dim_t row_i = i * row_length;
+    for (nnvm::dim_t j=csr_indptr[i]; j < csr_indptr[i+1]; j++) {
+      KERNEL_ASSIGN(out_data[row_i + csr_idx[j]], req,
+        OP::Map(dns_data[row_i + csr_idx[j]], csr_data[j]));
+    }
+  }
+};
+
+
+template<typename xpu, typename ForwardOp>
+inline void RegressionForwardImpl(mshadow::Stream<xpu> *s, const OpReqType req,
+                                  const TBlob &data, const TBlob &out) {
+  if (req == kNullOp) return;
+  MSHADOW_REAL_TYPE_SWITCH(data.type_flag_, DType, {
+    MXNET_ASSIGN_REQ_SWITCH(req, Req, {
+      const DType* in_data = data.dptr<DType>();
+      DType* out_data = out.dptr<DType>();
+      using namespace mxnet_op;
+      Kernel<op_with_req<ForwardOp, Req>, xpu>::Launch(
+        s, out.Size(), out_data, in_data);
+    });
+  });
+}
+
 template<typename xpu, typename ForwardOp>
 void RegressionForward(const nnvm::NodeAttrs& attrs,
                        const OpContext& ctx,
                        const std::vector<TBlob>& inputs,
                        const std::vector<OpReqType>& req,
                        const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-  MSHADOW_REAL_TYPE_SWITCH(inputs[reg_enum::kData].type_flag_, DType, {
-    MXNET_ASSIGN_REQ_SWITCH(req[reg_enum::kOut], Req, {
-      const DType* in_data = inputs[reg_enum::kData].dptr<DType>();
-      DType* out_data = outputs[reg_enum::kOut].dptr<DType>();
-      using namespace mxnet_op;
-      Kernel<op_with_req<ForwardOp, Req>, xpu>::Launch(
-        s, outputs[reg_enum::kOut].Size(), out_data, in_data);
-    });
-  });
+  RegressionForwardImpl<xpu, ForwardOp>(s, req[reg_enum::kOut],
+    inputs[reg_enum::kData], outputs[reg_enum::kOut]);
+}
+
+template<typename xpu, typename ForwardOp>
+void RegressionForwardEx(const nnvm::NodeAttrs& attrs,
+                         const OpContext& ctx,
+                         const std::vector<NDArray>& inputs,
+                         const std::vector<OpReqType>& req,
+                         const std::vector<NDArray>& outputs) {
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(inputs[reg_enum::kData].storage_type(), kDefaultStorage);
+  CHECK_EQ(inputs[reg_enum::kOut].storage_type(), kDefaultStorage);
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  RegressionForwardImpl<xpu, ForwardOp>(s, req[reg_enum::kOut],
+    inputs[reg_enum::kData].data(), outputs[reg_enum::kOut].data());
 }
 
 template<typename xpu, typename BackwardOp>
@@ -101,26 +183,89 @@ void RegressionBackward(const nnvm::NodeAttrs& attrs,
                         const std::vector<TBlob>& inputs,
                         const std::vector<OpReqType>& req,
                         const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 2);
+  CHECK_EQ(outputs.size(), 2);
+  if (req[reg_enum::kData] == kNullOp) return;
   const RegressionOutputParam& param = nnvm::get<RegressionOutputParam>(attrs.parsed);
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
   // inputs are in_label, out_data
   // outputs are data_grad, label_grad
-  MSHADOW_REAL_TYPE_SWITCH(inputs[1].type_flag_, DType, {
-    MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-      const DType* in_label = inputs[0].dptr<DType>();
-      const DType* out_data = inputs[1].dptr<DType>();
-      DType* data_grad = outputs[0].dptr<DType>();
-      const real_t num_output = inputs[0].Size()/inputs[0].shape_[0];
+  const TBlob& in_label = inputs[0], out_data = inputs[1];
+  const TBlob& data_grad = outputs[0];
+  MSHADOW_REAL_TYPE_SWITCH(out_data.type_flag_, DType, {
+    MXNET_ASSIGN_REQ_SWITCH(req[reg_enum::kData], Req, {
+      const DType* in_label_ptr = in_label.dptr<DType>();
+      const DType* out_data_ptr = out_data.dptr<DType>();
+      DType* data_grad_ptr = data_grad.dptr<DType>();
+      const real_t num_output = in_label.Size()/in_label.shape_[0];
       using namespace mxnet_op;
       Kernel<op_with_req<BackwardOp, Req>, xpu>::Launch(
-        s, outputs[0].Size(), data_grad, out_data, in_label);
+        s, data_grad.Size(), data_grad_ptr, out_data_ptr, in_label_ptr);
       Kernel<op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
-        s, outputs[0].Size(), data_grad, data_grad,
+        s, data_grad.Size(), data_grad_ptr, data_grad_ptr,
         static_cast<DType>(param.grad_scale/num_output));
     });
   });
 }
 
+
+template<typename xpu, typename BackwardOp>
+inline void RegressionBackwardCSRImpl(mshadow::Stream<xpu> *s,
+                                      const RegressionOutputParam& param,
+                                      const OpReqType req,
+                                      const NDArray &data, const NDArray &label,
+                                      const NDArray &data_grad) {
+  if (req == kNullOp) return;
+  using namespace mshadow;
+  using namespace mxnet_op;
+  using namespace csr;
+  const TShape dshape = data.shape();
+  const nnvm::dim_t num_rows = dshape[0];
+  const nnvm::dim_t row_length = dshape[1];
+  CHECK_EQ(label.aux_type(kIndPtr), label.aux_type(kIdx))
+    << "Type of indices array and index pointer array of the label should be the same";
+  MSHADOW_IDX_TYPE_SWITCH(label.aux_type(kIdx), IType, {
+    MSHADOW_REAL_TYPE_SWITCH(label.dtype(), DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req, Req, {
+        const IType* label_indptr = label.aux_data(kIndPtr).dptr<IType>();
+        const IType* label_idx = label.aux_data(kIdx).dptr<IType>();
+        const DType* label_data = label.data().dptr<DType>();
+        const DType* data_ptr = data.data().dptr<DType>();
+        DType* grad_ptr = data_grad.data().dptr<DType>();
+        if (req != kWriteInplace) {
+          Kernel<op_with_req<mshadow_op::identity, Req>, xpu>::Launch(s,
+            dshape.Size(), grad_ptr, data_ptr);
+        }
+        Kernel<DnsCsrSparseKernel<BackwardOp, Req>, xpu>::Launch(s, num_rows,
+          grad_ptr, data_ptr, label_data, label_idx, label_indptr, row_length);
+        Kernel<op_with_req<mshadow_op::mul, Req>, xpu>::Launch(s, dshape.Size(),
+          grad_ptr, grad_ptr, static_cast<DType>(param.grad_scale/row_length));
+      });
+    });
+  });
+}
+
+
+template<typename xpu, typename BackwardOP>
+void RegressionBackwardEx(const nnvm::NodeAttrs& attrs,
+                          const OpContext& ctx,
+                          const std::vector<NDArray>& inputs,
+                          const std::vector<OpReqType>& req,
+                          const std::vector<NDArray>& outputs) {
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 2U);
+  const RegressionOutputParam& param = nnvm::get<RegressionOutputParam>(attrs.parsed);
+  const auto label_stype = inputs[0].storage_type();
+  const auto data_stype = inputs[1].storage_type();
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  if (data_stype == kDefaultStorage && label_stype == kCSRStorage) {
+    RegressionBackwardCSRImpl<xpu, BackwardOP>(s, param, req[0], inputs[1],
+      inputs[0], outputs[0]);
+  } else {
+    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
+  }
+}
+
 struct RegressionOpGrad {
   const char *op_name;
   std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
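
For illustration (not part of the diff above), here is a minimal NumPy/SciPy sketch of what
RegressionBackwardCSRImpl computes when BackwardOp is mshadow_op::minus, i.e. the
LinearRegressionOutput backward with a dense forward output and a csr label. The function name
csr_regression_backward and the use of scipy.sparse are illustrative only:

    import numpy as np
    import scipy.sparse as sp

    def csr_regression_backward(out_data, label_csr, grad_scale=1.0):
        """out_data: dense (n, m) array; label_csr: scipy.sparse.csr_matrix of shape (n, m)."""
        n, m = out_data.shape
        # identity copy of the dense output (the op_with_req<identity> launch in the diff)
        grad = out_data.copy()
        # DnsCsrSparseKernel: apply (out - label) only at stored label positions, row by row
        for i in range(n):
            for j in range(label_csr.indptr[i], label_csr.indptr[i + 1]):
                col = label_csr.indices[j]
                grad[i, col] = out_data[i, col] - label_csr.data[j]
        # final scaling, matching the dense path: grad_scale / num_outputs
        return grad * (grad_scale / m)

    # example: dense 2x3 output, csr label with two stored entries
    out = np.array([[0.5, -0.2, 0.1], [1.0, 0.0, -0.5]])
    lbl = sp.csr_matrix(np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 2.0]]))
    print(csr_regression_backward(out, lbl))

Zero entries of the csr label contribute (out - 0), which is exactly what the identity copy
already produced, so only the stored entries need the second kernel pass.
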
diff --git a/src/operator/regression_output.cc b/src/operator/regression_output.cc
index 0b8ce69062b..9539a15fc93 100644
--- a/src/operator/regression_output.cc
+++ b/src/operator/regression_output.cc
@@ -26,37 +26,38 @@
 #include "./elemwise_op_common.h"
 
 
-#define MXNET_OPERATOR_REGISTER_REGRESSION_FWD(__name$, __kernel$, __bwdop$)   \
-  NNVM_REGISTER_OP(__name$)                                                    \
-  .set_num_inputs(2)                                                           \
-  .set_num_outputs(1)                                                          \
-  .set_attr<nnvm::FListInputNames>("FListInputNames",                          \
-    [](const NodeAttrs& attrs) {                                               \
-      return std::vector<std::string>{"data", "label"};                        \
-    })                                                                         \
-  .set_attr<nnvm::FInferShape>("FInferShape", RegressionOpShape)               \
-  .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)                \
-  .set_attr<nnvm::FGradient>("FGradient", RegressionOpGrad{__bwdop$})          \
-  .set_attr<nnvm::FInplaceOption>("FInplaceOption",                            \
-  [](const NodeAttrs& attrs){                                                  \
-    return std::vector<std::pair<int, int> >{{0, 0}};                          \
-  })                                                                           \
-  .set_attr<FCompute>("FCompute<cpu>", RegressionForward<cpu, __kernel$>)      \
-  .add_argument("data", "NDArray-or-Symbol", "Input data to the function.")    \
-  .add_argument("label", "NDArray-or-Symbol", "Input label to the function.")  \
+#define MXNET_OPERATOR_REGISTER_REGRESSION_FWD(__name$, __kernel$, __bwdop$)           \
+  NNVM_REGISTER_OP(__name$)                                                            \
+  MXNET_ADD_SPARSE_OP_ALIAS(__name$)                                                   \
+  .set_num_inputs(2)                                                                   \
+  .set_num_outputs(1)                                                                  \
+  .set_attr<nnvm::FListInputNames>("FListInputNames",                                  \
+    [](const NodeAttrs& attrs) {                                                       \
+      return std::vector<std::string>{"data", "label"};                                \
+    })                                                                                 \
+  .set_attr<nnvm::FInferShape>("FInferShape", RegressionOpShape)                       \
+  .set_attr<nnvm::FGradient>("FGradient", RegressionOpGrad{__bwdop$})                  \
+  .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)                        \
+  .set_attr<nnvm::FInplaceOption>("FInplaceOption",                                    \
+  [](const NodeAttrs& attrs){                                                          \
+    return std::vector<std::pair<int, int> >{{0, 0}};                                  \
+  })                                                                                   \
+  .set_attr<FCompute>("FCompute<cpu>", RegressionForward<cpu, __kernel$>)              \
+  .add_argument("data", "NDArray-or-Symbol", "Input data to the function.")            \
+  .add_argument("label", "NDArray-or-Symbol", "Input label to the function.")          \
   .add_arguments(RegressionOutputParam::__FIELDS__())
 
-#define MXNET_OPERATOR_REGISTER_REGRESSION_BWD(__name$, __kernel$)         \
-  NNVM_REGISTER_OP(__name$)                                                \
-  .set_num_inputs(2)                                                       \
-  .set_num_outputs(2)                                                      \
-  .set_attr_parser(ParamParser<RegressionOutputParam>)                     \
-  .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 2>)            \
-  .set_attr<nnvm::TIsBackward>("TIsBackward", true)                        \
-  .set_attr<nnvm::FInplaceOption>("FInplaceOption",                        \
-  [](const NodeAttrs& attrs){                                              \
-    return std::vector<std::pair<int, int> >{{1, 0}};                      \
-  })                                                                       \
+#define MXNET_OPERATOR_REGISTER_REGRESSION_BWD(__name$, __kernel$)                      \
+  NNVM_REGISTER_OP(__name$)                                                             \
+  .set_num_inputs(2)                                                                    \
+  .set_num_outputs(2)                                                                   \
+  .set_attr_parser(ParamParser<RegressionOutputParam>)                                  \
+  .set_attr<nnvm::TIsBackward>("TIsBackward", true)                                     \
+  .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 2>)                         \
+  .set_attr<nnvm::FInplaceOption>("FInplaceOption",                                     \
+  [](const NodeAttrs& attrs){                                                           \
+    return std::vector<std::pair<int, int> >{{1, 0}};                                   \
+  })                                                                                    \
   .set_attr<FCompute>("FCompute<cpu>", RegressionBackward<cpu, __kernel$>)
 
 namespace mxnet {
@@ -67,6 +68,8 @@ DMLC_REGISTER_PARAMETER(RegressionOutputParam);
 
 MXNET_OPERATOR_REGISTER_REGRESSION_FWD(LinearRegressionOutput,
   mshadow_op::identity, "_backward_linear_reg_out")
+.set_attr<FInferStorageType>("FInferStorageType", RegressionInferStorageType<true>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", RegressionForwardEx<cpu, mshadow_op::identity>)
 .describe(R"code(Computes and optimizes for squared loss during backward propagation.
 Just outputs ``data`` during forward propagation.
 
@@ -78,12 +81,19 @@ then the squared loss estimated over :math:`n` samples is defined as
 .. note::
    Use the LinearRegressionOutput as the final output layer of a net.
 
+The storage type of ``label`` can be ``default`` or ``csr``
+
+- LinearRegressionOutput(default, default) = default
+- LinearRegressionOutput(default, csr) = default
+
 By default, gradients of this loss function are scaled by factor `1/m`, where m is the number of regression outputs of a training example.
 The parameter `grad_scale` can be used to change this scale to `grad_scale/m`.
 
 )code" ADD_FILELINE);
 
-MXNET_OPERATOR_REGISTER_REGRESSION_BWD(_backward_linear_reg_out, mshadow_op::minus);
+MXNET_OPERATOR_REGISTER_REGRESSION_BWD(_backward_linear_reg_out, mshadow_op::minus)
+.set_attr<FInferStorageType>("FInferStorageType", RegressionInferStorageType<false>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", RegressionBackwardEx<cpu, mshadow_op::minus>);
 
 MXNET_OPERATOR_REGISTER_REGRESSION_FWD(MAERegressionOutput,
   mshadow_op::identity, "_backward_mae_reg_out")
@@ -99,6 +109,11 @@ then the mean absolute error (MAE) estimated over :math:`n` samples is defined a
 .. note::
    Use the MAERegressionOutput as the final output layer of a net.
 
+The storage type of ``label`` can be ``default`` or ``csr``
+
+- MAERegressionOutput(default, default) = default
+- MAERegressionOutput(default, csr) = default
+
 By default, gradients of this loss function are scaled by factor `1/m`, where m is the number of regression outputs of a training example.
 The parameter `grad_scale` can be used to change this scale to `grad_scale/m`.
 
@@ -108,6 +123,8 @@ MXNET_OPERATOR_REGISTER_REGRESSION_BWD(_backward_mae_reg_out, mshadow_op::minus_
 
 MXNET_OPERATOR_REGISTER_REGRESSION_FWD(LogisticRegressionOutput,
   mshadow_op::sigmoid, "_backward_logistic_reg_out")
+.set_attr<FInferStorageType>("FInferStorageType", RegressionInferStorageType<true>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", RegressionForwardEx<cpu, mshadow_op::sigmoid>)
 .describe(R"code(Applies a logistic function to the input.
 
 The logistic function, also known as the sigmoid function, is computed as
@@ -120,12 +137,19 @@ It is suitable for binary classification or probability prediction tasks.
 .. note::
    Use the LogisticRegressionOutput as the final output layer of a net.
 
+The storage type of ``label`` can be ``default`` or ``csr``
+
+- LogisticRegressionOutput(default, default) = default
+- LogisticRegressionOutput(default, csr) = default
+
 By default, gradients of this loss function are scaled by factor `1/m`, where m is the number of regression outputs of a training example.
 The parameter `grad_scale` can be used to change this scale to `grad_scale/m`.
 
 )code" ADD_FILELINE);
 
-MXNET_OPERATOR_REGISTER_REGRESSION_BWD(_backward_logistic_reg_out, mshadow_op::minus);
+MXNET_OPERATOR_REGISTER_REGRESSION_BWD(_backward_logistic_reg_out, mshadow_op::minus)
+.set_attr<FInferStorageType>("FInferStorageType", RegressionInferStorageType<false>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", RegressionBackwardEx<cpu, mshadow_op::minus>);
 
 }  // namespace op
 }  // namespace mxnet
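
A minimal usage sketch of the csr-label dispatch documented in the describe blocks above
(LinearRegressionOutput(default, csr) = default). It mirrors the unit-test pattern added later
in this diff; the shapes and values are arbitrary and chosen only for illustration:

    import mxnet as mx

    data = mx.symbol.Variable('data')
    label = mx.symbol.Variable('label', stype='csr')
    net = mx.symbol.LinearRegressionOutput(data=data, label=label)

    exe = net.simple_bind(mx.cpu(), grad_req={'data': 'write', 'label': 'null'},
                          data=(50, 30), label=(50, 30))
    exe.arg_dict['data'][:] = mx.random.uniform(-1, 1, (50, 30))
    exe.arg_dict['label'][:] = mx.nd.sparse.zeros('csr', (50, 30))  # sparse target
    exe.forward(is_train=True)
    exe.backward()
    # with default-storage data and a csr label, both the output and the data
    # gradient are expected to be default storage
    print(exe.outputs[0].stype, exe.grad_dict['data'].stype)
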
diff --git a/src/operator/regression_output.cu b/src/operator/regression_output.cu
index e3a2e7ea2b2..ca11b84a212 100644
--- a/src/operator/regression_output.cu
+++ b/src/operator/regression_output.cu
@@ -28,10 +28,12 @@ namespace mxnet {
 namespace op {
 
 NNVM_REGISTER_OP(LinearRegressionOutput)
-.set_attr<FCompute>("FCompute<gpu>", RegressionForward<gpu, mshadow_op::identity>);
+.set_attr<FCompute>("FCompute<gpu>", RegressionForward<gpu, mshadow_op::identity>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", RegressionForwardEx<gpu, mshadow_op::identity>);
 
 NNVM_REGISTER_OP(_backward_linear_reg_out)
-.set_attr<FCompute>("FCompute<gpu>", RegressionBackward<gpu, mshadow_op::minus>);
+.set_attr<FCompute>("FCompute<gpu>", RegressionBackward<gpu, mshadow_op::minus>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", RegressionBackwardEx<gpu, mshadow_op::minus>);
 
 NNVM_REGISTER_OP(MAERegressionOutput)
 .set_attr<FCompute>("FCompute<gpu>", RegressionForward<gpu, mshadow_op::identity>);
@@ -40,10 +42,12 @@ NNVM_REGISTER_OP(_backward_mae_reg_out)
 .set_attr<FCompute>("FCompute<gpu>", RegressionBackward<gpu, mshadow_op::minus_sign>);
 
 NNVM_REGISTER_OP(LogisticRegressionOutput)
-.set_attr<FCompute>("FCompute<gpu>", RegressionForward<gpu, mshadow_op::sigmoid>);
+.set_attr<FCompute>("FCompute<gpu>", RegressionForward<gpu, mshadow_op::sigmoid>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", RegressionForwardEx<gpu, mshadow_op::sigmoid>);
 
 NNVM_REGISTER_OP(_backward_logistic_reg_out)
-.set_attr<FCompute>("FCompute<gpu>", RegressionBackward<gpu, mshadow_op::minus>);
+.set_attr<FCompute>("FCompute<gpu>", RegressionBackward<gpu, mshadow_op::minus>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", RegressionBackwardEx<gpu, mshadow_op::minus>);
 
 }  // namespace op
 }  // namespace mxnet
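
And an imperative sketch (illustration only, not part of the diff) of the same path through the
newly registered FComputeEx kernels; on a machine with a GPU, creating the arrays with
ctx=mx.gpu() would exercise the GPU registrations added just above. Per
RegressionInferStorageType, the output storage is expected to be default:

    import mxnet as mx

    data = mx.random.uniform(-1, 1, (4, 5))                  # default (dense) storage
    label = mx.random.uniform(0, 1, (4, 5)).tostype('csr')   # csr label
    out = mx.nd.LinearRegressionOutput(data=data, label=label)
    print(out.stype)  # expected: 'default'
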
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 7889e084f74..1a04e8e024d 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -218,41 +218,57 @@ def check_slice_channel(data_ndim, axis, num_outputs, squeeze_axis):
     check_slice_channel(data_ndim=3, axis=-1, num_outputs=2, squeeze_axis=False)
     check_slice_channel(data_ndim=5, axis=-2, num_outputs=3, squeeze_axis=True)
 
-
-def check_regression(symbol, forward, backward):
-    data = mx.symbol.Variable('data')
-    label = mx.symbol.Variable('label')
-    out = symbol(data, label)
-    shape = (3, 1)
-    arr_data = mx.random.uniform(-1, 1, shape, ctx=mx.cpu()).copyto(default_context())
-    arr_label = mx.random.uniform(0, 1, shape[0], ctx=mx.cpu()).copyto(default_context())
-    arr_grad = mx.nd.empty(shape)
-    exec1 = out.bind(default_context(),
-                     args=[arr_data, arr_label],
-                     args_grad={"data" : arr_grad})
-    exec1.forward(is_train=True)
-    out1 = exec1.outputs[0].asnumpy()
-    npout = forward(arr_data.asnumpy())
-    # Non-zero atol required by test_operator_gpu.py:test_regression with seed 651640549
-    atol = 1e-5
-    assert_almost_equal(npout, out1, atol=atol)
-
-    exec1.backward()
-    npout = backward(npout,  arr_label.asnumpy().reshape(npout.shape))
-    assert_almost_equal(npout, arr_grad.asnumpy(), atol=atol)
-
-
 @with_seed()
 def test_regression():
+    ''' test regression operator '''
+    def check_regression(symbol, forward, backward, shape, stype='default', densities=[0, 0.5, 1]):
+        # init executor
+        data = mx.symbol.Variable('data')
+        label = mx.symbol.Variable('label', stype=stype)
+        out = symbol(data, label)
+        grad_req = {'data': 'write', 'label': 'null'}
+        out_exec = out.simple_bind(default_context(), grad_req=grad_req,
+            data=shape, label=shape)
+        arg_map = dict(zip(out.list_arguments(), out_exec.arg_arrays))
+        grad_map = dict(zip(out.list_arguments(), out_exec.grad_arrays))
+        # init data
+        arr_data = mx.random.uniform(-1, 1, shape)
+        arg_map["data"][:] = arr_data
+        # init label based on density
+        arr_label = arg_map["label"]
+        atol = 1e-5
+        for density in densities:
+            arr_label[:] = rand_ndarray(shape, stype, density=density)
+            out_exec.forward(is_train=True)
+            out_exec.backward()
+            np_out = forward(arr_data.asnumpy())
+            out_grad = backward(np_out, arr_label.asnumpy().reshape(np_out.shape)) / shape[1]
+            assert_almost_equal(out_exec.outputs[0].asnumpy(), np_out, atol=atol)
+            assert_almost_equal(grad_map["data"].asnumpy(), out_grad, atol=atol)
+
+    shape = (50, 30)
+
     check_regression(mx.symbol.LogisticRegressionOutput,
                      lambda x: 1.0 / (1.0 + np.exp(-x)),
-                     lambda x, y : x - y)
+                     lambda x, y : x - y,
+                     shape)
     check_regression(mx.symbol.LinearRegressionOutput,
                      lambda x: x,
-                     lambda x, y : x - y)
+                     lambda x, y : x - y,
+                     shape)
     check_regression(mx.symbol.MAERegressionOutput,
                      lambda x: x,
-                     lambda x, y : np.where(x > y, np.ones(x.shape), -np.ones(x.shape)))
+                     lambda x, y : np.where(x > y, np.ones(x.shape), -np.ones(x.shape)),
+                     shape)
+    check_regression(mx.symbol.LogisticRegressionOutput,
+                     lambda x: 1.0 / (1.0 + np.exp(-x)),
+                     lambda x, y : x - y,
+                     shape, stype='csr')
+    check_regression(mx.symbol.LinearRegressionOutput,
+                     lambda x: x,
+                     lambda x, y : x - y,
+                     shape, stype='csr')
+   
 
 def check_softmax_grad(xpu):
     x = mx.sym.Variable('x')

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services