You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/05/09 17:58:36 UTC

[incubator-mxnet] branch master updated: [MXNET-407] Better error handling of NDArray setitem autograd (#10844)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 333e7fe  [MXNET-407] Better error handling of NDArray setitem autograd (#10844)
333e7fe is described below

commit 333e7fef64033572d1b02017309d97e7b91a9343
Author: reminisce <wu...@gmail.com>
AuthorDate: Wed May 9 10:58:31 2018 -0700

    [MXNET-407] Better error handling of NDArray setitem autograd (#10844)
    
    * Initial commit
    
    * More fix
---
 python/mxnet/ndarray/ndarray.py       | 30 ++++++++++--------
 src/imperative/imperative.cc          |  2 +-
 src/operator/tensor/indexing_op.cc    | 57 ++++++++++++++++++++++++++++++-----
 src/operator/tensor/indexing_op.h     |  7 +++--
 tests/python/unittest/test_ndarray.py | 15 ++++++++-
 5 files changed, 87 insertions(+), 24 deletions(-)

diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 2411932..7bfb3c7 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -682,17 +682,20 @@ fixed-size items.
         on the values of slices' steps."""
         shape = self.shape
         if isinstance(key, integer_types):
-            sliced_arr = self._at(key)
-            sliced_arr[:] = value
-            return
-        elif isinstance(key, py_slice):
-            if key.step is None or key.step == 1:  # trivial step
-                if key.start is not None or key.stop is not None:
-                    sliced_arr = self._slice(key.start, key.stop)
-                    sliced_arr[:] = value
-                    return
-                # assign value to the whole NDArray
-                # may need to broadcast first
+            if key < 0:
+                key += shape[0]
+            if key < 0 or key >= shape[0]:
+                if key < 0:
+                    key -= shape[0]
+                raise IndexError('index %d is out of bounds for axis 0 with size %d'
+                                 % (key, shape[0]))
+            key = py_slice(key, key+1)  # key must be >= 0 here
+
+        if isinstance(key, py_slice):
+            assign_to_self = key.step is None or key.step == 1
+            assign_to_self &= key.start is None or key.start == 0
+            assign_to_self &= key.stop is None or key.stop == shape[0]
+            if assign_to_self:  # trivial case, assign value to self
                 if isinstance(value, NDArray):
                     if value.handle is not self.handle:
                         if value.shape != shape:
@@ -709,7 +712,7 @@ fixed-size items.
                     value_nd = self._prepare_value_nd(value, shape)
                     value_nd.copyto(self)
                 return
-            else:  # non-trivial step, use _slice_assign or _slice_assign_scalar
+            else:  # non-trivial case, use _slice_assign or _slice_assign_scalar
                 key = (key,)
 
         assert isinstance(key, tuple), "key=%s must be a tuple of slices and integers" % str(key)
@@ -762,7 +765,8 @@ fixed-size items.
         indices = self._get_index_nd(key)
         vshape = _get_oshape_of_gather_nd_op(self.shape, indices.shape)
         value_nd = self._prepare_value_nd(value, vshape)
-        _internal._scatter_set_nd(data=value_nd, indices=indices, shape=self.shape, out=self)
+        _internal._scatter_set_nd(lhs=self, rhs=value_nd, indices=indices,
+                                  shape=self.shape, out=self)
 
     def _get_nd_basic_indexing(self, key):
         """This function is called when key is a slice, or an integer,
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index c5a4740..7caf305 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -194,7 +194,7 @@ void Imperative::RecordOp(
       << "will cause undefined behavior when evaluating gradients. "
       << "Please call backward first to clear the graph or do this out side of "
       << "a record section. Also note that you cannot use inplace operations "
-      << "like +=, *=, relu(x, out=x), etc inside a record section.";
+      << "like +=, *=, relu(x, out=x), y[idx]=x, etc inside a record section.";
   }
 
   bool need_grad = false;
diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc
index 6f0f468..fbb94b2 100644
--- a/src/operator/tensor/indexing_op.cc
+++ b/src/operator/tensor/indexing_op.cc
@@ -668,7 +668,9 @@ Examples::
 NNVM_REGISTER_OP(_scatter_set_nd)
 .describe(R"code(This operator has the same functionality as scatter_nd
 except that it does not reset the elements not indexed by the input
-index `NDArray` in the input data `NDArray`.
+index `NDArray` in the input data `NDArray`. output should be explicitly
+given and be the same as lhs.
+
 .. note:: This operator is for internal use only.
 
 Examples::
@@ -676,21 +678,62 @@ Examples::
   data = [2, 3, 0]
   indices = [[1, 1, 0], [0, 1, 0]]
   out = [[1, 1], [1, 1]]
-  scatter_nd(data=data, indices=indices, out=out)
+  _scatter_set_nd(lhs=out, rhs=data, indices=indices, out=out)
   out = [[0, 1], [2, 3]]
 
 )code")
 .set_num_outputs(1)
-.set_num_inputs(2)
+.set_num_inputs(3)
 .set_attr_parser(ParamParser<ScatterNDParam>)
 .set_attr<nnvm::FListInputNames>("FListInputNames",
   [](const NodeAttrs& attrs) {
-    return std::vector<std::string>{"data", "indices"};
+    return std::vector<std::string>{"lhs", "rhs", "indices"};
+  })
+.set_attr<nnvm::FInferShape>("FInferShape",
+  [](const nnvm::NodeAttrs& attrs,
+     std::vector<TShape> *in_attrs,
+     std::vector<TShape> *out_attrs) {
+    CHECK_EQ(in_attrs->size(), 3U);
+    CHECK_EQ(out_attrs->size(), 1U);
+    SHAPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+    std::vector<TShape> tmp_in_attrs = {in_attrs->at(1), in_attrs->at(2)};
+    if (!ScatterNDShape(attrs, &tmp_in_attrs, out_attrs)) {
+      return false;
+    }
+    SHAPE_ASSIGN_CHECK(*in_attrs, 1, tmp_in_attrs[0]);
+    SHAPE_ASSIGN_CHECK(*in_attrs, 2, tmp_in_attrs[1]);
+    SHAPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+    return true;
+  })
+.set_attr<nnvm::FInferType>("FInferType",
+  [](const nnvm::NodeAttrs& attrs,
+     std::vector<int> *in_attrs,
+     std::vector<int> *out_attrs) {
+    CHECK_EQ(in_attrs->size(), 3U);
+    CHECK_EQ(out_attrs->size(), 1U);
+    TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+    std::vector<int> tmp_in_attrs = {in_attrs->at(1), in_attrs->at(2)};
+    if (!ScatterNDType(attrs, &tmp_in_attrs, out_attrs)) {
+      return false;
+    }
+    TYPE_ASSIGN_CHECK(*in_attrs, 1, tmp_in_attrs[0]);
+    TYPE_ASSIGN_CHECK(*in_attrs, 2, tmp_in_attrs[1]);
+    TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+    return true;
   })
-.set_attr<nnvm::FInferShape>("FInferShape", ScatterNDShape)
-.set_attr<nnvm::FInferType>("FInferType", ScatterNDType)
 .set_attr<FCompute>("FCompute<cpu>", ScatterSetNDForward<cpu>)
-.add_argument("data", "NDArray-or-Symbol", "data")
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
+  [](const NodeAttrs& attrs){
+    return std::vector<bool>{true};
+  })
+.add_argument("lhs", "NDArray-or-Symbol", "source input")
+.add_argument("rhs", "NDArray-or-Symbol", "value to assign")
 .add_argument("indices", "NDArray-or-Symbol", "indices")
 .add_arguments(ScatterNDParam::__FIELDS__());
 
diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index ef0779b..f824add 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -1111,7 +1111,7 @@ inline bool ScatterNDType(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(out_attrs->size(), 1U);
   TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]);
   TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[0]);
-  return true;
+  return in_attrs->at(0) != -1 && in_attrs->at(1) != -1;
 }
 
 struct scatter_nd {
@@ -1228,7 +1228,10 @@ void ScatterSetNDForward(const nnvm::NodeAttrs& attrs,
                          const std::vector<TBlob>& inputs,
                          const std::vector<OpReqType>& req,
                          const std::vector<TBlob>& outputs) {
-  ScatterNDForward<xpu>(attrs, ctx, inputs, {kWriteInplace}, outputs);
+  CHECK_EQ(inputs.size(), 3U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(inputs[0].dptr_, outputs[0].dptr_);
+  ScatterNDForward<xpu>(attrs, ctx, {inputs[1], inputs[2]}, {kWriteInplace}, outputs);
 }
 
 }  // namespace op
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 9ff2f1a..496f80f 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -23,7 +23,7 @@ import unittest
 from nose.tools import raises
 from common import setup_module, with_seed, assertRaises, TemporaryDirectory
 from mxnet.test_utils import almost_equal
-from mxnet.test_utils import assert_almost_equal
+from mxnet.test_utils import assert_almost_equal, assert_exception
 from mxnet.test_utils import default_context
 from mxnet.test_utils import np_reduce
 from mxnet.test_utils import same
@@ -1064,6 +1064,18 @@ def test_ndarray_indexing():
         x_grad[index] = value
         assert same(x_grad.asnumpy(), x.grad.asnumpy())
 
+    def test_setitem_autograd(np_array, index):
+        x = mx.nd.array(np_array, dtype=np_array.dtype)
+        out_shape = x[index].shape
+        y = mx.nd.random.uniform(shape=out_shape)
+        y.attach_grad()
+        try:
+            with mx.autograd.record():
+                x[index] = y
+                assert False  # should not reach here
+        except mx.base.MXNetError as err:
+            assert str(err).find('Inplace operations (+=, -=, x[:]=, etc) are not supported when recording with') != -1
+
     def np_int(index, int_type=np.int32):
         def convert(num):
             if num is None:
@@ -1187,6 +1199,7 @@ def test_ndarray_indexing():
         test_getitem(np_array, index[0], index[1])
         test_setitem(np_array, index[0], index[1])
         test_getitem_autograd(np_array, index[0])
+        test_setitem_autograd(np_array, index[0])
 
 
 def test_assign_float_value_to_ndarray():

-- 
To stop receiving notification emails like this one, please contact
jxie@apache.org.