You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink itself was not preserved in this plain-text copy).
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/10/27 16:52:33 UTC

[incubator-mxnet] branch master updated: [BUGFIX] Fix #20293 (#20462)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 026dbf8  [BUGFIX] Fix #20293 (#20462)
026dbf8 is described below

commit 026dbf8d91a0c6bd9af86813a939c3643a2064e5
Author: Zhenghui Jin <69...@users.noreply.github.com>
AuthorDate: Wed Oct 27 09:51:09 2021 -0700

    [BUGFIX] Fix #20293 (#20462)
    
    * fix 20293
    
    * avoid state.array_reqs being overridden by reqs
    
    * update
    
    * fix AddTo grad_req in StaticBackward
    
    * fix lint
    
    * fix executor
---
 python/mxnet/executor.py                    |  5 +++--
 python/mxnet/symbol/symbol.py               |  7 +++++--
 src/imperative/cached_op.cc                 | 21 +++++++++++++++++----
 src/imperative/inplace_addto_detect_pass.cc |  7 ++++++-
 tests/python/unittest/test_executor.py      | 18 ++++++++++++++++++
 5 files changed, 49 insertions(+), 9 deletions(-)

diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py
index 85dba4b..36e8f39 100644
--- a/python/mxnet/executor.py
+++ b/python/mxnet/executor.py
@@ -34,7 +34,7 @@ class Executor:
     >>> c = 2 * a + b
     >>> texec = c._bind(mx.cpu(), {'a': mx.nd.array([1,2]), 'b':mx.nd.array([2,3])})
     """
-    def __init__(self, sym, ctx, args, args_grad, grad_req, aux_states):
+    def __init__(self, sym, ctx, args, args_grad, grad_req, aux_states, static_alloc=False):
         self.outputs = None
         self._input_names = sym.list_inputs()
         self._aux_names = sym.list_auxiliary_states()
@@ -42,6 +42,7 @@ class Executor:
         self._output_names = sym.list_outputs()
         self._ctx = ctx
         self._grad_req = grad_req
+        self.static_alloc = static_alloc
         # grad_req
         self._requires_grad = False
         if isinstance(grad_req, dict):
@@ -121,7 +122,7 @@ class Executor:
                         with self._ctx:
                             self._args[i].attach_grad(req, stype=g.stype)
                             self._args[i].grad[:] = g
-        self._cached_op = ndarray.CachedOp(sym)
+        self._cached_op = ndarray.CachedOp(sym, flags=[("static_alloc", self.static_alloc)])
 
     def get_optimized_symbol(self):
         """Get an optimized version of the symbol from the executor.
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 34e53da..a9b1a60 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -1793,7 +1793,7 @@ class Symbol(SymbolBase):
         return Executor(self, ctx, args, args_grad, grad_req, aux_states)
 
     def _bind(self, ctx, args, args_grad=None, grad_req='write',
-              aux_states=None):
+              aux_states=None, static_alloc=False):
         """Binds the current symbol to an executor and returns it.
 
         We first declare the computation and then bind to the data to run.
@@ -1856,6 +1856,9 @@ class Symbol(SymbolBase):
               `auxiliary_states` to the corresponding `NDArray`,
             - In either case, all the auxiliary states need to be provided.
 
+        static_alloc : bool, default False
+            Statically allocate memory to improve speed. Memory usage may increase.
+
         Returns
         -------
         executor : Executor
@@ -1874,7 +1877,7 @@ class Symbol(SymbolBase):
         gradient they interested in.
         """
         assert isinstance(grad_req, (str, dict))
