Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/11/20 20:13:22 UTC

[GitHub] piiswrong closed pull request #8725: [WIP] Fix custom op when used with need_top_grad=False

URL: https://github.com/apache/incubator-mxnet/pull/8725

This pull request was merged from a forked repository. Because GitHub hides
the original diff of a foreign (fork) pull request once it is merged, the diff
is reproduced below for the sake of provenance:

diff --git a/python/mxnet/operator.py b/python/mxnet/operator.py
index f515bf83b8..259b6d5655 100644
--- a/python/mxnet/operator.py
+++ b/python/mxnet/operator.py
@@ -563,7 +563,7 @@ def infer_storage_type_backward(self, in_stype):
             list of aux stypes calculated from in_stype,
             in the same order as declared in list_auxiliary_states.
         """
-        return in_stype, [in_stype[0]]*len(self.list_outputs()), \
+        return in_stype, [in_stype[0]]*len(self.list_arguments()), \
             [in_stype[0]]*len(self.list_auxiliary_states())
 
     def list_outputs(self):
@@ -717,10 +717,10 @@ def infer_storage_type_backward_entry(num_tensor, tensor_stypes, _):
                     n_in = len(op_prop.list_arguments())
                     n_out = len(op_prop.list_outputs())
                     n_aux = len(op_prop.list_auxiliary_states())
-                    total_inputs = n_in + 2 * n_out
+                    total_inputs = (n_in + 2 * n_out) if op_prop.need_top_grad_ else (n_in + n_out)
                     total_aux = n_aux
                     total_outputs = n_in
-                    assert num_tensor == (2 * n_in + 2 * n_out + n_aux)
+                    assert num_tensor == (total_inputs + total_aux + total_outputs)
 
                     stypes = [_STORAGE_TYPE_ID_TO_STR[tensor_stypes[i]] \
                              for i in range(total_inputs + total_aux)]
@@ -923,6 +923,10 @@ def backward_entry(num_ndarray, ndarraies, tags, reqs, is_train, _):
                         try:
                             tensors = [[] for i in range(5)]
                             for i in range(num_ndarray):
+                                # continue for ograd when need_top_grad_ is False
+                                # This will cause len(ograd) = 0 when passed to backward
+                                if not op_prop.need_top_grad_ and tags[i] == 3:
+                                    continue
                                 if tags[i] == 2 or tags[i] == 4:
                                     tensors[tags[i]].append(_ndarray_cls(cast(ndarraies[i],
                                                                               NDArrayHandle),
diff --git a/src/operator/custom/custom.cc b/src/operator/custom/custom.cc
index 609f6acd2f..e26040a636 100644
--- a/src/operator/custom/custom.cc
+++ b/src/operator/custom/custom.cc
@@ -370,7 +370,8 @@ inline bool BackwardInferStorageType(const nnvm::NodeAttrs& attrs,
   }
 
   std::vector<int> stypes;
-  stypes.reserve(params.num_outs * 2 + params.num_args * 2 + params.num_auxs);
+  const size_t num_bwd_args = params.bwd_idx.size();
+  stypes.reserve(num_bwd_args + params.num_args + params.num_auxs);
   for (size_t i = 0; i < iattr->size(); ++i) {
     stypes.push_back((*iattr)[i]);
   }
@@ -382,17 +383,17 @@ inline bool BackwardInferStorageType(const nnvm::NodeAttrs& attrs,
       params.info->callbacks[kCustomOpPropBackwardInferStorageType])(
       stypes.size(), stypes.data(),
       params.info->contexts[kCustomOpPropBackwardInferStorageType]));
-  for (size_t i = 0; i < 2 * params.num_outs + params.num_args; ++i) {
+  for (size_t i = 0; i < num_bwd_args; ++i) {
     STORAGE_TYPE_ASSIGN_CHECK(*iattr, i, stypes[i]);
   }
   for (size_t i = 0; i < params.num_args; ++i) {
     STORAGE_TYPE_ASSIGN_CHECK(
-        *oattr, i, stypes[i + 2 * params.num_outs + params.num_args]);
+        *oattr, i, stypes[i + num_bwd_args]);
   }
   for (size_t i = 0; i < params.num_auxs; ++i) {
     STORAGE_TYPE_ASSIGN_CHECK(
-        *iattr, i + 2 * params.num_outs + params.num_args,
-        stypes[i + 2 * params.num_outs + 2 * params.num_args]);
+        *iattr, i + num_bwd_args,
+        stypes[i + num_bwd_args + params.num_args]);
   }
 
   DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 55a3a57218..a5ea7b5dbe 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -3653,6 +3653,77 @@ def create_operator(self, ctx, shapes, dtypes):
     assert (y.stype == 'csr')
     assert (aux.stype == 'csr')
 
+    # test for backward compatibility, i.e. the correctness of default implementation of
+    # infer storage in custom operator
+    class Mult(mx.operator.CustomOp):
+        def forward(self, is_train, req, in_data, out_data, aux):
+            self.assign(out_data[0], req[0], in_data[0]*in_data[1])
+
+        def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
+            self.assign(in_grad[0], req[0], in_data[1]*out_grad[0])
+            self.assign(in_grad[1], req[1], in_data[0]*out_grad[0])
+
+    @mx.operator.register("mult")
+    class MultProp(mx.operator.CustomOpProp):
+        def __init__(self):
+            super(MultProp, self).__init__(need_top_grad=True)
+
+        def list_arguments(self):
+            return ['lhs', 'rhs']
+
+        def list_outputs(self):
+            return ['output']
+
+        def infer_shape(self, in_shape):
+            return in_shape, [in_shape[0]], []
+
+        def create_operator(self, ctx, shapes, dtypes):
+            return Mult()
+
+    class MultNoGrad(mx.operator.CustomOp):
+        def forward(self, is_train, req, in_data, out_data, aux):
+            self.assign(out_data[0], req[0], in_data[0]*in_data[1])
+
+        def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
+            self.assign(in_grad[0], req[0], in_data[1])
+            self.assign(in_grad[1], req[1], in_data[0])
+            assert (len(out_grad) == 0)
+
+    @mx.operator.register("mult_no_grad")
+    class MultNoGradProp(mx.operator.CustomOpProp):
+        def __init__(self):
+            super(MultNoGradProp, self).__init__(need_top_grad=False)
+
+        def list_arguments(self):
+            return ['lhs', 'rhs']
+
+        def list_outputs(self):
+            return ['output']
+
+        def infer_shape(self, in_shape):
+            return in_shape, [in_shape[0]], []
+
+        def create_operator(self, ctx, shapes, dtypes):
+            return MultNoGrad()
+
+    lhs = mx.nd.array(np.random.uniform(-1, 1, size=(4, 10)))
+    rhs = mx.nd.array(np.random.uniform(-1, 1, size=(4, 10)))
+    lhs.attach_grad()
+    rhs.attach_grad()
+    with mx.contrib.autograd.train_section():
+        y = mx.nd.Custom(lhs, rhs, op_type='mult')
+        y.backward()
+    mx.nd.waitall()
+    assert_almost_equal(rhs.asnumpy(), lhs.grad.asnumpy())
+    assert_almost_equal(lhs.asnumpy(), rhs.grad.asnumpy())
+
+    with mx.contrib.autograd.train_section():
+        y = mx.nd.Custom(lhs, rhs, op_type='mult_no_grad')
+        y.backward()
+    mx.nd.waitall()
+    assert_almost_equal(rhs.asnumpy(), lhs.grad.asnumpy())
+    assert_almost_equal(lhs.asnumpy(), rhs.grad.asnumpy())
+
 def test_psroipooling():
     for num_rois in [1, 2]:
         for num_classes, num_group in itertools.product([2, 3], [2, 3]):
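
In summary, when a CustomOpProp is registered with need_top_grad=False, the
patch stops passing output-gradient tensors to backward (so out_grad arrives
empty, as the MultNoGrad test above asserts) and makes both the Python default
infer_storage_type_backward and the C++ BackwardInferStorageType count tensors
accordingly. The sketch below is illustrative only; the helper
backward_tensor_counts is not part of MXNet or of this pull request, it merely
spells out the counts that the patched assertion in
infer_storage_type_backward_entry enforces:

# Illustrative helper (not from the PR): tensor counts seen by the backward
# storage-type inference callback after this change.
def backward_tensor_counts(n_in, n_out, n_aux, need_top_grad):
    # Backward inputs are in_data + out_data, plus out_grad only when the
    # operator was registered with need_top_grad=True.
    total_inputs = (n_in + 2 * n_out) if need_top_grad else (n_in + n_out)
    total_aux = n_aux
    total_outputs = n_in  # one gradient slot per forward argument
    return total_inputs + total_aux + total_outputs

# For the "mult" test operator above (2 arguments, 1 output, no aux states):
assert backward_tensor_counts(2, 1, 0, need_top_grad=True) == 6   # out_grad included
assert backward_tensor_counts(2, 1, 0, need_top_grad=False) == 5  # out_grad omitted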


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services