You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text rendering).
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/09/23 21:43:19 UTC

[incubator-mxnet] branch master updated: [FEATURE] Add feature of retain_grad (#20500)

This is an automated email from the ASF dual-hosted git repository.

zhenghuijin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 05d1b26  [FEATURE] Add feature of retain_grad  (#20500)
05d1b26 is described below

commit 05d1b26520d0dcaf3bf9d55343b9792eccb63a6f
Author: KexinFeng <fe...@umn.edu>
AuthorDate: Thu Sep 23 16:41:08 2021 -0500

    [FEATURE] Add feature of retain_grad  (#20500)
    
    * Replace "CloneGradient" with "ElemwiseGradUseNone"
    
    * fix issue elemwise_add
    
    * fix elemwise_add issue with `ElemwiseGradUseNone`
    
    * reverse_to_CloneGradient
    
    * add_retain_grad
    
    * unit_test
    
    * tidy_up
    
    * tidy_up
    
    * sanity
    
    * const_reference
    
    * const_ref
    
    * merge_rg_to_ag
    
    * sanity
    
    * sanity
    
    * add_drop_grad
    
    * sanity_check
    
    * sanity_check
    
    * sanity_check
    
    * build_err
    
    * build_err
    
    * skip_remark_variable
    
    * repetitive_mark
    
    * ReInit_in_dropgrad
    
    * ReInit_in_dropgrad
    
    * sanity_check
    
    * add drop and tests to gluon
    
    * sanity
    
    * update exec_pass.h
    
    Co-authored-by: Zhenghui Jin <69...@users.noreply.github.com>
---
 include/mxnet/c_api.h                  |   8 +++
 include/mxnet/imperative.h             |   4 ++
 python/mxnet/ndarray/ndarray.py        |   5 ++
 python/mxnet/numpy/multiarray.py       |   5 ++
 src/c_api/c_api_ndarray.cc             |  12 ++++
 src/imperative/exec_pass.h             |   4 +-
 src/imperative/imperative.cc           | 101 +++++++++++++++++++++++++++------
 src/nnvm/gradient.cc                   |  30 ++++++++--
 tests/python/unittest/test_autograd.py |  67 +++++++++++++++++++++-
 9 files changed, 211 insertions(+), 25 deletions(-)

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index e759b76..926b31e 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1275,6 +1275,14 @@ MXNET_DLL int MXAutogradMarkVariables(uint32_t num_var,
                                       uint32_t *reqs_array,
                                       NDArrayHandle *grad_handles);
 /*!
+ * \brief unmark nonleaf NDArrays to free the memory
+ * \param num_var number of variable NDArrays
+ * \param var_handles variable NDArrays
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXAutogradDropGrads(uint32_t num_var,
+                                  NDArrayHandle *var_handles);
+/*!
  * \brief compute the gradient of outputs w.r.t variabels
  * \param num_output number of output NDArray
  * \param output_handles output NDArrays
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index d998a74..76ccf25 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -272,12 +272,16 @@ class Imperative {
   void MarkVariables(const std::vector<NDArray*>& variables,
                      const std::vector<uint32_t>& grad_reqs,
                      const std::vector<NDArray*>& gradients);
+  /*! \brief unmark nonleaf variables to free the memory. */
+  void DropGrads(const std::vector<NDArray*>& variables);
   /*! \brief compute the gradient of outputs w.r.t variables. */
   std::vector<NDArray*> Backward(const std::vector<NDArray*>& outputs,
                                  const std::vector<NDArray*>& ograds,
                                  const std::vector<NDArray*>& variables,
                                  bool is_train, bool retain_graph,
                                  bool create_graph);
+  /*! \brief Return the marked nonleaf nodes. */
+  std::vector<nnvm::ObjectPtr> ListNonleafVariables(const nnvm::Symbol& sym) const;
   /*! \return AutogradRuntime singleton */
   static Imperative* Get();
   /*! \brief Should op execution bulking be employed during inference. */
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index cbd0c51..1f49fc5 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -2885,6 +2885,11 @@ fixed-size items.
             ctypes.pointer(mx_uint(grad_req)),
             ctypes.pointer(grad.handle)))
 
+    def drop_grad(self):
+        """Free the memory of the marked ndarray."""
+        check_call(_LIB.MXAutogradDropGrads(
+            1, ctypes.pointer(self.handle)))
+
     @property
     def grad(self):
         """Returns gradient buffer attached to this NDArray."""
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 5cca1fa..c2d9db9 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -1410,6 +1410,11 @@ class ndarray(NDArray):  # pylint: disable=invalid-name
             ctypes.pointer(mx_uint(grad_req)),
             ctypes.pointer(grad.handle)))
 
