You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/05/02 03:59:02 UTC

[incubator-mxnet] branch master updated: fix flaky test for hard_sigmoid (#10759)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 65df1ee  fix flaky test for hard_sigmoid (#10759)
65df1ee is described below

commit 65df1ee711a07fccc74a0131d7b3fb67b4f48e74
Author: Hao Jin <ha...@users.noreply.github.com>
AuthorDate: Tue May 1 20:58:56 2018 -0700

    fix flaky test for hard_sigmoid (#10759)
---
 tests/python/unittest/test_operator.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 838d8d8..7ee67dd 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -601,18 +601,18 @@ def test_hard_sigmoid():
     for dtype in [np.float16, np.float32, np.float64]:
         if dtype is np.float16:
             rtol = 1e-2
-            atol = 1e-3
         else:
             rtol = 1e-3
-            atol = 1e-5
+        atol = 1e-3
+        eps = 1e-3
         xa = np.random.uniform(low=-3.0,high=3.0,size=shape).astype(dtype)
         # function not differentiable at x=2.5 and -2.5
-        xa[xa == 2.5] = xa[xa == 2.5] - 1e-2
-        xa[xa == -2.5] = xa[xa == -2.5] - 1e-2
+        xa[abs(xa-2.5) < eps] -= 2 * eps
+        xa[abs(xa+2.5) < eps] += 2 * eps
         ya = fhardsigmoid(xa)
         grad_xa = fhardsigmoid_grad(xa, np.ones(shape))
         if dtype is not np.float16:
-            check_numeric_gradient(y, [xa], numeric_eps=1e-3, rtol=rtol, atol=atol, dtype=dtype)
+            check_numeric_gradient(y, [xa], numeric_eps=eps, rtol=rtol, atol=atol, dtype=dtype)
         check_symbolic_forward(y, [xa], [ya], rtol=rtol, atol=atol, dtype=dtype)
         check_symbolic_backward(y, [xa], [np.ones(shape)], [grad_xa], rtol=rtol, atol=atol, dtype=dtype)
 

-- 
To stop receiving notification emails like this one, please contact
zhasheng@apache.org.