You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/03/20 22:36:51 UTC

[GitHub] piiswrong closed pull request #9996: add softsign operator to v1.1.0

piiswrong closed pull request #9996: add softsign operator to v1.1.0
URL: https://github.com/apache/incubator-mxnet/pull/9996
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub does not otherwise display it after the merge:

diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index af7ef513f14..1d4284e1ac2 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -111,6 +111,10 @@ MXNET_UNARY_MATH_OP(sigmoid, 1.0f / (1.0f + math::exp(-a)));
 
 MXNET_UNARY_MATH_OP(sigmoid_grad, math::id(a) * (1.0f - math::id(a)));
 
+MXNET_UNARY_MATH_OP(softsign, a / (1.0f + math::fabs(a)));
+
+MXNET_UNARY_MATH_OP(softsign_grad, 1.0f /  math::sqr(1.0f + math::fabs(a)));
+
 MXNET_UNARY_MATH_OP_NC(relu, a > DType(0) ? a : DType(0));
 
 MXNET_UNARY_MATH_OP_NC(relu_grad, a > DType(0) ? DType(1) : DType(0));
diff --git a/src/operator/nn/activation-inl.h b/src/operator/nn/activation-inl.h
index ac8b747f0f3..f2bc45e1bec 100644
--- a/src/operator/nn/activation-inl.h
+++ b/src/operator/nn/activation-inl.h
@@ -45,7 +45,7 @@ namespace op {
 namespace activation {
 enum ActivationOpInputs {kData};
 enum ActivationOpOutputs {kOut};
-enum ActivationOpType {kReLU, kSigmoid, kTanh, kSoftReLU};
+enum ActivationOpType {kReLU, kSigmoid, kTanh, kSoftReLU, kSoftSign};
 }  // activation
 
 struct ActivationParam : public dmlc::Parameter<ActivationParam> {
@@ -57,6 +57,7 @@ struct ActivationParam : public dmlc::Parameter<ActivationParam> {
     .add_enum("sigmoid", activation::kSigmoid)
     .add_enum("tanh", activation::kTanh)
     .add_enum("softrelu", activation::kSoftReLU)
+    .add_enum("softsign", activation::kSoftSign)
     .describe("Activation function to be applied.");
   }
 };
diff --git a/src/operator/nn/activation.cc b/src/operator/nn/activation.cc
index 401a9e3eaa5..d634033c80f 100644
--- a/src/operator/nn/activation.cc
+++ b/src/operator/nn/activation.cc
@@ -64,6 +64,9 @@ Operator *CreateOp<cpu>(ActivationParam param, int dtype, const TShape& dshape)
       case activation::kSoftReLU:
         op = new ActivationOp<cpu, mshadow_op::softrelu, mshadow_op::softrelu_grad, DType>();
         break;
+      case activation::kSoftSign:
+        op = new ActivationOp<cpu, mshadow_op::softsign, mshadow_op::softsign_grad, DType>();
+        break;
       default:
         LOG(FATAL) << "unknown activation type";
     }
@@ -88,6 +91,7 @@ The following activation functions are supported:
 - `sigmoid`: :math:`y = \frac{1}{1 + exp(-x)}`
 - `tanh`: Hyperbolic tangent, :math:`y = \frac{exp(x) - exp(-x)}{exp(x) + exp(-x)}`
 - `softrelu`: Soft ReLU, or SoftPlus, :math:`y = log(1 + exp(x))`
+- `softsign`: :math:`y = \frac{x}{1 + abs(x)}`
 
 )code" ADD_FILELINE)
 .add_argument("data", "NDArray-or-Symbol", "Input array to activation function.")
diff --git a/src/operator/nn/activation.cu b/src/operator/nn/activation.cu
index c2f6be9f37c..a7fdaa14a0b 100644
--- a/src/operator/nn/activation.cu
+++ b/src/operator/nn/activation.cu
@@ -58,6 +58,9 @@ Operator *CreateOp<gpu>(ActivationParam param, int dtype, const TShape& dshape)
       case activation::kTanh:
         op = new ActivationOp<gpu, mshadow_op::tanh, mshadow_op::tanh_grad, DType>();
         break;