+    def drop_grad(self):
+        """Free the memory of the marked ndarray."""
+        check_call(_LIB.MXAutogradDropGrads(
+            1, ctypes.pointer(self.handle)))
+
     @property
     def grad(self):
         """Returns gradient buffer attached to this ndarray."""
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index d967ae6..3d66996 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -335,6 +335,18 @@ int MXAutogradMarkVariables(uint32_t num_var,
   API_END();
 }
 
+int MXAutogradDropGrads(uint32_t num_var,
+                       NDArrayHandle *var_handles) {
+  API_BEGIN();
+  std::vector<NDArray*> variables;
+  variables.reserve(num_var);
+  for (uint32_t i = 0; i < num_var; ++i) {
+    variables.emplace_back(static_cast<NDArray*>(var_handles[i]));
+  }
+  Imperative::Get()->DropGrads(variables);
+  API_END();
+}
+
 int MXAutogradComputeGradient(uint32_t num_output, NDArrayHandle* output_handles) {
   return MXAutogradBackward(num_output, output_handles, nullptr, 0);
 }
diff --git a/src/imperative/exec_pass.h b/src/imperative/exec_pass.h
index 440fc6b..5f27a16 100644
--- a/src/imperative/exec_pass.h
+++ b/src/imperative/exec_pass.h
@@ -287,12 +287,14 @@ inline Graph MXGradient(
     std::vector<const Op*> zero_ops  = std::vector<const Op*>(),
     std::string copy_op_str          = std::string(),
     mxnet::ShapeVector in_arg_shapes = mxnet::ShapeVector(),
-    DTypeVector in_arg_dtypes        = DTypeVector()) {
+    DTypeVector in_arg_dtypes        = DTypeVector(),
+    std::vector<NodeEntry> us        = std::vector<NodeEntry>() ) {
   graph.attrs["grad_ys"]          = std::make_shared<any>(std::move(ys));
   graph.attrs["grad_xs"]          = std::make_shared<any>(std::move(xs));
   graph.attrs["grad_ys_out_grad"] = std::make_shared<any>(std::move(ys_out_grad));
   graph.attrs["in_arg_shapes"]    = std::make_shared<any>(std::move(in_arg_shapes));
   graph.attrs["in_arg_dtypes"]    = std::make_shared<any>(std::move(in_arg_dtypes));
+  graph.attrs["grad_us"]          = std::make_shared<any>(std::move(us));
 
   if (aggregate_fun != nullptr) {
     graph.attrs["grad_aggregate_fun"] = std::make_shared<any>(aggregate_fun);
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index 0ec5ae5..af1ee09 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -142,29 +142,54 @@ void Imperative::MarkVariables(const std::vector<NDArray*>& variables,
                                const std::vector<uint32_t>& grad_reqs,
                                const std::vector<NDArray*>& gradients) {
   for (uint32_t i = 0; i < variables.size(); ++i) {
-    std::string str_c(std::to_string(variable_count_++));
-
-    variables[i]->autograd_entry_ =
-        nnvm::NodeEntry{nnvm::Symbol::CreateVariable("var" + str_c).outputs[0].node, 0, 0};
-    AGInfo& info = AGInfo::Create(variables[i]->autograd_entry_.node);
-    info.outputs.emplace_back(variables[i]->Detach());
-    info.out_grads.emplace_back(gradients[i]->Detach());
-    info.grad_req = static_cast<OpReqType>(grad_reqs[i]);
-    info.ctx      = variables[i]->ctx();
-
-    gradients[i]->autograd_entry_ =
-        nnvm::NodeEntry{nnvm::Symbol::CreateVariable("grad" + str_c).outputs[0].node, 0, 0};
-    AGInfo& grad_info = AGInfo::Create(gradients[i]->autograd_entry_.node);
-    grad_info.outputs.emplace_back(gradients[i]->Detach());
-    grad_info.ctx = gradients[i]->ctx();
+    // Unmarked leaf nodes have null autograd_entry_, while marked nonleaf nodes don't.
+    if (!variables[i]->autograd_entry_.node || variables[i]->autograd_entry_.node->is_variable()) {
+      std::string str_c(std::to_string(variable_count_++));
+      variables[i]->autograd_entry_ =
+          nnvm::NodeEntry{nnvm::Symbol::CreateVariable("var" + str_c).outputs[0].node, 0, 0};
+      AGInfo& info = AGInfo::Create(variables[i]->autograd_entry_.node);
+      info.outputs.emplace_back(variables[i]->Detach());
+      info.out_grads.emplace_back(gradients[i]->Detach());
+      info.grad_req = static_cast<OpReqType>(grad_reqs[i]);
+      info.ctx      = variables[i]->ctx();
+
+      gradients[i]->autograd_entry_ =
+          nnvm::NodeEntry{nnvm::Symbol::CreateVariable("grad" + str_c).outputs[0].node, 0, 0};
+      AGInfo& grad_info = AGInfo::Create(gradients[i]->autograd_entry_.node);
+      grad_info.outputs.emplace_back(gradients[i]->Detach());
+      grad_info.ctx = gradients[i]->ctx();
+    } else {
+      AGInfo& info = AGInfo::Get(variables[i]->autograd_entry_.node);
+      CHECK_EQ(info.out_grads.size(), 0)
+        <<"The node has already been marked. Cannot mark it again.";
+      info.out_grads.emplace_back(gradients[i]->Detach());
+      info.grad_req = static_cast<OpReqType>(grad_reqs[i]);
+      info.ctx      = variables[i]->ctx();
+    }
+  }
+}
+
+// Unmark the variables to free the memory.
+void Imperative::DropGrads(const std::vector<NDArray*>& variables) {
+  for (auto variable : variables) {
+    if (variable->autograd_entry_.node) {
+      AGInfo& info = AGInfo::Get(variable->autograd_entry_.node);
+      CHECK_NE(info.out_grads.size(), 0)
+        <<"The node has empty out_grads already. Cannot DropGrads again.";
+      for (auto grad : info.out_grads) {
+        grad.ReInit();
+      }
+      info.out_grads.clear();
+      info.grad_req = kNullOp;
+    }
   }
 }
 
 void Imperative::GetBackwardDependency(const nnvm::ObjectPtr& node,
                                        uint32_t num_inputs,
                                        uint32_t num_outputs,
-                                       std::vector<bool>* p_save_inputs,
-                                       std::vector<bool>* p_save_outputs) {
+                                       std::vector<bool> *p_save_inputs,
+                                       std::vector<bool> *p_save_outputs) {
   static auto& fgradient          = nnvm::Op::GetAttr<nnvm::FGradient>("FGradient");
   std::vector<bool>& save_inputs  = *p_save_inputs;
   std::vector<bool>& save_outputs = *p_save_outputs;
@@ -488,6 +513,12 @@ std::vector<NDArray*> Imperative::Backward(const std::vector<NDArray*>& outputs,
     }
     CHECK_GT(xs.size(), 0) << "There are no inputs in computation graph that require gradients.";
   }
+  std::vector<ObjectPtr> nleaf_vars = ListNonleafVariables(sym);
+  std::vector<NodeEntry> us;
+  us.reserve(nleaf_vars.size());
+  for (const auto& i : nleaf_vars) {
+    us.emplace_back(NodeEntry{i, 0, 0});
+  }
 
   Graph g_graph = pass::MXGradient(graph,
                                    graph.outputs,
@@ -496,7 +527,10 @@ std::vector<NDArray*> Imperative::Backward(const std::vector<NDArray*>& outputs,
                                    mxnet::AggregateGradient,
                                    nullptr,
                                    zero_ops,
-                                   "_copy");
+                                   "_copy",
+                                   ShapeVector(),
+                                   DTypeVector(),
+                                   us);
   CHECK_EQ(g_graph.outputs.size(), xs.size());
   for (const auto& e : g_graph.outputs) {
     if (e.node->op() == nullptr) {
@@ -575,6 +609,20 @@ std::vector<NDArray*> Imperative::Backward(const std::vector<NDArray*>& outputs,
     arrays[eid]    = x_grads[i - num_forward_outputs];
     ref_count[eid] = 1;
   }
+  const std::vector<NodeEntry>& us_grads =
+    g_graph.GetAttr<std::vector<NodeEntry>>("nleaf_grads");
+  CHECK_EQ(us_grads.size(), us.size())
+    << "Size of queried nleaf_vars and size of their gradients don't match.";
+  for (size_t i = 0; i < us_grads.size(); i++) {
+    size_t eid = idx.entry_id(us_grads[i]);
+    AGInfo& info = AGInfo::Get(us[i].node);
+    if (arrays[eid]->dtype_ == -1) {
+      arrays[eid] = &info.out_grads[0];
+    } else {
+      info.out_grads[0] = *arrays[eid];
+    }
+    ref_count[eid] = 1;
+  }
 
   // Assign context
   auto vctx = PlaceDevice(idx);
@@ -627,6 +675,11 @@ std::vector<NDArray*> Imperative::Backward(const std::vector<NDArray*>& outputs,
     size_t eid      = idx.entry_id(idx.outputs()[i]);
     array_reqs[eid] = x_reqs[i - num_forward_outputs];
   }
+  for (size_t i = 0; i < us_grads.size(); i++) {
+    size_t eid = idx.entry_id(us_grads[i]);
+    AGInfo& info = AGInfo::Get(us[i].node);
+    array_reqs[eid] = info.grad_req;
+  }
 
   const auto& shapes         = graph.GetAttr<mxnet::ShapeVector>("shape");
   const auto& dtypes         = graph.GetAttr<DTypeVector>("dtype");
@@ -766,4 +819,16 @@ void Imperative::DCInfo::Compute(const NDArray& arr) {
   info.outputs_.clear();
 }
 
+std::vector<nnvm::ObjectPtr> Imperative::ListNonleafVariables(const nnvm::Symbol& sym) const {
+  using namespace nnvm;
+  std::vector<ObjectPtr> ret;
+  DFSVisit(sym.outputs, [&ret](const ObjectPtr& node) {
+    AGInfo& info = AGInfo::Get(node);
+    if (info.out_grads.size() > 0 && !node->is_variable()) {
+      ret.push_back(node);
+    }
+  });
+  return ret;
+}
+
 }  // namespace mxnet
diff --git a/src/nnvm/gradient.cc b/src/nnvm/gradient.cc
index 2609f78..625eaa3 100644
--- a/src/nnvm/gradient.cc
+++ b/src/nnvm/gradient.cc
@@ -62,7 +62,8 @@ Graph BuildGradientGraph(const Graph& src,
                          const std::vector<ObjectPtr>& topo_order,
                          std::unordered_map<const Node*, std::vector<GradEntry> > output_grads,
                          std::function<int(const Node&)> mirror_fun,
-                         const std::unordered_map<const Node*, ObjectPtr>& mirror_map);
+                         const std::unordered_map<const Node*, ObjectPtr>& mirror_map,
+                         const std::vector<NodeEntry>& us = std::vector<NodeEntry>());
 
 /*!
  * \brief Auxiliary function that maps the forward node of the source graph to
@@ -88,6 +89,8 @@ Graph Gradient(Graph src) {
   const std::vector<NodeEntry>& ys_out_grad =
       src.GetAttr<std::vector<NodeEntry> >("grad_ys_out_grad");
   CHECK_EQ(ys.size(), ys_out_grad.size());
+  const std::vector<NodeEntry>& us =
+      src.GetAttr<std::vector<NodeEntry> >("grad_us");
 
   // initialize a topological order of the graph nodes and `output_grads`
   // that maps every operator node to its gradient entries
@@ -120,7 +123,7 @@ Graph Gradient(Graph src) {
   std::unordered_map<const Node*, ObjectPtr> mirror_map;
 
   // complete the backward graph of the src, but without backward mirroring
-  nnvm::Graph gsrc = BuildGradientGraph(src, xs, topo_order, output_grads, nullptr, mirror_map);
+  nnvm::Graph gsrc = BuildGradientGraph(src, xs, topo_order, output_grads, nullptr, mirror_map, us);
   if (mirror_fun == nullptr) {
     return gsrc;  // Gradient pass without mirroring ends here.
   }
@@ -504,12 +507,14 @@ inline bool CheckGradAllZero(const std::vector<NodeEntry>& grads,
   return true;
 }
 
+
 Graph BuildGradientGraph(const Graph& src,
                          const std::vector<NodeEntry>& xs,
                          const std::vector<ObjectPtr>& topo_order,
                          std::unordered_map<const Node*, std::vector<GradEntry> > output_grads,
                          std::function<int(const Node&)> mirror_fun,
-                         const std::unordered_map<const Node*, ObjectPtr>& mirror_map) {
+                         const std::unordered_map<const Node*, ObjectPtr>& mirror_map,
+                         const std::vector<NodeEntry>& us) {
   static auto& grad_fun_map = Op::GetAttr<nnvm::FGradient>("FGradient");
 
   // gradient aggregation function
@@ -617,7 +622,7 @@ Graph BuildGradientGraph(const Graph& src,
       CHECK(src_fwd_node->inputs.size() <= input_grads.size());
       for (auto input_iter = src_fwd_node->inputs.begin(); input_iter != src_fwd_node->inputs.end();
            ++input_iter, ++input_grad_iter) {
-        // propagate the input gradients to the output gradients of the input nodes
+        // propagate the input_grads to the corresponding GradEntries mapped by output_grads
         output_grads[input_iter->node.get()][input_iter->index].grads.emplace_back(
             std::move(*input_grad_iter));
       }
@@ -661,6 +666,20 @@ Graph BuildGradientGraph(const Graph& src,
       ret.outputs[kv.second.second] = kv.first;
     }
   }
+
+  // Take the us' grad NodeEntry and store them in graph.attrs
+  std::vector<NodeEntry> nleaf_grads;
+  nleaf_grads.reserve(us.size());
+  for (const NodeEntry& e : us) {
+    GradEntry& entry = output_grads[e.node.get()][e.index];
+    // aggregate sum if it hasn't been
+    if (entry.sum.node.get() == nullptr) {
+      entry.sum = agg_fun(std::move(entry.grads));
+    }
+    nleaf_grads.push_back(entry.sum);
+  }
+  ret.attrs["nleaf_grads"] = std::make_shared<any>(std::move(nleaf_grads));
+
   return ret;
 }
 
@@ -673,7 +692,8 @@ NNVM_REGISTER_PASS(MXGradient)
     .depend_graph_attr("grad_xs")
     .depend_graph_attr("in_arg_shapes")
     .depend_graph_attr("in_arg_dtypes")
-    .depend_graph_attr("grad_ys_out_grad");
+    .depend_graph_attr("grad_ys_out_grad")
+    .depend_graph_attr("grad_us");
 
 }  // namespace
 
diff --git a/tests/python/unittest/test_autograd.py b/tests/python/unittest/test_autograd.py
index 554d830..c48d204 100644
--- a/tests/python/unittest/test_autograd.py
+++ b/tests/python/unittest/test_autograd.py
@@ -243,7 +243,7 @@ def test_detach_updated_grad():
     assert x._fresh_grad == False
 
 
-def test_retain_grad():
+def test_retain_graph():
     x = mx.nd.ones((2, 2))
     dx = mx.nd.zeros((2, 2))
     mark_variables([x], [dx], grad_reqs='add')
@@ -519,3 +519,68 @@ def test_gradient():
     dx.backward()
     assert abs(x.grad.asscalar() - 2.71828175) < 1e-7
 
+def test_retain_grad_drop_grad():
+    x = nd.array([1,2,3,4])
+    x.attach_grad()
+    y = nd.array([5,6,7,8])
+    y.attach_grad()
+
+    with mx.autograd.record():
+        u = x * y
+        z = u * x
+
+    u.attach_grad()
+    z.attach_grad()
+    out_grad = nd.array([10, 10, 10, 10])
+    z.backward(out_grad, retain_graph=True)
+    
+    assert (u.grad == out_grad * x).asnumpy().all()
+    assert (z.grad == out_grad).asnumpy().all()
+    assert (x.grad == out_grad * 2 * x * y).asnumpy().all()
+    assert (y.grad == out_grad * x*x).asnumpy().all()
+
+    u.drop_grad()
+    z.drop_grad()
+    y.drop_grad()
+    out_grad = nd.array([0.1, 0.1, 0.1, 0.1])
+    z.backward(out_grad)
+
+    assert u.grad is None and z.grad is None and y.grad is None
+    assert (x.grad == out_grad * 2 * x * y).asnumpy().all()
+
+def test_retain_grad_drop_grad_gluon():
+    class CompBlock(mx.gluon.HybridBlock):
+        def __init__(self):
+            super().__init__()
+            self.marked_var = None
+        def forward(self, a, b):
+            out1 = a*b
+            out2 = out1 * a
+            self.marked_var = out1
+            return out2
+    x = mx.np.array([1,2,3,4])
+    y = mx.np.array([5,6,7,8])
+    x.attach_grad()
+    y.attach_grad()
+    block2 = CompBlock()
+    block2.initialize()
+    # block2.hybridize()
+    with mx.autograd.record():
+        z = block2(x, y)
+    u = block2.marked_var
+    u.attach_grad()
+    z.attach_grad()
+    z.backward(retain_graph=True)
+
+    assert (u.grad == x).all()
+    assert (z.grad == mx.np.array([1,1,1,1])).all()
+    assert (x.grad == 2 * x * y).all()
+    assert (y.grad == x*x).all()
+
+    u.drop_grad()
+    z.drop_grad()
+    y.drop_grad()
+    z.backward()
+
+    assert u.grad is None and z.grad is None and y.grad is None
+    assert (x.grad == 2 * x * y).all()