You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sa...@apache.org on 2020/10/23 07:04:14 UTC
[incubator-mxnet] branch v1.8.x updated: Fix SoftReLU fused
operator numerical stability (#17849) (#19390)
This is an automated email from the ASF dual-hosted git repository.
samskalicky pushed a commit to branch v1.8.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.8.x by this push:
new ddab468 Fix SoftReLU fused operator numerical stability (#17849) (#19390)
ddab468 is described below
commit ddab4683b53ef18bb2115d298bd1ee2c85049f42
Author: Manu Seth <22...@users.noreply.github.com>
AuthorDate: Fri Oct 23 00:01:59 2020 -0700
Fix SoftReLU fused operator numerical stability (#17849) (#19390)
* fix numerically unstable fused softrelu op
* implement test for softrelu numerical stability
Co-authored-by: RuRo <ru...@ya.ru>
---
src/operator/fusion/fused_op-inl.h | 5 ++++-
tests/python/gpu/test_fusion.py | 3 +++
2 files changed, 7 insertions(+), 1 deletion(-)
diff --git a/src/operator/fusion/fused_op-inl.h b/src/operator/fusion/fused_op-inl.h
index c838d85..0b10f82 100644
--- a/src/operator/fusion/fused_op-inl.h
+++ b/src/operator/fusion/fused_op-inl.h
@@ -566,7 +566,10 @@ __device__ inline DType sigmoid(const DType val) {
template <typename DType>
__device__ inline DType softrelu(const DType val) {
- return logf(1 + expf(val));
+ // Avoid overflow of exp for large inputs.
+ // The threshold 20 is chosen such that softrelu(a) = a
+ // for a > 20 in single-precision (float) arithmetic.
+ return val > 20 ? val : logf(1 + expf(val));
}
template <typename DType>
diff --git a/tests/python/gpu/test_fusion.py b/tests/python/gpu/test_fusion.py
index 1bbf598..1febf8d 100644
--- a/tests/python/gpu/test_fusion.py
+++ b/tests/python/gpu/test_fusion.py
@@ -138,6 +138,9 @@ def check_unary_ops():
for act_type in ['relu', 'sigmoid', 'tanh', 'softrelu', 'softsign']:
announce_check("Activation(act_type='{}')".format(act_type))
check_fused_symbol(mx.sym.Activation(a, act_type=act_type), a=arr)
+ if act_type == 'softrelu':
+ # Check that softrelu implementation doesn't overflow on large inputs
+ check_fused_symbol(mx.sym.Activation(a, act_type=act_type), a=1000 * arr)
# Cast requires dtype
for dtype in ['float16', 'float32', 'float64', 'int32']: