Posted to commits@mxnet.apache.org by sa...@apache.org on 2020/10/23 07:04:14 UTC

[incubator-mxnet] branch v1.8.x updated: Fix SoftReLU fused operator numerical stability (#17849) (#19390)

This is an automated email from the ASF dual-hosted git repository.

samskalicky pushed a commit to branch v1.8.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.8.x by this push:
     new ddab468  Fix SoftReLU fused operator numerical stability (#17849) (#19390)
ddab468 is described below

commit ddab4683b53ef18bb2115d298bd1ee2c85049f42
Author: Manu Seth <22...@users.noreply.github.com>
AuthorDate: Fri Oct 23 00:01:59 2020 -0700

    Fix SoftReLU fused operator numerical stability (#17849) (#19390)
    
    * fix numerically unstable fused softrelu op
    
    * implement test for softrelu numerical stability
    
    Co-authored-by: RuRo <ru...@ya.ru>
---
 src/operator/fusion/fused_op-inl.h | 5 ++++-
 tests/python/gpu/test_fusion.py    | 3 +++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/operator/fusion/fused_op-inl.h b/src/operator/fusion/fused_op-inl.h
index c838d85..0b10f82 100644
--- a/src/operator/fusion/fused_op-inl.h
+++ b/src/operator/fusion/fused_op-inl.h
@@ -566,7 +566,10 @@ __device__ inline DType sigmoid(const DType val) {
 
 template <typename DType>
 __device__ inline DType softrelu(const DType val) {
-  return logf(1 + expf(val));
+  // Avoid overflow of exp for large inputs.
+  // The threshold 20 is chosen such that softrelu(a) = a
+  // for a > 20 to within floating point precision.
+  return val > 20 ? val : logf(1 + expf(val));
 }
 
 template <typename DType>
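
[Editor's note, not part of the commit: a minimal NumPy sketch of the instability this hunk
addresses. The values and variable names below are illustrative; only the math mirrors the
kernel change. For large x, exp(x) overflows to inf, so log(1 + exp(x)) returns inf, while
the exact value x + log1p(exp(-x)) differs from x by roughly exp(-20) ~= 2e-9 once x > 20,
which is below single-precision epsilon, so returning x directly is safe.]

    import numpy as np

    # Naive form overflows: in float32, exp(x) is inf for x >~ 88.7, so log(1 + exp(x)) -> inf.
    x = np.float32(1000.0)
    with np.errstate(over='ignore'):
        naive = np.log(np.float32(1.0) + np.exp(x))     # inf

    # Thresholded form matching the patched kernel: for x > 20 the correction term
    # log1p(exp(-x)) is about exp(-20) ~= 2e-9, smaller than float32 epsilon (~1.2e-7),
    # so softrelu(x) == x to float precision and exp() is never evaluated.
    stable = x if x > 20 else np.log1p(np.exp(x))
    print(naive, stable)    # inf 1000.0
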
diff --git a/tests/python/gpu/test_fusion.py b/tests/python/gpu/test_fusion.py
index 1bbf598..1febf8d 100644
--- a/tests/python/gpu/test_fusion.py
+++ b/tests/python/gpu/test_fusion.py
@@ -138,6 +138,9 @@ def check_unary_ops():
     for act_type in ['relu', 'sigmoid', 'tanh', 'softrelu', 'softsign']:
         announce_check("Activation(act_type='{}')".format(act_type))
         check_fused_symbol(mx.sym.Activation(a, act_type=act_type), a=arr)
+        if act_type == 'softrelu':
+            # Check that softrelu implementation doesn't overflow on large inputs
+            check_fused_symbol(mx.sym.Activation(a, act_type=act_type), a=1000 * arr)
 
     # Cast requires dtype
     for dtype in ['float16', 'float32', 'float64', 'int32']:
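
[Editor's note, not part of the patch: a rough standalone check along the lines of the new
test case. It assumes a GPU build of MXNet with pointwise fusion enabled; the small graph,
shapes, and values are illustrative only.]

    import mxnet as mx
    import numpy as np

    # Hypothetical quick check outside the test harness: the graph needs more than one
    # fusable pointwise op for fusion to kick in, hence the extra 2 * a.
    a = mx.sym.Variable('a')
    sym = mx.sym.Activation(2 * a, act_type='softrelu')
    exe = sym.simple_bind(ctx=mx.gpu(0), a=(3,))
    out = exe.forward(is_train=False, a=mx.nd.array([10., 100., 1000.], dtype='float32'))[0]
    print(out.asnumpy())    # expected roughly [20., 200., 2000.], not inf
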