You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/01/28 02:38:38 UTC

[incubator-mxnet] 03/03: refactor regression ops to nnvm interface (#9540)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch v1.1.0
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 32caa1033631668016cf60755ae71b4b3186996c
Author: Ziyue Huang <zy...@gmail.com>
AuthorDate: Sun Jan 28 06:24:33 2018 +0800

    refactor regression ops to nnvm interface (#9540)
    
    * refactor regression ops
    
    * fix err for instantiation of minus_sign
    
    * remove useless header file init_op.h
    
    * replace with macro and address other comments
    
    * update
    
    * minor revise docs
    
    * add mae test
---
 src/operator/operator_tune.cc          |   2 +
 src/operator/regression_output-inl.h   | 228 +++++++++++++--------------------
 src/operator/regression_output.cc      | 107 +++++++++-------
 src/operator/regression_output.cu      |  41 +++---
 tests/python/unittest/test_operator.py |   4 +-
 5 files changed, 170 insertions(+), 212 deletions(-)

diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index 7cdf7a2..e0f8306 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -286,12 +286,14 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::plus);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minus);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::mul);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minus_sign);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rminus);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rdiv);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::plus);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::minus);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mul);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::div);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::minus_sign);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rminus);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rdiv);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div_grad);  // NOLINT()
diff --git a/src/operator/regression_output-inl.h b/src/operator/regression_output-inl.h
index 08b2f0a..4642f8d 100644
--- a/src/operator/regression_output-inl.h
+++ b/src/operator/regression_output-inl.h
@@ -18,28 +18,28 @@
   */
 
 /*!
- * Copyright (c) 2015 by Contributors
- * \file regression_ouput-inl.h
+ * \file regression_output-inl.h
  * \brief Regression output operator.
- */
+*/
 #ifndef MXNET_OPERATOR_REGRESSION_OUTPUT_INL_H_
 #define MXNET_OPERATOR_REGRESSION_OUTPUT_INL_H_
 
-#include <dmlc/logging.h>
-#include <mxnet/operator.h>
-#include <map>
-#include <string>
+#include <mxnet/operator_util.h>
 #include <vector>
 #include <utility>
