You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by la...@apache.org on 2020/09/22 01:00:09 UTC

[incubator-mxnet] branch master updated: Fix numpy ndarray `__getitem__` for HybridBlock.forward usecase (#19171)

This is an automated email from the ASF dual-hosted git repository.

lausen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new dd44c0c  Fix numpy ndarray `__getitem__` for HybridBlock.forward usecase (#19171)
dd44c0c is described below

commit dd44c0c3bc168b3e88cda22c443283894fd24c54
Author: Leonard Lausen <la...@amazon.com>
AuthorDate: Mon Sep 21 17:57:24 2020 -0700

    Fix numpy ndarray `__getitem__` for HybridBlock.forward usecase (#19171)
---
 python/mxnet/numpy/multiarray.py               | 65 +++++++++++++++++++++-----
 python/mxnet/symbol/numpy/_symbol.py           | 27 ++++-------
 tests/python/unittest/test_deferred_compute.py | 14 ++++++
 3 files changed, 78 insertions(+), 28 deletions(-)

diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index f2b5706..79c995f 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -724,6 +724,7 @@ class ndarray(NDArray):
         if isinstance(key, ndarray) and key.dtype == _np.bool_:
             return self._get_np_boolean_indexing(key, ndim, shape)
 
+        all = __builtins__['all']  # `def all` below shadows the all builtin
         if ndim == 0 and key != ():
             raise IndexError('scalar tensor can only accept `()` as index')
         # Handle simple cases for higher speed
@@ -736,28 +737,72 @@ class ndarray(NDArray):
                 out = out[idx]
             return out
         if isinstance(key, integer_types):
+            # Equivalent to isinstance(key, integer_types) case in numpy/_symbol.py
             if key > shape[0] - 1:
                 raise IndexError(
                     'index {} is out of bounds for axis 0 with size {}'.format(
                         key, shape[0]))
             return self._at(key)
         elif isinstance(key, py_slice):
+            # Unlike numpy/_symbol.py, calls MXNDArraySlice64 writable memory
+            # sharing if key.step not in [None, 1]. Equivalent otherwise to
+            # isinstance(key, py_slice) case in _symbol.py.
             if key.step is None or key.step == 1:
                 if key.start is not None or key.stop is not None:
                     return self._slice(key.start, key.stop)
                 else:
                     return self
-            elif key.step == 0:
+            elif key.step != 0:
+                start = [None] if key.start is None else key.start
+                stop = [None] if key.stop is None else key.stop
+                return _npi.slice(self, start, stop, key.step)
+            else:
                 raise ValueError("slice step cannot be zero")
-
-
-        all = __builtins__['all']  # `def all` below shadows the all builtin
-        if (isinstance(key, tuple) and all( \
-        (isinstance(arr, NDArray) \
-        and _np.issubdtype(arr.dtype, _np.integer) and arr.ndim > 0) \
-        for arr in key)):
+        elif isinstance(key, tuple) and \
+           all((isinstance(arr, NDArray) and _np.issubdtype(arr.dtype, _np.integer) and \
+                arr.ndim > 0) for arr in key):
+            # Equivalent case in numpy/_symbol.py
             return _npi.advanced_indexing_multiple(self, _npi.stack(*key))
-
+        elif isinstance(key, tuple) and dc.is_deferred_compute():
+            # Equivalent to isinstance(key, tuple) case in numpy/_symbol.py
+            # Only enabled in deferred compute mode, as this codepath prevents
+            # memory sharing which may be desired in non-deferred compute
+            # imperative mode.
+            begin = []
+            end = []
+            step = []
+            new_shape = ()
+            assert len(key)  # len(key) == 0 is handled above
+            unsupported = False
+            for index in key:
+                if isinstance(index, py_slice):
+                    if index.step is not None and index.step == 0:
+                        raise ValueError("slice step cannot be zero")
+                    begin.append(index.start)
+                    end.append(index.stop)
+                    step.append(index.step)
+                    new_shape += (-2,)
+                elif isinstance(index, integer_types):
+                    if index >= 0:
+                        begin.append(index)
+                        end.append(index+1)
+                        step.append(1)
+                    else:
+                        begin.append(index)
+                        end.append(index - 1)
+                        step.append(-1)
+                    new_shape += (-3,)
+                else:
+                    unsupported = True
+                    break
+            if not unsupported:
+                new_shape += (-4,)
+                sliced = _npi.slice(self, begin, end, step)
+                return _npi.reshape(sliced, new_shape)
+
+        # Special handling for cases only supported in imperative mode
+        if dc.is_deferred_compute():
+            raise TypeError('The type of indexing used is not supported in HybridBlock.')
         # For 0-d boolean indices: A new axis is added,
         # but at the same time no axis is "used". So if we have True,
         # we add a new axis (a bit like with np.newaxis). If it is
@@ -779,8 +824,6 @@ class ndarray(NDArray):
                 key = (_np.newaxis,) + key
             return self._get_np_basic_indexing(key)
         elif indexing_dispatch_code == _NDARRAY_ADVANCED_INDEXING:
-            if dc.is_deferred_compute():
-                raise TypeError('Advanced indexing is not supported in HybridBlock.')
             if prepend == _NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE:
                 return empty((0,) + self._get_np_adanced_indexing(key).shape,
                              dtype=self.dtype, ctx=self.ctx)
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index 95b232b..4c3fa7d 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -99,6 +99,7 @@ class _Symbol(Symbol):
                 raise TypeError('indices of symbol group must be integers or slices, not {}'
                                 .format(type(key)))
         else:
+            all = __builtins__['all']  # pylint: disable=redefined-outer-name
             if isinstance(key, integer_types):
                 if key == -1:
                     sliced = _npi.slice(self, [key], [None])
@@ -112,15 +113,20 @@ class _Symbol(Symbol):
                     return _npi.slice(self, start, stop, key.step)
                 else:
                     raise ValueError("slice step cannot be zero")
+            elif isinstance(key, Symbol):
+                return _npi.advanced_indexing(self, key)
+            elif isinstance(key, tuple) and len(key) == 0:
+                return self
+            elif isinstance(key, tuple) and all(isinstance(k, Symbol) for k in key):
+                key = _npi.stack(*[i for i in key])
+                sliced = _npi.advanced_indexing_multiple(self, key)
+                return sliced
             elif isinstance(key, tuple):
                 begin = []
                 end = []
                 step = []
                 new_shape = ()
-                result = self
-                is_symbol_tuple = False
-                if len(key) == 0:
-                    return self
+                assert len(key)  # len(key) == 0 handled above
                 for index in key:
                     if isinstance(index, py_slice):
                         if index.step is not None and index.step == 0:
@@ -139,25 +145,12 @@ class _Symbol(Symbol):
                             end.append(index - 1)
                             step.append(-1)
                         new_shape += (-3,)
-                    elif isinstance(index, Symbol):
-                        if new_shape != ():
-                            new_shape += (-4,)
-                            sliced = _npi.slice(result, begin, end, step)
-                            result = _npi.reshape(sliced, new_shape)
-                        if not is_symbol_tuple:
-                            is_symbol_tuple = True
                     else:
                         raise IndexError('Only integer, slice, symbol or tuple of these types'
                                          ' are supported! Received key={}'.format(key))
-                if is_symbol_tuple:
-                    key = _npi.stack(*[i for i in key])
-                    sliced = _npi.advanced_indexing_multiple(self, key)
-                    return sliced
                 new_shape += (-4,)
                 sliced = _npi.slice(self, begin, end, step)
                 return _npi.reshape(sliced, new_shape)
-            elif isinstance(key, Symbol):
-                return _npi.advanced_indexing(self, key)
             else:
                 raise IndexError('Only integer, slice, tuple or Symbol of these types are supported! '
                                  'Received key={}'.format(key))
diff --git a/tests/python/unittest/test_deferred_compute.py b/tests/python/unittest/test_deferred_compute.py
index ea6f2b4..cf867aa 100644
--- a/tests/python/unittest/test_deferred_compute.py
+++ b/tests/python/unittest/test_deferred_compute.py
@@ -561,3 +561,17 @@ def test_dc_hybridblock_symbolblock_error():
     net.hybridize()
     with pytest.raises(RuntimeError):
         out_hybrid = net(data)  # Raises RuntimeError
+
+
+def test_indexing_shape_change():
+    class ConcatBlock(mx.gluon.nn.HybridBlock):
+        def forward(self, inputs):
+            return mx.np.concatenate([
+                inputs,
+                mx.np.pad(inputs[:,1:], ((0,0), (0,1))),
+            ])
+
+    net = ConcatBlock()
+    net.hybridize()
+    net(mx.np.random.uniform(size=(8, 16)))
+    net(mx.np.random.uniform(size=(8, 8)))