You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/11/25 21:58:20 UTC
[incubator-mxnet] branch master updated: multi-device support (#8812)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new fdc0766 multi-device support (#8812)
fdc0766 is described below
commit fdc0766971ed95811d0db15ad0d878998192fce5
Author: Sheng Zha <sz...@users.noreply.github.com>
AuthorDate: Sat Nov 25 13:58:18 2017 -0800
multi-device support (#8812)
---
python/mxnet/gluon/utils.py | 3 ++-
tests/python/gpu/test_operator_gpu.py | 9 +++++++++
2 files changed, 11 insertions(+), 1 deletion(-)
diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index 890fb60..88effc9 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -117,7 +117,8 @@ def clip_global_norm(arrays, max_norm):
"""Rescales NDArrays so that the sum of their 2-norm is smaller than `max_norm`.
"""
assert len(arrays) > 0
- total_norm = ndarray.add_n(*[ndarray.dot(x, x)
+ ctx = arrays[0].context
+ total_norm = ndarray.add_n(*[ndarray.dot(x, x).as_in_context(ctx)
for x in (arr.reshape((-1,)) for arr in arrays)])
total_norm = ndarray.sqrt(total_norm).asscalar()
if not np.isfinite(total_norm):
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 13b547e..15354b6 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -1423,6 +1423,15 @@ def test_cuda_rtc():
assert (y.asnumpy() == 12).all()
+def test_global_norm_clip_multi_device():
+ x1 = mx.nd.ones((3,3), ctx=mx.gpu(0))
+ x2 = mx.nd.ones((4,4), ctx=mx.gpu(1))
+ norm = gluon.utils.clip_global_norm([x1, x2], 1.0)
+ assert norm == 5.0
+ assert_almost_equal(x1.asnumpy(), np.ones((3,3))/5)
+ assert_almost_equal(x2.asnumpy(), np.ones((4,4))/5)
+
+
def test_cross_device_autograd():
x = mx.nd.random.uniform(shape=(10,))
x.attach_grad()
--
To stop receiving notification emails like this one, please contact
"commits@mxnet.apache.org" <co...@mxnet.apache.org>.