You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/08/27 04:05:07 UTC

[incubator-mxnet] branch master updated: Make check_isfinite, check_scale optional in clip_global_norm (#12042)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 308ada1  Make check_isfinite, check_scale optional in clip_global_norm (#12042)
308ada1 is described below

commit 308ada1e412a56343e012f1ef7a4aa4fbe243032
Author: Leonard Lausen <le...@lausen.nl>
AuthorDate: Sun Aug 26 21:04:59 2018 -0700

    Make check_isfinite, check_scale optional in clip_global_norm (#12042)
    
    * Make check_isfinite, check_scale optional in clip_global_norm
    
    If both are set to False, clip_global_norm does not force any synchronization,
    which can increase throughput.
    
    * Add tests
    
    * Remove check_scale
    
    * Document return type
    
    * Fix test_gluon_gpu
---
 python/mxnet/gluon/utils.py         | 38 ++++++++++++++++++++++++++++---------
 tests/python/gpu/test_gluon_gpu.py  | 16 ++++++++++------
 tests/python/unittest/test_gluon.py | 11 ++++++-----
 3 files changed, 45 insertions(+), 20 deletions(-)

diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index f04479d..d5a14a6 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -115,8 +115,23 @@ def split_and_load(data, ctx_list, batch_axis=0, even_split=True):
     return [i.as_in_context(ctx) for i, ctx in zip(slices, ctx_list)]
 
 
-def clip_global_norm(arrays, max_norm):
+def clip_global_norm(arrays, max_norm, check_isfinite=True):
     """Rescales NDArrays so that the sum of their 2-norm is smaller than `max_norm`.
+
+    Parameters
+    ----------
+    arrays : list of NDArray
+    max_norm : float
+    check_isfinite : bool, default True
+         If True, check that the total_norm is finite (not nan or inf). This
+         requires a blocking .asscalar() call.
+
+    Returns
+    -------
+    NDArray or float
+      Total norm. Return type is NDArray of shape (1,) if check_isfinite is
+      False. Otherwise a float is returned.
+
     """
     def _norm(array):
         if array.stype == 'default':
@@ -126,15 +141,20 @@ def clip_global_norm(arrays, max_norm):
     assert len(arrays) > 0
     ctx = arrays[0].context
     total_norm = ndarray.add_n(*[_norm(arr).as_in_context(ctx) for arr in arrays])
-    total_norm = ndarray.sqrt(total_norm).asscalar()
-    if not np.isfinite(total_norm):
-        warnings.warn(UserWarning('nan or inf is detected. Clipping results will be undefined.'),
-                      stacklevel=2)
+    total_norm = ndarray.sqrt(total_norm)
+    if check_isfinite:
+        if not np.isfinite(total_norm.asscalar()):
+            warnings.warn(
+                UserWarning('nan or inf is detected. '
+                            'Clipping results will be undefined.'), stacklevel=2)
     scale = max_norm / (total_norm + 1e-8)
-    if scale < 1.0:
-        for arr in arrays:
-            arr *= scale
-    return total_norm
+    scale = ndarray.min(ndarray.concat(scale, ndarray.ones(1, ctx=ctx), dim=0))
+    for arr in arrays:
+        arr *= scale.as_in_context(arr.context)
+    if check_isfinite:
+        return total_norm.asscalar()
+    else:
+        return total_norm
 
 
 def _indent(s_, numSpaces):
diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index 42d65da..69375af 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -111,12 +111,16 @@ def test_gluon_ctc_consistency():
 
 @with_seed()
 def test_global_norm_clip_multi_device():
-    x1 = mx.nd.ones((3,3), ctx=mx.gpu(0))
-    x2 = mx.nd.ones((4,4), ctx=mx.cpu(0))
-    norm = gluon.utils.clip_global_norm([x1, x2], 1.0)
-    assert norm == 5.0
-    assert_almost_equal(x1.asnumpy(), np.ones((3,3))/5)
-    assert_almost_equal(x2.asnumpy(), np.ones((4,4))/5)
+    for check_isfinite in [True, False]:
+        x1 = mx.nd.ones((3,3), ctx=mx.gpu(0))
+        x2 = mx.nd.ones((4,4), ctx=mx.cpu(0))
+        norm = gluon.utils.clip_global_norm([x1, x2], 1.0, check_isfinite=check_isfinite)
+        if check_isfinite:
+            assert norm == 5.0
+        else:
+            assert norm.asscalar() == 5.0
+        assert_almost_equal(x1.asnumpy(), np.ones((3, 3)) / 5)
+        assert_almost_equal(x2.asnumpy(), np.ones((4, 4)) / 5)
 
 
 def _check_batchnorm_result(input, num_devices=1, cuda=False):
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 61b441a..bf9f5a7 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -735,10 +735,10 @@ def test_sequential_warning():
 @with_seed()
 def test_global_norm_clip():
     stypes = ['default', 'row_sparse']
-    def check_global_norm_clip(stype):
+    def check_global_norm_clip(stype, check_isfinite):
         x1 = mx.nd.ones((3,3)).tostype(stype)
         x2 = mx.nd.ones((4,4)).tostype(stype)
-        norm = gluon.utils.clip_global_norm([x1, x2], 1.0)
+        norm = gluon.utils.clip_global_norm([x1, x2], 1.0, check_isfinite=check_isfinite)
         assert norm == 5.0
         assert_almost_equal(x1.asnumpy(), np.ones((3,3))/5)
         assert_almost_equal(x2.asnumpy(), np.ones((4,4))/5)
@@ -746,11 +746,12 @@ def test_global_norm_clip():
         x3 = mx.nd.array([1.0, 2.0, float('nan')]).tostype(stype)
         with warnings.catch_warnings(record=True) as w:
             warnings.simplefilter("always")
-            gluon.utils.clip_global_norm([x1, x3], 2.0)
-            assert len(w) == 1
+            gluon.utils.clip_global_norm([x1, x3], 2.0, check_isfinite=check_isfinite)
+            assert len(w) == check_isfinite
 
     for stype in stypes:
-        check_global_norm_clip(stype)
+        for check_isfinite in [True, False]:
+            check_global_norm_clip(stype, check_isfinite)
 
 @with_seed()
 def test_embedding():