You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/08/12 19:24:21 UTC

[incubator-mxnet] branch master updated: support selu activation function (#12059)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 0241e2c  support selu activation function (#12059)
0241e2c is described below

commit 0241e2c7c36ab73fd20fe9e43ec078051c9d172d
Author: Hao Jin <ha...@users.noreply.github.com>
AuthorDate: Sun Aug 12 15:24:10 2018 -0400

    support selu activation function (#12059)
---
 python/mxnet/gluon/nn/activations.py   |  4 +---
 src/operator/leaky_relu-inl.h          | 19 ++++++++++++++++++-
 src/operator/leaky_relu.cc             |  2 ++
 src/operator/mshadow_op.h              | 10 ++++++++++
 src/operator/operator_tune.cc          |  2 ++
 tests/python/unittest/test_operator.py | 31 +++++++++++++++++++++++++++++++
 6 files changed, 64 insertions(+), 4 deletions(-)

diff --git a/python/mxnet/gluon/nn/activations.py b/python/mxnet/gluon/nn/activations.py
index 422301a..fa8eee9 100644
--- a/python/mxnet/gluon/nn/activations.py
+++ b/python/mxnet/gluon/nn/activations.py
@@ -176,11 +176,9 @@ class SELU(HybridBlock):
     """
     def __init__(self, **kwargs):
         super(SELU, self).__init__(**kwargs)
-        self._scale = 1.0507009873554804934193349852946
-        self._alpha = 1.6732632423543772848170429916717
 
     def hybrid_forward(self, F, x):
-        return self._scale * F.where(x > 0, x, self._alpha * (F.exp(x) - 1.0))
+        return F.LeakyReLU(x, act_type='selu', name='fwd')
 
 
 class Swish(HybridBlock):
diff --git a/src/operator/leaky_relu-inl.h b/src/operator/leaky_relu-inl.h
index 20aabc8..1c4f48b 100644
--- a/src/operator/leaky_relu-inl.h
+++ b/src/operator/leaky_relu-inl.h
@@ -47,7 +47,7 @@ namespace op {
 namespace leakyrelu {
 enum LeakyReLUOpInputs {kData, kGamma};
 enum LeakyReLUOpOutputs {kOut, kMask};
-enum LeakyReLUOpType {kLeakyReLU, kPReLU, kRReLU, kELU};
+enum LeakyReLUOpType {kLeakyReLU, kPReLU, kRReLU, kELU, kSELU};
 enum LeakyReLUOpResource {kRandom};
 }  // namespace leakyrelu
 
@@ -63,6 +63,7 @@ struct LeakyReLUParam : public dmlc::Parameter<LeakyReLUParam> {
     .add_enum("leaky", leakyrelu::kLeakyReLU)
     .add_enum("prelu", leakyrelu::kPReLU)
     .add_enum("elu", leakyrelu::kELU)
+    .add_enum("selu", leakyrelu::kSELU)
     .describe("Activation function to be applied.");
     DMLC_DECLARE_FIELD(slope).set_default(0.25f)
     .describe("Init slope for the activation. (For leaky and elu only)");
@@ -182,6 +183,13 @@ class LeakyReLUOp : public Operator {
         });
         break;
       }
+      case leakyrelu::kSELU: {
+        MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kOut], Req, {
+          mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::selu, Req>, xpu>::Launch(
+            s, out.size(0) * out.size(1) * out.size(2), out.dptr_, data.dptr_);
+        });
+        break;
+      }
       default:
         LOG(FATAL) << "Not implmented";
     }
@@ -270,6 +278,15 @@ class LeakyReLUOp : public Operator {
         });
         break;
       }
+      case leakyrelu::kSELU: {
+        MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kData], Req, {
+          mxnet_op::Kernel<mxnet_op::op_with_req<
+            mxnet_op::backward_grad_tuned<mshadow_op::selu_grad>, Req>, xpu>::Launch(
+              s, gdata.size(0) * gdata.size(1) * gdata.size(2), gdata.dptr_, grad.dptr_,
+              output.dptr_);
+        });
+        break;
+      }
       default:
         LOG(FATAL) << "Not implmented";
     }
diff --git a/src/operator/leaky_relu.cc b/src/operator/leaky_relu.cc
index 99b6ba3..4bb2423 100644
--- a/src/operator/leaky_relu.cc
+++ b/src/operator/leaky_relu.cc
@@ -54,6 +54,8 @@ when the input is negative and has a slope of one when input is positive.
 The following modified ReLU Activation functions are supported:
 
 - *elu*: Exponential Linear Unit. `y = x > 0 ? x : slope * (exp(x)-1)`
+- *selu*: Scaled Exponential Linear Unit. `y = lambda * (x > 0 ? x : alpha * (exp(x) - 1))` where
+  *lambda = 1.0507009873554804934193349852946* and *alpha = 1.6732632423543772848170429916717*.
 - *leaky*: Leaky ReLU. `y = x > 0 ? x : slope * x`
 - *prelu*: Parametric ReLU. This is same as *leaky* except that `slope` is learnt during training.
 - *rrelu*: Randomized ReLU. same as *leaky* but the `slope` is uniformly and randomly chosen from
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 7a2032d..3397193 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -42,8 +42,12 @@ namespace mshadow_op {
 
 #ifdef __CUDA_ARCH__
 __constant__ const float PI = 3.14159265358979323846;
+__constant__ const float SELU_ALPHA = 1.6732632423543772848170429916717;
+__constant__ const float SELU_LAMBDA = 1.0507009873554804934193349852946;
 #else
 const float PI = 3.14159265358979323846;
+const float SELU_ALPHA = 1.6732632423543772848170429916717;
+const float SELU_LAMBDA = 1.0507009873554804934193349852946;
 using std::isnan;
 #endif
 using std::enable_if;
@@ -126,6 +130,12 @@ MXNET_UNARY_MATH_OP_NC(relu, a > DType(0) ? a : DType(0));
 
 MXNET_UNARY_MATH_OP_NC(relu_grad, a > DType(0) ? DType(1) : DType(0));
 
+MXNET_UNARY_MATH_OP_NC(selu, DType(SELU_LAMBDA) *
+                         (a > DType(0) ? a : DType(math::id(SELU_ALPHA) * math::expm1(a))));
+
+MXNET_UNARY_MATH_OP_NC(selu_grad,  // a is output y; grad = y > 0 ? lambda : y + lambda * alpha
+                       a > DType(0) ? DType(SELU_LAMBDA) : DType(SELU_LAMBDA * SELU_ALPHA + a));
+
 MXNET_BINARY_MATH_OP_NC(prelu_grad, a > DType(0) ? DType(0) : a);
 
 MXNET_BINARY_MATH_OP_NC(xelu, a > DType(0) ? a :
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index 0953cba..cf5412f 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -217,6 +217,8 @@ IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::softsign);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::softsign_grad);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::relu);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::relu_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::selu);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::selu_grad);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::tanh);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::tanh_grad);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::softrelu);  // NOLINT()
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 22caaca..54eb0fd 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -820,6 +820,37 @@ def test_prelu():
                                     [g_xa_full, g_gam_full], rtol=rtol, atol=atol, dtype=dtype)
 
 @with_seed()
+def test_selu():
+    """Check LeakyReLU(act_type='selu') forward/backward against NumPy references."""
+    alpha, lamb = 1.6732632423543772848170429916717, 1.0507009873554804934193349852946
+    def fselu(x):
+        # Reference forward: lamb * x for x >= 0, lamb * alpha * expm1(x) for x < 0.
+        out = x.copy()
+        out[x < 0] = alpha * np.expm1(out[x < 0])
+        return out * lamb
+    def fselu_grad(grad, x, y):
+        # d(selu)/dx: lamb for x >= 0; lamb * alpha * exp(x), i.e. y + lamb * alpha, for x < 0.
+        out = np.ones(x.shape).astype(x.dtype)
+        out[x < 0] = alpha * np.exp(x[x < 0])
+        return out * lamb
+    # Run the same three checks for float16, float32 and float64 inputs.
+    shape = (3, 4)
+    x = mx.sym.Variable("x")
+    y = mx.sym.LeakyReLU(data=x, act_type="selu")
+    for dtype in [np.float16, np.float32, np.float64]:
+        xa = np.random.uniform(low=-0.1,high=0.1,size=shape).astype(dtype)
+        eps, rtol, atol = (7.5e-4, 1e-1, 1e-2) if dtype is np.float16 else (1e-4, 1e-2, 1e-4)
+        if dtype is np.float16:
+            xa /= 10.0
+        xa[abs(xa) < eps] = 0.01  # keep samples at least eps away from the kink at x == 0
+        ya = fselu(xa)
+        ga = fselu_grad(np.ones(shape).astype(dtype), xa, ya)
+        check_numeric_gradient(y, [xa], numeric_eps=eps, rtol=rtol, atol=atol, dtype=dtype)
+        check_symbolic_forward(y, [xa], [ya], rtol=rtol, atol=atol, dtype=dtype)
+        check_symbolic_backward(y, [xa], [np.ones(shape)], [ga], rtol=rtol, atol=atol, dtype=dtype)
+
+
+@with_seed()
 def test_sigmoid():
     def fsigmoid(a):
         return np.divide(1.0, (1.0 + np.exp(-a)))