Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/11/15 18:27:21 UTC

[incubator-mxnet] branch master updated: fix custom op error when using auxiliary states (#8637)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new e27f054  fix custom op error when using auxiliary states (#8637)
e27f054 is described below

commit e27f054b4693b556cddb7f7c262651862ca3d02d
Author: Eric Junyuan Xie <pi...@users.noreply.github.com>
AuthorDate: Wed Nov 15 10:27:18 2017 -0800

    fix custom op error when using auxiliary states (#8637)
    
    * fix custom op error when using auxiliary states
    
    * Update custom.cc
    
    * Update custom.cc
---
 python/mxnet/optimizer.py              | 14 +++++++-------
 src/imperative/cached_op.cc            | 16 ++++------------
 src/operator/custom/custom.cc          | 16 ++++++++++++----
 src/operator/tensor/matrix_op.cu       |  2 +-
 tests/python/unittest/test_operator.py | 25 +++++++++++++++----------
 5 files changed, 39 insertions(+), 34 deletions(-)
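
For orientation before the diff: the bug surfaces when a CustomOp declares auxiliary states, i.e. mutable tensors that participate in the computation but receive no gradient. A minimal sketch of such an operator (hypothetical names, modeled on the test added at the bottom of this patch):

    import mxnet as mx

    class IdentityWithAux(mx.operator.CustomOp):
        def forward(self, is_train, req, in_data, out_data, aux):
            aux[0][:] = 1                       # mutate the auxiliary state
            self.assign(out_data[0], req[0], in_data[0])

        def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
            self.assign(in_grad[0], req[0], out_grad[0])

    @mx.operator.register("identity_with_aux")
    class IdentityWithAuxProp(mx.operator.CustomOpProp):
        def list_arguments(self):
            return ['data']

        def list_outputs(self):
            return ['output']

        def list_auxiliary_states(self):
            return ['state']

        def infer_shape(self, in_shape):
            # returns (arg shapes, output shapes, aux shapes)
            return in_shape, [in_shape[0]], [in_shape[0]]

        def infer_type(self, in_type):
            return in_type, [in_type[0]], [in_type[0]]

        def create_operator(self, ctx, shapes, dtypes):
            return IdentityWithAux()

Differentiating through such an op previously failed because both the gradient graph and the create-operator callback accounted only for the regular arguments, not the auxiliary states; the custom.cc hunks below patch both places.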

diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 66c261b..eaaf521 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -101,7 +101,7 @@ class Optimizer(object):
         assert isinstance(param_idx2name, dict), \
             'param_idx2name should be a dict of param indexes to names.'
         self.idx2name = param_idx2name.copy()
-        self.sym = sym
+        self.sym_info = (sym.attr_dict(), sym.list_arguments()) if sym is not None else ()
         self.param_dict = param_dict if param_dict else {}
 
         self.set_lr_mult({})
@@ -321,9 +321,9 @@ class Optimizer(object):
             compatibility, and we recommend to use the name instead.
         """
         self.lr_mult = {}
-        if self.sym is not None:
-            attr = self.sym.attr_dict()
-            for name in self.sym.list_arguments():
+        if self.sym_info:
+            attr, arg_names = self.sym_info
+            for name in arg_names:
                 if name in attr and '__lr_mult__' in attr[name]:
                     self.lr_mult[name] = float(attr[name]['__lr_mult__'])
         self.lr_mult.update(args_lr_mult)
@@ -358,9 +358,9 @@ class Optimizer(object):
         for n in self.idx2name.values():
             if not (n.endswith('_weight') or n.endswith('_gamma')):
                 self.wd_mult[n] = 0.0
-        if self.sym is not None:
-            attr = self.sym.attr_dict()
-            for name in self.sym.list_arguments():
+        if self.sym_info:
+            attr, arg_names = self.sym_info
+            for name in arg_names:
                 if name in attr and '__wd_mult__' in attr[name]:
                     self.wd_mult[name] = float(attr[name]['__wd_mult__'])
         self.wd_mult.update(args_wd_mult)
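
The optimizer.py change above replaces the stored Symbol with a tuple of plain Python data. One plausible motivation (the commit message does not say) is that the tuple is picklable and does not keep the whole Symbol alive, while still carrying everything set_lr_mult and set_wd_mult read. A minimal sketch of that attribute flow:

    import mxnet as mx

    w = mx.symbol.Variable('fc_weight', attr={'__lr_mult__': '0.1'})
    net = mx.symbol.FullyConnected(data=mx.symbol.Variable('data'),
                                   weight=w, num_hidden=8, name='fc')

    # What the Optimizer now stores instead of the Symbol itself:
    attr, arg_names = net.attr_dict(), net.list_arguments()
    lr_mult = {n: float(attr[n]['__lr_mult__'])
               for n in arg_names if '__lr_mult__' in attr.get(n, {})}
    print(lr_mult)   # {'fc_weight': 0.1}
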
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 60d66db..e9d801f 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -136,7 +136,7 @@ std::vector<nnvm::NodeEntry> Imperative::CachedOp::Gradient(
     const std::vector<nnvm::NodeEntry>& ograds) {
   using namespace nnvm;
   static const auto _backward_CachedOp = Op::Get("_backward_CachedOp");
-  static const auto _CachedOp_NoGrad = Op::Get("_CachedOp_NoGrad");
+  static const auto _NoGrad = Op::Get("_NoGradient");
 
   auto p = Node::Create();
   p->attrs.op = _backward_CachedOp;
@@ -152,13 +152,12 @@ std::vector<nnvm::NodeEntry> Imperative::CachedOp::Gradient(
   const auto& auxs = mutable_input_nodes();
   if (auxs.size()) {
     auto nop = Node::Create();
-    nop->attrs.op = _CachedOp_NoGrad;
-    nop->attrs.parsed = static_cast<uint32_t>(auxs.size());
-    nop->control_deps.push_back(node);
+    nop->attrs.op = _NoGrad;
+    nop->attrs.name = "NoGradient";
     uint32_t j = 0, k = 0;
     for (const auto& i : fwd_graph_.indexed_graph().input_nodes()) {
       if (auxs.count(i)) {
-        ret.emplace_back(NodeEntry{nop, j++, 0});
+        ret.emplace_back(NodeEntry{nop, 0, 0});
       } else {
         ret.emplace_back(NodeEntry{p, k++, 0});
       }
@@ -491,11 +490,4 @@ NNVM_REGISTER_OP(_backward_CachedOp)
 .set_attr<bool>("TIsLayerOpBackward", true)
 .set_attr<bool>("TIsBackward", true);
 
-NNVM_REGISTER_OP(_CachedOp_NoGrad)
-.set_num_inputs(0)
-.set_num_outputs([](const NodeAttrs& attrs) {
-    const uint32_t& nout = nnvm::get<uint32_t>(attrs.parsed);
-    return nout;
-  });
-
 }  // namespace mxnet
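
The cached_op.cc change drops the private multi-output _CachedOp_NoGrad operator in favor of the shared _NoGradient operator, and points every auxiliary (mutable) input at that single node's output instead of giving each one its own output slot. Hybridized Gluon blocks run through CachedOp, so any block with auxiliary states exercises this path; a sketch, assuming a build that includes this patch:

    import mxnet as mx
    from mxnet import autograd, gluon

    # BatchNorm carries two auxiliary inputs (running mean/var).  Under
    # hybridize() the block executes via CachedOp, so recording a backward
    # pass goes through the Gradient() function patched above, where the
    # aux inputs map to the shared _NoGradient node.
    net = gluon.nn.BatchNorm()
    net.initialize()
    net.hybridize()
    x = mx.nd.ones((2, 3))
    x.attach_grad()
    with autograd.record():
        y = net(x)
    y.backward()
    print(x.grad.shape)   # (2, 3): data gets a gradient, the aux states do not
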
diff --git a/src/operator/custom/custom.cc b/src/operator/custom/custom.cc
index 456c39c..683423f 100644
--- a/src/operator/custom/custom.cc
+++ b/src/operator/custom/custom.cc
@@ -212,9 +212,17 @@ std::vector<nnvm::NodeEntry> Gradient(
   }
 
   std::vector<nnvm::NodeEntry> ret;
-  for (index_t i = 0; i < g->num_outputs(); ++i) {
+  for (index_t i = 0; i < params.num_args; ++i) {
     ret.emplace_back(nnvm::NodeEntry{g, i, 0});
   }
+  if (params.num_auxs) {
+    nnvm::NodePtr ng = nnvm::Node::Create();
+    ng->attrs.op = nnvm::Op::Get("_NoGradient");
+    ng->attrs.name = "NoGradient";
+    for (index_t i = 0; i < params.num_auxs; ++i) {
+      ret.emplace_back(nnvm::NodeEntry{ng, 0, 0});
+    }
+  }
 
   return ret;
 }
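
After this hunk the entries returned by Gradient() line up with the Custom node's inputs: one real gradient for each of the num_args arguments, then the single _NoGradient output reused for each of the num_auxs auxiliary states. The split it mirrors is visible at the symbol level (reusing the hypothetical identity_with_aux op sketched near the top):

    import mxnet as mx

    data = mx.symbol.Variable('data')
    state = mx.symbol.Variable('state')
    op = mx.symbol.Custom(data=data, state=state, op_type='identity_with_aux')
    print(op.list_arguments())          # ['data']  -> real gradient entries
    print(op.list_auxiliary_states())   # ['state'] -> _NoGradient entries
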
@@ -225,8 +233,8 @@ OpStatePtr CreateState(const NodeAttrs& attrs, Context ctx,
                        const std::vector<int>& in_type) {
   const CustomParam& params = nnvm::get<CustomParam>(attrs.parsed);
 
-  std::vector<uint32_t*> shapes(params.num_args);
-  std::vector<int> ndims(params.num_args);
+  std::vector<uint32_t*> shapes(in_shape.size());
+  std::vector<int> ndims(in_shape.size());
   size_t buff_size = 0;
   for (const auto& i : in_shape) buff_size += i.ndim();
   std::vector<uint32_t> buff(buff_size);
@@ -245,7 +253,7 @@ OpStatePtr CreateState(const NodeAttrs& attrs, Context ctx,
   MXCallbackList *op_info = new MXCallbackList;
   CHECK(reinterpret_cast<CustomOpCreateFunc>(
       params.info->callbacks[kCustomOpPropCreateOperator])(
-          os.str().c_str(), params.num_args, shapes.data(), ndims.data(), in_type.data(),
+          os.str().c_str(), shapes.size(), shapes.data(), ndims.data(), in_type.data(),
           op_info, params.info->contexts[kCustomOpPropCreateOperator]));
 
   CustomParam state = params;
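
The CreateState hunks size the shape and ndim buffers by in_shape.size(), i.e. arguments plus auxiliary states, rather than num_args, and pass that count through to the create-operator callback. In Python terms, create_operator now receives shape and dtype entries for the aux states as well; with the identity_with_aux sketch above:

    import mxnet as mx

    x = mx.nd.ones((4, 10))
    state = mx.nd.zeros((4, 10))
    # create_operator(ctx, shapes, dtypes) now sees two shape entries,
    # [[4, 10], [4, 10]] (data, then state), instead of the argument's only.
    y = mx.nd.Custom(x, state, op_type='identity_with_aux')
    y.wait_to_read()
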
diff --git a/src/operator/tensor/matrix_op.cu b/src/operator/tensor/matrix_op.cu
index 237b872..c21a0e7 100644
--- a/src/operator/tensor/matrix_op.cu
+++ b/src/operator/tensor/matrix_op.cu
@@ -47,7 +47,7 @@ NNVM_REGISTER_OP(_backward_slice)
 NNVM_REGISTER_OP(_slice_assign)
 .set_attr<FCompute>("FCompute<gpu>", SliceAssignOpForward<gpu>);
 
-NNVM_REGISTER_OP(_crop_assign_scalar)
+NNVM_REGISTER_OP(_slice_assign_scalar)
 .set_attr<FCompute>("FCompute<gpu>", SliceAssignScalarOpForward<gpu>);
 
 NNVM_REGISTER_OP(slice_axis)
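
The matrix_op.cu change registers the GPU kernel under the operator's current name, _slice_assign_scalar, instead of the stale _crop_assign_scalar; the rename suggests the canonical op previously had no GPU FCompute, so scalar slice-assignment failed on GPU. A quick way to exercise that path (assumes a CUDA build and an available gpu(0)):

    import mxnet as mx

    x = mx.nd.zeros((3, 4), ctx=mx.gpu(0))
    x[1:2] = 5.0             # dispatches to _slice_assign_scalar
    print(x.asnumpy())       # row 1 is now all 5s
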
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 93dc4a0..3484b18 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -3571,9 +3571,11 @@ def test_custom_op():
     class Sqr(mx.operator.CustomOp):
         def forward(self, is_train, req, in_data, out_data, aux):
             self.assign(out_data[0], req[0], in_data[0]*in_data[0])
+            aux[0][:] = 1
 
         def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
             self.assign(in_grad[0], req[0], 2*in_data[0]*out_grad[0])
+            assert (aux[0].asnumpy() == 1).all()
 
     @mx.operator.register("sqr")
     class SqrProp(mx.operator.CustomOpProp):
@@ -3586,31 +3588,34 @@ def test_custom_op():
         def list_outputs(self):
             return ['output']
 
+        def list_auxiliary_states(self):
+            return ['aux']
+
         def infer_shape(self, in_shape):
-            return in_shape, [in_shape[0]], []
+            return in_shape, [in_shape[0]], [in_shape[0]]
 
         def infer_type(self, in_type):
-            return in_type, [in_type[0]], []
+            return in_type, [in_type[0]], [in_type[0]]
 
         def create_operator(self, ctx, shapes, dtypes):
             return Sqr()
 
     data = mx.symbol.Variable('data')
-    op = mx.symbol.Custom(data=data, name='sqr', op_type='sqr')
+    aux = mx.symbol.Variable('aux')
+    op = mx.symbol.Custom(data=data, aux=aux, name='sqr', op_type='sqr')
     x = mx.nd.array(np.random.uniform(-1, 1, size=(4, 10)))
-    check_numeric_gradient(op, [x])
+    aux = mx.nd.zeros_like(x)
+    check_numeric_gradient(op, [x], [aux])
 
-    data = mx.symbol.Variable('data')
     data = mx.symbol.cast(data, dtype='float64')
-    op = mx.symbol.Custom(data=data, name='sqr', op_type='sqr')
     op = mx.symbol.cast(op, dtype='float32')
     x = mx.nd.array(np.random.uniform(-1, 1, size=(4, 10)))
-    check_numeric_gradient(op, [x])
+    aux = mx.nd.zeros_like(x)
+    check_numeric_gradient(op, [x], [aux])
 
-    dx = mx.nd.zeros_like(x)
-    mx.contrib.autograd.mark_variables([x], [dx])
+    x.attach_grad()
     with mx.contrib.autograd.train_section():
-        y = mx.nd.Custom(x, op_type='sqr')
+        y = mx.nd.Custom(x, aux, op_type='sqr')
         y.backward()
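
Putting the pieces together, an end-to-end check of the fixed path (assuming the 'sqr' op from the updated test above is registered; the aux state is passed positionally after the data):

    import mxnet as mx
    import numpy as np

    x = mx.nd.array(np.random.uniform(-1, 1, size=(4, 10)))
    aux = mx.nd.zeros_like(x)
    x.attach_grad()
    with mx.contrib.autograd.train_section():
        y = mx.nd.Custom(x, aux, op_type='sqr')
        y.backward()
    assert np.allclose(x.grad.asnumpy(), 2 * x.asnumpy())   # d(x^2)/dx = 2x
    assert (aux.asnumpy() == 1).all()   # forward wrote into the aux state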
 
 
