You are viewing a plain-text version of this content; the hyperlink to its canonical version is not preserved in plain text.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/27 04:05:00 UTC

[GitHub] eric-haibin-lin closed pull request #12042: Make check_isfinite, check_scale optional in clip_global_norm

eric-haibin-lin closed pull request #12042: Make check_isfinite, check_scale optional in clip_global_norm
URL: https://github.com/apache/incubator-mxnet/pull/12042
 
 
   

This is a pull request (PR) merged from a forked repository.
Because GitHub hides the original diff on merge, it is displayed below
for the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, since GitHub would not otherwise display it:

diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index f04479d2371..d5a14a6859a 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -115,8 +115,23 @@ def split_and_load(data, ctx_list, batch_axis=0, even_split=True):
     return [i.as_in_context(ctx) for i, ctx in zip(slices, ctx_list)]
 
 
-def clip_global_norm(arrays, max_norm):
+def clip_global_norm(arrays, max_norm, check_isfinite=True):
     """Rescales NDArrays so that the sum of their 2-norm is smaller than `max_norm`.
+
+    Parameters
+    ----------
+    arrays : list of NDArray
+    max_norm : float
+    check_isfinite : bool, default True
+         If True, check that the total_norm is finite (not nan or inf). This
+         requires a blocking .asscalar() call.
+
+    Returns
+    -------
+    NDArray or float
+      Total norm. Return type is NDArray of shape (1,) if check_isfinite is
+      False. Otherwise a float is returned.
+
     """
     def _norm(array):
         if array.stype == 'default':
@@ -126,15 +141,20 @@ def _norm(array):
     assert len(arrays) > 0
     ctx = arrays[0].context
     total_norm = ndarray.add_n(*[_norm(arr).as_in_context(ctx) for arr in arrays])
-    total_norm = ndarray.sqrt(total_norm).asscalar()
-    if not np.isfinite(total_norm):
-        warnings.warn(UserWarning('nan or inf is detected. Clipping results will be undefined.'),
-                      stacklevel=2)
+    total_norm = ndarray.sqrt(total_norm)
+    if check_isfinite:
+        if not np.isfinite(total_norm.asscalar()):
+            warnings.warn(
+                UserWarning('nan or inf is detected. '
+                            'Clipping results will be undefined.'), stacklevel=2)
     scale = max_norm / (total_norm + 1e-8)
-    if scale < 1.0:
-        for arr in arrays:
-            arr *= scale
-    return total_norm
+    scale = ndarray.min(ndarray.concat(scale, ndarray.ones(1, ctx=ctx), dim=0))
+    for arr in arrays:
+        arr *= scale.as_in_context(arr.context)
+    if check_isfinite:
+        return total_norm.asscalar()
+    else:
+        return total_norm
 
 
 def _indent(s_, numSpaces):
diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index 42d65dab5fd..69375afdfe0 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -111,12 +111,16 @@ def test_gluon_ctc_consistency():
 
 @with_seed()
 def test_global_norm_clip_multi_device():
-    x1 = mx.nd.ones((3,3), ctx=mx.gpu(0))
-    x2 = mx.nd.ones((4,4), ctx=mx.cpu(0))
-    norm = gluon.utils.clip_global_norm([x1, x2], 1.0)
-    assert norm == 5.0
-    assert_almost_equal(x1.asnumpy(), np.ones((3,3))/5)
-    assert_almost_equal(x2.asnumpy(), np.ones((4,4))/5)
+    for check_isfinite in [True, False]:
+        x1 = mx.nd.ones((3,3), ctx=mx.gpu(0))
+        x2 = mx.nd.ones((4,4), ctx=mx.cpu(0))
+        norm = gluon.utils.clip_global_norm([x1, x2], 1.0, check_isfinite=check_isfinite)
+        if check_isfinite:
+            assert norm == 5.0
+        else:
+            assert norm.asscalar() == 5.0
+        assert_almost_equal(x1.asnumpy(), np.ones((3, 3)) / 5)
+        assert_almost_equal(x2.asnumpy(), np.ones((4, 4)) / 5)
 
 
 def _check_batchnorm_result(input, num_devices=1, cuda=False):
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 61b441a5f84..bf9f5a77c84 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -735,10 +735,10 @@ def test_sequential_warning():
 @with_seed()
 def test_global_norm_clip():
     stypes = ['default', 'row_sparse']
-    def check_global_norm_clip(stype):
+    def check_global_norm_clip(stype, check_isfinite):
         x1 = mx.nd.ones((3,3)).tostype(stype)
         x2 = mx.nd.ones((4,4)).tostype(stype)
-        norm = gluon.utils.clip_global_norm([x1, x2], 1.0)
+        norm = gluon.utils.clip_global_norm([x1, x2], 1.0, check_isfinite=check_isfinite)
         assert norm == 5.0
         assert_almost_equal(x1.asnumpy(), np.ones((3,3))/5)
         assert_almost_equal(x2.asnumpy(), np.ones((4,4))/5)
@@ -746,11 +746,12 @@ def check_global_norm_clip(stype):
         x3 = mx.nd.array([1.0, 2.0, float('nan')]).tostype(stype)
         with warnings.catch_warnings(record=True) as w:
             warnings.simplefilter("always")
-            gluon.utils.clip_global_norm([x1, x3], 2.0)
-            assert len(w) == 1
+            gluon.utils.clip_global_norm([x1, x3], 2.0, check_isfinite=check_isfinite)
+            assert len(w) == check_isfinite
 
     for stype in stypes:
-        check_global_norm_clip(stype)
+        for check_isfinite in [True, False]:
+            check_global_norm_clip(stype, check_isfinite)
 
 @with_seed()
 def test_embedding():


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services