-        return Executor(self, ctx, args, args_grad, grad_req, aux_states)
+        return Executor(self, ctx, args, args_grad, grad_req, aux_states, static_alloc)
 
     def gradient(self, wrt):
         """Gets the autodiff of current symbol.
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 692f9d6..894ef09 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -308,6 +308,21 @@ bool CachedOp::SetBackwardGraph(GraphInfo* info,
     g.attrs[AddPrefix(BACKWARD, REF_COUNT)] = std::make_shared<dmlc::any>(std::move(ref_count));
   }
 
+  // Set AddTo Entry based on the req that users provide
+  if (detect_inplace_addto) {
+    std::vector<int> addto_entry(idx.num_node_entries(), 0);
+    for (size_t i = 0; i < info->grad_graph.outputs.size(); ++i) {
+      if (reqs[i] == kAddTo) {
+        auto entry = info->grad_graph.outputs[i];
+        if (!idx.exist(entry.node.get()))
+          continue;
+        auto eid         = idx.entry_id(entry);
+        addto_entry[eid] = 1;
+      }
+    }
+    g.attrs["addto_entry"] = std::make_shared<nnvm::any>(std::move(addto_entry));
+  }
+
   auto shapes = info->fwd_graph.GetAttr<mxnet::ShapeVector>("shape");
   shapes.resize(idx.num_node_entries(), mxnet::TShape());
   auto dtypes = info->fwd_graph.GetAttr<DTypeVector>("dtype");
@@ -1047,8 +1062,7 @@ void CachedOp::StaticBackward(const bool retain_graph,
       auto entry = state.info.grad_graph.outputs[iter->second];
       if (!idx.exist(entry.node.get()))
         continue;
-      auto eid              = idx.entry_id(entry);
-      state.array_reqs[eid] = reqs[iter->second];
+      auto eid = idx.entry_id(entry);
       // An input and an output may share the same array.
       INIT_DETACHED(outputs[iter->second], arrays[eid]);
       arrays[eid] = outputs[iter->second];
@@ -1058,8 +1072,7 @@ void CachedOp::StaticBackward(const bool retain_graph,
       auto entry = state.info.grad_graph.outputs[i];
       if (!idx.exist(entry.node.get()))
         continue;
-      auto eid              = idx.entry_id(entry);
-      state.array_reqs[eid] = reqs[i];
+      auto eid = idx.entry_id(entry);
       // An input and an output may share the same array.
       INIT_DETACHED(outputs[i], arrays[eid]);
       arrays[eid] = outputs[i];
diff --git a/src/imperative/inplace_addto_detect_pass.cc b/src/imperative/inplace_addto_detect_pass.cc
index a3633bc..86480b4 100644
--- a/src/imperative/inplace_addto_detect_pass.cc
+++ b/src/imperative/inplace_addto_detect_pass.cc
@@ -39,7 +39,12 @@ Graph DetectInplaceAddTo(Graph g) {
   auto& idx                      = g.indexed_graph();
   // reference cont.
   std::vector<int> ref_count(idx.num_node_entries(), 0);
-  std::vector<int> addto_entry(idx.num_node_entries(), 0);
+  std::vector<int> addto_entry;
+  if (g.attrs.count("addto_entry")) {
+    addto_entry = g.GetAttr<std::vector<int> >("addto_entry");
+  } else {
+    addto_entry = std::vector<int>(idx.num_node_entries(), 0);
+  }
   std::vector<int> skip_plus_node(idx.num_nodes(), 0);
 
   for (auto& e : idx.outputs()) {
diff --git a/tests/python/unittest/test_executor.py b/tests/python/unittest/test_executor.py
index b735c83..f374798 100644
--- a/tests/python/unittest/test_executor.py
+++ b/tests/python/unittest/test_executor.py
@@ -159,3 +159,21 @@ def test_cached_op_init():
     check_init(False, False)
     check_init(True, False)
     check_init(True, True)
+
+def test_elemwise_add_grad():
+    json = "{\"nodes\": [{\"op\":\"null\",\"name\":\".Inputs.Input1\",\"inputs\":[]},{\"op\":\"null\",\"name\":\".Inputs.Input2\",\"inputs\":[]},{\"op\":\"elemwise_add\",\"name\":\".$0\",\"inputs\":[[0,0,0],[1,0,0]]},{\"op\":\"_copy\",\"name\":\".Outputs.Output\",\"inputs\":[[2,0,0]]}],\"arg_nodes\":[0,1],\"heads\":[[3,0,0]]}"
+    sym = mx.symbol.fromjson(json)
+
+    ex = sym._bind(
+        mx.cpu(), 
+        {'.Inputs.Input1': mx.nd.array([0.4]), '.Inputs.Input2': mx.nd.array([0.5])},
+        args_grad={
+            '.Inputs.Input1': mx.ndarray.zeros((1)), 
+            '.Inputs.Input2': mx.ndarray.zeros((1))
+        },
+        grad_req={'.Inputs.Input1': 'null', '.Inputs.Input2': 'write'}
+    )
+    ex.forward(is_train=True)
+    print(ex.outputs)
+    ex.backward(out_grads=mx.nd.array([1]))
+    print(ex.grad_arrays)