[GitHub] piiswrong closed pull request #8246: Continued Work on Advanced Indexing

piiswrong closed pull request #8246: Continued Work on Advanced Indexing

diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index ad1acaf9b5..6cbf3284e5 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -96,6 +96,11 @@
 # pylint: enable= no-member
+# Return code for dispatching indexing function call
 def _new_empty_handle():
     """Returns a new empty handle.
@@ -389,100 +394,36 @@ def __setitem__(self, key, value):
         >>> x.asnumpy()
         array([[ 1.,  2.,  1.],
                [ 0.,  0.,  4.]], dtype=float32)
-        """
-        # pylint: disable=too-many-branches
-        if not self.writable:
-            raise ValueError('Cannot assign to readonly NDArray')
-        if isinstance(key, integer_types):
-            sliced_arr = self._at(key)
-            sliced_arr[:] = value
-            return
-        elif isinstance(key, py_slice):
-            if key.step is not None:
-                raise ValueError('NDArray only supports slicing with step size 1')
-            if key.start is not None or key.stop is not None:
-                sliced_arr = self._slice(key.start, key.stop)
-                sliced_arr[:] = value
-                return
-            if isinstance(value, NDArray):
-                if value.handle is not self.handle:
-                    value.copyto(self)
-            elif isinstance(value, numeric_types):
-                _internal._set_value(float(value), out=self)
-            elif isinstance(value, (np.ndarray, np.generic)):
-                self._sync_copyfrom(value)
-            else:
-                raise TypeError(
-                    'NDArray does not support assignment with %s of type %s'%(
-                        str(value), str(type(value))))
-        elif isinstance(key, tuple):
-            # multi-dimension indexing
-            my_shape = self.shape
-            assert len(key) <= len(my_shape), \
-                "Indexing dimensions exceed array dimensions, %d vs %d"%(
-                    len(key), len(my_shape))
-            begin = [0 for _ in my_shape]
-            end = [x for x in my_shape]
-            expand = []
-            for i, slice_i in enumerate(key):
-                if isinstance(slice_i, integer_types):
-                    assert slice_i < my_shape[i]
-                    begin[i] = slice_i
-                    end[i] = slice_i + 1
-                    expand.append(i)
-                elif isinstance(slice_i, py_slice):
-                    # only support continuous slicing
-                    assert slice_i.step is None, \
-                        "NDArray only supports slicing with step size 1."
-                    begin[i] = slice_i.start or 0
-                    end[i] = slice_i.stop or my_shape[i]
-                    assert begin[i] < end[i]
-                    assert end[i] <= my_shape[i]
-                else:
-                    raise ValueError(
-                        "NDArray does not support slicing with key %s of type %s."%(
-                            str(slice_i), str(type(slice_i))))
-            if isinstance(value, NDArray):
-                value = value.as_in_context(self.context)
-                self._slice_assign(value, begin, end, expand)
-            elif isinstance(value, numeric_types):
-                _internal._crop_assign_scalar(self, out=self,
-                                              begin=begin, end=end,
-                                              scalar=value)
-            elif isinstance(value, (np.ndarray, np.generic)):
-                value = array(value, ctx=self.context, dtype=self.dtype)
-                self._slice_assign(value, begin, end, expand)
-            else:
-                raise TypeError(
-                    'NDArray does not support assignment with %s of type %s'%(
-                        str(value), str(type(value))))
+        >>> x[[0], [1, 2]] = 5
+        >>> x.asnumpy()
+        array([[ 1.,  5.,  5.],
+               [ 0.,  0.,  4.]], dtype=float32)
+        >>> x[::-1, 0:2:2] = [6]
+        >>> x.asnumpy()
+        array([[ 6.,  5.,  5.],
+               [ 6.,  0.,  4.]], dtype=float32)
+        """
+        indexing_dispatch_code = _get_indexing_dispatch_code(key)
+        if indexing_dispatch_code == _NDARRAY_BASIC_INDEXING:
+            self._set_nd_basic_indexing(key, value)
+        elif indexing_dispatch_code == _NDARRAY_ADVANCED_INDEXING:
+            self._set_nd_advanced_indexing(key, value)
-            raise ValueError(
-                "NDArray does not support slicing with key %s of type %s."%(
-                    str(key), str(type(key))))
-        # pylint: enable=too-many-branches
-    def _slice_assign(self, value, begin, end, expand):
-        vshape = list(value.shape)
-        if expand and len(vshape) != len(begin):
-            if len(expand) + len(vshape) != len(begin):
-                sshape = [e - b for e, b in zip(end, begin)]
-                for i in reversed(expand):
-                    sshape.pop(i)
-                raise ValueError(
-                    "Cannot assign NDArray with shape %s to NDArray slice with " \
-                    "shape %s"%(str(vshape), str(sshape)))
-            for i in expand:
-                vshape.insert(i, 1)
-            value = value.reshape(vshape)
-        _internal._crop_assign(self, value, out=self,
-                               begin=begin, end=end)
+            raise ValueError('Indexing NDArray with index=%s and type=%s is not supported'
+                             % (str(key), str(type(key))))
     def __getitem__(self, key):
         """x.__getitem__(i) <=> x[i]
-        Returns a sliced view of this array.
+        Returns a sliced view of this array if the elements fetched are contiguous in memory;
+        otherwise, returns a newly created NDArray.
+        This function supports advanced indexing defined in the following reference with
+        some limitations.
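+        https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#combining-advanced-and-basic-indexing  # pylint: disable=line-too-long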
+        The following features/functionality are not supported for now:
+        1. If key is a list type, only a list of integers is supported,
+           i.e. key=[1, 2] is okay, while not for key=[[1]].
+        2. Ellipsis (...) and np.newaxis are not supported.
+        3. Boolean array indexing.
@@ -502,80 +443,353 @@ def __getitem__(self, key):
         >>> x.asnumpy()
         array([[ 2.,  2.,  2.],
                [ 3.,  4.,  5.]], dtype=float32)
-        """
-        # multi-dimensional slicing is not supported yet
+        >>> x = mx.nd.arange(0, 8, dtype='int32').reshape((2, 2, 2))
+        >>> x[[0, 1]]
+        [[[0 1]
+          [2 3]]
+         [[4 5]
+          [6 7]]]
+        >>> x[1:, [0, 1]]
+        [[[4 5]
+          [6 7]]]
+        >>> y = np.array([0, 1], dtype='int32')
+        >>> x[1:, y]
+        [[[4 5]
+          [6 7]]]
+        >>> y = mx.nd.array([0, 1], dtype='int32')
+        >>> x[1:, y]
+        [[[4 5]
+          [6 7]]]
+        """
+        indexing_dispatch_code = _get_indexing_dispatch_code(key)
+        if indexing_dispatch_code == _NDARRAY_BASIC_INDEXING:
+            return self._get_nd_basic_indexing(key)
+        elif indexing_dispatch_code == _NDARRAY_ADVANCED_INDEXING:
+            return self._get_nd_advanced_indexing(key)
+        else:
+            raise ValueError('Indexing NDArray with index=%s and type=%s is not supported'
+                             % (str(key), str(type(key))))
+    def _get_index_nd(self, key):
+        """Returns an index array for use in scatter_nd and gather_nd."""
+        def _is_advanced_index(index):
+            """The definition of advanced index here includes integers as well, while
+            integers are considered as basic index type when the key contains only
+            slices and integers."""
+            return not isinstance(index, py_slice)
+        if isinstance(key, (NDArray, np.ndarray, list, integer_types, py_slice)):
+            key = (key,)
+        assert isinstance(key, tuple),\
+            'index=%s must be a NDArray, or np.ndarray, or list, or tuple ' \
+            ' type to use advanced indexing, received type=%s' % (str(key), str(type(key)))
+        assert len(key) > 0, "Cannot slice with empty indices"
+        shape = self.shape
+        assert len(shape) >= len(key),\
+            "Slicing dimensions exceeds array dimensions, %d vs %d" % (len(key), len(shape))
+        indices = []
+        dtype = 'int32'  # index data type passed to gather_nd op
+        need_broadcast = (len(key) != 1)
+        advanced_indices = []  # include list, NDArray, np.ndarray, integer
+        basic_indices = []  # include only slices
+        advanced_index_bshape = None  # final advanced index shape
+        for i, idx_i in enumerate(key):
+            is_advanced_index = True
+            if isinstance(idx_i, (np.ndarray, list, tuple)):
+                idx_i = array(idx_i, ctx=self.context, dtype=dtype)
+                advanced_indices.append(i)
+            elif isinstance(idx_i, py_slice):
+                start, stop, step = _get_index_range(idx_i.start, idx_i.stop, shape[i], idx_i.step)
+                idx_i = arange(start, stop, step, ctx=self.context, dtype=dtype)
+                basic_indices.append(i)
+                is_advanced_index = False
+            elif isinstance(idx_i, integer_types):
+                start, stop, step = _get_index_range(idx_i, idx_i+1, shape[i], 1)
+                idx_i = arange(start, stop, step, ctx=self.context, dtype=dtype)
+                advanced_indices.append(i)
+            elif isinstance(idx_i, NDArray):
+                if dtype != idx_i.dtype:
+                    idx_i = idx_i.astype(dtype)
+                advanced_indices.append(i)
+            else:
+                raise IndexError('Indexing NDArray with index=%s of type=%s is not supported'
+                                 % (str(key), str(type(key))))
+            if is_advanced_index:
+                if advanced_index_bshape is None:
+                    advanced_index_bshape = idx_i.shape
+                elif advanced_index_bshape != idx_i.shape:
+                    need_broadcast = True
+                    advanced_index_bshape = _get_broadcast_shape(advanced_index_bshape, idx_i.shape)
+            indices.append(idx_i)
+        # Get final index shape for gather_nd. See the following reference
+        # for determining the output array shape.
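+        # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#combining-advanced-and-basic-indexing  # pylint: disable=line-too-long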
+        if len(advanced_indices) == 0:
+            raise ValueError('Advanced index tuple must contain at least one of the following types:'
+                             ' list, tuple, NDArray, np.ndarray, integer, received index=%s' % key)
+        # determine the output array's shape by checking whether advanced_indices are all adjacent
+        # or separated by slices
+        advanced_indices_adjacent = True
+        for i in range(0, len(advanced_indices)-1):
+            if advanced_indices[i] + 1 != advanced_indices[i+1]:
+                advanced_indices_adjacent = False
+                break
+        index_bshape_list = []  # index broadcasted shape
+        if advanced_indices_adjacent:
+            for i in range(0, advanced_indices[0]):
+                index_bshape_list.extend(indices[i].shape)
+                if not need_broadcast and indices[i].shape != advanced_index_bshape:
+                    need_broadcast = True
+            index_bshape_list.extend(advanced_index_bshape)
+            for i in range(advanced_indices[-1]+1, len(indices)):
+                if not need_broadcast and indices[i].shape != advanced_index_bshape:
+                    need_broadcast = True
+                index_bshape_list.extend(indices[i].shape)
+        else:
+            index_bshape_list.extend(advanced_index_bshape)
+            for i in basic_indices:
+                index_bshape_list.extend(indices[i].shape)
+                if not need_broadcast and indices[i].shape != advanced_index_bshape:
+                    need_broadcast = True
+        index_bshape = tuple(index_bshape_list)
+        # Need to broadcast all ndarrays in indices to the final shape.
+        # For example, suppose an array has shape=(5, 6, 7, 8) and
+        # key=(slice(1, 5), [[1, 2]], slice(2, 5), [1]).
+        # Since key[1] and key[3] are two advanced indices here and they are
+        # separated by basic indices key[0] and key[2], the output shape
+        # is (1, 2, 4, 3), where the first two elements come from the shape
+        # that key[1] and key[3] should broadcast to, which is (1, 2), and
+        # the last two elements come from the shape of two basic indices.
+        # In order to broadcast all basic and advanced indices to the output shape,
+        # we need to reshape them based on their axis. For example, to broadcast key[0],
+        # with shape=(4,), we first need to reshape it into (1, 1, 4, 1), and then
+        # broadcast the reshaped array to (1, 2, 4, 3); to broadcast key[1], we first
+        # reshape it into (1, 2, 1, 1), then broadcast the reshaped array to (1, 2, 4, 3).
+        if need_broadcast:
+            broadcasted_indices = []
+            idx_rshape = [1] * len(index_bshape)
+            if advanced_indices_adjacent:
+                advanced_index_bshape_start = advanced_indices[0]  # start index of advanced_index_bshape in index_shape
+                advanced_index_bshape_stop = advanced_index_bshape_start + len(advanced_index_bshape)
+                for i, idx in enumerate(key):
+                    if _is_advanced_index(idx):
+                        k = advanced_index_bshape_stop
+                        # find the reshaped shape for indices[i]
+                        for dim_size in indices[i].shape[::-1]:
+                            k -= 1
+                            idx_rshape[k] = dim_size
+                    else:
+                        if i < advanced_indices[0]:  # slice is on the left side of advanced indices
+                            idx_rshape[i] = indices[i].shape[0]
+                        elif i > advanced_indices[-1]:  # slice is on the right side of advanced indices
+                            idx_rshape[i-len(key)] = indices[i].shape[0]
+                        else:
+                            raise ValueError('basic index i=%d cannot be between advanced index i=%d and i=%d'
+                                             % (i, advanced_indices[0], advanced_indices[-1]))
+                    # broadcast current index to the final shape
+                    broadcasted_indices.append(indices[i].reshape(tuple(idx_rshape)).broadcast_to(index_bshape))
+                    # reset idx_rshape to ones
+                    for j, _ in enumerate(idx_rshape):
+                        idx_rshape[j] = 1
+            else:
+                basic_index_offset = len(advanced_index_bshape)
+                for i, idx in enumerate(key):
+                    if _is_advanced_index(idx):
+                        k = len(advanced_index_bshape)
+                        for dim_size in indices[i].shape[::-1]:
+                            k -= 1
+                            idx_rshape[k] = dim_size
+                    else:
+                        idx_rshape[basic_index_offset] = indices[i].shape[0]
+                        basic_index_offset += 1
+                    # broadcast current index to the final shape
+                    broadcasted_indices.append(indices[i].reshape(tuple(idx_rshape)).broadcast_to(index_bshape))
+                    # reset idx_rshape to ones
+                    for j, _ in enumerate(idx_rshape):
+                        idx_rshape[j] = 1
+            indices = broadcasted_indices
+        return op.stack(*indices)
+    def _prepare_value_nd(self, value, vshape):
+        """Given value and vshape, create an `NDArray` from value with the same
+        context and dtype as the current one and broadcast it to vshape."""
+        if isinstance(value, numeric_types):
+            value_nd = full(shape=vshape, val=value, ctx=self.context, dtype=self.dtype)
+        elif isinstance(value, NDArray):
+            value_nd = value.as_in_context(self.context)
+            if value_nd.dtype != self.dtype:
+                value_nd = value_nd.astype(self.dtype)
+        else:
+            try:
+                value_nd = array(value, ctx=self.context, dtype=self.dtype)
+            except:
+                raise TypeError('NDArray does not support assignment with non-array-like'
+                                ' object %s of type %s' % (str(value), str(type(value))))
+        if value_nd.shape != vshape:
+            value_nd = value_nd.broadcast_to(vshape)
+        return value_nd
+    def _set_nd_basic_indexing(self, key, value):
+        """This function is called by __setitem__ when key is a basic index, i.e.
+        an integer, or a slice, or a tuple of integers and slices. No restrictions
+        on the values of slices' steps."""
+        shape = self.shape
+        if isinstance(key, integer_types):
+            sliced_arr = self._at(key)
+            sliced_arr[:] = value
+            return
+        elif isinstance(key, py_slice):
+            if key.step is None or key.step == 1:  # trivial step
+                if key.start is not None or key.stop is not None:
+                    sliced_arr = self._slice(key.start, key.stop)
+                    sliced_arr[:] = value
+                    return
+                # assign value to the whole NDArray
+                # may need to broadcast first
+                if isinstance(value, NDArray):
+                    if value.handle is not self.handle:
+                        value.copyto(self)
+                elif isinstance(value, numeric_types):
+                    _internal._full(shape=shape, ctx=self.context,
+                                    dtype=self.dtype, value=value, out=self)
+                elif isinstance(value, (np.ndarray, np.generic)):
+                    if isinstance(value, np.generic) or value.shape != shape:
+                        value = np.broadcast_to(value, shape)
+                    self._sync_copyfrom(value)
+                else:  # value might be a list or a tuple
+                    value_nd = self._prepare_value_nd(value, shape)
+                    value_nd.copyto(self)
+                return
+            else:  # non-trivial step, use _slice_assign or _slice_assign_scalar
+                key = (key,)
+        assert isinstance(key, tuple), "key=%s must be a tuple of slices and integers" % str(key)
+        assert len(key) <= len(shape), "Indexing dimensions exceed array dimensions, %d vs %d"\
+                                       % (len(key), len(shape))
+        begin = []
+        end = []
+        steps = []
+        oshape = []  # output shape of slice using key
+        vshape = []  # value shape of data[key]
+        for i, slice_i in enumerate(key):
+            dim_size = 1
+            if isinstance(slice_i, py_slice):
+                begin.append(slice_i.start)
+                end.append(slice_i.stop)
+                steps.append(slice_i.step)
+                start, stop, step = _get_index_range(slice_i.start, slice_i.stop,
+                                                     shape[i], slice_i.step)
+                dim_size = _get_dim_size(start, stop, step)
+                vshape.append(dim_size)
+            elif isinstance(slice_i, integer_types):
+                begin.append(slice_i)
+                end.append(slice_i+1)
+                steps.append(1)
+            else:
+                raise ValueError("basic indexing does not support index=%s of type=%s"
+                                 % (str(slice_i), str(type(slice_i))))
+            oshape.append(dim_size)
+        oshape.extend(shape[len(key):])
+        vshape.extend(shape[len(key):])
+        # if key contains all integers, vshape should be (1,)
+        if len(vshape) == 0:
+            vshape.append(1)
+        oshape = tuple(oshape)
+        vshape = tuple(vshape)
+        if isinstance(value, numeric_types):
+            _internal._slice_assign_scalar(self, out=self, begin=begin, end=end,
+                                           step=steps, scalar=float(value))
+        else:
+            value_nd = self._prepare_value_nd(value, vshape)
+            if vshape != oshape:
+                value_nd = value_nd.reshape(oshape)
+            _internal._slice_assign(self, value_nd, begin, end, steps, out=self)
+    def _set_nd_advanced_indexing(self, key, value):
+        """This function is called by __setitem__ when key is an advanced index."""
+        indices = self._get_index_nd(key)
+        vshape = _get_oshape_of_gather_nd_op(self.shape, indices.shape)
+        value_nd = self._prepare_value_nd(value, vshape)
+        _internal._scatter_set_nd(data=value_nd, indices=indices, shape=self.shape, out=self)
+    def _get_nd_basic_indexing(self, key):
+        """This function is called when key is a slice, or an integer,
+        or a tuple of slices or integers"""
+        shape = self.shape
         if isinstance(key, integer_types):
-            if key > self.shape[0] - 1:
+            if key > shape[0] - 1:
                 raise IndexError(
                     'index {} is out of bounds for axis 0 with size {}'.format(
-                        key, self.shape[0]))
+                        key, shape[0]))
             return self._at(key)
         elif isinstance(key, py_slice):
-            if key.step is not None:
-                raise ValueError("NDArray only supports slicing with step size 1.")
-            if key.start is not None or key.stop is not None:
+            if key.step is not None and key.step != 1:
+                if key.step == 0:
+                    raise ValueError("slice step cannot be zero")
+                return op.slice(self, begin=(key.start,), end=(key.stop,), step=(key.step,))
+            elif key.start is not None or key.stop is not None:
                 return self._slice(key.start, key.stop)
-            return self
-        elif isinstance(key, tuple):
-            shape = self.shape
-            assert len(key) > 0, "Cannot slice with empty indices"
-            assert len(shape) >= len(key), \
-                "Slicing dimensions exceeds array dimensions, %d vs %d"%(
-                    len(key), len(shape))
-            if isinstance(key[0], (NDArray, np.ndarray, list, tuple)):
-                indices = []
-                dtype = 'int32'
-                shape = None
-                for idx_i in key:
-                    if not isinstance(idx_i, NDArray):
-                        assert isinstance(idx_i, (NDArray, np.ndarray, list, tuple)), \
-                            "Combining basic and advanced indexing is not supported " \
-                            "yet. Indices must be all NDArray or all slice, not a " \
-                            "mix of both."
-                        idx_i = array(idx_i, ctx=self.context, dtype=dtype)
-                    else:
-                        dtype = idx_i.dtype
-                    if shape is None:
-                        shape = idx_i.shape
-                    else:
-                        assert shape == idx_i.shape, \
-                            "All index arrays must have the same shape: %s vs %s. " \
-                            "Broadcasting is not supported yet."%(shape, idx_i.shape)
-                    indices.append(idx_i)
-                indices = op.stack(*indices)
-                return op.gather_nd(self, indices)
-                oshape = []
-                begin = []
-                end = []
-                i = -1
-                for i, slice_i in enumerate(key):
-                    if isinstance(slice_i, integer_types):
-                        begin.append(slice_i)
-                        end.append(slice_i+1)
-                    elif isinstance(slice_i, py_slice):
-                        if slice_i.step is not None:
-                            raise ValueError("NDArray only supports slicing with step size 1.")
-                        begin.append(0 if slice_i.start is None else slice_i.start)
-                        end.append(shape[i] if slice_i.stop is None else slice_i.stop)
-                        oshape.append(end[i] - begin[i])
-                    elif isinstance(slice_i, (NDArray, np.ndarray, list, tuple)):
-                        raise ValueError(
-                            "Combining basic and advanced indexing is not supported " \
-                            "yet. Indices must be all NDArray or all slice, not a " \
-                            "mix of both.")
-                    else:
-                        raise ValueError(
-                            "NDArray does not support slicing with key %s of type %s."%(
-                                str(slice_i), str(type(slice_i))))
-                oshape.extend(shape[i+1:])
-                if len(oshape) == 0:
-                    oshape.append(1)
-                return op.slice(self, begin, end).reshape(oshape)
-        else:
-            raise ValueError(
-                "NDArray does not support slicing with key %s of type %s."%(
-                    str(key), str(type(key))))
+                return self
+        if not isinstance(key, tuple):
+            raise ValueError('index=%s must be a slice, or an ineger, or a tuple'
+                             ' of slices and integers to use basic indexing, received type=%s'
+                             % (str(key), str(type(key))))
+        assert len(key) != 0, 'basic index cannot be an empty tuple'
+        begin = []
+        end = []
+        step = []
+        kept_axes = []  # axes where slice_i is a slice
+        i = -1
+        for i, slice_i in enumerate(key):
+            if isinstance(slice_i, integer_types):
+                begin.append(slice_i)
+                end.append(slice_i+1)
+                step.append(1)
+            elif isinstance(slice_i, py_slice):
+                if slice_i.step == 0:
+                    raise ValueError('basic index=%s cannot have slice=%s with step = 0'
+                                     % (str(key), str(slice_i)))
+                begin.append(slice_i.start)
+                end.append(slice_i.stop)
+                step.append(slice_i.step)
+                kept_axes.append(i)
+            else:
+                raise ValueError('basic_indexing does not support slicing with '
+                                 'index=%s of type=%s.' % (str(slice_i), str(type(slice_i))))
+        kept_axes.extend(range(i+1, len(shape)))
+        sliced_nd = op.slice(self, begin, end, step)
+        if len(kept_axes) == len(shape):
+            return sliced_nd
+        # squeeze sliced_shape to remove the axes indexed by integers
+        oshape = []
+        sliced_shape = sliced_nd.shape
+        for axis in kept_axes:
+            oshape.append(sliced_shape[axis])
+        # if key is a tuple of integers, still need to keep 1 dim
+        # while in Numpy, the output will become a value instead of an ndarray
+        if len(oshape) == 0:
+            oshape.append(1)
+        oshape = tuple(oshape)
+        assert np.prod(oshape) == np.prod(sliced_shape), 'oshape=%s has different size'\
+                                                         ' than sliced_shape=%s'\
+                                                         % (oshape, sliced_shape)
+        return sliced_nd.reshape(oshape)
+    def _get_nd_advanced_indexing(self, key):
+        """Get item when key is a tuple of any objects of the following types:
+        NDArray, np.ndarray, list, tuple, slice, and integer."""
+        return op.gather_nd(self, self._get_index_nd(key))
     def _sync_copyfrom(self, source_array):
         """Performs a synchronized copy from the `source_array` to the current array.
@@ -638,26 +852,10 @@ def _slice(self, start, stop):
         array([], shape=(0, 2), dtype=float32)
         handle = NDArrayHandle()
-        if start is None:
-            start = mx_uint(0)
-        elif start < 0:
-            length = self.shape[0]
-            start += length
-            assert start >= 0, "Slicing start %d exceeds limit of %d"%(start-length, length)
-            start = mx_uint(start)
-        else:
-            start = mx_uint(start)
-        if stop is None:
-            stop = mx_uint(self.shape[0])
-        elif stop < 0:
-            length = self.shape[0]
-            stop += length
-            assert stop >= 0, "Slicing end %d exceeds limit of %d"%(stop-length, length)
-            stop = mx_uint(stop)
-        else:
-            stop = mx_uint(stop)
+        start, stop, _ = _get_index_range(start, stop, self.shape[0])
-            self.handle, start, stop, ctypes.byref(handle)))
+            self.handle, mx_uint(start), mx_uint(stop), ctypes.byref(handle)))
         return NDArray(handle=handle, writable=self.writable)
     def _at(self, idx):
@@ -684,9 +882,14 @@ def _at(self, idx):
         array([ 1.], dtype=float32)
         handle = NDArrayHandle()
-        idx = mx_uint(idx)
+        if idx < 0:
+            length = self.shape[0]
+            idx += length
+            if idx < 0:
+                raise IndexError('index %d is out of bounds for axis 0 with size %d'
+                                 % (idx-length, length))
-            self.handle, idx, ctypes.byref(handle)))
+            self.handle, mx_uint(idx), ctypes.byref(handle)))
         return NDArray(handle=handle, writable=self.writable)
     def reshape(self, shape):
@@ -1779,6 +1982,119 @@ def tostype(self, stype):
         return op.cast_storage(self, stype=stype)
+def _get_indexing_dispatch_code(key):
+    """Returns a dispatch code for calling basic or advanced indexing functions."""
+    if isinstance(key, (NDArray, np.ndarray)):
+    elif isinstance(key, list):
+        # TODO(junwu): Add support for nested lists besides integer list
+        for i in key:
+            if not isinstance(i, integer_types):
+                raise TypeError('Indexing NDArray only supports a list of integers as index'
+                                ' when key is of list type, received element=%s of type=%s'
+                                % (str(i), str(type(i))))
+    elif isinstance(key, (integer_types, py_slice)):
+    elif isinstance(key, tuple):
+        for idx in key:
+            if isinstance(idx, (NDArray, np.ndarray, list, tuple)):
+                return _NDARRAY_ADVANCED_INDEXING
+            elif not isinstance(idx, (py_slice, integer_types)):
+                raise ValueError("NDArray does not support slicing with key %s of type %s."
+                                 % (str(idx), str(type(idx))))
+    else:
+def _get_index_range(start, stop, length, step=1):
+    """Given start, stop, step and array length, return
+    absolute values of start, stop, and step for generating index range.
+    The returned values have been compensated by adding length if they
+    are less than zero for all the cases but slice(None, None, -1).
+    Note that the returned value of stop is not necessarily >= 0, since
+    absolute stop is -1 in the case of slice(None, None, -1)."""
+    if step == 0:
+        raise ValueError('step size cannot be zero')
+    if length < 0:
+        raise ValueError('array length cannot be less than zero')
+    if step is None:
+        step = 1
+    if start is None:
+        if step > 0:
+            start = 0
+        else:
+            start = length - 1
+    elif start < 0:
+        start += length
+        if start < 0:
+            raise IndexError('Slicing start %d exceeds limit of %d' % (start-length, length))
+    elif start >= length:
+        raise IndexError('Slicing start %d exceeds limit of %d' % (start, length))
+    if stop is None:
+        if step > 0:
+            stop = length
+        else:
+            # this supports case such as ::-1
+            # stop = -1 here refers to the element before index 0,
+            # instead of the last element in the array
+            stop = -1
+    elif stop < 0:
+        stop += length
+        if stop < 0:
+            raise IndexError('Slicing stop %d exceeds limit of %d' % (stop-length, length))
+    elif stop > length:
+        raise IndexError('Slicing stop %d exceeds limit of %d' % (stop, length))
+    return start, stop, step
+def _get_oshape_of_gather_nd_op(dshape, ishape):
+    """Given data and index shapes, get the output `NDArray` shape.
+    This basically implements the infer shape logic of op gather_nd."""
+    assert len(dshape) > 0 and len(ishape) > 0
+    oshape = list(ishape[1:])
+    if ishape[0] < len(dshape):
+        oshape.extend(dshape[ishape[0]:])
+    return tuple(oshape)
+def _get_dim_size(start, stop, step):
+    """Given start, stop, and step, calculate the number of elements
+    of this slice."""
+    assert step != 0
+    if step > 0:
+        assert start < stop
+        dim_size = (stop - start - 1) // step + 1
+    else:
+        assert stop < start
+        dim_size = (start - stop - 1) // (-step) + 1
+    return dim_size
+def _get_broadcast_shape(shape1, shape2):
+    """Given two shapes that are not identical, find the shape
+    that both input shapes can broadcast to."""
+    if shape1 == shape2:
+        return shape1
+    length1 = len(shape1)
+    length2 = len(shape2)
+    if length1 > length2:
+        shape = list(shape1)
+    else:
+        shape = list(shape2)
+    i = max(length1, length2) - 1
+    for a, b in zip(shape1[::-1], shape2[::-1]):
+        if a != 1 and b != 1 and a != b:
+            raise ValueError('shape1=%s is not broadcastable to shape2=%s' % (shape1, shape2))
+        shape[i] = max(a, b)
+        i -= 1
+    return tuple(shape)
 def onehot_encode(indices, out):
     """One-hot encoding indices into matrix out.
@@ -1858,10 +2174,8 @@ def full(shape, val, ctx=None, dtype=mx_real_t, out=None):
     >>> mx.nd.full((1, 2), 2.0, dtype='float16').asnumpy()
     array([[ 2.,  2.]], dtype=float16)
-    if ctx is None:
-        ctx = Context.default_ctx
-    dtype = mx_real_t if dtype is None else dtype
-    out = _internal._full(shape=shape, ctx=ctx, dtype=dtype, value=val, out=out)
+    out = empty(shape, ctx, dtype) if out is None else out
+    out[:] = val
     return out
diff --git a/src/operator/tensor/ b/src/operator/tensor/
index 7c8e53e529..273ebec488 100644
--- a/src/operator/tensor/
+++ b/src/operator/tensor/
@@ -438,6 +438,35 @@ Examples::
 .add_argument("indices", "NDArray-or-Symbol", "indices")
+.describe(R"code(This operator has the same functionality as scatter_nd
+except that it does not reset the elements not indexed by the input
+index `NDArray` in the input data `NDArray`.
+.. note:: This operator is for internal use only.
+  data = [2, 3, 0]
+  indices = [[1, 1, 0], [0, 1, 0]]
+  out = [[1, 1], [1, 1]]
+  scatter_nd(data=data, indices=indices, out=out)
+  out = [[0, 1], [2, 3]]
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data", "indices"};
+  })
+.set_attr<nnvm::FInferShape>("FInferShape", ScatterNDShape)
+.set_attr<nnvm::FInferType>("FInferType", ScatterNDType)
+.set_attr<FCompute>("FCompute<cpu>", ScatterSetNDForward<cpu>)
+.add_argument("data", "NDArray-or-Symbol", "data")
+.add_argument("indices", "NDArray-or-Symbol", "indices")
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/tensor/ b/src/operator/tensor/
index e2822e8515..2cddd006a6 100644
--- a/src/operator/tensor/
+++ b/src/operator/tensor/
@@ -49,5 +49,8 @@ NNVM_REGISTER_OP(gather_nd)
 .set_attr<FCompute>("FCompute<gpu>", ScatterNDForward<gpu>);
+.set_attr<FCompute>("FCompute<gpu>", ScatterSetNDForward<gpu>);
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index 7624f2d1d4..684794bcd9 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -1208,10 +1208,10 @@ struct scatter_nd {
 template<typename xpu>
 void ScatterNDForward(const nnvm::NodeAttrs& attrs,
-                     const OpContext& ctx,
-                     const std::vector<TBlob>& inputs,
-                     const std::vector<OpReqType>& req,
-                     const std::vector<TBlob>& outputs) {
+                      const OpContext& ctx,
+                      const std::vector<TBlob>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<TBlob>& outputs) {
   using namespace mshadow;
   CHECK_EQ(inputs.size(), 2U);
   CHECK_EQ(outputs.size(), 1U);
@@ -1225,7 +1225,9 @@ void ScatterNDForward(const nnvm::NodeAttrs& attrs,
   mshadow::Shape<10> strides;
   for (int i = M-1, stride = K; i >= 0; stride *= oshape[i], --i) strides[i] = stride;
   MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {  // output data type switch
-    Fill<true>(s, outputs[0], req[0], 0);
+    if (kWriteTo == req[0]) {
+      Fill<true>(s, outputs[0], req[0], 0);
+    }
     MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {  // indices data type switch
       mxnet_op::Kernel<scatter_nd, xpu>::Launch(
         s, N, req[0], N, M, K, strides, outputs[0].dptr<DType>(),
@@ -1234,6 +1236,19 @@ void ScatterNDForward(const nnvm::NodeAttrs& attrs,
+ * This is for internal use only.
+ * DO NOT call this function unless you have to.
+ */
+template<typename xpu>
+void ScatterSetNDForward(const nnvm::NodeAttrs& attrs,
+                         const OpContext& ctx,
+                         const std::vector<TBlob>& inputs,
+                         const std::vector<OpReqType>& req,
+                         const std::vector<TBlob>& outputs) {
+  ScatterNDForward<xpu>(attrs, ctx, inputs, {kWriteInplace}, outputs);
 }  // namespace op
 }  // namespace mxnet
 #ifdef __CUDACC__
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 6940cfa71a..1f25b94ff3 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -385,41 +385,6 @@ struct SliceParam : public dmlc::Parameter<SliceParam> {
-inline TShape GetSliceShape(const SliceParam& param, const TShape& dshape) {
-  CHECK_LE(param.begin.ndim(), dshape.ndim())
-    << "Slicing axis exceeds data dimensions";
-  CHECK_LE(param.end.ndim(), dshape.ndim())
-    << "Slicing axis exceeds data dimensions";
-  CHECK_EQ(param.begin.ndim(), param.end.ndim())
-    << "begin and end must have the same length";
-  TShape oshape = dshape;
-  for (index_t i = 0; i < param.begin.ndim(); ++i) {
-    int s = 0, e = dshape[i];
-    if (e != 0) {
-      if (param.begin[i]) {
-        CHECK_LE(*param.begin[i], e)
-          << "Slicing begin exceeds data dimensions "
-          << param.begin << " vs " << dshape;
-        s = *param.begin[i];
-        if (s < 0) s += dshape[i];
-      }
-      if (param.end[i]) {
-        CHECK_LE(*param.end[i], e)
-          << "Slicing end exceeds data dimensions "
-          << param.end << " vs " << dshape;
-        e = *param.end[i];
-        if (e < 0) e += dshape[i];
-      }
-      CHECK(s >= 0 && s < e && e <= static_cast<int>(dshape[i]))
-        << "Invalid slicing begin " << param.begin << " and end "
-        << param.end << " for data of shape " << dshape;
-    }
-    oshape[i] = e - s;
-  }
-  return oshape;
 inline bool SliceForwardInferStorageType(const nnvm::NodeAttrs& attrs,
                                          const int dev_mask,
                                          DispatchMode* dispatch_mode,
@@ -683,41 +648,43 @@ void SliceEx(const nnvm::NodeAttrs& attrs,
 template<int ndim>
-inline void GetIndexRange(const SliceParam& param,
-                          const TShape& dshape,
+inline void GetIndexRange(const TShape& dshape,
+                          const nnvm::Tuple<dmlc::optional<int>>& param_begin,
+                          const nnvm::Tuple<dmlc::optional<int>>& param_end,
+                          const nnvm::Tuple<dmlc::optional<int>>& param_step,
                           common::StaticArray<int, ndim>* begin,
                           common::StaticArray<int, ndim>* end,
                           common::StaticArray<int, ndim>* step) {
   CHECK_NE(dshape.ndim(), 0U);
   CHECK_NE(dshape.Size(), 0U);
-  CHECK_LE(param.begin.ndim(), dshape.ndim())
+  CHECK_LE(param_begin.ndim(), dshape.ndim())
     << "Slicing axis exceeds data dimensions";
-  CHECK_LE(param.end.ndim(), dshape.ndim())
+  CHECK_LE(param_end.ndim(), dshape.ndim())
     << "Slicing axis exceeds data dimensions";
-  CHECK_EQ(param.begin.ndim(), param.end.ndim())
+  CHECK_EQ(param_begin.ndim(), param_end.ndim())
     << "begin and end must have the same length";
   CHECK_EQ(ndim, dshape.ndim())
     << "Static array size=" << ndim
     << " is not equal to data shape ndim=" << dshape.ndim();
-  if (param.step.ndim() != 0U) {
-    CHECK_EQ(param.step.ndim(), param.begin.ndim())
+  if (param_step.ndim() != 0U) {
+    CHECK_EQ(param_step.ndim(), param_begin.ndim())
       << "step and begin must have the same length";
-  for (index_t i = 0; i < param.begin.ndim(); ++i) {
+  for (index_t i = 0; i < param_begin.ndim(); ++i) {
     int b = 0, e = dshape[i], s = 1;
     const int len = dshape[i];
-    if (param.step.ndim() != 0U) {
-      const auto& opt_step_val = param.step[i];
+    if (param_step.ndim() != 0U) {
+      const auto& opt_step_val = param_step[i];
       if (opt_step_val.has_value()) {
         s = opt_step_val.value();
         CHECK_NE(s, 0) << "slice op step[" << i << "] cannot be 0";
-    if (param.begin[i].has_value()) {
-      b = param.begin[i].value();
+    if (param_begin[i].has_value()) {
+      b = param_begin[i].value();
       if (b < 0) {
         b += len;
         CHECK_GE(b, 0) << "slicing with begin[" << i << "]="
@@ -729,8 +696,8 @@ inline void GetIndexRange(const SliceParam& param,
     CHECK_LT(b, len) << "slicing with begin[" << i << "]="
                      << b << " exceends limit of " << len;
-    if (param.end[i].has_value()) {
-      e = param.end[i].value();
+    if (param_end[i].has_value()) {
+      e = param_end[i].value();
       if (e < 0) {
         e += len;
         CHECK_GE(e, 0) << "slicing with end[" << i << "]="
@@ -746,13 +713,27 @@ inline void GetIndexRange(const SliceParam& param,
     (*end)[i] = e;
     (*step)[i] = s;
-  for (index_t i = param.begin.ndim(); i < dshape.ndim(); ++i) {
+  for (index_t i = param_begin.ndim(); i < dshape.ndim(); ++i) {
     (*begin)[i] = 0;
     (*end)[i] = dshape[i];
     (*step)[i] = 1;
+inline void SetSliceOpOutputDimSize(const index_t i, const int b,
+                                    const int e, const int s,
+                                    TShape* oshape) {
+  if (s > 0) {
+    CHECK_LT(b, e) << "slicing with begin=[" << i << "]=" << b << ", end[" << i << "]="
+                   << e << ", and step[" << i << "]=" << s << " is invalid";
+    (*oshape)[i] = (e - b - 1) / s + 1;
+  } else {
+    CHECK_LT(e, b) << "slicing with begin=[" << i << "]=" << b << ", end[" << i << "]="
+                   << e << ", and step[" << i << "]=" << s << " is invalid";
+    (*oshape)[i] = (b - e - 1) / (-s) + 1;
+  }
 inline bool SliceOpShape(const nnvm::NodeAttrs& attrs,
                          std::vector<TShape>* in_attrs,
                          std::vector<TShape>* out_attrs) {
@@ -764,19 +745,10 @@ inline bool SliceOpShape(const nnvm::NodeAttrs& attrs,
   TShape oshape = dshape;
   MXNET_NDIM_SWITCH(dshape.ndim(), ndim, {
     common::StaticArray<int, ndim> begin, end, step;
-    GetIndexRange(param, dshape, &begin, &end, &step);
+    GetIndexRange(dshape, param.begin, param.end, param.step, &begin, &end, &step);
     for (index_t i = 0; i < param.begin.ndim(); ++i) {
       const int b = begin[i], e = end[i], s = step[i];
-      if (s > 0) {
-        CHECK_LT(b, e) << "slicing with begin=[" << i << "]=" << b << ", end[" << i << "]="
-                       << e << ", and step[" << i << "]=" << s << " is invalid";
-        oshape[i] = (e - b - 1) / s + 1;
-      } else {
-        CHECK_LT(e, b) << "slicing with begin=[" << i << "]=" << b << ", end[" << i << "]="
-                       << e << ", and step[" << i << "]=" << s << " is invalid";
-        oshape[i] = (b - e - 1) / (-s) + 1;
-      }
+      SetSliceOpOutputDimSize(i, b, e, s, &oshape);
@@ -832,7 +804,7 @@ void SliceOpForward(const nnvm::NodeAttrs& attrs,
   const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
   MXNET_NDIM_SWITCH(data.ndim(), ndim, {
     common::StaticArray<int, ndim> begin, end, step;
-    GetIndexRange(param, data.shape_, &begin, &end, &step);
+    GetIndexRange(data.shape_, param.begin, param.end, param.step, &begin, &end, &step);
     MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
       mxnet_op::Kernel<slice_forward<ndim>, xpu>::Launch(s, out.shape_.FlatTo2D()[0],
           out.dptr<DType>(), data.dptr<DType>(), req[0],
@@ -842,32 +814,32 @@ void SliceOpForward(const nnvm::NodeAttrs& attrs,
 template<int ndim>
-struct slice_backward {
+struct slice_assign {
   // i is the i-th row after flattening out into 2D tensor
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, DType* igrad, const DType* ograd,
+  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* val,
                                   const OpReqType req,
-                                  const mshadow::Shape<ndim> dshape,
                                   const mshadow::Shape<ndim> oshape,
+                                  const mshadow::Shape<ndim> vshape,
                                   const common::StaticArray<int, ndim> begin,
                                   const common::StaticArray<int, ndim> step) {
-    const int data_last_dim_size = dshape[ndim-1];
-    const int out_last_dim_size = oshape[ndim-1];
+    const int data_last_dim_size = oshape[ndim-1];
+    const int out_last_dim_size = vshape[ndim-1];
     const int step_last_dim = step[ndim-1];
     const int begin_last_dim = begin[ndim-1];
-    int ograd_offset = i * out_last_dim_size;
+    int offset = i * out_last_dim_size;
     for (int j = 0; j < out_last_dim_size; ++j) {
-      int irow = 0;  // row id of flattend 2D igrad
+      int irow = 0;  // row id of flattened 2D out
       int stride = 1;
       int idx = i;
       #pragma unroll
       for (int k = ndim - 2; k >= 0; --k) {
-        irow += stride * ((idx % oshape[k]) * step[k] + begin[k]);
-        idx /= oshape[k];
-        stride *= dshape[k];
+        irow += stride * ((idx % vshape[k]) * step[k] + begin[k]);
+        idx /= vshape[k];
+        stride *= oshape[k];
-      KERNEL_ASSIGN(igrad[irow * data_last_dim_size + j * step_last_dim + begin_last_dim],
-                    req, ograd[ograd_offset++]);
+      KERNEL_ASSIGN(out[irow * data_last_dim_size + j * step_last_dim + begin_last_dim],
+                    req, val[offset++]);
@@ -894,136 +866,145 @@ void SliceOpBackward(const nnvm::NodeAttrs& attrs,
   MXNET_NDIM_SWITCH(ograd.ndim(), ndim, {
     common::StaticArray<int, ndim> begin, end, step;
-    GetIndexRange(param, igrad.shape_, &begin, &end, &step);
+    GetIndexRange(igrad.shape_, param.begin, param.end, param.step, &begin, &end, &step);
     MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
-      mxnet_op::Kernel<slice_backward<ndim>, xpu>::Launch(s, ograd.shape_.FlatTo2D()[0],
+      mxnet_op::Kernel<slice_assign<ndim>, xpu>::Launch(s, ograd.shape_.FlatTo2D()[0],
           igrad.dptr<DType>(), ograd.dptr<DType>(), req[0],
           igrad.shape_.get<ndim>(), ograd.shape_.get<ndim>(), begin, step);
-inline bool SliceAssignShape(const nnvm::NodeAttrs& attrs,
-                             std::vector<TShape> *in_attrs,
-                             std::vector<TShape> *out_attrs) {
-  const TShape& lshape = (*in_attrs)[0];
-  if (lshape.ndim() == 0) return false;
+inline bool SliceAssignOpShape(const nnvm::NodeAttrs& attrs,
+                               std::vector<TShape> *in_attrs,
+                               std::vector<TShape> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  const TShape& dshape = (*in_attrs)[0];
+  if (dshape.ndim() == 0U || dshape.Size() == 0U) return false;
+  TShape vshape = dshape;  // vshape is the value shape on the right hand side
   const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
-  SHAPE_ASSIGN_CHECK(*in_attrs, 1, GetSliceShape(param, lshape));
-  SHAPE_ASSIGN_CHECK(*out_attrs, 0, lshape);
-  return true;
-template<typename xpu>
-void SliceAssignImpl(mshadow::Stream<xpu> *s, const SliceParam& param,
-                     const TBlob& dst, const TBlob& src) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
-  index_t N = dst.ndim();
-  TShape begin(N), end(N);
-  for (index_t i = 0; i < N; ++i) {
-    int s = 0;
-    if (param.begin[i]) {
-      s = *param.begin[i];
-      if (s < 0) s += dst.size(i);
-    }
-    begin[i] = s;
-    end[i] = s + src.size(i);
-  }
-  MSHADOW_TYPE_SWITCH(dst.type_flag_, DType, {
-    switch (dst.ndim()) {
-      case 0:
-        break;
-      case 1: {
-        Tensor<xpu, 1, DType> out = dst.get<xpu, 1, DType>(s);
-        Tensor<xpu, 1, DType> in = src.get<xpu, 1, DType>(s);
-        slice(out, begin.get<1>(), end.get<1>()) = in;
-        break;
-      }
-      case 2: {
-        Tensor<xpu, 2, DType> out = dst.get<xpu, 2, DType>(s);
-        Tensor<xpu, 2, DType> in = src.get<xpu, 2, DType>(s);
-        slice(out, begin.get<2>(), end.get<2>()) = in;
-        break;
-      }
-      case 3: {
-        Tensor<xpu, 3, DType> out = dst.get<xpu, 3, DType>(s);
-        Tensor<xpu, 3, DType> in = src.get<xpu, 3, DType>(s);
-        slice(out, begin.get<3>(), end.get<3>()) = in;
-        break;
-      }
-      case 4: {
-        Tensor<xpu, 4, DType> out = dst.get<xpu, 4, DType>(s);
-        Tensor<xpu, 4, DType> in = src.get<xpu, 4, DType>(s);
-        slice(out, begin.get<4>(), end.get<4>()) = in;
-        break;
-      }
-      case 5: {
-        Tensor<xpu, 5, DType> out = dst.get<xpu, 5, DType>(s);
-        Tensor<xpu, 5, DType> in = src.get<xpu, 5, DType>(s);
-        slice(out, begin.get<5>(), end.get<5>()) = in;
-        break;
-      }
-      default:
-        LOG(FATAL) << "CropAssign supports at most 5 dimensions";
-        break;
+  MXNET_NDIM_SWITCH(dshape.ndim(), ndim, {
+    common::StaticArray<int, ndim> begin, end, step;
+    GetIndexRange(dshape, param.begin, param.end, param.step, &begin, &end, &step);
+    for (index_t i = 0; i < param.begin.ndim(); ++i) {
+      const int b = begin[i], e = end[i], s = step[i];
+      SetSliceOpOutputDimSize(i, b, e, s, &vshape);
+  SHAPE_ASSIGN_CHECK(*in_attrs, 1, vshape);
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape);
+  return true;
 template<typename xpu>
-void SliceAssign(const nnvm::NodeAttrs& attrs,
-                 const OpContext& ctx,
-                 const std::vector<TBlob>& inputs,
-                 const std::vector<OpReqType>& req,
-                 const std::vector<TBlob>& outputs) {
+void SliceAssignOpForward(const nnvm::NodeAttrs& attrs,
+                          const OpContext& ctx,
+                          const std::vector<TBlob>& inputs,
+                          const std::vector<OpReqType>& req,
+                          const std::vector<TBlob>& outputs) {
   using namespace mshadow;
-  using namespace mshadow::expr;
-  const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
+  CHECK_EQ(inputs.size(), 2U);  // data[index] = val, data and val are two inputs
+  CHECK_EQ(outputs.size(), 1U);
+  if (req[0] == kNullOp) return;
   Stream<xpu> *s = ctx.get_stream<xpu>();
-  if (req[0] == kNullOp) {
-    return;
-  } else if (req[0] == kWriteTo) {
+  const TBlob& data = inputs[0];
+  const TBlob& val = inputs[1];
+  const TBlob& out = outputs[0];
+  if (req[0] == kWriteTo) {
     MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       Tensor<xpu, 1, DType> in = inputs[0].FlatTo1D<xpu, DType>(s);
       Tensor<xpu, 1, DType> out = outputs[0].FlatTo1D<xpu, DType>(s);
       Copy(out, in, s);
   } else if (req[0] != kWriteInplace) {
-    LOG(FATAL) << "CropAssign only supports kWriteTo and kWriteInplace";
+    LOG(FATAL) << "_slice_assign only supports kWriteTo and kWriteInplace";
-  SliceAssignImpl<xpu>(s, param, outputs[0], inputs[1]);
+  const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
+  MXNET_NDIM_SWITCH(data.ndim(), ndim, {
+    common::StaticArray<int, ndim> begin, end, step;
+    GetIndexRange(data.shape_, param.begin, param.end, param.step, &begin, &end, &step);
+    MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
+      mxnet_op::Kernel<slice_assign<ndim>, xpu>::Launch(s, val.shape_.FlatTo2D()[0],
+          out.dptr<DType>(), val.dptr<DType>(), req[0],
+          out.shape_.get<ndim>(), val.shape_.get<ndim>(), begin, step);
+    })
+  })
-struct SimpleCropAssignScalarParam : public dmlc::Parameter<SimpleCropAssignScalarParam> {
+struct SliceAssignScalarParam : public dmlc::Parameter<SliceAssignScalarParam> {
   real_t scalar;
-  TShape begin, end;
-  DMLC_DECLARE_PARAMETER(SimpleCropAssignScalarParam) {
+  nnvm::Tuple<dmlc::optional<int>> begin, end;
+  nnvm::Tuple<dmlc::optional<int>> step;
+  DMLC_DECLARE_PARAMETER(SliceAssignScalarParam) {
     .describe("The scalar value for assignment.");
-    .describe("starting coordinates");
+    .describe("starting indices for the slice operation, supports negative indices.");
-    .describe("ending coordinates");
+    .describe("ending indices for the slice operation, supports negative indices.");
+    .set_default(nnvm::Tuple<dmlc::optional<int>>())
+    .describe("step for the slice operation, supports negative values.");
+  }
+inline bool SliceAssignScalarOpShape(const nnvm::NodeAttrs& attrs,
+                                    std::vector<TShape> *in_attrs,
+                                    std::vector<TShape> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  const TShape& dshape = (*in_attrs)[0];
+  if (dshape.ndim() == 0U || dshape.Size() == 0U) return false;
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape);
+  return true;
+template<int ndim>
+struct slice_assign_scalar {
+  // i is the i-th row after flattening out into 2D tensor
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out, const DType val,
+                                  const OpReqType req,
+                                  const mshadow::Shape<ndim> oshape,
+                                  const mshadow::Shape<ndim> vshape,
+                                  const common::StaticArray<int, ndim> begin,
+                                  const common::StaticArray<int, ndim> step) {
+    const int data_last_dim_size = oshape[ndim-1];
+    const int out_last_dim_size = vshape[ndim-1];
+    const int step_last_dim = step[ndim-1];
+    const int begin_last_dim = begin[ndim-1];
+    for (int j = 0; j < out_last_dim_size; ++j) {
+      int irow = 0;  // row id of flattened 2D out
+      int stride = 1;
+      int idx = i;
+      #pragma unroll
+      for (int k = ndim - 2; k >= 0; --k) {
+        irow += stride * ((idx % vshape[k]) * step[k] + begin[k]);
+        idx /= vshape[k];
+        stride *= oshape[k];
+      }
+      KERNEL_ASSIGN(out[irow * data_last_dim_size + j * step_last_dim + begin_last_dim], req, val);
+    }
 template<typename xpu>
-void CropAssignScalar(const nnvm::NodeAttrs& attrs,
-                      const OpContext& ctx,
-                      const std::vector<TBlob>& inputs,
-                      const std::vector<OpReqType>& req,
-                      const std::vector<TBlob>& outputs) {
+void SliceAssignScalarOpForward(const nnvm::NodeAttrs& attrs,
+                                const OpContext& ctx,
+                                const std::vector<TBlob>& inputs,
+                                const std::vector<OpReqType>& req,
+                                const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
   using namespace mshadow;
-  using namespace mshadow::expr;
-  const SimpleCropAssignScalarParam& param = nnvm::get<SimpleCropAssignScalarParam>(attrs.parsed);
   Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TBlob& data = inputs[0];
+  const TBlob& out = outputs[0];
   if (req[0] == kWriteTo) {
     MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       Tensor<xpu, 1, DType> in = inputs[0].FlatTo1D<xpu, DType>(s);
@@ -1031,63 +1012,24 @@ void CropAssignScalar(const nnvm::NodeAttrs& attrs,
       Copy(out, in, s);
   } else if (req[0] != kWriteInplace) {
-    LOG(FATAL) << "CropAssignScalar only supports kWriteTo and kWriteInplace";
+    LOG(FATAL) << "_crop_assign_scalar only supports kWriteTo and kWriteInplace";
-  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-    switch (outputs[0].shape_.ndim()) {
-      case 0:
-        break;
-      case 1: {
-        Tensor<xpu, 1, DType> out = outputs[0].get<xpu, 1, DType>(s);
-        slice(out, param.begin.get<1>(), param.end.get<1>()) = \
-            static_cast<DType>(param.scalar);
-        break;
-      }
-      case 2: {
-        Tensor<xpu, 2, DType> out = outputs[0].get<xpu, 2, DType>(s);
-        slice(out, param.begin.get<2>(), param.end.get<2>()) = \
-            static_cast<DType>(param.scalar);
-        break;
-      }
-      case 3: {
-        Tensor<xpu, 3, DType> out = outputs[0].get<xpu, 3, DType>(s);
-        slice(out, param.begin.get<3>(), param.end.get<3>()) = \
-            static_cast<DType>(param.scalar);
-        break;
-      }
-      case 4: {
-        Tensor<xpu, 4, DType> out = outputs[0].get<xpu, 4, DType>(s);
-        slice(out, param.begin.get<4>(), param.end.get<4>()) = \
-            static_cast<DType>(param.scalar);
-        break;
-      }
-      case 5: {
-        Tensor<xpu, 5, DType> out = outputs[0].get<xpu, 5, DType>(s);
-        slice(out, param.begin.get<5>(), param.end.get<5>()) = \
-            static_cast<DType>(param.scalar);
-        break;
-      }
-      default:
-        LOG(FATAL) << "CropAssign supports at most 5 dimensions";
-        break;
+  TShape vshape = data.shape_;
+  const SliceAssignScalarParam& param = nnvm::get<SliceAssignScalarParam>(attrs.parsed);
+  MXNET_NDIM_SWITCH(data.ndim(), ndim, {
+    common::StaticArray<int, ndim> begin, end, step;
+    GetIndexRange(data.shape_, param.begin, param.end, param.step, &begin, &end, &step);
+    for (index_t i = 0; i < param.begin.ndim(); ++i) {
+      const int b = begin[i], e = end[i], s = step[i];
+      SetSliceOpOutputDimSize(i, b, e, s, &vshape);
-  });
-inline bool CropAssignScalarShape(const nnvm::NodeAttrs& attrs,
-                                  std::vector<TShape> *in_attrs,
-                                  std::vector<TShape> *out_attrs) {
-  const SimpleCropAssignScalarParam& param = nnvm::get<SimpleCropAssignScalarParam>(attrs.parsed);
-  TShape& lshape = (*in_attrs)[0];
-  CHECK_EQ(lshape.ndim(), param.begin.ndim());
-  CHECK_EQ(lshape.ndim(), param.end.ndim());
-  for (index_t i = 0; i < lshape.ndim(); ++i) {
-    CHECK_LT(param.begin[i], param.end[i]);
-    CHECK_LE(param.end[i], lshape[i]);
-  }
-  SHAPE_ASSIGN_CHECK(*out_attrs, 0, lshape);
-  return true;
+    MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
+      mxnet_op::Kernel<slice_assign_scalar<ndim>, xpu>::Launch(s, vshape.FlatTo2D()[0],
+          out.dptr<DType>(), static_cast<DType>(param.scalar), req[0],
+          out.shape_.get<ndim>(), vshape.get<ndim>(), begin, step);
+    })
+  })
 struct SliceAxisParam : public dmlc::Parameter<SliceAxisParam> {
diff --git a/src/operator/tensor/ b/src/operator/tensor/
index 7f109b69d8..cba9efd1a9 100644
--- a/src/operator/tensor/
+++ b/src/operator/tensor/
@@ -31,7 +31,7 @@ DMLC_REGISTER_PARAMETER(ReshapeParam);
@@ -323,18 +323,19 @@ NNVM_REGISTER_OP(_slice_assign)
     return std::vector<std::string>{"lhs", "rhs"};
-.set_attr<nnvm::FInferShape>("FInferShape", SliceAssignShape)
+.set_attr<nnvm::FInferShape>("FInferShape", SliceAssignOpShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
   [](const NodeAttrs& attrs){
     return std::vector<std::pair<int, int> >{{0, 0}};
-.set_attr<FCompute>("FCompute<cpu>", SliceAssign<cpu>)
+.set_attr<FCompute>("FCompute<cpu>", SliceAssignOpForward<cpu>)
 .add_argument("lhs", "NDArray-or-Symbol", "Source input")
 .add_argument("rhs", "NDArray-or-Symbol", "value to assign")
 .MXNET_DESCRIBE("(Assign the scalar to a cropped subset of the input.\n\n"
@@ -342,16 +343,16 @@ NNVM_REGISTER_OP(_crop_assign_scalar)
-.set_attr<nnvm::FInferShape>("FInferShape", CropAssignScalarShape)
+.set_attr<nnvm::FInferShape>("FInferShape", SliceAssignScalarOpShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
   [](const NodeAttrs& attrs){
     return std::vector<std::pair<int, int> >{{0, 0}};
-.set_attr<FCompute>("FCompute<cpu>", CropAssignScalar<cpu>)
+.set_attr<FCompute>("FCompute<cpu>", SliceAssignScalarOpForward<cpu>)
 .add_argument("data", "NDArray-or-Symbol", "Source input")
 .describe(R"code(Slices along a given axis.
diff --git a/src/operator/tensor/ b/src/operator/tensor/
index 3866fc419f..237b87296c 100644
--- a/src/operator/tensor/
+++ b/src/operator/tensor/
@@ -45,10 +45,10 @@ NNVM_REGISTER_OP(_backward_slice)
 .set_attr<FCompute>("FCompute<gpu>", SliceOpBackward<gpu>);
-.set_attr<FCompute>("FCompute<gpu>", SliceAssign<gpu>);
+.set_attr<FCompute>("FCompute<gpu>", SliceAssignOpForward<gpu>);
-.set_attr<FCompute>("FCompute<gpu>", CropAssignScalar<gpu>);
+.set_attr<FCompute>("FCompute<gpu>", SliceAssignScalarOpForward<gpu>);
 .set_attr<FCompute>("FCompute<gpu>", SliceAxis<gpu>);
diff --git a/tests/python/unittest/ b/tests/python/unittest/
index 28aab7a4cb..5bdadc4a29 100644
--- a/tests/python/unittest/
+++ b/tests/python/unittest/
@@ -24,6 +24,7 @@
 from mxnet.test_utils import *
 from numpy.testing import assert_allclose
 import unittest
+import mxnet.autograd
 def check_with_uniform(uf, arg_shapes, dim=None, npuf=None, rmin=-10, type_list=[np.float32]):
     """check function consistency with uniform random numbers"""
@@ -88,14 +89,6 @@ def test_ndarray_setitem():
     x_np[1] = 1
     assert same(x.asnumpy(), x_np)
-    # all-dim indexing
-    x = mx.nd.zeros(shape)
-    val = mx.nd.ones((3, 2, 1))
-    x[:, 1:3, 1] = val
-    x_np = np.zeros(shape, dtype=x.dtype)
-    x_np[:, 1:3, 1:2] = val.asnumpy()
-    assert same(x.asnumpy(), x_np)
     # short all-dim indexing
     x = mx.nd.zeros(shape)
     val = mx.nd.ones((3, 2))
@@ -804,6 +797,135 @@ def test_bool():
     assert not bool(mx.nd.zeros((1,)))
     assert bool(mx.nd.ones((1,)))
+def test_ndarray_indexing():
+    def test_getitem(np_array, index, is_scalar=False):
+        """`is_scalar` indicates whether we should expect a scalar for the result.
+        If so, the indexed array of NDArray should call asscalar to compare
+        with numpy's indexed array."""
+        np_index = index
+        if isinstance(index, mx.nd.NDArray):
+            np_index = index.asnumpy()
+        if isinstance(index, tuple):
+            np_index = []
+            for idx in index:
+                if isinstance(idx, mx.nd.NDArray):
+                    np_index.append(idx.asnumpy())
+                else:
+                    np_index.append(idx)
+            np_index = tuple(np_index)
+        np_indexed_array = np_array[np_index]
+        mx_array = mx.nd.array(np_array, dtype=np_array.dtype)
+        mx_indexed_array = mx_array[index]
+        if is_scalar:
+            mx_indexed_array = mx_indexed_array.asscalar()
+        else:
+            mx_indexed_array = mx_indexed_array.asnumpy()
+        assert same(np_indexed_array, mx_indexed_array), 'Failed with index=%s' % str(index)
+    def test_setitem(np_array, index, is_scalar):
+        def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None):
+            if np_value is not None:
+                np_array[np_index] = np_value
+            else:
+                np_array[np_index] = mx_value
+            mx_array[mx_index] = mx_value
+            assert same(np_array, mx_array.asnumpy())
+        np_index = index
+        if isinstance(index, mx.nd.NDArray):
+            np_index = index.asnumpy()
+        if isinstance(index, tuple):
+            np_index = []
+            for idx in index:
+                if isinstance(idx, mx.nd.NDArray):
+                    np_index.append(idx.asnumpy())
+                else:
+                    np_index.append(idx)
+            np_index = tuple(np_index)
+        mx_array = mx.nd.array(np_array, dtype=np_array.dtype)
+        np_array = mx_array.asnumpy()
+        if is_scalar:
+            # test value is a numeric type
+            assert_same(np_array, np_index, mx_array, index, np.random.randint(low=-10000, high=0))
+            value_nd = [np.random.randint(low=-10000, high=0)]
+            assert_same(np_array, np_index, mx_array, index, value_nd, value_nd[0])
+        else:
+            indexed_array_shape = np_array[np_index].shape
+            np_indexed_array = np.random.randint(low=-10000, high=0, size=indexed_array_shape)
+            # test value is a numpy array without broadcast
+            assert_same(np_array, np_index, mx_array, index, np_indexed_array)
+            # test value is an numeric_type
+            assert_same(np_array, np_index, mx_array, index, np.random.randint(low=-10000, high=0))
+            if len(indexed_array_shape) > 1:
+                # test numpy array with broadcast
+                assert_same(np_array, np_index, mx_array, index,
+                            np.random.randint(low=-10000, high=0, size=(indexed_array_shape[-1],)))
+                # test list with broadcast
+                assert_same(np_array, np_index, mx_array, index,
+                            [np.random.randint(low=-10000, high=0)] * indexed_array_shape[-1])
+    def test_getitem_autograd(np_array, index):
+        x = mx.nd.array(np_array, dtype=np_array.dtype)
+        x.attach_grad()
+        with mx.autograd.record():
+            y = x[index]
+        y.backward()
+        value = mx.nd.ones_like(y)
+        x_grad = mx.nd.zeros_like(x)
+        x_grad[index] = value
+        assert same(x_grad.asnumpy(), x.grad.asnumpy())
+    shape = (8, 16, 9, 9)
+    np_array = np.arange(np.prod(shape), dtype='int32').reshape(shape)
+    # index_list is a list of tuples. The tuple's first element is the index, the second one is a boolean value
+    # indicating whether we should expect the result as a scalar compared to numpy.
+    index_list = [(0, False), (5, False), (-1, False),
+                  (slice(5), False), (slice(1, 5), False), (slice(1, 5, 2), False),
+                  (slice(7, 0, -1), False), (slice(None, 6), False), (slice(None, 6, 3), False),
+                  (slice(1, None), False), (slice(1, None, 3), False), (slice(None, None, 2), False),
+                  (slice(None, None, -1), False), (slice(None, None, -2), False),
+                  ((slice(None), slice(None), 1, 8), False),
+                  ((slice(None), 2, slice(1, 5), 1), False),
+                  ((1, 2, 3), False), ((1, 2, 3, 4), True),
+                  ((slice(None, None, -1), 2, slice(1, 5), 1), False),
+                  ((slice(None, None, -1), 2, slice(1, 7, 2), 1), False),
+                  ((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)), False),
+                  ((slice(1, 8, 2), 1, slice(3, 8), 2), False),
+                  ([1], False), ([1, 2], False), ([2, 1, 3], False), ([7, 5, 0, 3, 6, 2, 1], False),
+                  (np.array([6, 3], dtype=np.int32), False),
+                  (np.array([[3, 4], [0, 6]], dtype=np.int32), False),
+                  (np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32), False),
+                  (np.array([[2], [0], [1]], dtype=np.int32), False),
+                  (mx.nd.array([4, 7], dtype=np.int32), False),
+                  (mx.nd.array([[3, 6], [2, 1]], dtype=np.int32), False),
+                  (mx.nd.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32), False),
+                  ((1, [2, 3]), False), ((1, [2, 3], np.array([[3], [0]], dtype=np.int32)), False),
+                  ((1, [2], np.array([[5], [3]], dtype=np.int32), slice(None)), False),
+                  ((1, [2, 3], np.array([[6], [0]], dtype=np.int32), slice(2, 5)), False),
+                  ((1, [2, 3], np.array([[4], [7]], dtype=np.int32), slice(2, 5, 2)), False),
+                  ((1, [2], np.array([[3]], dtype=np.int32), slice(None, None, -1)), False),
+                  ((1, [2], np.array([[3]], dtype=np.int32), np.array([[5, 7], [2, 4]], dtype=np.int64)), False),
+                  ((1, [2], mx.nd.array([[4]], dtype=np.int32), mx.nd.array([[1, 3], [5, 7]], dtype='int64')),
+                   False),
+                  ([0], False), ([0, 1], False), ([1, 2, 3], False), ([2, 0, 5, 6], False),
+                  (([1, 1], [2, 3]), False), (([1], [4], [5]), False), (([1], [4], [5], [6]), False),
+                  (([[1]], [[2]]), False), (([[1]], [[2]], [[3]], [[4]]), False),
+                  ((slice(0, 2), [[1], [6]], slice(0, 2), slice(0, 5, 2)), False),
+                  (([[[[1]]]], [[1]], slice(0, 3), [1, 5]), False),
+                  (([[[[1]]]], 3, slice(0, 3), [1, 3]), False),
+                  (([[[[1]]]], 3, slice(0, 3), 0), False),
+                  (([[[[1]]]], [[2], [12]], slice(0, 3), slice(None)), False),
+                  (([1, 2], slice(3, 5), [2, 3], [3, 4]), False),
+                  (([1, 2], slice(3, 5), (2, 3), [3, 4]), False)]
+    for index in index_list:
+        test_getitem(np_array, index[0], index[1])
+        test_setitem(np_array, index[0], index[1])
+        test_getitem_autograd(np_array, index[0])
 if __name__ == '__main__':