+#include "./mshadow_op.h"
+#include "./mxnet_op.h"
 #include "./operator_common.h"
 
 namespace mxnet {
 namespace op {
 
+/*!
+ * \brief Input/output index enumerations shared by the regression output operators.
+ */
 namespace reg_enum {
 enum RegressionOutputOpInputs {kData, kLabel};
 enum RegressionOutputOutputs {kOut};
-enum RegressionOutputType {kLinear, kLogistic, kMAE};
 }  // reg_enum
 
 struct RegressionOutputParam : public dmlc::Parameter<RegressionOutputParam> {
@@ -50,146 +50,124 @@ struct RegressionOutputParam : public dmlc::Parameter<RegressionOutputParam> {
   };
 };
 
-// Special Operator to output regression value in forward
-// And get gradient in calculation.
-template<typename xpu, typename ForwardOp, typename BackwardOp>
-class RegressionOutputOp : public Operator {
- public:
-  explicit RegressionOutputOp(RegressionOutputParam param) : param_(param) {}
-
-  virtual void Forward(const OpContext &ctx,
-                       const std::vector<TBlob> &in_data,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &out_data,
-                       const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(in_data.size(), 2U) << "RegressionOutputOp Input: [data, label]";
-    CHECK_EQ(out_data.size(), 1U) << "RegressionOutputOp Output: [output]";
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 2> data = in_data[reg_enum::kData].FlatTo2D<xpu, real_t>(s);
-    Tensor<xpu, 2> out = out_data[reg_enum::kOut].FlatTo2D<xpu, real_t>(s);
-    Assign(out, req[reg_enum::kOut], F<ForwardOp>(data));
-  }
-
-  virtual void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(in_data.size(), 2U);
-    CHECK_EQ(out_grad.size(), 1U);
-    CHECK_GE(in_grad.size(), 1U);
-    CHECK_GE(req.size(), 1U);
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    real_t num_output =
-      in_data[reg_enum::kLabel].Size()/in_data[reg_enum::kLabel].shape_[0];
-    Tensor<xpu, 2> out = out_data[reg_enum::kOut].FlatTo2D<xpu, real_t>(s);
-    Tensor<xpu, 2> grad = in_grad[reg_enum::kData].FlatTo2D<xpu, real_t>(s);
-    Tensor<xpu, 2> label = in_data[reg_enum::kLabel]
-      .get_with_shape<xpu, 2, real_t>(out.shape_, s);
-    Assign(grad, req[reg_enum::kData], param_.grad_scale/num_output*
-      F<BackwardOp>(out, reshape(label, grad.shape_)));
-  }
-
- private:
-  RegressionOutputParam param_;
-};
-
-// Decalre Factory function, used for dispatch specialization
-template<typename xpu>
-Operator* CreateRegressionOutputOp(reg_enum::RegressionOutputType type,
-                                   RegressionOutputParam param);
-
-#if DMLC_USE_CXX11
-template<reg_enum::RegressionOutputType type>
-class RegressionOutputProp : public OperatorProperty {
- public:
-  std::vector<std::string> ListArguments() const override {
-    return {"data", "label"};
-  }
-
-  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
-    param_.Init(kwargs);
-  }
-
-  std::map<std::string, std::string> GetParams() const override {
-    return param_.__DICT__();
-  }
-
-  bool InferShape(std::vector<TShape> *in_shape,
-                  std::vector<TShape> *out_shape,
-                  std::vector<TShape> *aux_shape) const override {
-    using namespace mshadow;
-    CHECK_EQ(in_shape->size(), 2) << "Input:[data, label]";
-    const TShape &dshape = in_shape->at(0);
-    if (dshape.ndim() == 0) return false;
-    auto &lshape = (*in_shape)[1];
-    if (lshape.ndim() == 0) {
-      // special treatment for 1D output, to allow 1D label by default.
-      // Think about change convention later
-      if (dshape.ndim() == 2 && dshape[1] == 1) {
-        lshape = Shape1(dshape[0]);
-      } else {
-        lshape = dshape;
-      }
-    } else if (lshape[0] != dshape[0] || lshape.Size() != dshape.Size()) {
-      std::ostringstream os;
-      os << "Shape inconsistent, Provided=" << lshape << ','
-         << " inferred shape=" << dshape;
-      throw ::mxnet::op::InferShapeError(os.str(), 1);
-    }
-    out_shape->clear();
-    out_shape->push_back(dshape);
-    return true;
-  }
-
-  OperatorProperty* Copy() const override {
-    auto ptr = new RegressionOutputProp<type>();
-    ptr->param_ = param_;
-    return ptr;
-  }
-
-  std::string TypeString() const override {
-    switch (type) {
-      case reg_enum::kLinear: return "LinearRegressionOutput";
-      case reg_enum::kLogistic: return "LogisticRegressionOutput";
-      case reg_enum::kMAE: return "MAERegressionOutput";
-      default: LOG(FATAL) << "unknown type"; return "";
+/*!
+ * \brief Shape inference shared by the regression output operators: the output
+ *        takes the data shape. An unknown label shape is inferred as (N,) when
+ *        data is (N, 1) and as the data shape otherwise; a known label shape must
+ *        agree with the data in its first dimension and in total size.
+ */
+inline bool RegressionOpShape(const nnvm::NodeAttrs& attrs,
+                              std::vector<TShape> *in_attrs,
+                              std::vector<TShape> *out_attrs) {
+  using namespace mshadow;
+  CHECK_EQ(in_attrs->size(), 2U) << "Input:[data, label]";
+  const TShape &dshape = in_attrs->at(0);
+  if (dshape.ndim() == 0) return false;
+  auto &lshape = (*in_attrs)[1];
+  if (lshape.ndim() == 0) {
+    // special treatment for 1D output, to allow 1D label by default.
+    // Think about change convention later
+    if (dshape.ndim() == 2 && dshape[1] == 1) {
+      lshape = Shape1(dshape[0]);
+    } else {
+      lshape = dshape;
     }
+  } else if (lshape[0] != dshape[0] || lshape.Size() != dshape.Size()) {
+    std::ostringstream os;
+    os << "Shape inconsistent, Provided=" << lshape << ','
+       << " inferred shape=" << dshape;
+    throw ::mxnet::op::InferShapeError(os.str(), 1);
   }
-
-  std::vector<int> DeclareBackwardDependency(
-    const std::vector<int> &out_grad,
-    const std::vector<int> &in_data,
-    const std::vector<int> &out_data) const override {
-    return {in_data[reg_enum::kLabel], out_data[reg_enum::kOut]};
-  }
-
-  std::vector<std::pair<int, void*> > BackwardInplaceOption(
-    const std::vector<int> &out_grad,
-    const std::vector<int> &in_data,
-    const std::vector<int> &out_data,
-    const std::vector<void*> &in_grad) const override {
-    return {{out_data[reg_enum::kOut], in_grad[reg_enum::kData]}};
-  }
-
-  std::vector<std::pair<int, void*> > ForwardInplaceOption(
-    const std::vector<int> &in_data,
-    const std::vector<void*> &out_data) const override {
-    return {{in_data[reg_enum::kData], out_data[reg_enum::kOut]}};
+  out_attrs->clear();
+  out_attrs->push_back(dshape);
+  return true;
+}
+
+/*!
+ * \brief Forward pass: writes ForwardOp(data) element-wise to the output
+ *        according to req[kOut]; the label input is not read.
+ */
+template<typename xpu, typename ForwardOp>
+void RegressionForward(const nnvm::NodeAttrs& attrs,
+                       const OpContext& ctx,
+                       const std::vector<TBlob>& inputs,
+                       const std::vector<OpReqType>& req,
+                       const std::vector<TBlob>& outputs) {
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  MSHADOW_REAL_TYPE_SWITCH(inputs[reg_enum::kData].type_flag_, DType, {
+    MXNET_ASSIGN_REQ_SWITCH(req[reg_enum::kOut], Req, {
+      const DType* in_data = inputs[reg_enum::kData].dptr<DType>();
+      DType* out_data = outputs[reg_enum::kOut].dptr<DType>();
+      using namespace mxnet_op;
+      Kernel<op_with_req<ForwardOp, Req>, xpu>::Launch(
+        s, outputs[reg_enum::kOut].Size(), out_data, in_data);
+    });
+  });
+}
+
+/*!
+ * \brief Element-wise backward kernel of the regression output operators:
+ *        in_grad[i] = scale * BackwardOp(out[i], label[i]), assigned per req.
+ *        The scale is applied in the same pass so that kAddTo accumulates only
+ *        the new scaled gradient; scaling in_grad in a second in-place pass
+ *        would also rescale the gradient previously accumulated in it.
+ */
+template<typename BackwardOp, int req>
+struct RegressionBackwardKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType *in_grad, const DType *out,
+                                  const DType *label, const DType scale) {
+    KERNEL_ASSIGN(in_grad[i], req, scale * BackwardOp::Map(out[i], label[i]));
+  }
+};
+
+/*!
+ * \brief Backward pass: data_grad = grad_scale / num_output * BackwardOp(out, label),
+ *        element-wise and according to req[0], where num_output is the number of
+ *        label elements per sample. outputs[1] (label_grad) is not written.
+ */
+template<typename xpu, typename BackwardOp>
+void RegressionBackward(const nnvm::NodeAttrs& attrs,
+                        const OpContext& ctx,
+                        const std::vector<TBlob>& inputs,
+                        const std::vector<OpReqType>& req,
+                        const std::vector<TBlob>& outputs) {
+  const RegressionOutputParam& param = nnvm::get<RegressionOutputParam>(attrs.parsed);
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  // inputs are in_label, out_data
+  // outputs are data_grad, label_grad
+  MSHADOW_REAL_TYPE_SWITCH(inputs[1].type_flag_, DType, {
+    MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+      const DType* in_label = inputs[0].dptr<DType>();
+      const DType* out_data = inputs[1].dptr<DType>();
+      DType* data_grad = outputs[0].dptr<DType>();
+      const real_t num_output = inputs[0].Size()/inputs[0].shape_[0];
+      using namespace mxnet_op;
+      Kernel<RegressionBackwardKernel<BackwardOp, Req>, xpu>::Launch(
+        s, outputs[0].Size(), data_grad, out_data, in_label,
+        static_cast<DType>(param.grad_scale/num_output));
+    });
+  });
+}
+
+/*!
+ * \brief Gradient function: creates the backward node `op_name` whose inputs are
+ *        the label and the forward output, copying the forward attributes;
+ *        the incoming output gradients (ograds) are not used.
+ */
+struct RegressionOpGrad {
+  const char *op_name;
+  std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
+                                          const std::vector<nnvm::NodeEntry>& ograds) const {
+    std::vector<nnvm::NodeEntry> heads;
+    heads.push_back(n->inputs[reg_enum::kLabel]);
+    heads.emplace_back(nnvm::NodeEntry{n, reg_enum::kOut, 0});
+    return MakeGradNode(op_name, n, heads, n->attrs.dict);
   }
