Posted to commits@mxnet.apache.org by li...@apache.org on 2020/07/25 23:21:03 UTC

[incubator-mxnet] branch master updated: add support for np.ndarray in autograd.function (#18790)

This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 98b3f73  add support for np.ndarray in autograd.function (#18790)
98b3f73 is described below

commit 98b3f73bd0f30034e3f6848eb75d38c30c8b60b4
Author: Sheng Zha <sz...@users.noreply.github.com>
AuthorDate: Sat Jul 25 16:19:36 2020 -0700

    add support for np.ndarray in autograd.function (#18790)
---
 python/mxnet/autograd.py               | 14 +++++---
 tests/python/unittest/test_autograd.py | 61 ++++++++++++++++++++++++++++++++++
 2 files changed, 71 insertions(+), 4 deletions(-)
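
In effect, this patch lets a custom autograd.Function run under NumPy
semantics: with mx.npx.set_np() active, the arrays reconstructed from the
C handles and the gradients handed to backward are mx.np.ndarray rather
than classic mx.nd.NDArray. A minimal usage sketch (not part of the
commit; the Square function and its values are illustrative):

    import mxnet as mx
    from mxnet import autograd

    mx.npx.set_np()  # enable NumPy semantics: arrays are mx.np.ndarray

    class Square(autograd.Function):
        def forward(self, x):
            self.save_for_backward(x)
            return x * x

        def backward(self, dy):
            x, = self.saved_tensors
            # with this patch, dy arrives as an mx.np.ndarray
            return 2 * x * dy

    x = mx.np.array([1.0, 2.0, 3.0])
    x.attach_grad()
    with autograd.record():
        y = Square()(x)
    y.backward()
    print(x.grad)  # expected: [2. 4. 6.]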

diff --git a/python/mxnet/autograd.py b/python/mxnet/autograd.py
index f968275..aac7cbc 100644
--- a/python/mxnet/autograd.py
+++ b/python/mxnet/autograd.py
@@ -28,6 +28,7 @@ from .base import NDArrayHandle, c_array, c_handle_array, c_array_buf, MXCallbac
 from .ndarray import NDArray, _ndarray_cls
 from .ndarray import _GRAD_REQ_MAP
 from .symbol import Symbol
+from .util import is_np_array
 
 
 def set_recording(is_recording): #pylint: disable=redefined-outer-name
@@ -448,25 +449,30 @@ class Function(object):
             outputs = (outputs,)
 
         key = Function._registry.inc()
+        if is_np_array():
+            from .numpy import ndarray
+            array_cls = ndarray
+        else:
+            array_cls = NDArray
 
         def backward_entry(num_ograds, num_igrads, ptrs, reqs, is_train, _):
             """entry point for backward."""
             # pylint: disable=W0613
             try:
-                output_grads = [NDArray(ctypes.cast(i, NDArrayHandle), writable=False) \
+                output_grads = [array_cls(ctypes.cast(i, NDArrayHandle), writable=False) \
                                 for i in ptrs[:num_ograds]]
-                input_grads = [NDArray(ctypes.cast(i, NDArrayHandle), writable=True) \
+                input_grads = [array_cls(ctypes.cast(i, NDArrayHandle), writable=True) \
                                for i in ptrs[num_ograds:num_ograds+num_igrads]]
                 reqs = [reqs[i] for i in range(num_igrads)]
                 rets = self.backward(*output_grads)
-                if isinstance(rets, NDArray):
+                if isinstance(rets, array_cls):
                     rets = (rets,)
                 assert len(rets) == len(input_grads), \
                     "%s.backward must return exactly the same number " \
                     "of NDArrays as the number of NDArray arguments to forward. " \
                     "Expecting %d, got %d" % (self.__class__.__name__, len(input_grads), len(rets))
                 for igrad, ret, req in zip(input_grads, rets, reqs):
-                    assert isinstance(ret, NDArray), \
+                    assert isinstance(ret, array_cls), \
                         "autograd.Function.backward must return NDArrays, not %s"%type(ret)
                     if req == 0:  # null
                         return True
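
For context on the hunk above: the wrapper class is resolved up front via
is_np_array(), which reads the global NumPy-semantics flag, so the handles
wrapped in backward_entry match the array type the caller's forward and
backward code expects. A minimal sketch of that flag (not taken from the
patch; the printed values assume a fresh process):

    import mxnet as mx
    from mxnet.util import is_np_array

    print(is_np_array())  # False: classic mode, wrappers are mx.nd.NDArray
    mx.npx.set_np()
    print(is_np_array())  # True: NumPy mode, wrappers are mx.np.ndarray
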
diff --git a/tests/python/unittest/test_autograd.py b/tests/python/unittest/test_autograd.py
index 6a75eed..f9a7ecc 100644
--- a/tests/python/unittest/test_autograd.py
+++ b/tests/python/unittest/test_autograd.py
@@ -407,6 +407,67 @@ def test_function1():
 
 @with_seed()
 @pytest.mark.garbage_expected
+@use_np
+def test_np_function():
+    class func(Function):
+        def forward(self, x, y):
+            m = x / y
+            n = x * y
+            self.save_for_backward(x, y)
+            return m, n
+
+        def backward(self, dm, dn):
+            x, y = self.saved_tensors
+            dx = dm / y + dn * y
+            dy = dn * x - dm * x / y / y
+            return dx, dy
+
+    f = func()
+    x = mx.np.random.uniform(size=(10,))
+    x.attach_grad()
+    y = mx.np.random.uniform(size=(10,))
+    y.attach_grad()
+    with record():
+        m, n = f(x, y)
+        backward([m, n])
+
+    dx1 = x.grad.asnumpy()
+    dy1 = y.grad.asnumpy()
+
+    with record():
+        backward([x/y, x*y])
+
+    # Non-zero atol required, as exposed by seed 630179191
+    atol = 1e-6
+    assert_almost_equal(x.grad.asnumpy(), dx1, atol=atol)
+    assert_almost_equal(y.grad.asnumpy(), dy1, atol=atol)
+
+
+@with_seed()
+@pytest.mark.garbage_expected
+@use_np
+def test_np_function1():
+    class Foo(mx.autograd.Function):
+        def __init__(self):
+            super(Foo, self).__init__()
+
+        def forward(self, X):
+            return X + 1
+
+        def backward(self, dY):
+            return dY
+
+    with mx.autograd.record():
+        X = mx.np.zeros((3, 4))
+        # X.attach_grad()  # uncommenting this line works
+        for i in range(5):
+            f = Foo()
+            X = f(X)
+        X.wait_to_read()
+
+
+@with_seed()
+@pytest.mark.garbage_expected
 def test_get_symbol():
     x = mx.nd.ones((1,))
     x.attach_grad()