You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/11/01 03:44:46 UTC

[GitHub] reminisce commented on a change in pull request #8246: [WIP] Continued Work on Advanced Indexing

reminisce commented on a change in pull request #8246: [WIP] Continued Work on Advanced Indexing
URL: https://github.com/apache/incubator-mxnet/pull/8246#discussion_r148177103
 
 

 ##########
 File path: python/mxnet/ndarray/ndarray.py
 ##########
 @@ -507,71 +509,357 @@ def __getitem__(self, key):
                         key, self.shape[0]))
             return self._at(key)
         elif isinstance(key, py_slice):
-            if key.step is not None:
-                raise ValueError("NDArray only supports slicing with step size 1.")
-            if key.start is not None or key.stop is not None:
+            if key.step is not None and key.step != 1:
+                if key.step == 0:
+                    raise ValueError("slice step cannot be zero")
+                start, stop, step = _get_index_range(key.start, key.stop, self.shape[0], key.step)
+                indices = arange(start, stop, step, ctx=self.context, dtype='int32')
+                if len(indices) == 0:
+                    raise ValueError('slicing NDArray with %s is not valid'
+                                     ' since it would generate an empty ndarray' % key)
+                return self.take(indices)
+            elif key.start is not None or key.stop is not None:
                 return self._slice(key.start, key.stop)
-            return self
-        elif isinstance(key, tuple):
-            shape = self.shape
-            assert len(key) > 0, "Cannot slice with empty indices"
-            assert len(shape) >= len(key), \
-                "Slicing dimensions exceeds array dimensions, %d vs %d"%(
-                    len(key), len(shape))
-            if isinstance(key[0], (NDArray, np.ndarray, list, tuple)):
-                indices = []
-                dtype = 'int32'
-                shape = None
-                for idx_i in key:
-                    if not isinstance(idx_i, NDArray):
-                        assert isinstance(idx_i, (NDArray, np.ndarray, list, tuple)), \
-                            "Combining basic and advanced indexing is not supported " \
-                            "yet. Indices must be all NDArray or all slice, not a " \
-                            "mix of both."
-                        idx_i = array(idx_i, ctx=self.context, dtype=dtype)
-                    else:
-                        dtype = idx_i.dtype
-                    if shape is None:
-                        shape = idx_i.shape
+            else:
+                return self
+
+        # Start dealing with more complicated indexing cases than only slicing at axis=0.
+        # if key is any type of NDArray, list, or np.ndarray make a tuple from it
 
 Review comment:
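  Are you referring to something like this: `a[x for x in range(3)]`?
  I tried it with NumPy, and it doesn't work: the Python parser rejects an unparenthesized generator expression as a subscript with a `SyntaxError`, before NumPy is even involved.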
   ```python
   >>> import numpy as np
   >>> a = np.arange(16).reshape(4, 4)
   >>> a
   array([[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11],
          [12, 13, 14, 15]])
   >>> a[x for x in range(3)]
     File "<stdin>", line 1
       a[x for x in range(3)]
             ^
   SyntaxError: invalid syntax
   ```

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services