+};
 
-  Operator* CreateOperator(Context ctx) const override;
 
- protected:
-  RegressionOutputParam param_;
-};
-#endif  // DMLC_USE_CXX11
 }  // namespace op
 }  // namespace mxnet
+
 #endif  // MXNET_OPERATOR_REGRESSION_OUTPUT_INL_H_
diff --git a/src/operator/regression_output.cc b/src/operator/regression_output.cc
index 2f8042e..7b0fbae 100644
--- a/src/operator/regression_output.cc
+++ b/src/operator/regression_output.cc
@@ -18,61 +18,76 @@
   */
 
 /*!
- * Copyright (c) 2015 by Contributors
  * \file regression_output.cc
- * \brief regression output operator
+ * \brief Regression output operator.
 */
+
 #include "./regression_output-inl.h"
-#include "./mshadow_op.h"
+
+// Registers a regression output operator: __kernel$ is the element-wise forward
+// functor and __bwdop$ names the backward operator created by FGradient.
+#define MXNET_OPERATOR_REGISTER_REGRESSION_FWD(__name$, __kernel$, __bwdop$)   \
+  NNVM_REGISTER_OP(__name$)                                                    \
+  .set_num_inputs(2)                                                           \
+  .set_num_outputs(1)                                                          \
+  .set_attr_parser(ParamParser<RegressionOutputParam>)                         \
+  .set_attr<nnvm::FListInputNames>("FListInputNames",                          \
+    [](const NodeAttrs& attrs) {                                               \
+      return std::vector<std::string>{"data", "label"};                        \
+    })                                                                         \
+  .set_attr<nnvm::FInferShape>("FInferShape", RegressionOpShape)               \
+  .set_attr<nnvm::FGradient>("FGradient", RegressionOpGrad{__bwdop$})          \
+  .set_attr<nnvm::FInplaceOption>("FInplaceOption",                            \
+  [](const NodeAttrs& attrs){                                                  \
+    return std::vector<std::pair<int, int> >{{0, 0}};                          \
+  })                                                                           \
+  .set_attr<FCompute>("FCompute<cpu>", RegressionForward<cpu, __kernel$>)      \
+  .add_argument("data", "NDArray-or-Symbol", "Input data to the function.")    \
+  .add_argument("label", "NDArray-or-Symbol", "Input label to the function.")  \
+  .add_arguments(RegressionOutputParam::__FIELDS__())
+
+// Registers the backward operator of a regression output: __kernel$ is the
+// element-wise functor applied to (forward output, label) by RegressionBackward.
+#define MXNET_OPERATOR_REGISTER_REGRESSION_BWD(__name$, __kernel$)         \
+  NNVM_REGISTER_OP(__name$)                                                \
+  .set_num_inputs(2)                                                       \
+  .set_num_outputs(2)                                                      \
+  .set_attr_parser(ParamParser<RegressionOutputParam>)                     \
+  .set_attr<nnvm::TIsBackward>("TIsBackward", true)                        \
+  .set_attr<nnvm::FInplaceOption>("FInplaceOption",                        \
+  [](const NodeAttrs& attrs){                                              \
+    return std::vector<std::pair<int, int> >{{1, 0}};                      \
+  })                                                                       \
+  .set_attr<FCompute>("FCompute<cpu>", RegressionBackward<cpu, __kernel$>)
 
 namespace mxnet {
 namespace op {
 
-template<>
-Operator *CreateRegressionOutputOp<cpu>(reg_enum::RegressionOutputType type,
-                                        RegressionOutputParam param) {
-  switch (type) {
-    case reg_enum::kLinear:
-      return new RegressionOutputOp<cpu, op::mshadow_op::identity, op::mshadow_op::minus>(param);
-    case reg_enum::kLogistic:
-      return new RegressionOutputOp<cpu, mshadow_op::sigmoid, op::mshadow_op::minus>(param);
-    case reg_enum::kMAE:
-      return new RegressionOutputOp<cpu, op::mshadow_op::identity, mshadow_op::minus_sign>(param);
-    default:
-      LOG(FATAL) << "unknown activation type " << type;
-  }
-  return nullptr;
-}
-
-// DO_BIND_DISPATCH comes from operator_common.h
-template<reg_enum::RegressionOutputType type>
-Operator *RegressionOutputProp<type>::CreateOperator(Context ctx) const {
-  DO_BIND_DISPATCH(CreateRegressionOutputOp, type, param_);
-}
 
 DMLC_REGISTER_PARAMETER(RegressionOutputParam);
 
-MXNET_REGISTER_OP_PROPERTY(LinearRegressionOutput, RegressionOutputProp<reg_enum::kLinear>)
+MXNET_OPERATOR_REGISTER_REGRESSION_FWD(LinearRegressionOutput,
+  mshadow_op::identity, "_backward_linear_reg_out")
 .describe(R"code(Computes and optimizes for squared loss during backward propagation.
 Just outputs ``data`` during forward propagation.
 
 If :math:`\hat{y}_i` is the predicted value of the i-th sample, and :math:`y_i` is the corresponding target value,
 then the squared loss estimated over :math:`n` samples is defined as
 
-:math:`\text{SquaredLoss}(y, \hat{y} ) = \frac{1}{n} \sum_{i=0}^{n-1} \left( y_i - \hat{y}_i \right)^2`
+:math:`\text{SquaredLoss}(\textbf{Y}, \hat{\textbf{Y}} ) = \frac{1}{n} \sum_{i=0}^{n-1} \lVert  \textbf{y}_i - \hat{\textbf{y}}_i  \rVert_2^2`
 
 .. note::
    Use the LinearRegressionOutput as the final output layer of a net.
 
-By default, gradients of this loss function are scaled by factor `1/n`, where n is the number of training examples.
-The parameter `grad_scale` can be used to change this scale to `grad_scale/n`.
+By default, gradients of this loss function are scaled by factor `1/m`, where m is the number of regression outputs of a training example.
+The parameter `grad_scale` can be used to change this scale to `grad_scale/m`.
+
+)code" ADD_FILELINE);
 
-)code" ADD_FILELINE)
-.add_argument("data", "NDArray-or-Symbol", "Input data to the function.")
-.add_argument("label", "NDArray-or-Symbol", "Input label to the function.")
-.add_arguments(RegressionOutputParam::__FIELDS__());
+MXNET_OPERATOR_REGISTER_REGRESSION_BWD(_backward_linear_reg_out, mshadow_op::minus);
 
-MXNET_REGISTER_OP_PROPERTY(MAERegressionOutput, RegressionOutputProp<reg_enum::kMAE>)
+MXNET_OPERATOR_REGISTER_REGRESSION_FWD(MAERegressionOutput,
+  mshadow_op::identity, "_backward_mae_reg_out")
 .describe(R"code(Computes mean absolute error of the input.
 
 MAE is a risk metric corresponding to the expected value of the absolute error.
@@ -80,24 +95,24 @@ MAE is a risk metric corresponding to the expected value of the absolute error.
 If :math:`\hat{y}_i` is the predicted value of the i-th sample, and :math:`y_i` is the corresponding target value,
 then the mean absolute error (MAE) estimated over :math:`n` samples is defined as
 
-:math:`\text{MAE}(y, \hat{y} ) = \frac{1}{n} \sum_{i=0}^{n-1} \left| y_i - \hat{y}_i \right|`
+:math:`\text{MAE}(\textbf{Y}, \hat{\textbf{Y}} ) = \frac{1}{n} \sum_{i=0}^{n-1} \lVert \textbf{y}_i - \hat{\textbf{y}}_i \rVert_1`
 
 .. note::
    Use the MAERegressionOutput as the final output layer of a net.
 
-By default, gradients of this loss function are scaled by factor `1/n`, where n is the number of training examples.
-The parameter `grad_scale` can be used to change this scale to `grad_scale/n`.
+By default, gradients of this loss function are scaled by factor `1/m`, where m is the number of regression outputs of a training example.
+The parameter `grad_scale` can be used to change this scale to `grad_scale/m`.
 
-)code" ADD_FILELINE)
-.add_argument("data", "NDArray-or-Symbol", "Input data to the function.")
-.add_argument("label", "NDArray-or-Symbol", "Input label to the function.")
-.add_arguments(RegressionOutputParam::__FIELDS__());
+)code" ADD_FILELINE);
 