+      case activation::kSoftSign:
+        op = new ActivationOp<gpu, mshadow_op::softsign, mshadow_op::softsign_grad, DType>();
+        break;
       default:
         LOG(FATAL) << "unknown activation";
     }
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index e0f8306565d..c13f1ac2fae 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -213,6 +213,8 @@ IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_grad);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sigmoid);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sigmoid_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::softsign);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::softsign_grad);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::relu);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::relu_grad);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::tanh);  // NOLINT()
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index 13a58d0165a..c675e51a076 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -105,6 +105,23 @@ The storage type of ``sigmoid`` output is always dense
 
 MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_sigmoid,
                                                unary_bwd<mshadow_op::sigmoid_grad>);
+// softsign
+MXNET_OPERATOR_REGISTER_UNARY(softsign)
+MXNET_ADD_SPARSE_OP_ALIAS(softsign)
+.describe(R"code(Computes softsign of x element-wise.
+
+.. math::
+   y = x / (1 + abs(x))
+
+The storage type of ``softsign`` output is always dense
+
+)code" ADD_FILELINE)
+  .set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::softsign>)
+  .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_softsign"});
+
+MXNET_OPERATOR_REGISTER_BINARY(_backward_softsign)
+.set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu,
+  unary_bwd<mshadow_op::softsign_grad> >);
 
 // copy
 MXNET_OPERATOR_REGISTER_UNARY(_copy)
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cu b/src/operator/tensor/elemwise_unary_op_basic.cu
index 41eef903401..8dfa9af74ce 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cu
+++ b/src/operator/tensor/elemwise_unary_op_basic.cu
@@ -40,6 +40,14 @@ NNVM_REGISTER_OP(_backward_sigmoid)
 .set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<
   gpu, unary_bwd<mshadow_op::sigmoid_grad>>);
 
+// softsign
+NNVM_REGISTER_OP(softsign)
+.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::softsign>);
+
+NNVM_REGISTER_OP(_backward_softsign)
+.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<
+  gpu, unary_bwd<mshadow_op::softsign_grad>>);
+
 // copy
 NNVM_REGISTER_OP(_copy)
 .set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>)
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 640cd347dc1..d38e40e914a 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -467,6 +467,20 @@ def fsigmoid(a):
     check_symbolic_forward(y, [xa], [ya])
     check_symbolic_backward(y, [xa], [np.ones(shape)], [ya * (1 - ya)])
 
+def test_softsign():
+    def fsoftsign(a):
+        return np.divide(a, (1.0 + np.abs(a)))
+    def fsoftsign_grad(a):
+        return np.divide(1.0, np.square((1.0 + np.abs(a))))
+    shape = (3, 4)
+    x = mx.symbol.Variable("x")
+    y = mx.sym.softsign(x)
+    xa = np.random.uniform(low=-1.0,high=1.0,size=shape)
+    ya = fsoftsign(xa)
+    ya_grad = fsoftsign_grad(xa)
+    check_numeric_gradient(y, [xa], numeric_eps=1E-3)
+    check_symbolic_forward(y, [xa], [ya])
+    check_symbolic_backward(y, [xa], [np.ones(shape)], [ya_grad])
 
 def test_binary_logic():
     def _inner_test(forward_gt, logic_sym, x_shape, y_shape, test_scalar=True):
@@ -4688,6 +4702,10 @@ def test_unary_math_operators():
                     lambda x: 1. / (np.exp(-x) + 1.),
                     lambda x: 1. / (np.exp(-x) + 1.) / (np.exp(x) + 1.),
                     -3.0, 3.0],
+        'softsign': [lambda x: mx.sym.softsign(x),
+                    lambda x: x / (1. + np.abs(x)),
+                    lambda x: 1. / np.square(1. + np.abs(x)),
+                    -3.0, 3.0],
         'sin': [lambda x: mx.sym.sin(x),
                 lambda x: np.sin(x),
                 lambda x: np.cos(x),


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services