You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text version).
Posted to commits@mxnet.apache.org by la...@apache.org on 2021/02/15 14:59:37 UTC

[incubator-mxnet] branch master updated: Use multi-tensor zeroing for resetting grads (#19894)

This is an automated email from the ASF dual-hosted git repository.

lausen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new da24765  Use multi-tensor zeroing for resetting grads (#19894)
da24765 is described below

commit da247654371664fb906ad286053c63e6fc2f5b5a
Author: Moises Hernandez <50...@users.noreply.github.com>
AuthorDate: Mon Feb 15 06:57:38 2021 -0800

    Use multi-tensor zeroing for resetting grads (#19894)
---
 python/mxnet/gluon/block.py | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 547fbaa..299df18 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -742,18 +742,16 @@ class Block:
                 if g.stype == 'row_sparse':
                     ndarray.zeros_like(g, out=g)
                 else:
-                    arrays[g.ctx].append(g)
+                    if is_np_array():
+                        arrays[g.ctx].append(g.as_nd_ndarray())
+                    else:
+                        arrays[g.ctx].append(g)
 
         if len(arrays) == 0:
             return
 
-        if is_np_array():
-            for arr in arrays.values():
-                for ele in arr:
-                    ele[()] = 0
-        else:
-            for arr in arrays.values():
-                ndarray.reset_arrays(*arr, num_arrays=len(arr))
+        for arr in arrays.values():
+            ndarray.reset_arrays(*arr, num_arrays=len(arr))
 
     def reset_ctx(self, ctx):
         """Re-assign all Parameters to other contexts.