-MXNET_REGISTER_OP_PROPERTY(LogisticRegressionOutput, RegressionOutputProp<reg_enum::kLogistic>)
+MXNET_OPERATOR_REGISTER_REGRESSION_BWD(_backward_mae_reg_out, mshadow_op::minus_sign);
+
+MXNET_OPERATOR_REGISTER_REGRESSION_FWD(LogisticRegressionOutput,
+  mshadow_op::sigmoid, "_backward_logistic_reg_out")
 .describe(R"code(Applies a logistic function to the input.
 
 The logistic function, also known as the sigmoid function, is computed as
-:math:`\frac{1}{1+exp(-x)}`.
+:math:`\frac{1}{1+exp(-\textbf{x})}`.
 
 Commonly, the sigmoid is used to squash the real-valued output of a linear model
 :math:wTx+b into the [0,1] range so that it can be interpreted as a probability.
@@ -106,13 +121,12 @@ It is suitable for binary classification or probability prediction tasks.
 .. note::
    Use the LogisticRegressionOutput as the final output layer of a net.
 
-By default, gradients of this loss function are scaled by factor `1/n`, where n is the number of training examples.
-The parameter `grad_scale` can be used to change this scale to `grad_scale/n`.
+By default, gradients of this loss function are scaled by factor `1/m`, where m is the number of regression outputs of a training example.
+The parameter `grad_scale` can be used to change this scale to `grad_scale/m`.
+
+)code" ADD_FILELINE);
 
