Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/05/08 22:10:38 UTC

[GitHub] piiswrong closed pull request #10847: [MXNET-408] [WIP] inplace ReLU activation

piiswrong closed pull request #10847: [MXNET-408] [WIP] inplace ReLU activation
URL: https://github.com/apache/incubator-mxnet/pull/10847

This is a PR merged from a forked repository. As GitHub hides the
original diff once a fork's pull request is merged, it is reproduced
below for the sake of provenance.
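
The crux of the change: for y = relu(x) = max(x, 0), the mask (x > 0) used
by the backward pass is identical to (y > 0), so the gradient can be
computed from the output alone. The input then no longer has to be kept
alive for the backward node, which is what allows the forward pass to run
in place. A minimal NumPy sketch of that equivalence (illustration only,
not MXNet code):

    import numpy as np

    x = np.random.randn(4, 5)
    y = np.maximum(x, 0.0)            # forward: y = relu(x)
    out_grad = np.random.randn(4, 5)  # incoming gradient dL/dy

    grad_from_input  = out_grad * (x > 0.0)  # what the backward pass used to need
    grad_from_output = out_grad * (y > 0.0)  # what it needs after this change
    assert np.allclose(grad_from_input, grad_from_output)

The full diff follows.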

diff --git a/src/operator/nn/activation-inl.h b/src/operator/nn/activation-inl.h
index 32a7a5ad617..a9f6dbeda89 100644
--- a/src/operator/nn/activation-inl.h
+++ b/src/operator/nn/activation-inl.h
@@ -83,7 +83,7 @@ struct hash<mxnet::op::ActivationParam> {
 namespace mxnet {
 namespace op {
 
-template<typename xpu, typename ForwardOp, typename BackwardOp, typename DType>
+template<typename xpu, typename ForwardOp, typename BackwardOp>
 void ActivationForward(const OpContext &ctx, const TBlob &in_data,
                        const OpReqType &req, const TBlob &out_data) {
   using namespace mshadow;
@@ -91,16 +91,16 @@ void ActivationForward(const OpContext &ctx, const TBlob &in_data,
   Stream<xpu> *s = ctx.get_stream<xpu>();
   const size_t sz = in_data.shape_.Size();
   if (sz) {
-    MXNET_ASSIGN_REQ_SWITCH(req, Req, {
-      mxnet_op::Kernel<mxnet_op::op_with_req<ForwardOp, Req>, xpu>::Launch(
-        s, sz,
-        out_data.dptr<DType>(),
-        in_data.dptr<DType>());
+    MSHADOW_REAL_TYPE_SWITCH(in_data.type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req, Req, {
+        mxnet_op::Kernel<mxnet_op::op_with_req<ForwardOp, Req>, xpu>::Launch(
+          s, sz, out_data.dptr<DType>(), in_data.dptr<DType>());
+      });
     });
   }
 }
 
-template<typename xpu, typename ForwardOp, typename BackwardOp, typename DType>
+template<typename xpu, typename ForwardOp, typename BackwardOp>
 void ActivationBackward(const OpContext &ctx, const TBlob &out_grad,
                         const TBlob &out_data, const OpReqType &req,
                         const TBlob &in_grad) {
@@ -109,13 +109,12 @@ void ActivationBackward(const OpContext &ctx, const TBlob &out_grad,
   Stream<xpu> *s = ctx.get_stream<xpu>();
   const size_t sz = out_data.shape_.Size();
   if (sz) {
-    MXNET_ASSIGN_REQ_SWITCH(req, Req, {
-      mxnet_op::Kernel<mxnet_op::op_with_req<
-        mxnet::op::mxnet_op::backward_grad_tuned<BackwardOp>, Req>, xpu>::Launch(
-        s, sz,
-        in_grad.dptr<DType>(),
-        out_grad.dptr<DType>(),
-        out_data.dptr<DType>());
+    MSHADOW_REAL_TYPE_SWITCH(out_grad.type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req, Req, {
+        mxnet_op::Kernel<mxnet_op::op_with_req<
+          mxnet_op::backward_grad_tuned<BackwardOp>, Req>, xpu>::Launch(
+            s, sz, in_grad.dptr<DType>(), out_grad.dptr<DType>(), out_data.dptr<DType>());
+      });
     });
   }
 }
@@ -123,72 +122,68 @@ void ActivationBackward(const OpContext &ctx, const TBlob &out_grad,
 template<typename xpu>
 void ActivationComputeImpl(const ActivationParam &param, const OpContext &ctx,
                            const TBlob &input, OpReqType req, const TBlob &output) {
-  MSHADOW_REAL_TYPE_SWITCH(input.type_flag_, DType, {
-    switch (param.act_type) {
-      case activation::kReLU:
-        ActivationForward<xpu, mshadow_op::relu, mshadow_op::relu_grad, DType>(
-            ctx, input, req, output);
-        break;
-      case activation::kSigmoid:
-        ActivationForward<xpu, mshadow_op::sigmoid, mshadow_op::sigmoid_grad, DType>(
-            ctx, input, req, output);
-        break;
-      case activation::kTanh:
-        ActivationForward<xpu, mshadow_op::tanh, mshadow_op::tanh_grad, DType>(
-            ctx, input, req, output);
-        break;
-      case activation::kSoftReLU:
-        ActivationForward<xpu, mshadow_op::softrelu, mshadow_op::softrelu_grad, DType>(
-            ctx, input, req, output);
-        break;
-      case activation::kSoftSign:
-        ActivationForward<xpu, mshadow_op::softsign, mshadow_op::softsign_grad, DType>(
-                ctx, input, req, output);
-            break;
-      default:
-        LOG(FATAL) << "unknown activation type";
-    }
-  });
+  switch (param.act_type) {
+    case activation::kReLU:
+      ActivationForward<xpu, mshadow_op::relu, mshadow_op::relu_grad>(
+          ctx, input, req, output);
+      break;
+    case activation::kSigmoid:
+      ActivationForward<xpu, mshadow_op::sigmoid, mshadow_op::sigmoid_grad>(
+          ctx, input, req, output);
+      break;
+    case activation::kTanh:
+      ActivationForward<xpu, mshadow_op::tanh, mshadow_op::tanh_grad>(
+          ctx, input, req, output);
+      break;
+    case activation::kSoftReLU:
+      ActivationForward<xpu, mshadow_op::softrelu, mshadow_op::softrelu_grad>(
+          ctx, input, req, output);
+      break;
+    case activation::kSoftSign:
+      ActivationForward<xpu, mshadow_op::softsign, mshadow_op::softsign_grad>(
+              ctx, input, req, output);
+          break;
+    default:
+      LOG(FATAL) << "unknown activation type";
+  }
 }
 
 template<typename xpu>
 void ActivationGradComputeImpl(const ActivationParam &param, const OpContext &ctx,
                                const TBlob &out_grad, const TBlob &out_data,
                                OpReqType req, const TBlob &output) {
-  MSHADOW_REAL_TYPE_SWITCH(out_grad.type_flag_, DType, {
-    switch (param.act_type) {
-      case activation::kReLU:
-        ActivationBackward<xpu, mshadow_op::relu, mshadow_op::relu_grad, DType>(
-            ctx, out_grad, out_data, req, output);
-        break;
-      case activation::kSigmoid:
-        ActivationBackward<xpu, mshadow_op::sigmoid, mshadow_op::sigmoid_grad, DType>(
-            ctx, out_grad, out_data, req, output);
-        break;
-      case activation::kTanh:
-        ActivationBackward<xpu, mshadow_op::tanh, mshadow_op::tanh_grad, DType>(
-            ctx, out_grad, out_data, req, output);
-        break;
-      case activation::kSoftReLU:
-        ActivationBackward<xpu, mshadow_op::softrelu, mshadow_op::softrelu_grad, DType>(
-            ctx, out_grad, out_data, req, output);
-        break;
-      case activation::kSoftSign:
-        ActivationBackward<xpu, mshadow_op::softsign, mshadow_op::softsign_grad, DType>(
-                ctx, out_grad, out_data, req, output);
-            break;
-      default:
-        LOG(FATAL) << "unknown activation type";
-    }
-  });
+  switch (param.act_type) {
+    case activation::kReLU:
+      ActivationBackward<xpu, mshadow_op::relu, mshadow_op::relu_grad>(
+          ctx, out_grad, out_data, req, output);
+      break;
+    case activation::kSigmoid:
+      ActivationBackward<xpu, mshadow_op::sigmoid, mshadow_op::sigmoid_grad>(
+          ctx, out_grad, out_data, req, output);
+      break;
+    case activation::kTanh:
+      ActivationBackward<xpu, mshadow_op::tanh, mshadow_op::tanh_grad>(
+          ctx, out_grad, out_data, req, output);
+      break;
+    case activation::kSoftReLU:
+      ActivationBackward<xpu, mshadow_op::softrelu, mshadow_op::softrelu_grad>(
+          ctx, out_grad, out_data, req, output);
+      break;
+    case activation::kSoftSign:
+      ActivationBackward<xpu, mshadow_op::softsign, mshadow_op::softsign_grad>(
+              ctx, out_grad, out_data, req, output);
+          break;
+    default:
+      LOG(FATAL) << "unknown activation type";
+  }
 }
 
 template<typename xpu>
 void ActivationCompute(const nnvm::NodeAttrs& attrs,
-    const OpContext& ctx,
-    const std::vector<TBlob>& inputs,
-    const std::vector<OpReqType>& req,
-    const std::vector<TBlob>& outputs) {
+                       const OpContext& ctx,
+                       const std::vector<TBlob>& inputs,
+                       const std::vector<OpReqType>& req,
+                       const std::vector<TBlob>& outputs) {
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
   const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
@@ -197,18 +192,19 @@ void ActivationCompute(const nnvm::NodeAttrs& attrs,
 
 template<typename xpu>
 void ActivationGradCompute(const nnvm::NodeAttrs& attrs,
-    const OpContext& ctx,
-    const std::vector<TBlob>& inputs,
-    const std::vector<OpReqType>& req,
-    const std::vector<TBlob>& outputs) {
+                           const OpContext& ctx,
+                           const std::vector<TBlob>& inputs,
+                           const std::vector<OpReqType>& req,
+                           const std::vector<TBlob>& outputs) {
+  const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
 #if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
-  CHECK_EQ(inputs.size(), 3U);
+  bool relu = param.act_type == activation::kReLU;
+  CHECK_EQ(inputs.size(), relu ? 2U : 3U);
 #else
   CHECK_EQ(inputs.size(), 2U);
 #endif
   CHECK_EQ(outputs.size(), 1U);
   CHECK_EQ(req.size(), 1U);
-  const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
   ActivationGradComputeImpl<xpu>(param, ctx, inputs[0], inputs[1], req[0], outputs[0]);
 }
 
diff --git a/src/operator/nn/activation.cc b/src/operator/nn/activation.cc
index 382efeb1447..595b8912ccc 100644
--- a/src/operator/nn/activation.cc
+++ b/src/operator/nn/activation.cc
@@ -45,7 +45,12 @@ struct ActivationGrad {
     std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
     heads.emplace_back(nnvm::NodeEntry{n, activation::kOut, 0});
 #if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
-    heads.push_back(n->inputs[activation::kData]);
+    const NodeAttrs& attrs = n->attrs;
+    // for ReLU, no need to pass input data. This enables inplace optimization during the
+    // forward pass.
+    if (dmlc::get<ActivationParam>(attrs.parsed).act_type != activation::kReLU) {
+      heads.push_back(n->inputs[activation::kData]);
+    }
 #endif
     return MakeGradNode(op_name, n, heads, n->attrs.dict);
   }
@@ -74,13 +79,15 @@ void ActivationGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                 const std::vector<NDArray>& inputs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<NDArray>& outputs) {
-  CHECK_EQ(inputs.size(), 3U);
   const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
+  bool relu = param.act_type == activation::kReLU;
+  CHECK_EQ(inputs.size(), relu ? 2U : 3U);
   if (SupportMKLDNN(inputs[0])) {
     MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
-    MKLDNNActivationBackward(attrs, ctx, inputs[0], inputs[2], req[0],
+    // XXX: for y = relu(x), y is passed as "in_data" to Backward()
+    MKLDNNActivationBackward(attrs, ctx, inputs[0], relu ? inputs[1] : inputs[2], req[0],
                              outputs[0]);
-      MKLDNN_OPCHECK_RUN(ActivationGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
+     MKLDNN_OPCHECK_RUN(ActivationGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
     return;
   }
   ActivationGradComputeImpl<cpu>(param, ctx, inputs[0].data(), inputs[1].data(),
@@ -112,23 +119,29 @@ inline static bool BackwardActStorageType(const nnvm::NodeAttrs& attrs,
                                           DispatchMode* dispatch_mode,
                                           std::vector<int> *in_attrs,
                                           std::vector<int> *out_attrs) {
+  bool ret = false;
 #if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
-  CHECK_EQ(in_attrs->size(), 3U);
+  const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
+  if (param.act_type != activation::kReLU) {
+    CHECK_EQ(in_attrs->size(), 3U);
+    ret = ElemwiseStorageType<3, 1, false, false, false>(attrs, dev_mask,
+                                                         dispatch_mode,
+                                                         in_attrs, out_attrs);
+  } else {
+    // for ReLU activation, the backward pass only needs ograd and output
+    CHECK_EQ(in_attrs->size(), 2U);
+    ret = ElemwiseStorageType<2, 1, false, false, false>(attrs, dev_mask,
+                                                         dispatch_mode,
+                                                         in_attrs, out_attrs);
+  }
 #else
   CHECK_EQ(in_attrs->size(), 2U);
+  ret = ElemwiseStorageType<2, 1, false, false, false>(attrs, dev_mask,
+                                                       dispatch_mode,
+                                                       in_attrs, out_attrs);
 #endif
   CHECK_EQ(out_attrs->size(), 1U);
-#if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
-  bool ret = ElemwiseStorageType<3, 1, false, false, false>(attrs, dev_mask,
-                                                            dispatch_mode,
-                                                            in_attrs, out_attrs);
-#else
-  bool ret = ElemwiseStorageType<2, 1, false, false, false>(attrs, dev_mask,
-                                                            dispatch_mode,
-                                                            in_attrs, out_attrs);
-#endif
 #if MXNET_USE_MKLDNN == 1
-  const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
   if (dev_mask == mshadow::cpu::kDevMask && SupportMKLDNNAct(param)) {
     *dispatch_mode = DispatchMode::kFComputeEx;
   }
@@ -162,7 +175,12 @@ The following activation functions are supported:
 .add_arguments(ActivationParam::__FIELDS__());
 
 NNVM_REGISTER_OP(_backward_Activation)
-.set_num_inputs(3)
+.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
+    int act_type = dmlc::get<ActivationParam>(attrs.parsed).act_type;
+    // for ReLU activation, the backward pass only needs ograd and output
+    if (act_type == activation::kReLU) return 2;
+    return 3;
+  })
 .set_num_outputs(1)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FInferStorageType>("FInferStorageType", BackwardActStorageType)
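
With the changes above, the backward path of Activation no longer has a
fixed arity: on CUDNN/MKLDNN builds it receives (ograd, output, input) for
every activation except ReLU, which drops the input, while on plain builds
it always receives (ograd, output). A hedged pure-Python restatement of the
input-count rule checked by ActivationGradCompute and
BackwardActStorageType (the helper name is made up for illustration):

    # Not MXNet source; just the arity rule encoded by the checks above.
    def backward_num_inputs(act_type, cudnn_or_mkldnn=True):
        if not cudnn_or_mkldnn:
            return 2                           # (ograd, output) on plain builds
        return 2 if act_type == 'relu' else 3  # ReLU drops the input data

    assert backward_num_inputs('relu') == 2
    assert backward_num_inputs('tanh') == 3
    assert backward_num_inputs('relu', cudnn_or_mkldnn=False) == 2
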
diff --git a/src/operator/nn/activation.cu b/src/operator/nn/activation.cu
index dc435b2acc1..68b4053efdd 100644
--- a/src/operator/nn/activation.cu
+++ b/src/operator/nn/activation.cu
@@ -55,12 +55,13 @@ void ActivationCompute<gpu>(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(outputs.size(), 1U);
   const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
 
-  // SoftReLU not supported by CUDNN yet
+  // SoftReLU and SoftSign are not supported by CUDNN yet
   if (param.act_type == activation::kSoftReLU) {
-    MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-      ActivationForward<gpu, mshadow_op::softrelu, mshadow_op::softrelu_grad, DType>(ctx,
-          inputs[0], req[0], outputs[0]);
-    });
+    ActivationForward<gpu, mshadow_op::softrelu, mshadow_op::softrelu_grad>(ctx,
+      inputs[0], req[0], outputs[0]);
+  } else if (param.act_type == activation::kSoftSign) {
+    ActivationForward<gpu, mshadow_op::softsign, mshadow_op::softsign_grad>(ctx,
+      inputs[0], req[0], outputs[0]);
   } else {
     MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
       get_cudnn_op<DType>(param).Forward(ctx, inputs[0], req[0], outputs[0]);
@@ -70,24 +71,28 @@ void ActivationCompute<gpu>(const nnvm::NodeAttrs& attrs,
 
 template<>
 void ActivationGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
-    const OpContext& ctx,
-    const std::vector<TBlob>& inputs,
-    const std::vector<OpReqType>& req,
-    const std::vector<TBlob>& outputs) {
-  CHECK_EQ(inputs.size(), 3U);
+                                const OpContext& ctx,
+                                const std::vector<TBlob>& inputs,
+                                const std::vector<OpReqType>& req,
+                                const std::vector<TBlob>& outputs) {
+  const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
+  bool relu = param.act_type == activation::kReLU;
+  CHECK_EQ(inputs.size(), relu ? 2U : 3U);
   CHECK_EQ(outputs.size(), 1U);
   CHECK_EQ(req.size(), 1U);
-  const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
 
-  // SoftReLU not supported by CUDNN yet
+  // Both SoftReLU and SoftSign are not supported by CUDNN yet
   if (param.act_type == activation::kSoftReLU) {
-    MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-      ActivationBackward<gpu, mshadow_op::softrelu, mshadow_op::softrelu_grad, DType>(
-          ctx, inputs[0], inputs[1], req[0], outputs[0]);
-    });
+    ActivationBackward<gpu, mshadow_op::softrelu, mshadow_op::softrelu_grad>(
+      ctx, inputs[0], inputs[1], req[0], outputs[0]);
+  } else if (param.act_type == activation::kSoftSign) {
+    ActivationBackward<gpu, mshadow_op::softsign, mshadow_op::softsign_grad>(
+      ctx, inputs[0], inputs[1], req[0], outputs[0]);
   } else {
     MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-      get_cudnn_op<DType>(param).Backward(ctx, inputs[0], inputs[2], inputs[1], req[0], outputs[0]);
+      // XXX: for y = relu(x), y is passed as "in_data" to Backward()
+      get_cudnn_op<DType>(param).Backward(ctx, inputs[0], relu ? inputs[1] : inputs[2],
+                                          inputs[1], req[0], outputs[0]);
     });
   }
 }
diff --git a/src/operator/nn/cudnn/cudnn_activation-inl.h b/src/operator/nn/cudnn/cudnn_activation-inl.h
index a89e7bfaf08..2c1f442808c 100644
--- a/src/operator/nn/cudnn/cudnn_activation-inl.h
+++ b/src/operator/nn/cudnn/cudnn_activation-inl.h
@@ -130,6 +130,9 @@ class CuDNNActivationOp {
     #endif
   }
 
+  // backward computation for cudnn activation operator. Note that for relu
+  // it's okay to pass "out_data" as "in_data", since it doesn't make any
+  // difference in terms of computing the gradient of relu.
   void Backward(const OpContext &ctx, const TBlob &out_grad,
       const TBlob &in_data, const TBlob &out_data,
       const OpReqType &req, const TBlob &in_grad) {
diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc
index 9be5bfbc150..a057527d473 100644
--- a/src/operator/nn/mkldnn/mkldnn_act.cc
+++ b/src/operator/nn/mkldnn/mkldnn_act.cc
@@ -165,6 +165,8 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   stream->Submit();
 }
 
+// For backward relu activation, it's okay to pass "out_data" as "in_data" to this
+// function, since the computation only involves non-zeros.
 void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                               const NDArray &out_grad, const NDArray &in_data,
                               const OpReqType &req, const NDArray &in_grad) {
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index fda47fc5f95..93519d3ee8f 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -86,7 +86,7 @@ The storage type of ``relu`` output depends upon the input storage type:
 .set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<1, 1, false, true, false>)
 .set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::relu>)
 .set_attr<FComputeEx>("FComputeEx<cpu>", UnaryOp::ComputeEx<cpu, mshadow_op::relu>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_relu"});
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_relu"});
 
 MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_relu,
                                                unary_bwd<mshadow_op::relu_grad>);
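
Since relu is now registered with ElemwiseGradUseOut, its gradient node
consumes the operator's output instead of its input; the numeric result is
unchanged. A quick end-to-end check, assuming an MXNet build that includes
this patch and the autograd API:

    import mxnet as mx

    x = mx.nd.array([[-1.0, 0.0, 0.5],
                     [ 2.0, -3.0, 4.0]])
    x.attach_grad()
    with mx.autograd.record():
        y = mx.nd.relu(x)        # the gradient graph now reuses y, not x
    y.backward()
    print(x.grad.asnumpy())      # expected mask: [[0. 0. 1.] [1. 0. 1.]]
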
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 08c749e597e..83dfc428110 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -1133,14 +1133,17 @@ def test_fullyconnected_with_type():
 
 @with_seed()
 def test_activation_with_type():
-    sym = mx.sym.Activation(name='act', act_type='sigmoid')
-    ctx_list = [{'ctx': mx.gpu(0), 'act_data': (2, 2, 10, 10), 'type_dict': {'act_data': np.float64}},
-                {'ctx': mx.gpu(0), 'act_data': (2, 2, 10, 10), 'type_dict': {'act_data': np.float32}},
-                {'ctx': mx.gpu(0), 'act_data': (2, 2, 10, 10), 'type_dict': {'act_data': np.float16}},
-                {'ctx': mx.cpu(0), 'act_data': (2, 2, 10, 10), 'type_dict': {'act_data': np.float64}},
-                {'ctx': mx.cpu(0), 'act_data': (2, 2, 10, 10), 'type_dict': {'act_data': np.float32}},
-                {'ctx': mx.cpu(0), 'act_data': (2, 2, 10, 10), 'type_dict': {'act_data': np.float16}}]
-    check_consistency(sym, ctx_list)
+    act_types = ['relu', 'sigmoid', 'tanh', 'softrelu', 'softsign']
+    shape = (2, 2, 10, 10)
+    for act_type in act_types:
+        sym = mx.sym.Activation(name='act', act_type=act_type)
+        ctx_list = [{'ctx': mx.gpu(0), 'act_data': shape, 'type_dict': {'act_data': np.float64}},
+                    {'ctx': mx.gpu(0), 'act_data': shape, 'type_dict': {'act_data': np.float32}},
+                    {'ctx': mx.gpu(0), 'act_data': shape, 'type_dict': {'act_data': np.float16}},
+                    {'ctx': mx.cpu(0), 'act_data': shape, 'type_dict': {'act_data': np.float64}},
+                    {'ctx': mx.cpu(0), 'act_data': shape, 'type_dict': {'act_data': np.float32}},
+                    {'ctx': mx.cpu(0), 'act_data': shape, 'type_dict': {'act_data': np.float16}}]
+        check_consistency(sym, ctx_list)
 
 
 @with_seed()
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 7ee67dd2066..d6486e3781c 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -5816,6 +5816,48 @@ def get_output_names_callback(name, arr):
                             name='pooling')
     check_name(us_sym, ['pooling_output'])
 
+@with_seed()
+def test_activation():
+    shape=(9, 10)
+    dtype_l = [np.float64, np.float32, np.float16]
+    rtol_l = [1e-7, 1e-6, 1e-2]
+    atol_l = [1e-7, 1e-6, 1e-2]
+    rtol_fd = 1e-5
+    atol_fd = 1e-6
+    num_eps = 1e-6
+    unary_ops = {
+        'relu': [lambda x: mx.sym.Activation(x, act_type='relu'),
+                 lambda x: np.maximum(x, 0.),
+                 lambda x: 1. * (x > 0.),
+                 -5.0, 5.0],
+        'sigmoid': [lambda x: mx.sym.Activation(x, act_type='sigmoid'),
+                    lambda x: 1. / (np.exp(-x) + 1.),
+                    lambda x: 1. / (np.exp(-x) + 1.) / (np.exp(x) + 1.),
+                    -3.0, 3.0],
+        'tanh': [lambda x: mx.sym.Activation(x, act_type='tanh'),
+                 lambda x: np.tanh(x),
+                 lambda x: 1. - np.tanh(x) ** 2,
+                 -4.0, 4.0],
+        'softrelu': [lambda x: mx.sym.Activation(x, act_type='softrelu'),
+                    lambda x: np.log(1. + np.exp(x)),
+                    lambda x: 1. - 1 / (1 + np.exp(x)),
+                    -3.0, 3.0],
+    }
+    # Loop over operators
+    for name, op in unary_ops.items():
+        # Loop over dtype's
+        for ind in range(len(dtype_l)):
+            dtype = dtype_l[ind]
+            rtol = rtol_l[ind]
+            atol = atol_l[ind]
+            compare_forw_backw_unary_op(
+                name, op[0], op[1], op[2], shape, op[3], op[4], rtol, atol,
+                dtype)
+        # Finite difference testing
+        finite_diff_unary_op(
+            name, op[0], shape, op[3], op[4], rtol_fd, atol_fd, num_eps)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
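
The new tests can be exercised with nose, which the test modules already
use via nose.runmodule(); to target just the new unit test, something along
these lines should work (assumes the repository layout shown in the diff
and nose on the path):

    import nose

    # Select a single test with nose's file:function syntax.
    nose.run(argv=['nosetests', '-v',
                   'tests/python/unittest/test_operator.py:test_activation'])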


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services