Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/12 19:24:12 UTC

[GitHub] eric-haibin-lin closed pull request #12059: Support selu activation function

eric-haibin-lin closed pull request #12059: Support selu activation function
URL: https://github.com/apache/incubator-mxnet/pull/12059

This is a PR merged from a forked repository. Because GitHub hides the
original diff of a foreign (forked) pull request once it is merged, the
diff is reproduced below for the sake of provenance:

diff --git a/python/mxnet/gluon/nn/activations.py b/python/mxnet/gluon/nn/activations.py
index 422301a6a48..fa8eee9d298 100644
--- a/python/mxnet/gluon/nn/activations.py
+++ b/python/mxnet/gluon/nn/activations.py
@@ -176,11 +176,9 @@ class SELU(HybridBlock):
     """
     def __init__(self, **kwargs):
         super(SELU, self).__init__(**kwargs)
-        self._scale = 1.0507009873554804934193349852946
-        self._alpha = 1.6732632423543772848170429916717
 
     def hybrid_forward(self, F, x):
-        return self._scale * F.where(x > 0, x, self._alpha * (F.exp(x) - 1.0))
+        return F.LeakyReLU(x, act_type='selu', name='fwd')
 
 
 class Swish(HybridBlock):
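
For reference, a minimal usage sketch of the updated Gluon block (illustrative
only, assuming an MXNet build that includes this change); the SELU block now
lowers to the fused LeakyReLU kernel instead of composing F.where and F.exp:

    import mxnet as mx
    from mxnet.gluon import nn

    act = nn.SELU()
    act.hybridize()  # hybrid_forward now emits a single F.LeakyReLU(x, act_type='selu')

    x = mx.nd.array([[-1.0, 0.0, 2.0]])
    y = act(x)  # negative inputs map to lambda * alpha * (exp(x) - 1)
    print(y)
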
diff --git a/src/operator/leaky_relu-inl.h b/src/operator/leaky_relu-inl.h
index 20aabc8ae32..1c4f48b32ed 100644
--- a/src/operator/leaky_relu-inl.h
+++ b/src/operator/leaky_relu-inl.h
@@ -47,7 +47,7 @@ namespace op {
 namespace leakyrelu {
 enum LeakyReLUOpInputs {kData, kGamma};
 enum LeakyReLUOpOutputs {kOut, kMask};
-enum LeakyReLUOpType {kLeakyReLU, kPReLU, kRReLU, kELU};
+enum LeakyReLUOpType {kLeakyReLU, kPReLU, kRReLU, kELU, kSELU};
 enum LeakyReLUOpResource {kRandom};
 }  // namespace leakyrelu
 
@@ -63,6 +63,7 @@ struct LeakyReLUParam : public dmlc::Parameter<LeakyReLUParam> {
     .add_enum("leaky", leakyrelu::kLeakyReLU)
     .add_enum("prelu", leakyrelu::kPReLU)
     .add_enum("elu", leakyrelu::kELU)
+    .add_enum("selu", leakyrelu::kSELU)
     .describe("Activation function to be applied.");
     DMLC_DECLARE_FIELD(slope).set_default(0.25f)
     .describe("Init slope for the activation. (For leaky and elu only)");
@@ -182,6 +183,13 @@ class LeakyReLUOp : public Operator {
         });
         break;
       }
+      case leakyrelu::kSELU: {
+        MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kOut], Req, {
+          mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::selu, Req>, xpu>::Launch(
+            s, out.size(0) * out.size(1) * out.size(2), out.dptr_, data.dptr_);
+        });
+        break;
+      }
       default:
         LOG(FATAL) << "Not implemented";
     }
@@ -270,6 +278,15 @@ class LeakyReLUOp : public Operator {
         });
         break;
       }
+      case leakyrelu::kSELU: {
+        MXNET_ASSIGN_REQ_SWITCH(req[leakyrelu::kData], Req, {
+          mxnet_op::Kernel<mxnet_op::op_with_req<
+            mxnet_op::backward_grad_tuned<mshadow_op::selu_grad>, Req>, xpu>::Launch(
+              s, gdata.size(0) * gdata.size(1) * gdata.size(2), gdata.dptr_, grad.dptr_,
+              output.dptr_);
+        });
+        break;
+      }
       default:
         LOG(FATAL) << "Not implemented";
     }
diff --git a/src/operator/leaky_relu.cc b/src/operator/leaky_relu.cc
index 99b6ba362f7..4bb24237b8e 100644
--- a/src/operator/leaky_relu.cc
+++ b/src/operator/leaky_relu.cc
@@ -54,6 +54,8 @@ when the input is negative and has a slope of one when input is positive.
 The following modified ReLU Activation functions are supported:
 
 - *elu*: Exponential Linear Unit. `y = x > 0 ? x : slope * (exp(x)-1)`
+- *selu*: Scaled Exponential Linear Unit. `y = lambda * (x > 0 ? x : alpha * (exp(x) - 1))` where
+  *lambda = 1.0507009873554804934193349852946* and *alpha = 1.6732632423543772848170429916717*.
 - *leaky*: Leaky ReLU. `y = x > 0 ? x : slope * x`
 - *prelu*: Parametric ReLU. This is same as *leaky* except that `slope` is learnt during training.
 - *rrelu*: Randomized ReLU. same as *leaky* but the `slope` is uniformly and randomly chosen from
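
A quick cross-check of the formula documented above against a NumPy reference,
via the ndarray front end (illustrative sketch; assumes act_type='selu' is
available in the installed build):

    import numpy as np
    import mxnet as mx

    lam = 1.0507009873554804934193349852946
    alpha = 1.6732632423543772848170429916717

    xa = np.array([-2.0, -0.5, 0.0, 0.5, 2.0], dtype=np.float32)
    y = mx.nd.LeakyReLU(mx.nd.array(xa), act_type='selu').asnumpy()
    ref = lam * np.where(xa > 0, xa, alpha * np.expm1(xa))
    assert np.allclose(y, ref, rtol=1e-5, atol=1e-6)
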
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 7a2032df758..339719375fd 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -42,8 +42,12 @@ namespace mshadow_op {
 
 #ifdef __CUDA_ARCH__
 __constant__ const float PI = 3.14159265358979323846;
+__constant__ const float SELU_ALPHA = 1.6732632423543772848170429916717;
+__constant__ const float SELU_LAMBDA = 1.0507009873554804934193349852946;
 #else
 const float PI = 3.14159265358979323846;
+const float SELU_ALPHA = 1.6732632423543772848170429916717;
+const float SELU_LAMBDA = 1.0507009873554804934193349852946;
 using std::isnan;
 #endif
 using std::enable_if;
@@ -126,6 +130,12 @@ MXNET_UNARY_MATH_OP_NC(relu, a > DType(0) ? a : DType(0));
 
 MXNET_UNARY_MATH_OP_NC(relu_grad, a > DType(0) ? DType(1) : DType(0));
 
+MXNET_UNARY_MATH_OP_NC(selu, DType(SELU_LAMBDA) *
+                         (a > DType(0) ? a : DType(math::id(SELU_ALPHA) * math::expm1(a))));
+
+MXNET_UNARY_MATH_OP_NC(selu_grad,
+                       DType(SELU_LAMBDA) * (a > DType(0) ? DType(1) : DType(SELU_ALPHA) + a / DType(SELU_LAMBDA)));
+
 MXNET_BINARY_MATH_OP_NC(prelu_grad, a > DType(0) ? DType(0) : a);
 
 MXNET_BINARY_MATH_OP_NC(xelu, a > DType(0) ? a :
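
Note that the backward kernel in leaky_relu-inl.h passes the forward output
(output.dptr_), so selu_grad is written in terms of y rather than x: for x > 0,
dy/dx = lambda, and for x < 0, dy/dx = lambda * alpha * exp(x) = y + lambda * alpha.
A small NumPy sanity check of that identity (illustrative only, not part of the diff):

    import numpy as np

    lam = 1.0507009873554804934193349852946
    alpha = 1.6732632423543772848170429916717

    x = np.linspace(-3.0, -1e-3, 100)
    y = lam * alpha * np.expm1(x)      # forward output on the negative branch
    dydx = lam * alpha * np.exp(x)     # analytic derivative
    assert np.allclose(dydx, y + lam * alpha)
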
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index 0953cbaf519..cf5412f9824 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -217,6 +217,8 @@ IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::softsign);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::softsign_grad);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::relu);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::relu_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::selu);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::selu_grad);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::tanh);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::tanh_grad);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::softrelu);  // NOLINT()
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 90e85d123d5..949f2c59c45 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -819,6 +819,37 @@ def fprelu_grad(x, y, gamma):
             check_symbolic_backward(y, [xa, gam_full], [np.ones(shape), np.ones(gam_full.shape)],
                                     [g_xa_full, g_gam_full], rtol=rtol, atol=atol, dtype=dtype)
 
+@with_seed()
+def test_selu():
+    alpha = 1.6732632423543772848170429916717
+    lamb = 1.0507009873554804934193349852946
+    def fselu(x):
+        neg_indices = x < 0
+        out = x.copy()
+        out[neg_indices] = alpha * np.expm1(out[neg_indices])
+        return out * lamb
+    def fselu_grad(grad, x, y):
+        neg_indices = x < 0
+        out = np.ones(x.shape).astype(x.dtype)
+        out[neg_indices] = y[neg_indices] / lamb + alpha
+        return out * lamb
+
+    shape = (3, 4)
+    x = mx.sym.Variable("x")
+    y = mx.sym.LeakyReLU(data=x, act_type="selu")
+    for dtype in [np.float16, np.float32, np.float64]:
+        xa = np.random.uniform(low=-0.1,high=0.1,size=shape).astype(dtype)
+        eps, rtol, atol = (7.5e-4, 1e-1, 1e-2) if dtype is np.float16 else (1e-4, 1e-2, 1e-4)
+        if dtype is np.float16:
+            xa /= 10.0
+        xa[abs(xa) < eps] = 0.01
+        ya = fselu(xa)
+        ga = fselu_grad(np.ones(shape).astype(dtype), xa, ya)
+        check_numeric_gradient(y, [xa], numeric_eps=eps, rtol=rtol, atol=atol, dtype=dtype)
+        check_symbolic_forward(y, [xa], [ya], rtol=rtol, atol=atol, dtype=dtype)
+        check_symbolic_backward(y, [xa], [np.ones(shape)], [ga], rtol=rtol, atol=atol, dtype=dtype)
+
+
 @with_seed()
 def test_sigmoid():
     def fsigmoid(a):

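Outside the test harness, the same forward/backward behaviour can be
spot-checked interactively with autograd (an illustrative sketch, not part of
the PR):

    import mxnet as mx
    from mxnet import autograd

    x = mx.nd.array([-1.0, -0.1, 0.0, 0.5])
    x.attach_grad()
    with autograd.record():
        y = mx.nd.LeakyReLU(x, act_type='selu')
    y.backward()
    print(x.grad)  # ~lambda for x > 0, lambda * alpha * exp(x) for x < 0
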

 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services