You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/11/05 00:12:39 UTC

[GitHub] piiswrong closed pull request #8322: fix custom op error when using auxiliary states

piiswrong closed pull request #8322: fix custom op error when using auxiliary states
URL: https://github.com/apache/incubator-mxnet/pull/8322
 
 
   

This is a PR merged from a forked repository.
Because GitHub hides the original diff of a pull request from a fork
once it is merged, the diff is reproduced below for the sake of
provenance:

diff --git a/ps-lite b/ps-lite
index acdb698fa3..bdd4c67e9e 160000
--- a/ps-lite
+++ b/ps-lite
@@ -1 +1 @@
-Subproject commit acdb698fa3bb80929ef83bb37c705f025e119b82
+Subproject commit bdd4c67e9e34dc0b8350ce306b0caa737eb31c83
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 66c261b880..eaaf521645 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -101,7 +101,7 @@ def __init__(self, rescale_grad=1., param_idx2name=None, wd=0.,
         assert isinstance(param_idx2name, dict), \
             'param_idx2name should be a dict of param indexes to names.'
         self.idx2name = param_idx2name.copy()
-        self.sym = sym
+        self.sym_info = (sym.attr_dict(), sym.list_arguments()) if sym is not None else ()
         self.param_dict = param_dict if param_dict else {}
 
         self.set_lr_mult({})
@@ -321,9 +321,9 @@ def set_lr_mult(self, args_lr_mult):
             compatibility, and we recommend to use the name instead.
         """
         self.lr_mult = {}
-        if self.sym is not None:
-            attr = self.sym.attr_dict()
-            for name in self.sym.list_arguments():
+        if self.sym_info:
+            attr, arg_names = self.sym_info
+            for name in arg_names:
                 if name in attr and '__lr_mult__' in attr[name]:
                     self.lr_mult[name] = float(attr[name]['__lr_mult__'])
         self.lr_mult.update(args_lr_mult)
@@ -358,9 +358,9 @@ def set_wd_mult(self, args_wd_mult):
         for n in self.idx2name.values():
             if not (n.endswith('_weight') or n.endswith('_gamma')):
                 self.wd_mult[n] = 0.0
-        if self.sym is not None:
-            attr = self.sym.attr_dict()
-            for name in self.sym.list_arguments():
+        if self.sym_info:
+            attr, arg_names = self.sym_info
+            for name in arg_names:
                 if name in attr and '__wd_mult__' in attr[name]:
                     self.wd_mult[name] = float(attr[name]['__wd_mult__'])
         self.wd_mult.update(args_wd_mult)
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index eb99aabf11..c7eadcc5c8 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -139,7 +139,7 @@ std::vector<nnvm::NodeEntry> Imperative::CachedOp::Gradient(
     const std::vector<nnvm::NodeEntry>& ograds) {
   using namespace nnvm;
   static const auto _backward_CachedOp = Op::Get("_backward_CachedOp");
-  static const auto _CachedOp_NoGrad = Op::Get("_CachedOp_NoGrad");
+  static const auto _NoGrad = Op::Get("_NoGradient");
 
   auto p = Node::Create();
   p->attrs.op = _backward_CachedOp;
@@ -155,13 +155,12 @@ std::vector<nnvm::NodeEntry> Imperative::CachedOp::Gradient(
   const auto& auxs = mutable_input_nodes();
   if (auxs.size()) {
     auto nop = Node::Create();
-    nop->attrs.op = _CachedOp_NoGrad;
-    nop->attrs.parsed = static_cast<uint32_t>(auxs.size());
-    nop->control_deps.push_back(node);
+    nop->attrs.op = _NoGrad;
+    nop->attrs.name = "NoGradient";
     uint32_t j = 0, k = 0;
     for (const auto& i : fwd_graph_.indexed_graph().input_nodes()) {
       if (auxs.count(i)) {
-        ret.emplace_back(NodeEntry{nop, j++, 0});
+        ret.emplace_back(NodeEntry{nop, 0, 0});
       } else {
         ret.emplace_back(NodeEntry{p, k++, 0});
       }
@@ -475,11 +474,4 @@ NNVM_REGISTER_OP(_backward_CachedOp)
 .set_attr<bool>("TIsLayerOpBackward", true)
 .set_attr<bool>("TIsBackward", true);
 
-NNVM_REGISTER_OP(_CachedOp_NoGrad)
-.set_num_inputs(0)
-.set_num_outputs([](const NodeAttrs& attrs) {
-    const uint32_t& nout = nnvm::get<uint32_t>(attrs.parsed);
-    return nout;
-  });
-
 }  // namespace mxnet
diff --git a/src/operator/custom/custom.cc b/src/operator/custom/custom.cc
index 456c39c17b..f0f7f2d024 100644
--- a/src/operator/custom/custom.cc
+++ b/src/operator/custom/custom.cc
@@ -212,10 +212,19 @@ std::vector<nnvm::NodeEntry> Gradient(
   }
 
   std::vector<nnvm::NodeEntry> ret;
-  for (index_t i = 0; i < g->num_outputs(); ++i) {
+  for (index_t i = 0; i < params.num_args; ++i) {
     ret.emplace_back(nnvm::NodeEntry{g, i, 0});
   }
 
+  if (params.num_auxs) {
+    nnvm::NodePtr ng = nnvm::Node::Create();
+    ng->attrs.op = nnvm::Op::Get("_NoGradient");
+    ng->attrs.name = "NoGradient";
+    for (index_t i = 0; i < params.num_auxs; ++i) {
+      ret.emplace_back(nnvm::NodeEntry{ng, 0, 0});
+    }
+  }
+
   return ret;
 }
 
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 024e089832..fa8a28bf94 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -3569,9 +3569,11 @@ def test_custom_op():
     class Sqr(mx.operator.CustomOp):
         def forward(self, is_train, req, in_data, out_data, aux):
             self.assign(out_data[0], req[0], in_data[0]*in_data[0])
+            aux[0][:] = 1
 
         def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
             self.assign(in_grad[0], req[0], 2*in_data[0]*out_grad[0])
+            assert (aux[0].asnumpy() == 1).all()
 
     @mx.operator.register("sqr")
     class SqrProp(mx.operator.CustomOpProp):
@@ -3584,31 +3586,34 @@ def list_arguments(self):
         def list_outputs(self):
             return ['output']
 
+        def list_auxiliary_states(self):
+            return ['aux']
+
         def infer_shape(self, in_shape):
-            return in_shape, [in_shape[0]], []
+            return in_shape, [in_shape[0]], [in_shape[0]]
 
         def infer_type(self, in_type):
-            return in_type, [in_type[0]], []
+            return in_type, [in_type[0]], [in_type[0]]
 
         def create_operator(self, ctx, shapes, dtypes):
             return Sqr()
 
     data = mx.symbol.Variable('data')
-    op = mx.symbol.Custom(data=data, name='sqr', op_type='sqr')
+    aux = mx.symbol.Variable('aux')
+    op = mx.symbol.Custom(data=data, aux=aux, name='sqr', op_type='sqr')
     x = mx.nd.array(np.random.uniform(-1, 1, size=(4, 10)))
-    check_numeric_gradient(op, [x])
+    aux = mx.nd.zeros_like(x)
+    check_numeric_gradient(op, [x], [aux])
 
-    data = mx.symbol.Variable('data')
     data = mx.symbol.cast(data, dtype='float64')
-    op = mx.symbol.Custom(data=data, name='sqr', op_type='sqr')
     op = mx.symbol.cast(op, dtype='float32')
     x = mx.nd.array(np.random.uniform(-1, 1, size=(4, 10)))
-    check_numeric_gradient(op, [x])
+    aux = mx.nd.zeros_like(x)
+    check_numeric_gradient(op, [x], [aux])
 
-    dx = mx.nd.zeros_like(x)
-    mx.contrib.autograd.mark_variables([x], [dx])
+    x.attach_grad()
     with mx.contrib.autograd.train_section():
-        y = mx.nd.Custom(x, op_type='sqr')
+        y = mx.nd.Custom(x, aux, op_type='sqr')
         y.backward()
 
 


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services