-)code" ADD_FILELINE)
-.add_argument("data", "NDArray-or-Symbol", "Input data to the function.")
-.add_argument("label", "NDArray-or-Symbol", "Input label to the function.")
-.add_arguments(RegressionOutputParam::__FIELDS__());
+MXNET_OPERATOR_REGISTER_REGRESSION_BWD(_backward_logistic_reg_out, mshadow_op::minus);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/regression_output.cu b/src/operator/regression_output.cu
index cb951f1..e3a2e7e 100644
--- a/src/operator/regression_output.cu
+++ b/src/operator/regression_output.cu
@@ -18,31 +18,34 @@
   */
 
 /*!
- * Copyright (c) 2015 by Contributors
  * \file regression_output.cu
- * \brief regression output operator
+ * \brief Regression output operator.
 */
 #include "./regression_output-inl.h"
-#include "./mshadow_op.h"
+
 
 namespace mxnet {
 namespace op {
 
-template<>
-Operator *CreateRegressionOutputOp<gpu>(reg_enum::RegressionOutputType type,
-                                        RegressionOutputParam param) {
-  switch (type) {
-    case reg_enum::kLinear:
-      return new RegressionOutputOp<gpu, op::mshadow_op::identity, op::mshadow_op::minus>(param);
-    case reg_enum::kLogistic:
-      return new RegressionOutputOp<gpu, mshadow_op::sigmoid, op::mshadow_op::minus>(param);
-    case reg_enum::kMAE:
-      return new RegressionOutputOp<gpu, op::mshadow_op::identity, mshadow_op::minus_sign>(param);
-    default:
-      LOG(FATAL) << "unknown activation type " << type;
-  }
-  return NULL;
-}
+// GPU FCompute kernels of the regression output operators; the operators
+// themselves (shapes, gradients, CPU kernels) are registered in regression_output.cc.
+NNVM_REGISTER_OP(LinearRegressionOutput)
+.set_attr<FCompute>("FCompute<gpu>", RegressionForward<gpu, mshadow_op::identity>);
+
+NNVM_REGISTER_OP(_backward_linear_reg_out)
+.set_attr<FCompute>("FCompute<gpu>", RegressionBackward<gpu, mshadow_op::minus>);
+
+NNVM_REGISTER_OP(MAERegressionOutput)
+.set_attr<FCompute>("FCompute<gpu>", RegressionForward<gpu, mshadow_op::identity>);
+
+NNVM_REGISTER_OP(_backward_mae_reg_out)
+.set_attr<FCompute>("FCompute<gpu>", RegressionBackward<gpu, mshadow_op::minus_sign>);
+
+NNVM_REGISTER_OP(LogisticRegressionOutput)
+.set_attr<FCompute>("FCompute<gpu>", RegressionForward<gpu, mshadow_op::sigmoid>);
+
+NNVM_REGISTER_OP(_backward_logistic_reg_out)
+.set_attr<FCompute>("FCompute<gpu>", RegressionBackward<gpu, mshadow_op::minus>);
+
 }  // namespace op
 }  // namespace mxnet
-
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 742d055..640cd34 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -244,7 +244,11 @@ def test_regression():
     check_regression(mx.symbol.LinearRegressionOutput,
                      lambda x: x,
                      lambda x, y : x - y)
+    # expected MAE gradient: +1 where output > label, -1 otherwise
+    check_regression(mx.symbol.MAERegressionOutput,
+                     lambda x: x,
+                     lambda x, y : np.where(x > y, np.ones(x.shape), -np.ones(x.shape)))
 
 
 def check_softmax_grad(xpu):
     x = mx.sym.Variable('x')

-- 
To stop receiving notification emails like this one, please contact
haibin@apache.org.