You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/11/13 23:01:29 UTC

[GitHub] piiswrong closed pull request #8543: fix custom op error when using auxiliary states

piiswrong closed pull request #8543: fix custom op error when using auxiliary states
URL: https://github.com/apache/incubator-mxnet/pull/8543
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would otherwise not display it once the pull request is merged:

diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 66c261b880..eaaf521645 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -101,7 +101,7 @@ def __init__(self, rescale_grad=1., param_idx2name=None, wd=0.,
         assert isinstance(param_idx2name, dict), \
             'param_idx2name should be a dict of param indexes to names.'
         self.idx2name = param_idx2name.copy()
-        self.sym = sym
+        self.sym_info = (sym.attr_dict(), sym.list_arguments()) if sym is not None else ()
         self.param_dict = param_dict if param_dict else {}
 
         self.set_lr_mult({})
@@ -321,9 +321,9 @@ def set_lr_mult(self, args_lr_mult):
             compatibility, and we recommend to use the name instead.
         """
         self.lr_mult = {}
-        if self.sym is not None:
-            attr = self.sym.attr_dict()
-            for name in self.sym.list_arguments():
+        if self.sym_info:
+            attr, arg_names = self.sym_info
+            for name in arg_names:
                 if name in attr and '__lr_mult__' in attr[name]:
                     self.lr_mult[name] = float(attr[name]['__lr_mult__'])
         self.lr_mult.update(args_lr_mult)
@@ -358,9 +358,9 @@ def set_wd_mult(self, args_wd_mult):
         for n in self.idx2name.values():
             if not (n.endswith('_weight') or n.endswith('_gamma')):
                 self.wd_mult[n] = 0.0
-        if self.sym is not None:
-            attr = self.sym.attr_dict()
-            for name in self.sym.list_arguments():
+        if self.sym_info:
+            attr, arg_names = self.sym_info
+            for name in arg_names:
                 if name in attr and '__wd_mult__' in attr[name]:
                     self.wd_mult[name] = float(attr[name]['__wd_mult__'])
         self.wd_mult.update(args_wd_mult)
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 60d66db485..e9d801fae4 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -136,7 +136,7 @@ std::vector<nnvm::NodeEntry> Imperative::CachedOp::Gradient(
     const std::vector<nnvm::NodeEntry>& ograds) {
   using namespace nnvm;
   static const auto _backward_CachedOp = Op::Get("_backward_CachedOp");
-  static const auto _CachedOp_NoGrad = Op::Get("_CachedOp_NoGrad");
+  static const auto _NoGrad = Op::Get("_NoGradient");
 
   auto p = Node::Create();
   p->attrs.op = _backward_CachedOp;
@@ -152,13 +152,12 @@ std::vector<nnvm::NodeEntry> Imperative::CachedOp::Gradient(
   const auto& auxs = mutable_input_nodes();
   if (auxs.size()) {
     auto nop = Node::Create();
-    nop->attrs.op = _CachedOp_NoGrad;
-    nop->attrs.parsed = static_cast<uint32_t>(auxs.size());
-    nop->control_deps.push_back(node);
+    nop->attrs.op = _NoGrad;
+    nop->attrs.name = "NoGradient";
     uint32_t j = 0, k = 0;
     for (const auto& i : fwd_graph_.indexed_graph().input_nodes()) {
       if (auxs.count(i)) {
-        ret.emplace_back(NodeEntry{nop, j++, 0});
+        ret.emplace_back(NodeEntry{nop, 0, 0});
       } else {
         ret.emplace_back(NodeEntry{p, k++, 0});
       }
@@ -491,11 +490,4 @@ NNVM_REGISTER_OP(_backward_CachedOp)
 .set_attr<bool>("TIsLayerOpBackward", true)
 .set_attr<bool>("TIsBackward", true);
 
-NNVM_REGISTER_OP(_CachedOp_NoGrad)
-.set_num_inputs(0)
-.set_num_outputs([](const NodeAttrs& attrs) {
-    const uint32_t& nout = nnvm::get<uint32_t>(attrs.parsed);
-    return nout;
-  });
-
 }  // namespace mxnet
diff --git a/src/operator/custom/custom.cc b/src/operator/custom/custom.cc
index 456c39c17b..683423f969 100644
--- a/src/operator/custom/custom.cc
+++ b/src/operator/custom/custom.cc
@@ -212,9 +212,17 @@ std::vector<nnvm::NodeEntry> Gradient(
   }
 
   std::vector<nnvm::NodeEntry> ret;
-  for (index_t i = 0; i < g->num_outputs(); ++i) {
+  for (index_t i = 0; i < params.num_args; ++i) {
     ret.emplace_back(nnvm::NodeEntry{g, i, 0});
   }
+  if (params.num_auxs) {
+    nnvm::NodePtr ng = nnvm::Node::Create();
+    ng->attrs.op = nnvm::Op::Get("_NoGradient");
+    ng->attrs.name = "NoGradient";
+    for (index_t i = 0; i < params.num_auxs; ++i) {
+      ret.emplace_back(nnvm::NodeEntry{ng, 0, 0});
+    }
+  }
 
   return ret;
 }
@@ -225,8 +233,8 @@ OpStatePtr CreateState(const NodeAttrs& attrs, Context ctx,
                        const std::vector<int>& in_type) {
   const CustomParam& params = nnvm::get<CustomParam>(attrs.parsed);
 
-  std::vector<uint32_t*> shapes(params.num_args);
-  std::vector<int> ndims(params.num_args);
+  std::vector<uint32_t*> shapes(in_shape.size());
+  std::vector<int> ndims(in_shape.size());
   size_t buff_size = 0;
   for (const auto& i : in_shape) buff_size += i.ndim();
   std::vector<uint32_t> buff(buff_size);
@@ -245,7 +253,7 @@ OpStatePtr CreateState(const NodeAttrs& attrs, Context ctx,
   MXCallbackList *op_info = new MXCallbackList;
   CHECK(reinterpret_cast<CustomOpCreateFunc>(
       params.info->callbacks[kCustomOpPropCreateOperator])(
-          os.str().c_str(), params.num_args, shapes.data(), ndims.data(), in_type.data(),
+          os.str().c_str(), shapes.size(), shapes.data(), ndims.data(), in_type.data(),
           op_info, params.info->contexts[kCustomOpPropCreateOperator]));
 
   CustomParam state = params;
diff --git a/src/operator/tensor/matrix_op.cu b/src/operator/tensor/matrix_op.cu
index 237b87296c..c21a0e7cc3 100644
--- a/src/operator/tensor/matrix_op.cu
+++ b/src/operator/tensor/matrix_op.cu
@@ -47,7 +47,7 @@ NNVM_REGISTER_OP(_backward_slice)
 NNVM_REGISTER_OP(_slice_assign)
 .set_attr<FCompute>("FCompute<gpu>", SliceAssignOpForward<gpu>);
 
-NNVM_REGISTER_OP(_crop_assign_scalar)
+NNVM_REGISTER_OP(_slice_assign_scalar)
 .set_attr<FCompute>("FCompute<gpu>", SliceAssignScalarOpForward<gpu>);
 
 NNVM_REGISTER_OP(slice_axis)
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 93dc4a0534..3484b18d02 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -3571,9 +3571,11 @@ def test_custom_op():
     class Sqr(mx.operator.CustomOp):
         def forward(self, is_train, req, in_data, out_data, aux):
             self.assign(out_data[0], req[0], in_data[0]*in_data[0])
+            aux[0][:] = 1
 
         def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
             self.assign(in_grad[0], req[0], 2*in_data[0]*out_grad[0])
+            assert (aux[0].asnumpy() == 1).all()
 
     @mx.operator.register("sqr")
     class SqrProp(mx.operator.CustomOpProp):
@@ -3586,31 +3588,34 @@ def list_arguments(self):
         def list_outputs(self):
             return ['output']
 
+        def list_auxiliary_states(self):
+            return ['aux']
+
         def infer_shape(self, in_shape):
-            return in_shape, [in_shape[0]], []
+            return in_shape, [in_shape[0]], [in_shape[0]]
 
         def infer_type(self, in_type):
-            return in_type, [in_type[0]], []
+            return in_type, [in_type[0]], [in_type[0]]
 
         def create_operator(self, ctx, shapes, dtypes):
             return Sqr()
 
     data = mx.symbol.Variable('data')
-    op = mx.symbol.Custom(data=data, name='sqr', op_type='sqr')
+    aux = mx.symbol.Variable('aux')
+    op = mx.symbol.Custom(data=data, aux=aux, name='sqr', op_type='sqr')
     x = mx.nd.array(np.random.uniform(-1, 1, size=(4, 10)))
-    check_numeric_gradient(op, [x])
+    aux = mx.nd.zeros_like(x)
+    check_numeric_gradient(op, [x], [aux])
 
-    data = mx.symbol.Variable('data')
     data = mx.symbol.cast(data, dtype='float64')
-    op = mx.symbol.Custom(data=data, name='sqr', op_type='sqr')
     op = mx.symbol.cast(op, dtype='float32')
     x = mx.nd.array(np.random.uniform(-1, 1, size=(4, 10)))
-    check_numeric_gradient(op, [x])
+    aux = mx.nd.zeros_like(x)
+    check_numeric_gradient(op, [x], [aux])
 
-    dx = mx.nd.zeros_like(x)
-    mx.contrib.autograd.mark_variables([x], [dx])
+    x.attach_grad()
     with mx.contrib.autograd.train_section():
-        y = mx.nd.Custom(x, op_type='sqr')
+        y = mx.nd.Custom(x, aux, op_type='sqr')
         y.backward()
 
 


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services