You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by la...@apache.org on 2020/09/22 01:00:09 UTC
[incubator-mxnet] branch master updated: Fix numpy ndarray
`__getitem__` for HybridBlock.forward usecase (#19171)
This is an automated email from the ASF dual-hosted git repository.
lausen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new dd44c0c Fix numpy ndarray `__getitem__` for HybridBlock.forward usecase (#19171)
dd44c0c is described below
commit dd44c0c3bc168b3e88cda22c443283894fd24c54
Author: Leonard Lausen <la...@amazon.com>
AuthorDate: Mon Sep 21 17:57:24 2020 -0700
Fix numpy ndarray `__getitem__` for HybridBlock.forward usecase (#19171)
---
python/mxnet/numpy/multiarray.py | 65 +++++++++++++++++++++-----
python/mxnet/symbol/numpy/_symbol.py | 27 ++++-------
tests/python/unittest/test_deferred_compute.py | 14 ++++++
3 files changed, 78 insertions(+), 28 deletions(-)
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index f2b5706..79c995f 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -724,6 +724,7 @@ class ndarray(NDArray):
if isinstance(key, ndarray) and key.dtype == _np.bool_:
return self._get_np_boolean_indexing(key, ndim, shape)
+ all = __builtins__['all'] # `def all` below shadows the all builtin
if ndim == 0 and key != ():
raise IndexError('scalar tensor can only accept `()` as index')
# Handle simple cases for higher speed
@@ -736,28 +737,72 @@ class ndarray(NDArray):
out = out[idx]
return out
if isinstance(key, integer_types):
+ # Equivalent to isinstance(key, integer_types) case in numpy/_symbol.py
if key > shape[0] - 1:
raise IndexError(
'index {} is out of bounds for axis 0 with size {}'.format(
key, shape[0]))
return self._at(key)
elif isinstance(key, py_slice):
+ # Unlike numpy/_symbol.py, calls MXNDArraySlice64 writable memory
+ # sharing if key.step in [None, 1]. Otherwise equivalent to the
+ # isinstance(key, py_slice) case in numpy/_symbol.py.
if key.step is None or key.step == 1:
if key.start is not None or key.stop is not None:
return self._slice(key.start, key.stop)
else:
return self
- elif key.step == 0:
+ elif key.step != 0:
+ start = [None] if key.start is None else key.start
+ stop = [None] if key.stop is None else key.stop
+ return _npi.slice(self, start, stop, key.step)
+ else:
raise ValueError("slice step cannot be zero")
-
-
- all = __builtins__['all'] # `def all` below shadows the all builtin
- if (isinstance(key, tuple) and all( \
- (isinstance(arr, NDArray) \
- and _np.issubdtype(arr.dtype, _np.integer) and arr.ndim > 0) \
- for arr in key)):
+ elif isinstance(key, tuple) and \
+ all((isinstance(arr, NDArray) and _np.issubdtype(arr.dtype, _np.integer) and \
+ arr.ndim > 0) for arr in key):
+ # Equivalent case in numpy/_symbol.py
return _npi.advanced_indexing_multiple(self, _npi.stack(*key))
-
+ elif isinstance(key, tuple) and dc.is_deferred_compute():
+ # Equivalent to isinstance(key, tuple) case in numpy/_symbol.py
+ # Only enabled in deferred compute mode, as this codepath prevents
+ # memory sharing which may be desired in non-deferred compute
+ # imperative mode.
+ begin = []
+ end = []
+ step = []
+ new_shape = ()
+ assert len(key) # len(key) == 0 is handled above
+ unsupported = False
+ for index in key:
+ if isinstance(index, py_slice):
+ if index.step is not None and index.step == 0:
+ raise ValueError("slice step cannot be zero")
+ begin.append(index.start)
+ end.append(index.stop)
+ step.append(index.step)
+ new_shape += (-2,)
+ elif isinstance(index, integer_types):
+ if index >= 0:
+ begin.append(index)
+ end.append(index+1)
+ step.append(1)
+ else:
+ begin.append(index)
+ end.append(index - 1)
+ step.append(-1)
+ new_shape += (-3,)
+ else:
+ unsupported = True
+ break
+ if not unsupported:
+ new_shape += (-4,)
+ sliced = _npi.slice(self, begin, end, step)
+ return _npi.reshape(sliced, new_shape)
+
+ # Special handling for cases only supported in imperative mode
+ if dc.is_deferred_compute():
+ raise TypeError('The type of indexing used is not supported in HybridBlock.')
# For 0-d boolean indices: A new axis is added,
# but at the same time no axis is "used". So if we have True,
# we add a new axis (a bit like with np.newaxis). If it is
@@ -779,8 +824,6 @@ class ndarray(NDArray):
key = (_np.newaxis,) + key
return self._get_np_basic_indexing(key)
elif indexing_dispatch_code == _NDARRAY_ADVANCED_INDEXING:
- if dc.is_deferred_compute():
- raise TypeError('Advanced indexing is not supported in HybridBlock.')
if prepend == _NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE:
return empty((0,) + self._get_np_adanced_indexing(key).shape,
dtype=self.dtype, ctx=self.ctx)
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index 95b232b..4c3fa7d 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -99,6 +99,7 @@ class _Symbol(Symbol):
raise TypeError('indices of symbol group must be integers or slices, not {}'
.format(type(key)))
else:
+ all = __builtins__['all'] # pylint: disable=redefined-outer-name
if isinstance(key, integer_types):
if key == -1:
sliced = _npi.slice(self, [key], [None])
@@ -112,15 +113,20 @@ class _Symbol(Symbol):
return _npi.slice(self, start, stop, key.step)
else:
raise ValueError("slice step cannot be zero")
+ elif isinstance(key, Symbol):
+ return _npi.advanced_indexing(self, key)
+ elif isinstance(key, tuple) and len(key) == 0:
+ return self
+ elif isinstance(key, tuple) and all(isinstance(k, Symbol) for k in key):
+ key = _npi.stack(*[i for i in key])
+ sliced = _npi.advanced_indexing_multiple(self, key)
+ return sliced
elif isinstance(key, tuple):
begin = []
end = []
step = []
new_shape = ()
- result = self
- is_symbol_tuple = False
- if len(key) == 0:
- return self
+ assert len(key) # len(key) == 0 handled above
for index in key:
if isinstance(index, py_slice):
if index.step is not None and index.step == 0:
@@ -139,25 +145,12 @@ class _Symbol(Symbol):
end.append(index - 1)
step.append(-1)
new_shape += (-3,)
- elif isinstance(index, Symbol):
- if new_shape != ():
- new_shape += (-4,)
- sliced = _npi.slice(result, begin, end, step)
- result = _npi.reshape(sliced, new_shape)
- if not is_symbol_tuple:
- is_symbol_tuple = True
else:
raise IndexError('Only integer, slice, symbol or tuple of these types'
' are supported! Received key={}'.format(key))
- if is_symbol_tuple:
- key = _npi.stack(*[i for i in key])
- sliced = _npi.advanced_indexing_multiple(self, key)
- return sliced
new_shape += (-4,)
sliced = _npi.slice(self, begin, end, step)
return _npi.reshape(sliced, new_shape)
- elif isinstance(key, Symbol):
- return _npi.advanced_indexing(self, key)
else:
raise IndexError('Only integer, slice, tuple or Symbol of these types are supported! '
'Received key={}'.format(key))
diff --git a/tests/python/unittest/test_deferred_compute.py b/tests/python/unittest/test_deferred_compute.py
index ea6f2b4..cf867aa 100644
--- a/tests/python/unittest/test_deferred_compute.py
+++ b/tests/python/unittest/test_deferred_compute.py
@@ -561,3 +561,17 @@ def test_dc_hybridblock_symbolblock_error():
net.hybridize()
with pytest.raises(RuntimeError):
out_hybrid = net(data) # Raises RuntimeError
+
+
+def test_indexing_shape_change():
+ class ConcatBlock(mx.gluon.nn.HybridBlock):
+ def forward(self, inputs):
+ return mx.np.concatenate([
+ inputs,
+ mx.np.pad(inputs[:,1:], ((0,0), (0,1))),
+ ])
+
+ net = ConcatBlock()
+ net.hybridize()
+ net(mx.np.random.uniform(size=(8, 16)))
+ net(mx.np.random.uniform(size=(8, 8)))