Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/11/25 21:58:20 UTC

[incubator-mxnet] branch master updated: multi-device support (#8812)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new fdc0766  multi-device support (#8812)
fdc0766 is described below

commit fdc0766971ed95811d0db15ad0d878998192fce5
Author: Sheng Zha <sz...@users.noreply.github.com>
AuthorDate: Sat Nov 25 13:58:18 2017 -0800

    multi-device support (#8812)
---
 python/mxnet/gluon/utils.py           | 3 ++-
 tests/python/gpu/test_operator_gpu.py | 9 +++++++++
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index 890fb60..88effc9 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -117,7 +117,8 @@ def clip_global_norm(arrays, max_norm):
     """Rescales NDArrays so that the sum of their 2-norm is smaller than `max_norm`.
     """
     assert len(arrays) > 0
-    total_norm = ndarray.add_n(*[ndarray.dot(x, x)
+    ctx = arrays[0].context
+    total_norm = ndarray.add_n(*[ndarray.dot(x, x).as_in_context(ctx)
                                  for x in (arr.reshape((-1,)) for arr in arrays)])
     total_norm = ndarray.sqrt(total_norm).asscalar()
     if not np.isfinite(total_norm):
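The hunk above copies each per-array squared norm onto the context of the first array before summing, presumably so that ndarray.add_n only ever sees inputs on a single device even when the gradients are spread across GPUs. A minimal usage sketch of the fixed clip_global_norm follows; it is illustrative only, and mx.cpu(0)/mx.cpu(1) stand in for two GPUs so it also runs on a machine without any:

    # Usage sketch for the fixed clip_global_norm (illustrative only).
    # Assumptions: MXNet with Gluon is available; mx.cpu(0)/mx.cpu(1) stand in
    # for mx.gpu(0)/mx.gpu(1) so the sketch runs without a multi-GPU machine.
    import mxnet as mx
    from mxnet import gluon

    ctx_a, ctx_b = mx.cpu(0), mx.cpu(1)        # swap in mx.gpu(0), mx.gpu(1) on a multi-GPU box
    grads = [mx.nd.ones((3, 3), ctx=ctx_a),    # gradient shard on the first device
             mx.nd.ones((4, 4), ctx=ctx_b)]    # gradient shard on the second device

    # With this patch, the per-array squared norms are gathered on
    # grads[0].context before ndarray.add_n sums them.
    total_norm = gluon.utils.clip_global_norm(grads, 1.0)
    print(total_norm)                # 5.0 == sqrt(3*3*1 + 4*4*1)
    print(grads[0].asnumpy()[0, 0])  # ~0.2, i.e. rescaled by max_norm / total_norm
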
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 13b547e..15354b6 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -1423,6 +1423,15 @@ def test_cuda_rtc():
     assert (y.asnumpy() == 12).all()
 
 
+def test_global_norm_clip_multi_device():
+    x1 = mx.nd.ones((3,3), ctx=mx.gpu(0))
+    x2 = mx.nd.ones((4,4), ctx=mx.gpu(1))
+    norm = gluon.utils.clip_global_norm([x1, x2], 1.0)
+    assert norm == 5.0
+    assert_almost_equal(x1.asnumpy(), np.ones((3,3))/5)
+    assert_almost_equal(x2.asnumpy(), np.ones((4,4))/5)
+
+
 def test_cross_device_autograd():
     x = mx.nd.random.uniform(shape=(10,))
     x.attach_grad()
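
For reference, the constants asserted in the new test_global_norm_clip_multi_device follow directly from the shapes: the two arrays of ones contribute squared 2-norms of 9 and 16, so the global norm is sqrt(9 + 16) = 5 and every element ends up scaled by 1/5. A quick stand-alone check of those numbers (plain NumPy, no GPU needed):

    # Sanity check of the values asserted in test_global_norm_clip_multi_device.
    import numpy as np

    x1 = np.ones((3, 3))
    x2 = np.ones((4, 4))
    total_norm = np.sqrt((x1 ** 2).sum() + (x2 ** 2).sum())  # sqrt(9 + 16)
    print(total_norm)                # 5.0
    print((x1 / total_norm)[0, 0])   # 0.2, matching np.ones((3, 3)) / 5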

-- 
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.