You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/10/13 04:04:33 UTC

[incubator-mxnet] 02/03: numpy eye op (#16132)

This is an automated email from the ASF dual-hosted git repository.

haoj pushed a commit to branch numpy_pr_merge
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit ab60e34b6d2a2bce3f3eb39433a0bf18f8e0eb17
Author: Jake Lee <gs...@gmail.com>
AuthorDate: Fri Oct 11 21:58:42 2019 -0700

    numpy eye op (#16132)
---
 python/mxnet/ndarray/numpy/_op.py      | 35 ++++++++++++++++-
 python/mxnet/numpy/multiarray.py       | 29 +++++++++++++-
 python/mxnet/symbol/numpy/_symbol.py   | 35 ++++++++++++++++-
 src/operator/numpy/np_init_op.cc       | 30 ++++++--------
 src/operator/numpy/np_init_op.cu       |  3 ++
 src/operator/numpy/np_init_op.h        | 72 ++++++++++++++++++++++++++++++++++
 src/operator/tensor/init_op.h          | 40 +++++++++++--------
 tests/python/unittest/test_numpy_op.py | 68 ++++++++++++++++++++++++++++++++
 8 files changed, 274 insertions(+), 38 deletions(-)

diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index 2846d2b..0bf6232 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -22,7 +22,7 @@
 from __future__ import absolute_import
 import numpy as _np
 from ...base import numeric_types
-from ...util import set_module
+from ...util import _sanity_check_params, set_module
 from ...context import current_context
 from . import _internal as _npi
 from ..ndarray import NDArray
@@ -31,7 +31,7 @@ __all__ = ['zeros', 'ones', 'full', 'add', 'subtract', 'multiply', 'divide', 'mo
            'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs',
            'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2',
            'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
-           'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram',
+           'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye',
            'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'dstack',
            'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
            'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
@@ -789,6 +789,37 @@ def histogram(a, bins=10, range=None, normed=None, weights=None, density=None):
 
 
 @set_module('mxnet.ndarray.numpy')
+def eye(N, M=None, k=0, dtype=_np.float32, **kwargs):
+    """
+    Return a 2-D array with ones on the diagonal and zeros elsewhere.
+
+    Parameters
+    ----------
+    N : int
+        Number of rows in the output.
+    M : int, optional
+        Number of columns in the output. If None, defaults to N.
+    k : int, optional
+        Index of the diagonal: 0 (the default) refers to the main diagonal, a positive
+        value to an upper diagonal, and a negative value to a lower diagonal.
+    dtype : data-type, optional
+        Data-type of the returned array; defaults to float32.
+    ctx : Context, optional
+        Device context of the output. Defaults to the current context.
+
+    Returns
+    -------
+    I : ndarray of shape (N,M)
+        An array where all elements are equal to zero,
+        except for the k-th diagonal, whose values are equal to one.
+    """
+    _sanity_check_params('eye', ['order'], kwargs)
+    ctx = kwargs.pop('ctx', None)
+    ctx = current_context() if ctx is None else ctx
+    return _npi.eye(N, M, k, ctx, dtype)
+
+
+@set_module('mxnet.ndarray.numpy')
 def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None):  # pylint: disable=too-many-arguments
     r"""
     Return evenly spaced numbers over a specified interval.
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 00a7709..76df87c 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -50,7 +50,7 @@ __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'full', 'add', 'subtrac
            'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log',
            'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
            'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
-           'tensordot', 'histogram', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
+           'tensordot', 'histogram', 'eye', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
            'stack', 'vstack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var',
            'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot',
            'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
@@ -3637,6 +3637,33 @@ def histogram(a, bins=10, range=None, normed=None, weights=None, density=None):
 
 
 @set_module('mxnet.numpy')
+def eye(N, M=None, k=0, dtype=_np.float32, **kwargs):
+    """
+    Return a 2-D array with ones on the k-th diagonal and zeros elsewhere.
+
+    Parameters
+    ----------
+    N : int
+        Number of rows in the output.
+    M : int, optional
+        Number of columns in the output. If None, defaults to N.
+    k : int, optional
+        Diagonal offset: 0 (the default) is the main diagonal, positive values
+        select an upper diagonal and negative values a lower diagonal.
+    dtype : data-type, optional
+        Data-type of the returned array; defaults to float32.
+    **kwargs
+        Forwarded unchanged to ``_mx_nd_np.eye`` (e.g. ``ctx``).
+
+    Returns
+    -------
+    I : ndarray of shape (N, M)
+        Zeros everywhere except the k-th diagonal, which holds ones.
+    """
+    return _mx_nd_np.eye(N, M=M, k=k, dtype=dtype, **kwargs)
+
+
+@set_module('mxnet.numpy')
 def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None):  # pylint: disable=too-many-arguments
     r"""
     Return evenly spaced numbers over a specified interval.
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index de11cfb..cbd46f3 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -23,7 +23,7 @@ import ctypes
 import numpy as _np
 from . import _op as _mx_np_op
 from ...base import _LIB, SymbolHandle, numeric_types, mx_uint
-from ...util import check_call, set_module
+from ...util import check_call, _sanity_check_params, set_module
 from ...context import current_context
 from ..symbol import Symbol
 from .._internal import _set_np_symbol_class
@@ -33,7 +33,7 @@ __all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'rem
            'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp',
            'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
            'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
-           'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram',
+           'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye',
            'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'dstack',
            'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
            'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
@@ -1278,6 +1278,37 @@ def histogram(a, bins=10, range=None, normed=None, weights=None, density=None):
 
 
 @set_module('mxnet.symbol.numpy')
+def eye(N, M=None, k=0, dtype=_np.float32, **kwargs):
+    """
+    Return a 2-D array with ones on the diagonal and zeros elsewhere.
+
+    Parameters
+    ----------
+    N : int
+        Number of rows in the output.
+    M : int, optional
+        Number of columns in the output. If None, defaults to N.
+    k : int, optional
+        Index of the diagonal: 0 (the default) refers to the main diagonal, a positive
+        value to an upper diagonal, and a negative value to a lower diagonal.
+    dtype : data-type, optional
+        Data-type of the returned array; defaults to float32.
+    ctx : Context, optional
+        Device context of the output. Defaults to the current context.
+
+    Returns
+    -------
+    I : Symbol
+        Symbol of shape (N, M) whose elements are all zero,
+        except for the k-th diagonal, whose values are equal to one.
+    """
+    _sanity_check_params('eye', ['order'], kwargs)
+    ctx = kwargs.pop('ctx', None)
+    ctx = current_context() if ctx is None else ctx
+    return _npi.eye(N, M, k, ctx, dtype)
+
+
+@set_module('mxnet.symbol.numpy')
 def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments
     r"""
     Return evenly spaced numbers over a specified interval.
diff --git a/src/operator/numpy/np_init_op.cc b/src/operator/numpy/np_init_op.cc
index 2477573..7e65d6c 100644
--- a/src/operator/numpy/np_init_op.cc
+++ b/src/operator/numpy/np_init_op.cc
@@ -22,6 +22,7 @@
  * \file np_init_op.cc
  * \brief CPU Implementation of numpy init op
  */
+
 #include "../tensor/init_op.h"
 #include "../tensor/elemwise_unary_op.h"
 #include "./np_init_op.h"
@@ -29,6 +30,8 @@
 namespace mxnet {
 namespace op {
 
+
+DMLC_REGISTER_PARAMETER(NumpyEyeParam);
 DMLC_REGISTER_PARAMETER(IndicesOpParam);
 
 inline bool NumpyIndicesShape(const nnvm::NodeAttrs& attrs,
@@ -117,23 +120,6 @@ NNVM_REGISTER_OP(_np_ones_like)
 .add_argument("a", "NDArray-or-Symbol",
               "The shape and data-type of a define these same attributes of the returned array.");
 
-bool NumpyRangeShape(const nnvm::NodeAttrs& attrs,
-                     mxnet::ShapeVector* in_shapes,
-                     mxnet::ShapeVector* out_shapes) {
-  const RangeParam& param = nnvm::get<RangeParam>(attrs.parsed);
-  CHECK_EQ(in_shapes->size(), 0U);
-  CHECK_EQ(out_shapes->size(), 1U);
-  CHECK_NE(param.step, 0) << "_npi_arange does not support step=0";
-  CHECK_EQ(param.repeat, 1) << "_npi_arange only supports repeat=1, received " << param.repeat;
-  CHECK(param.stop.has_value()) << "_npi_arange requires stop to have a value";
-  double out_size = std::ceil((param.stop.value() - param.start) / param.step);
-  if (out_size < 0) {
-    out_size = 0;
-  }
-  SHAPE_ASSIGN_CHECK(*out_shapes, 0, mxnet::TShape({static_cast<nnvm::dim_t>(out_size)}));
-  return true;
-}
-
 NNVM_REGISTER_OP(_npi_arange)
 .set_num_inputs(0)
 .set_num_outputs(1)
@@ -143,6 +129,16 @@ NNVM_REGISTER_OP(_npi_arange)
 .set_attr<FCompute>("FCompute<cpu>", RangeCompute<cpu, RangeParam>)
 .add_arguments(RangeParam::__FIELDS__());
 
+NNVM_REGISTER_OP(_npi_eye)  // operator behind the Python-side _npi.eye calls
+.describe("Return a 2-D array with ones on the diagonal and zeros elsewhere.")
+.set_num_inputs(0)  // creation op: no input tensors
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyEyeParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyEyeShape)
+.set_attr<nnvm::FInferType>("FInferType", InitType<NumpyEyeParam>)
+.set_attr<FCompute>("FCompute<cpu>", NumpyEyeFill<cpu>)
+.add_arguments(NumpyEyeParam::__FIELDS__());
+
 NNVM_REGISTER_OP(_npi_indices)
 .describe("Return an array representing the indices of a grid.")
 .set_num_inputs(0)
diff --git a/src/operator/numpy/np_init_op.cu b/src/operator/numpy/np_init_op.cu
index e68dd9a..9a2b236 100644
--- a/src/operator/numpy/np_init_op.cu
+++ b/src/operator/numpy/np_init_op.cu
@@ -47,6 +47,9 @@ NNVM_REGISTER_OP(_np_ones_like)
 NNVM_REGISTER_OP(_npi_arange)
 .set_attr<FCompute>("FCompute<gpu>", RangeCompute<gpu, RangeParam>);
 
+NNVM_REGISTER_OP(_npi_eye)  // GPU FCompute; the op itself is registered in np_init_op.cc
+.set_attr<FCompute>("FCompute<gpu>", NumpyEyeFill<gpu>);
+
 NNVM_REGISTER_OP(_npi_indices)
 .set_attr<FCompute>("FCompute<gpu>", IndicesCompute<gpu>);
 
diff --git a/src/operator/numpy/np_init_op.h b/src/operator/numpy/np_init_op.h
index 3e1c345..9eb83e8 100644
--- a/src/operator/numpy/np_init_op.h
+++ b/src/operator/numpy/np_init_op.h
@@ -35,6 +35,34 @@
 namespace mxnet {
 namespace op {
 
+struct NumpyEyeParam : public dmlc::Parameter<NumpyEyeParam> {  // parameters of _npi_eye
+  nnvm::dim_t N;                  // number of rows
+  dmlc::optional<nnvm::dim_t> M;  // number of columns; unset means N
+  nnvm::dim_t k;                  // diagonal offset
+  std::string ctx;                // output context string, imperative calls only
+  int dtype;                      // output dtype as an mshadow type flag
+  DMLC_DECLARE_PARAMETER(NumpyEyeParam) {
+    DMLC_DECLARE_FIELD(N)
+    .describe("Number of rows in the output.");
+    DMLC_DECLARE_FIELD(M)
+    .set_default(dmlc::optional<nnvm::dim_t>())
+    .describe("Number of columns in the output. If None, defaults to N.");
+    DMLC_DECLARE_FIELD(k)
+    .set_default(0)
+    .describe("Index of the diagonal. 0 (the default) refers to the main diagonal, "
+              "a positive value refers to an upper diagonal, "
+              "and a negative value to a lower diagonal.");
+    DMLC_DECLARE_FIELD(ctx)
+    .set_default("")
+    .describe("Context of output, in format [cpu|gpu|cpu_pinned](n). "
+              "Only used for imperative calls.");
+    DMLC_DECLARE_FIELD(dtype)
+    .set_default(mshadow::kFloat32)
+    MXNET_ADD_ALL_TYPES
+    .describe("Data-type of the returned array.");
+  }
+};
+
 struct IndicesOpParam : public dmlc::Parameter<IndicesOpParam> {
   mxnet::TShape dimensions;
   int dtype;
@@ -52,6 +80,50 @@ struct IndicesOpParam : public dmlc::Parameter<IndicesOpParam> {
   }
 };
 
+inline bool NumpyRangeShape(const nnvm::NodeAttrs& attrs,
+                            mxnet::ShapeVector* in_shapes,
+                            mxnet::ShapeVector* out_shapes) {
+  // Shape inference for _npi_arange: a 1-D output holding
+  // ceil((stop - start) / step) elements, clamped at zero.
+  const RangeParam& param = nnvm::get<RangeParam>(attrs.parsed);
+  CHECK_EQ(in_shapes->size(), 0U);
+  CHECK_EQ(out_shapes->size(), 1U);
+  CHECK_NE(param.step, 0) << "_npi_arange does not support step=0";
+  CHECK_EQ(param.repeat, 1) << "_npi_arange only supports repeat=1, received " << param.repeat;
+  CHECK(param.stop.has_value()) << "_npi_arange requires stop to have a value";
+  const double steps = std::ceil((param.stop.value() - param.start) / param.step);
+  const nnvm::dim_t length = steps < 0 ? 0 : static_cast<nnvm::dim_t>(steps);
+  SHAPE_ASSIGN_CHECK(*out_shapes, 0, mxnet::TShape({length}));
+  return true;
+}
+
+inline bool NumpyEyeShape(const nnvm::NodeAttrs& attrs,
+                          mxnet::ShapeVector *in_attrs,
+                          mxnet::ShapeVector *out_attrs) {
+  const NumpyEyeParam& param = nnvm::get<NumpyEyeParam>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), 0U);  // creation op: no inputs
+  CHECK_EQ(out_attrs->size(), 1U);
+  nnvm::dim_t M = param.N;  // column count falls back to N when M is unset
+  if (param.M.has_value()) M = param.M.value();
+  CHECK(param.N >= 0) << "negative dimensions are not allowed. N is " << param.N;
+  CHECK(M >= 0) << "negative dimensions are not allowed. M is " << M;
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape2(param.N, M));
+  return out_attrs->at(0).ndim() != 0U;
+}
+template<typename xpu>  // FCompute of _npi_eye
+void NumpyEyeFill(const nnvm::NodeAttrs& attrs,
+                  const OpContext& ctx,
+                  const std::vector<TBlob>& inputs,
+                  const std::vector<OpReqType>& req,
+                  const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 0U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);  // EyeFillImpl reads req[0]; mirrors the check in EyeFill
+  if (outputs[0].shape_.Size() == 0) return;  // zero-size output: nothing to fill
+  const NumpyEyeParam& p = nnvm::get<NumpyEyeParam>(attrs.parsed);
+  EyeFillImpl<xpu>(outputs[0], ctx, req, p.M.has_value() ? p.M.value() : p.N, p.N, p.k);
+}
+
 template<int req>
 struct indices_fwd {
   template<typename DType>
diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h
index 8e8896e..d2107a1 100644
--- a/src/operator/tensor/init_op.h
+++ b/src/operator/tensor/init_op.h
@@ -483,6 +483,29 @@ void FillComputeZerosEx(const nnvm::NodeAttrs& attrs,
   }
 }
 
+template<typename xpu>  // shared by EyeFill and NumpyEyeFill
+inline void EyeFillImpl(const TBlob& out_data,
+                        const OpContext& ctx,
+                        const std::vector<OpReqType>& req,
+                        const nnvm::dim_t num_cols,
+                        const nnvm::dim_t N,
+                        const nnvm::dim_t k) {
+  // Zero-fill out_data, then let eye_dns_fill mark the nnz entries of diagonal k.
+  using namespace mxnet_op;
+  const nnvm::dim_t col_len = std::max(num_cols - std::abs(k), nnvm::dim_t{0});
+  const nnvm::dim_t row_len = std::max(N - std::abs(k), nnvm::dim_t{0});
+  const nnvm::dim_t nnz = k > 0 ? std::min(col_len, N) : std::min(row_len, num_cols);
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
+    MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+      Fill(s, out_data, req[0], static_cast<DType>(0));
+      if (nnz > 0) {
+        Kernel<eye_dns_fill<req_type>, xpu>::Launch(s, nnz, out_data.dptr<DType>(),
+          std::max(nnvm::dim_t{0}, k), k, num_cols);
+      }
+    });
+  });
+}
 
 template<typename xpu>
 void EyeFill(const nnvm::NodeAttrs& attrs,
@@ -493,25 +516,10 @@ void EyeFill(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(inputs.size(), 0U);
   CHECK_EQ(outputs.size(), 1U);
   CHECK_EQ(req.size(), 1U);
-  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
   const EyeParam& param = nnvm::get<EyeParam>(attrs.parsed);
   const TBlob& out_data = outputs[0];
   const nnvm::dim_t num_cols = param.M > 0 ? param.M : param.N;
-
-  const nnvm::dim_t cnnz = std::max(num_cols - std::abs(param.k), (nnvm::dim_t)0);
-  const nnvm::dim_t rnnz = std::max(param.N - std::abs(param.k), (nnvm::dim_t)0);
-  const nnvm::dim_t nnz = param.k > 0 ? std::min(cnnz, param.N) :
-                                        std::min(rnnz, num_cols);
-  using namespace mxnet_op;
-  MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
-    MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
-      Fill(s, out_data, req[0], static_cast<DType>(0));
-      if (nnz > 0) {
-        Kernel<eye_dns_fill<req_type>, xpu>::Launch(s, nnz, out_data.dptr<DType>(),
-          std::max(static_cast<nnvm::dim_t>(0), param.k), param.k, num_cols);
-      }
-    });
-  });
+  EyeFillImpl<xpu>(out_data, ctx, req, num_cols, param.N, param.k);  // shared with NumpyEyeFill
 }
 
 
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 978d5d3..3942023 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -2216,6 +2216,74 @@ def test_np_choice():
 
 @with_seed()
 @use_np
+def test_np_eye():  # np.eye vs. NumPy, imperatively and through a HybridBlock
+    configs = [  # each entry is N, or an (N, M[, k]) tuple unpacked into eye()
+        4,
+        1000,
+        (4, 3),
+        (5, None),
+        (4, None, 1),
+        (2, 2, 1),
+        (4, 6, 1),
+        (7, 3, -3),
+        (3, 2, -2),
+        (4, 0),
+        (0, 0),
+        (0, 3),
+        (0, 0, -2)
+    ]
+    exception_configs = [  # negative N or M; each is expected to raise MXNetError
+        -1,
+        -1000,
+        (-2, None),
+        (1, -1)
+    ]
+    dtypes = ['int32', 'float16', 'float32', 'float64', None]
+    for config in configs:
+        for dtype in dtypes:
+            if isinstance(config, tuple):
+                mx_ret = np.eye(*config, dtype=dtype)
+                np_ret = _np.eye(*config, dtype=dtype)
+            else:
+                mx_ret = np.eye(config, dtype=dtype)
+                np_ret = _np.eye(config, dtype=dtype)
+            assert same(mx_ret.asnumpy(), np_ret)
+    # negative dimensions must be rejected
+    for config in exception_configs:
+        if isinstance(config, tuple):
+            assertRaises(MXNetError, np.eye, *config)
+        else:
+            assertRaises(MXNetError, np.eye, config)
+
+    class TestEye(HybridBlock):  # wraps F.np.eye so the hybridized (symbolic) path also runs
+        def __init__(self, N, M=None, k=0, dtype=None):
+            super(TestEye, self).__init__()
+            self._N = N
+            self._M = M
+            self._k = k
+            self._dtype = dtype
+
+        def hybrid_forward(self, F, x):
+            return x + F.np.eye(self._N, self._M, self._k, dtype=self._dtype)
+
+    for dtype in dtypes:
+        x = np.zeros(shape=(), dtype=dtype)  # 0-d zeros: x + eye(...) leaves eye values as-is
+        for config in configs:
+            for hybridize in [False, True]:
+                if isinstance(config, tuple):
+                    net = TestEye(*config, dtype=dtype)
+                    np_out = _np.eye(*config, dtype=dtype)
+                else:
+                    net = TestEye(config, dtype=dtype)
+                    np_out = _np.eye(config, dtype=dtype)
+                if hybridize:
+                    net.hybridize()
+                mx_out = net(x)
+                assert same(mx_out.asnumpy(), np_out)
+
+
+@with_seed()
+@use_np
 def test_np_indices():
     dtypes = ['int32', 'int64', 'float16', 'float32', 'float64']
     shapes = [