You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/07/18 00:00:52 UTC

[incubator-mxnet] 16/42: [numpy] Fix d2l performance regression (#15173)

This is an automated email from the ASF dual-hosted git repository.

haoj pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit bf176db453cae552b3c7609147a7fbf60b54b7eb
Author: reminisce <wu...@gmail.com>
AuthorDate: Fri Jun 7 08:48:13 2019 -0700

    [numpy] Fix d2l performance regression (#15173)
    
    * Add np array adapter decorator for layers
    
    * Fix performance regression caused by too many conversions between nd.NDArray and np.ndarray
    
    * Fix pylint
    
    * Fix test backward compatibility issue
    
    * Fix test_lambda
---
 python/mxnet/gluon/data/vision/transforms.py   |  8 ++---
 python/mxnet/gluon/loss.py                     | 50 +++++++++++++-------------
 python/mxnet/gluon/nn/activations.py           |  8 ++---
 python/mxnet/gluon/nn/basic_layers.py          | 23 +++++++-----
 python/mxnet/gluon/utils.py                    | 38 ++++++++++++++++----
 python/mxnet/ndarray/ndarray.py                |  4 +--
 python/mxnet/ndarray/register.py               | 32 ++++++++---------
 python/mxnet/numpy/multiarray.py               | 50 ++++++++++++--------------
 python/mxnet/numpy_extension/__init__.py       |  1 -
 python/mxnet/optimizer/optimizer.py            |  4 +--
 python/mxnet/symbol/numpy/_symbol.py           |  2 +-
 python/mxnet/symbol/register.py                | 18 +++++-----
 python/mxnet/symbol/symbol.py                  |  2 +-
 python/mxnet/test_utils.py                     |  2 +-
 python/mxnet/util.py                           |  8 +++++
 src/operator/numpy/np_matrix_op.cu             |  3 ++
 src/operator/tensor/elemwise_unary_op_basic.cc |  1 +
 src/operator/tensor/matrix_op.cc               |  1 +
 tests/python/unittest/test_numpy_ndarray.py    | 21 ++++++-----
 tests/python/unittest/test_numpy_op.py         | 25 ++++++++++---
 20 files changed, 174 insertions(+), 127 deletions(-)

diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py
index 0e90c17..2648997 100644
--- a/python/mxnet/gluon/data/vision/transforms.py
+++ b/python/mxnet/gluon/data/vision/transforms.py
@@ -23,7 +23,7 @@ from ...block import Block, HybridBlock
 from ...nn import Sequential, HybridSequential
 from .... import image
 from ....base import numeric_types
-from ....util import is_np_array
+from ...utils import _adapt_np_array
 
 
 class Compose(Sequential):
@@ -134,11 +134,9 @@ class ToTensor(HybridBlock):
     def __init__(self):
         super(ToTensor, self).__init__()
 
+    @_adapt_np_array
     def hybrid_forward(self, F, x):
-        if is_np_array():
-            x = x.as_classic_ndarray()
-        out = F.image.to_tensor(x)
-        return out.as_np_ndarray() if is_np_array() else out
+        return F.image.to_tensor(x)
 
 
 class Normalize(HybridBlock):
diff --git a/python/mxnet/gluon/loss.py b/python/mxnet/gluon/loss.py
index 8cf41a2..79a5981 100644
--- a/python/mxnet/gluon/loss.py
+++ b/python/mxnet/gluon/loss.py
@@ -29,7 +29,8 @@ import numpy as np
 from .. import ndarray
 from ..base import numeric_types
 from .block import HybridBlock
-from .utils import _to_classic_arrays, _to_np_arrays
+from .utils import _adapt_np_array
+from ..util import is_np_array
 
 
 def _apply_weighting(F, loss, weight=None, sample_weight=None):
@@ -54,7 +55,10 @@ def _apply_weighting(F, loss, weight=None, sample_weight=None):
         Weighted loss
     """
     if sample_weight is not None:
-        loss = F.broadcast_mul(loss, sample_weight)
+        if is_np_array():
+            loss = loss * sample_weight
+        else:
+            loss = F.broadcast_mul(loss, sample_weight)
 
     if weight is not None:
         assert isinstance(weight, numeric_types), "weight must be a number"
@@ -65,7 +69,11 @@ def _apply_weighting(F, loss, weight=None, sample_weight=None):
 
 def _reshape_like(F, x, y):
     """Reshapes x to the same shape as y."""
-    return x.reshape(y.shape) if F is ndarray else F.reshape_like(x, y)
+    if F is ndarray:
+        return x.reshape(y.shape)
+    elif is_np_array():
+        F = F.npx
+    return F.reshape_like(x, y)
 
 
 class Loss(HybridBlock):
@@ -136,14 +144,16 @@ class L2Loss(Loss):
         super(L2Loss, self).__init__(weight, batch_axis, **kwargs)
 
     def hybrid_forward(self, F, pred, label, sample_weight=None):
-        # TODO(junwu): This is a temp solution to reuse legacy ops for np.ndarray.
-        # We should rewrite this with np/npx ops.
-        pred, label, sample_weight = _to_classic_arrays(pred, label, sample_weight)
         label = _reshape_like(F, label, pred)
-        loss = F.square(label - pred)
+        loss = F.np.square(label - pred) if is_np_array() else F.square(label - pred)
         loss = _apply_weighting(F, loss, self._weight / 2, sample_weight)
-        out = F.mean(loss, axis=self._batch_axis, exclude=True)
-        return _to_np_arrays(out)
+        if is_np_array():
+            if F is ndarray:
+                return F.np.mean(loss, axis=tuple(range(1, loss.ndim)))
+            else:
+                return F.npx.batch_flatten(loss).mean(axis=1)
+        else:
+            return F.mean(loss, axis=self._batch_axis, exclude=True)
 
 
 class L1Loss(Loss):
@@ -178,15 +188,12 @@ class L1Loss(Loss):
     def __init__(self, weight=None, batch_axis=0, **kwargs):
         super(L1Loss, self).__init__(weight, batch_axis, **kwargs)
 
+    @_adapt_np_array
     def hybrid_forward(self, F, pred, label, sample_weight=None):
-        # TODO(junwu): This is a temp solution to reuse legacy ops for np.ndarray.
-        # We should rewrite this with np/npx ops.
-        pred, label, sample_weight = _to_classic_arrays(pred, label, sample_weight)
         label = _reshape_like(F, label, pred)
         loss = F.abs(label - pred)
         loss = _apply_weighting(F, loss, self._weight, sample_weight)
-        out = F.mean(loss, axis=self._batch_axis, exclude=True)
-        return _to_np_arrays(out)
+        return F.mean(loss, axis=self._batch_axis, exclude=True)
 
 
 class SigmoidBinaryCrossEntropyLoss(Loss):
@@ -251,11 +258,8 @@ class SigmoidBinaryCrossEntropyLoss(Loss):
             weight, batch_axis, **kwargs)
         self._from_sigmoid = from_sigmoid
 
+    @_adapt_np_array
     def hybrid_forward(self, F, pred, label, sample_weight=None, pos_weight=None):
-        # TODO(junwu): This is a temp solution to reuse legacy ops for np.ndarray.
-        # We should rewrite this with np/npx ops.
-        pred, label, sample_weight, pos_weight =\
-            _to_classic_arrays(pred, label, sample_weight, pos_weight)
         label = _reshape_like(F, label, pred)
         if not self._from_sigmoid:
             if pos_weight is None:
@@ -277,8 +281,7 @@ class SigmoidBinaryCrossEntropyLoss(Loss):
                 loss = -(F.broadcast_mul(F.log(pred + eps) * label, pos_weight)
                          + F.log(1. - pred + eps) * (1. - label))
         loss = _apply_weighting(F, loss, self._weight, sample_weight)
-        out = F.mean(loss, axis=self._batch_axis, exclude=True)
-        return _to_np_arrays(out)
+        return F.mean(loss, axis=self._batch_axis, exclude=True)
 
 
 SigmoidBCELoss = SigmoidBinaryCrossEntropyLoss
@@ -354,10 +357,8 @@ class SoftmaxCrossEntropyLoss(Loss):
         self._sparse_label = sparse_label
         self._from_logits = from_logits
 
+    @_adapt_np_array
     def hybrid_forward(self, F, pred, label, sample_weight=None):
-        # TODO(junwu): This is a temp solution to reuse legacy ops for np.ndarray.
-        # We should rewrite this with np/npx ops.
-        pred, label = _to_classic_arrays(pred, label)
         if not self._from_logits:
             pred = F.log_softmax(pred, self._axis)
         if self._sparse_label:
@@ -366,8 +367,7 @@ class SoftmaxCrossEntropyLoss(Loss):
             label = _reshape_like(F, label, pred)
             loss = -F.sum(pred * label, axis=self._axis, keepdims=True)
         loss = _apply_weighting(F, loss, self._weight, sample_weight)
-        out = F.mean(loss, axis=self._batch_axis, exclude=True)
-        return _to_np_arrays(out)
+        return F.mean(loss, axis=self._batch_axis, exclude=True)
 
 
 SoftmaxCELoss = SoftmaxCrossEntropyLoss
diff --git a/python/mxnet/gluon/nn/activations.py b/python/mxnet/gluon/nn/activations.py
index 04a8227..6e0e7ca 100644
--- a/python/mxnet/gluon/nn/activations.py
+++ b/python/mxnet/gluon/nn/activations.py
@@ -22,7 +22,7 @@ __all__ = ['Activation', 'LeakyReLU', 'PReLU', 'ELU', 'SELU', 'Swish', 'GELU']
 
 from ... import initializer
 from ..block import HybridBlock
-from ..utils import _to_classic_arrays, _to_np_arrays
+from ...util import is_np_array
 
 
 class Activation(HybridBlock):
@@ -49,9 +49,9 @@ class Activation(HybridBlock):
         return self._act_type
 
     def hybrid_forward(self, F, x):
-        x = _to_classic_arrays(x)
-        out = F.Activation(x, act_type=self._act_type, name='fwd')
-        return _to_np_arrays(out)
+        if is_np_array():
+            F = F.npx
+        return F.Activation(x, act_type=self._act_type, name='fwd')
 
     def __repr__(self):
         s = '{name}({_act_type})'
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index 654e3ef..512863a 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -25,8 +25,9 @@ import numpy as np
 
 from .activations import Activation
 from ..block import Block, HybridBlock
-from ..utils import _indent, _to_classic_arrays, _to_np_arrays
+from ..utils import _indent, _adapt_np_array
 from ... import nd, sym
+from ...util import is_np_array
 
 
 class Sequential(Block):
@@ -217,14 +218,13 @@ class Dense(HybridBlock):
                 self.act = None
 
     def hybrid_forward(self, F, x, weight, bias=None):
-        # TODO(junwu): This is a temp solution to reuse legacy ops for np.ndarray.
-        # We should rewrite this with np/npx ops.
-        x, weight, bias = _to_classic_arrays(x, weight, bias)
+        if is_np_array():
+            F = F.npx
         act = F.FullyConnected(x, weight, bias, no_bias=bias is None, num_hidden=self._units,
                                flatten=self._flatten, name='fwd')
         if self.act is not None:
             act = self.act(act)
-        return _to_np_arrays(act)
+        return act
 
     def __repr__(self):
         s = '{name}({layout}, {act})'
@@ -264,13 +264,12 @@ class Dropout(HybridBlock):
         self._rate = rate
         self._axes = axes
 
+    @_adapt_np_array
     def hybrid_forward(self, F, x):
-        x = _to_classic_arrays(x)
         if self._rate > 0:
-            out = F.Dropout(x, p=self._rate, axes=self._axes, name='fwd', cudnn_off=False)
+            return F.Dropout(x, p=self._rate, axes=self._axes, name='fwd', cudnn_off=False)
         else:
-            out = F.identity(x)
-        return _to_np_arrays(out)
+            return F.identity(x)
 
     def __repr__(self):
         s = '{name}(p = {_rate}, axes={_axes})'
@@ -360,6 +359,7 @@ class BatchNorm(HybridBlock):
             dtype = 'float32'
         super(BatchNorm, self).cast(dtype)
 
+    @_adapt_np_array
     def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):
         return F.BatchNorm(x, gamma, beta, running_mean, running_var,
                            name='fwd', **self._kwargs)
@@ -413,6 +413,7 @@ class Embedding(HybridBlock):
                                       init=weight_initializer, dtype=dtype,
                                       allow_deferred_init=True, grad_stype=grad_stype)
 
+    @_adapt_np_array
     def hybrid_forward(self, F, x, weight):
         return F.Embedding(x, weight, name='fwd', **self._kwargs)
 
@@ -434,6 +435,7 @@ class Flatten(HybridBlock):
     def __init__(self, **kwargs):
         super(Flatten, self).__init__(**kwargs)
 
+    @_adapt_np_array
     def hybrid_forward(self, F, x):
         return F.Flatten(x)
 
@@ -519,6 +521,7 @@ class InstanceNorm(HybridBlock):
                                     shape=(in_channels,), init=beta_initializer,
                                     allow_deferred_init=True)
 
+    @_adapt_np_array
     def hybrid_forward(self, F, x, gamma, beta):
         if self._axis == 1:
             return F.InstanceNorm(x, gamma, beta,
@@ -607,6 +610,7 @@ class LayerNorm(HybridBlock):
                                     shape=(in_channels,), init=beta_initializer,
                                     allow_deferred_init=True)
 
+    @_adapt_np_array
     def hybrid_forward(self, F, data, gamma, beta):
         norm_data = F.LayerNorm(data, gamma=gamma, beta=beta, axis=self._axis, eps=self._epsilon)
         return norm_data
@@ -703,6 +707,7 @@ class HybridLambda(HybridBlock):
                 "Unrecognized function in lambda: {} of type {}"
                 .format(function, type(function)))
 
+    @_adapt_np_array
     def hybrid_forward(self, F, x, *args):
         return self._func(F, x, *args)
 
diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index 19f5c1a..08b29ef 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -38,7 +38,7 @@ except ImportError:
 import numpy as np
 
 from .. import ndarray
-from ..util import is_np_shape, is_np_array
+from ..util import is_np_shape, is_np_array, wraps_safely
 
 
 def split_data(data, num_slice, batch_axis=0, even_split=True):
@@ -459,7 +459,7 @@ def _check_same_symbol_type(symbols):
                             'symbols in the list to numpy symbols by calling `as_np_ndarray()` '
                             'on each of them; if you want classic ndarray output(s) from the '
                             'computation graph, please convert all the numpy symbols in the list '
-                            'to classic symbols by calling `as_classic_ndarray()` on each of them.')
+                            'to classic symbols by calling `as_nd_ndarray()` on each of them.')
     return np_symbol if is_np_sym else classic_symbol
 
 
@@ -474,16 +474,24 @@ def _check_all_np_ndarrays(out):
                             '{}'.format(str(type(array))))
 
 
-def _to_classic_arrays(*args):
+def _to_classic_arrays(*args, **kwargs):
     """Convert arrays to classic arrays. This is used in a Gluon layer for converting
     inputs of np arrays to classic arrays so that the layer built with legacy ops can still
     be used in np_array semantics."""
+    from ..numpy import ndarray as np_ndarray
+    from ..symbol.numpy import _Symbol as np_symbol
     num_inputs = len(args)
     assert num_inputs != 0
     if not is_np_array():
-        return args[0] if num_inputs == 1 else args
-    in_arrs = [arr if arr is None else arr.as_classic_ndarray() for arr in args]
-    return in_arrs[0] if num_inputs == 1 else in_arrs
+        return args, kwargs
+    in_arrs = [arr if arr is None else arr.as_nd_ndarray() for arr in args]
+    new_kwargs = {}
+    for k, v in kwargs.items():
+        if isinstance(v, (np_ndarray, np_symbol)):
+            new_kwargs[k] = v.as_nd_ndarray()
+        else:
+            new_kwargs[k] = v
+    return in_arrs, new_kwargs
 
 
 def _to_np_arrays(*args):
@@ -496,3 +504,21 @@ def _to_np_arrays(*args):
         return args[0] if num_outputs == 1 else args
     out = [arr.as_np_ndarray() for arr in args]
     return out[0] if num_outputs == 1 else out
+
+
+# TODO(junwu): This is a temp solution for allowing basic layers
+# implemented using legacy ops to accept np.ndarrays as inputs and return
+# np.ndarrays as outputs. We should remove it after changing all the layers
+# to use np ops in np_array semantics in the future.
+def _adapt_np_array(func):
+    @wraps_safely(func)
+    def _with_np_array(*args, **kwargs):
+        assert len(args) > 2, "expect at least three arguments in args"
+        if is_np_array():
+            input_args, kwargs = _to_classic_arrays(*args[2:], **kwargs)
+            input_args = list(args[0:2]) + input_args
+            out = func(*input_args, **kwargs)
+            return _to_np_arrays(out)
+        else:
+            return func(*args, **kwargs)
+    return _with_np_array
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index fc60518..1ba7bce 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -196,7 +196,7 @@ fixed-size items.
         check_call(_LIB.MXShallowCopyNDArray(self.handle, ctypes.byref(hdl)))
         return ndarray(handle=hdl, writable=self.writable)
 
-    def as_classic_ndarray(self):
+    def as_nd_ndarray(self):
         """A convenience function for creating a classic ndarray from the current
         ndarray with zero copy. For this class, it just returns itself since it is
         already a classic ndarray."""
@@ -962,7 +962,7 @@ fixed-size items.
                                  % (idx-length, length))
         check_call(_LIB.MXNDArrayAt(
             self.handle, mx_uint(idx), ctypes.byref(handle)))
-        return NDArray(handle=handle, writable=self.writable)
+        return self.__class__(handle=handle, writable=self.writable)
 
     def reshape(self, *shape, **kwargs):
         """Returns a **view** of this array with a new shape without altering any data.
diff --git a/python/mxnet/ndarray/register.py b/python/mxnet/ndarray/register.py
index cde1145..20e6223 100644
--- a/python/mxnet/ndarray/register.py
+++ b/python/mxnet/ndarray/register.py
@@ -48,8 +48,8 @@ def _verify_all_np_ndarrays(op_name, func_name, args, out):
         if (arr is not None) and (not isinstance(arr, np_ndarray)):
             raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. '
                             'This is a numpy operator which can only accept '
-                            'MXNet numpy ndarrays, while received a classic ndarray. '
-                            'Please call `as_np_ndarray()` upon the classic ndarray to '
+                            'MXNet numpy ndarrays, while received a legacy ndarray. '
+                            'Please call `as_np_ndarray()` upon the legacy ndarray to '
                             'convert it to an MXNet numpy ndarray, and then feed the converted '
                             'array to this operator.'
                             .format(op_name, func_name))
@@ -61,15 +61,15 @@ def _verify_all_np_ndarrays(op_name, func_name, args, out):
         if (arr is not None) and (not isinstance(arr, np_ndarray)):
             raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. '
                             'This is a numpy operator which can only write to MXNet numpy '
-                            'ndarrays, while received a classic ndarray. '
-                            'Please call `as_np_ndarray()` upon the classic ndarray to '
+                            'ndarrays, while received a legacy ndarray. '
+                            'Please call `as_np_ndarray()` upon the legacy ndarray to '
                             'convert it to an MXNet numpy ndarray, and then feed the converted '
                             'array to this operator.'
                             .format(op_name, func_name))
 
 
-def _verify_all_classic_ndarrays(op_name, func_name, args, out):
-    """Verify if all the arrays are classic ndarrays.
+def _verify_all_legacy_ndarrays(op_name, func_name, args, out):
+    """Verify if all the arrays are legacy ndarrays.
 
     Parameters
     ----------
@@ -87,10 +87,10 @@ def _verify_all_classic_ndarrays(op_name, func_name, args, out):
     for arr in args:
         if (arr is not None) and (isinstance(arr, np_ndarray)):
             raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. '
-                            'This is a classic operator which can only accept '
-                            'classic ndarrays, while received an MXNet numpy ndarray. '
-                            'Please call `as_classic_ndarray()` upon the numpy ndarray to '
-                            'convert it to a classic ndarray, and then feed the converted '
+                            'This is a legacy operator which can only accept '
+                            'legacy ndarrays, while received an MXNet numpy ndarray. '
+                            'Please call `as_nd_ndarray()` upon the numpy ndarray to '
+                            'convert it to a legacy ndarray, and then feed the converted '
                             'array to this operator.'
                             .format(op_name, func_name))
     if out is None:
@@ -100,10 +100,10 @@ def _verify_all_classic_ndarrays(op_name, func_name, args, out):
     for arr in out:
         if (arr is not None) and (isinstance(arr, np_ndarray)):
             raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. '
-                            'This is a classic operator which can only write to '
-                            'classic ndarrays, while received an MXNet numpy ndarray. '
-                            'Please call `as_classic_ndarray()` upon the numpy ndarray to '
-                            'convert it to a classic ndarray, and then feed the converted '
+                            'This is a legacy operator which can only write to '
+                            'legacy ndarrays, while received an MXNet numpy ndarray. '
+                            'Please call `as_nd_ndarray()` upon the numpy ndarray to '
+                            'convert it to a legacy ndarray, and then feed the converted '
                             'array to this operator.'
                             .format(op_name, func_name))
 
@@ -175,8 +175,6 @@ def _generate_ndarray_function_code(handle, op_name, func_name, signature_only=F
     doc_str_idx = 1
     if is_np_op:
         doc_str_idx = 2
-        code.append("""
-@use_np_shape""")
     if arr_name:
         code.append("""
 def %s(*%s, **kwargs):"""%(func_name, arr_name))
@@ -233,7 +231,7 @@ def %s(%s):"""%(func_name, ', '.join(signature)))
         vals.append(_np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name))
 
     verify_ndarrays_fn =\
-        _verify_all_np_ndarrays.__name__ if is_np_op else _verify_all_classic_ndarrays.__name__
+        _verify_all_np_ndarrays.__name__ if is_np_op else _verify_all_legacy_ndarrays.__name__
     if not signature_only:
         code.append("""
     {verify_fn}("{op_name}", "{func_name}", ndargs, out)
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 2f0cdbc..454b562 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -32,7 +32,7 @@ from ..ndarray._internal import _set_np_ndarray_class
 from . import _op as _mx_np_op
 from ..base import check_call, _LIB, NDArrayHandle
 from ..base import mx_real_t, c_array_buf, mx_uint, numeric_types, integer_types
-from ..util import _sanity_check_params, set_module, use_np_shape
+from ..util import _sanity_check_params, set_module
 from ..context import current_context
 from ..ndarray import numpy as _mx_nd_np
 from ..ndarray.numpy import _internal as _npi
@@ -82,15 +82,14 @@ def _get_index(idx):
     if isinstance(idx, NDArray) and not isinstance(idx, ndarray):
         raise TypeError('Cannot have mx.nd.NDArray as index')
     if isinstance(idx, ndarray):
-        return idx._as_classic_ndarray()
+        return idx._as_nd_ndarray()
     elif sys.version_info[0] > 2 and isinstance(idx, range):
-        return arange(idx.start, idx.stop, idx.step, dtype='int32')._as_classic_ndarray()
+        return arange(idx.start, idx.stop, idx.step, dtype='int32')._as_nd_ndarray()
     else:
         return idx
 
 
 @set_module('mxnet.numpy')  # pylint: disable=invalid-name
-@use_np_shape
 class ndarray(NDArray):
     """An array object represents a multidimensional, homogeneous array of fixed-size items.
     An associated data-type object describes the format of each element in the array
@@ -105,16 +104,16 @@ class ndarray(NDArray):
                 raise IndexError('scalar tensor can only accept `()` as index')
         if isinstance(key, tuple) and len(key) == 0:
             return self
-        if isinstance(key, integer_types):
-            key = (key,)
         if isinstance(key, tuple) and len(key) == self.ndim\
                 and all(isinstance(idx, integer_types) for idx in key):
-            out = self._as_classic_ndarray()
+            out = self._as_nd_ndarray()
             for idx in key:
                 out = out[idx]
             return out.reshape(()).as_np_ndarray()
+        if isinstance(key, integer_types):
+            return self._at(key)
         if isinstance(key, ndarray):
-            key = key._as_classic_ndarray()
+            key = key._as_nd_ndarray()
         elif isinstance(key, tuple):
             key = [_get_index(idx) for idx in key]
             key = tuple(key)
@@ -122,7 +121,7 @@ class ndarray(NDArray):
             key = [_get_index(idx) for idx in key]
         elif sys.version_info[0] > 2 and isinstance(key, range):
             key = _get_index(key)
-        return self._as_classic_ndarray().__getitem__(key).as_np_ndarray()
+        return self._as_nd_ndarray().__getitem__(key).as_np_ndarray()
 
     def __setitem__(self, key, value):
         # TODO(junwu): calling base class __setitem__ is a temp solution
@@ -132,16 +131,14 @@ class ndarray(NDArray):
             if not isinstance(key, tuple) or len(key) != 0:
                 raise IndexError('scalar tensor can only accept `()` as index')
         if isinstance(value, ndarray):
-            value = value._as_classic_ndarray()
+            value = value._as_nd_ndarray()
         # TODO(junwu): Better handling of this situation
         if isinstance(key, tuple) and len(key) == 0:
-            self._as_classic_ndarray().__setitem__(slice(None), value)
+            self._as_nd_ndarray().__setitem__(slice(None), value)
             return
 
-        if isinstance(key, integer_types):
-            key = (key,)
         if isinstance(key, ndarray):
-            key = key._as_classic_ndarray()
+            key = key._as_nd_ndarray()
         elif isinstance(key, tuple):
             key = [_get_index(idx) for idx in key]
             key = tuple(key)
@@ -149,7 +146,7 @@ class ndarray(NDArray):
             key = [_get_index(idx) for idx in key]
         elif sys.version_info[0] > 2 and isinstance(key, range):
             key = _get_index(key)
-        self._as_classic_ndarray().__setitem__(key, value)
+        self._as_nd_ndarray().__setitem__(key, value)
 
     def __add__(self, other):
         """x.__add__(y) <=> x + y"""
@@ -371,28 +368,26 @@ class ndarray(NDArray):
     def _slice(self, start, stop):
         raise NotImplementedError
 
-    def _at(self, idx):
-        raise NotImplementedError
-
     def all(self, axis=None, out=None, keepdims=False):
         raise NotImplementedError
 
     def any(self, axis=None, out=None, keepdims=False):
         raise NotImplementedError
 
-    def _as_classic_ndarray(self):
+    def _as_nd_ndarray(self):
         """This is not a user-facing API."""
         hdl = NDArrayHandle()
         check_call(_LIB.MXShallowCopyNDArray(self.handle, ctypes.byref(hdl)))
         return NDArray(handle=hdl, writable=self.writable)
 
-    def as_classic_ndarray(self):
+    def as_nd_ndarray(self):
         """Convert mxnet.numpy.ndarray to mxnet.ndarray.NDArray to use its fluent methods."""
-        if self.ndim == 0:  # TODO(junwu): this costs ~10ns, can be moved to backend
-            raise ValueError('cannot convert a scalar np.ndarray to mx.nd.NDArray')
-        if self.size == 0:  # TODO(junwu): this costs ~10ns, can be moved to backend
-            raise ValueError('cannot convert a zero-size np.ndarray to mx.nd.NDArray')
-        return self._as_classic_ndarray()
+        # TODO(junwu): Uncomment the following lines
+        # if self.ndim == 0:  # TODO(junwu): this costs ~10ns, can be moved to backend
+        #     raise ValueError('cannot convert a scalar np.ndarray to mx.nd.NDArray')
+        # if self.size == 0:  # TODO(junwu): this costs ~10ns, can be moved to backend
+        #     raise ValueError('cannot convert a zero-size np.ndarray to mx.nd.NDArray')
+        return self._as_nd_ndarray()
 
     def as_np_ndarray(self):
         """A convenience function for creating a numpy ndarray from the current ndarray
@@ -514,8 +509,8 @@ class ndarray(NDArray):
                [ 1.,  1.,  1.]], dtype=float32)
         """
         if isinstance(other, ndarray):
-            other = other._as_classic_ndarray()
-        return self._as_classic_ndarray().copyto(other).as_np_ndarray()
+            other = other._as_nd_ndarray()
+        return self._as_nd_ndarray().copyto(other).as_np_ndarray()
 
     def asscalar(self):
         raise AttributeError('mxnet.numpy.ndarray object has no attribute asscalar')
@@ -1229,7 +1224,6 @@ def empty(shape, dtype=None, **kwargs):
 
 
 @set_module('mxnet.numpy')
-@use_np_shape
 def array(object, dtype=None, **kwargs):
     """
     Create an array.
diff --git a/python/mxnet/numpy_extension/__init__.py b/python/mxnet/numpy_extension/__init__.py
index 6419c57..a15a1d4 100644
--- a/python/mxnet/numpy_extension/__init__.py
+++ b/python/mxnet/numpy_extension/__init__.py
@@ -27,6 +27,5 @@ from ..context import *  # pylint: disable=wildcard-import
 from ..util import use_np_shape, np_shape, is_np_shape, set_np_shape
 from ..util import use_np_array, np_array, is_np_array, set_np_array
 from ..util import set_np, use_np
-from .. import autograd
 
 __all__ = []
diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py
index 5ab256c..d953e92 100644
--- a/python/mxnet/optimizer/optimizer.py
+++ b/python/mxnet/optimizer/optimizer.py
@@ -1656,13 +1656,13 @@ def _as_classic(a, allow_np):
     if isinstance(a, (tuple, list)):
         if any(isinstance(x, np_ndarray) for x in a):
             if allow_np:
-                return [x.as_classic_ndarray() for x in a]
+                return [x.as_nd_ndarray() for x in a]
             else:
                 raise ValueError('Converting np.ndarray to mx.nd.NDArray is not allowed')
     else:
         if isinstance(a, np_ndarray):
             if allow_np:
-                return a.as_classic_ndarray()
+                return a.as_nd_ndarray()
             else:
                 raise ValueError('Converting np.ndarray to mx.nd.NDArray is not allowed')
     return a
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index 72f9eca..e333a62 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -177,7 +177,7 @@ class _Symbol(Symbol):
     def __len__(self):
         raise NotImplementedError
 
-    def as_classic_ndarray(self):
+    def as_nd_ndarray(self):
         """Convert _Symbol to mxnet.symbol.Symbol to use its convenience fluent methods."""
         hdl = SymbolHandle()
         check_call(_LIB.MXShallowCopySymbol(self.handle, ctypes.byref(hdl)))
diff --git a/python/mxnet/symbol/register.py b/python/mxnet/symbol/register.py
index 2bf3fbd..365a088 100644
--- a/python/mxnet/symbol/register.py
+++ b/python/mxnet/symbol/register.py
@@ -48,15 +48,15 @@ def _verify_np_symbol(op_name, func_name, sym):
     if not isinstance(sym, np_symbol):
         raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. '
                         'This is a numpy operator which can only accept '
-                        'MXNet numpy ndarrays, while received a classic ndarray. '
-                        'Please call `as_np_ndarray()` upon the classic ndarray to '
+                        'MXNet numpy ndarrays, while received a legacy ndarray. '
+                        'Please call `as_np_ndarray()` upon the legacy ndarray to '
                         'convert it to an MXNet numpy ndarray, and then feed the converted '
                         'array to this operator.'
                         .format(op_name, func_name))
 
 
-def _verify_classic_symbol(op_name, func_name, sym):
-    """Verify if the sym is a classic symbol.
+def _verify_legacy_symbol(op_name, func_name, sym):
+    """Verify if the sym is a legacy symbol.
 
     Parameters
     ----------
@@ -70,10 +70,10 @@ def _verify_classic_symbol(op_name, func_name, sym):
     from .numpy._symbol import _Symbol as np_symbol
     if isinstance(sym, np_symbol):
         raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. '
-                        'This is a classic operator which can only accept '
-                        'classic ndarrays, while received an MXNet numpy ndarray. '
-                        'Please call `as_classic_ndarray()` upon the numpy ndarray to '
-                        'convert it to a classic ndarray, and then feed the converted '
+                        'This is a legacy operator which can only accept '
+                        'legacy ndarrays, while received an MXNet numpy ndarray. '
+                        'Please call `as_nd_ndarray()` upon the numpy ndarray to '
+                        'convert it to a legacy ndarray, and then feed the converted '
                         'array to this operator.'
                         .format(op_name, func_name))
 
@@ -142,7 +142,7 @@ def _generate_symbol_function_code(handle, op_name, func_name, signature_only=Fa
     signature = ndsignature + signature
 
     is_np_op = _is_np_op(op_name)
-    verify_symbol_fn = _verify_np_symbol.__name__ if is_np_op else _verify_classic_symbol.__name__
+    verify_symbol_fn = _verify_np_symbol.__name__ if is_np_op else _verify_legacy_symbol.__name__
     code = []
     if arr_name:
         code.append("""
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 87893c4..eb9e759 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -68,7 +68,7 @@ class Symbol(SymbolBase):
         check_call(_LIB.MXShallowCopySymbol(self.handle, ctypes.byref(hdl)))
         return _Symbol(hdl)
 
-    def as_classic_ndarray(self):
+    def as_nd_ndarray(self):
         """Returns self. For the convenience of conversion between legacy and np symbols."""
         return self
 
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index 3264e51..0dcb54b 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -952,7 +952,7 @@ def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rto
         proj = proj.as_np_ndarray()
     out = sym * proj
     if is_np_sym:  # convert to classic symbol so that make_loss can be used
-        out = out.as_classic_ndarray()
+        out = out.as_nd_ndarray()
     out = mx.sym.make_loss(out)
 
     location = dict(list(location.items()) +
diff --git a/python/mxnet/util.py b/python/mxnet/util.py
index 013a717..11ec16e 100644
--- a/python/mxnet/util.py
+++ b/python/mxnet/util.py
@@ -79,6 +79,14 @@ def set_np_shape(active):
     >>> print(mx.is_np_shape())
     True
     """
+    # TODO(junwu): Consider uncommenting the following lines.
+    # import logging
+    # logging.info('NumPy-shape semantics has been activated in your code global scope. '
+    #              'This is required for using `mxnet.numpy` and `mxnet.numpy_extension` '
+    #              'modules as it enables creating and manipulating scalar and zero-size '
+    #              'tensors, which were not supported in MXNet before, as in the official '
+    #              'NumPy library. Please DO NOT manually deactivate this semantics while '
+    #              'using `mxnet.numpy` and `mxnet.numpy_extension` modules.')
     prev = ctypes.c_int()
     check_call(_LIB.MXSetIsNumpyShape(ctypes.c_int(active), ctypes.byref(prev)))
     return bool(prev.value)
diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu
index 5980e81..4cccf59 100644
--- a/src/operator/numpy/np_matrix_op.cu
+++ b/src/operator/numpy/np_matrix_op.cu
@@ -40,5 +40,8 @@ NNVM_REGISTER_OP(_npi_stack)
 NNVM_REGISTER_OP(_npi_concatenate)
 .set_attr<FCompute>("FCompute<gpu>", ConcatCompute<gpu>);
 
+NNVM_REGISTER_OP(_backward_np_concat)
+.set_attr<FCompute>("FCompute<gpu>", ConcatGradCompute<gpu>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index 4594b48..a955508 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -511,6 +511,7 @@ Negative indices are supported, and `None` can be used for either `lhs_end` or `
   - lhs shape = (30, 12), rhs shape = (4, 2, 2, 3), lhs_begin=-1, lhs_end=None, rhs_begin=1, rhs_end=None, output shape = (30, 2, 2, 3)
 
 )code" ADD_FILELINE)
+.add_alias("_npx_reshape_like")
 .set_num_inputs(2)
 .set_attr_parser(ParamParser<ReshapeLikeParam>)
 .set_attr<nnvm::FListInputNames>("FListInputNames",
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index 0f059e2..b1165c5 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -263,6 +263,7 @@ static inline bool FlattenStorageType(const nnvm::NodeAttrs& attrs,
 
 NNVM_REGISTER_OP(Flatten)
 .add_alias("flatten")
+.add_alias("_npx_batch_flatten")
 .describe(R"code(Flattens the input array into a 2-D array by collapsing the higher dimensions.
 
 .. note:: `Flatten` is deprecated. Use `flatten` instead.
diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py
index 1c71471..74b3d4d 100644
--- a/tests/python/unittest/test_numpy_ndarray.py
+++ b/tests/python/unittest/test_numpy_ndarray.py
@@ -20,13 +20,14 @@ from __future__ import absolute_import
 from __future__ import division
 import numpy as _np
 import mxnet as mx
-from mxnet import np, npx
+from mxnet import np, npx, autograd
 from mxnet.gluon import HybridBlock
 from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, assert_exception
 from common import with_seed
 
 
 @with_seed()
+@npx.use_np_shape
 def test_array_creation():
     dtypes = [_np.int8, _np.int32, _np.float16, _np.float32, _np.float64, None]
     objects = [
@@ -51,9 +52,9 @@ def test_array_creation():
 
 
 @with_seed()
+@npx.use_np_shape
 def test_zeros():
     # test np.zeros in Gluon
-    @npx.use_np_shape
     class TestZeros(HybridBlock):
         def __init__(self, shape, dtype=None):
             super(TestZeros, self).__init__()
@@ -63,13 +64,11 @@ def test_zeros():
         def hybrid_forward(self, F, x, *args, **kwargs):
             return x + F.np.zeros(shape, dtype)
 
-    @npx.use_np_shape
     class TestZerosOutputType(HybridBlock):
         def hybrid_forward(self, F, x, *args, **kwargs):
             return x, F.np.zeros(shape=())
 
     # test np.zeros in imperative
-    @npx.use_np_shape
     def check_zero_array_creation(shape, dtype):
         np_out = _np.zeros(shape=shape, dtype=dtype)
         mx_out = np.zeros(shape=shape, dtype=dtype)
@@ -101,9 +100,9 @@ def test_zeros():
 
 
 @with_seed()
+@npx.use_np_shape
 def test_ones():
     # test np.ones in Gluon
-    @npx.use_np_shape
     class TestOnes(HybridBlock):
         def __init__(self, shape, dtype=None):
             super(TestOnes, self).__init__()
@@ -113,13 +112,11 @@ def test_ones():
         def hybrid_forward(self, F, x, *args, **kwargs):
             return x * F.np.ones(shape, dtype)
 
-    @npx.use_np_shape
     class TestOnesOutputType(HybridBlock):
         def hybrid_forward(self, F, x, *args, **kwargs):
             return x, F.np.ones(shape=())
 
     # test np.ones in imperative
-    @npx.use_np_shape
     def check_ones_array_creation(shape, dtype):
         np_out = _np.ones(shape=shape, dtype=dtype)
         mx_out = np.ones(shape=shape, dtype=dtype)
@@ -314,7 +311,7 @@ def test_hybrid_block_multiple_outputs():
 
     class TestAllClassicOutputs(HybridBlock):
         def hybrid_forward(self, F, x, *args, **kwargs):
-            return F.relu(x.as_classic_ndarray()), F.sum(x.as_classic_ndarray())
+            return F.relu(x.as_nd_ndarray()), F.sum(x.as_nd_ndarray())
 
     data_np = np.ones((2, 3))
     for block, expected_out_type in [(TestAllClassicOutputs, mx.nd.NDArray),
@@ -330,7 +327,7 @@ def test_hybrid_block_multiple_outputs():
     @npx.use_np_array
     class TestMixedTypeOutputsFailure(HybridBlock):
         def hybrid_forward(self, F, x, *args, **kwargs):
-            return F.relu(x.as_classic_ndarray()), F.np.sum(x)
+            return F.relu(x.as_nd_ndarray()), F.np.sum(x)
 
     net = TestMixedTypeOutputsFailure()
     assert_exception(net, TypeError, data_np)
@@ -339,6 +336,7 @@ def test_hybrid_block_multiple_outputs():
 
 
 @with_seed()
+@npx.use_np_shape
 def test_grad_ndarray_type():
     data = np.array(2, dtype=_np.float32)
     data.attach_grad()
@@ -376,6 +374,7 @@ def test_np_ndarray_copy():
 
 
 @with_seed()
+@npx.use_np_shape
 def test_np_ndarray_indexing():
     def test_getitem(np_array, index):
         """`is_scalar` indicates whether we should expect a scalar for the result.
@@ -443,7 +442,7 @@ def test_np_ndarray_indexing():
     def test_getitem_autograd(np_array, index):
         x = np.array(np_array, dtype=np_array.dtype)
         x.attach_grad()
-        with npx.autograd.record():
+        with autograd.record():
             y = x[index]
         y.backward()
         value = np.ones_like(y)
@@ -457,7 +456,7 @@ def test_np_ndarray_indexing():
         y = np.random.uniform(size=out_shape)
         y.attach_grad()
         try:
-            with npx.autograd.record():
+            with autograd.record():
                 x[index] = y
                 assert False  # should not reach here
         except mx.base.MXNetError as err:
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index d00573e..4e80166 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -28,6 +28,7 @@ import random
 
 
 @with_seed()
+@npx.use_np_shape
 def test_np_sum():
     class TestSum(HybridBlock):
         def __init__(self, axis=None, dtype=None, keepdims=False):
@@ -78,8 +79,8 @@ def test_np_sum():
                         # test numeric
                         if itype == 'float32' and dtype == 'float32':
                             x_sym = mx.sym.Variable("x").as_np_ndarray()
-                            mx_sym = mx.sym.np.sum(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_classic_ndarray()
-                            check_numeric_gradient(mx_sym, [x.as_classic_ndarray()],
+                            mx_sym = mx.sym.np.sum(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_nd_ndarray()
+                            check_numeric_gradient(mx_sym, [x.as_nd_ndarray()],
                                                    numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32)
 
                         # test imperative
@@ -116,7 +117,7 @@ def test_np_dot():
         assert_almost_equal(np_res, mx_res.asnumpy(), rtol=1e-5, atol=1e-5)
         mx_a = mx.sym.Variable("a")
         mx_b = mx.sym.Variable("b")
-        mx_sym = mx.sym.np.dot(mx_a.as_np_ndarray(), mx_b.as_np_ndarray()).as_classic_ndarray()
+        mx_sym = mx.sym.np.dot(mx_a.as_np_ndarray(), mx_b.as_np_ndarray()).as_nd_ndarray()
         check_numeric_gradient(mx_sym, {"a": a, "b": b}, numeric_eps=eps, rtol=1e-2, atol=1e-3)
 
     bad_shapes = [((4, 5), (2, 3)), ((3, 4, 5), (6, ))]
@@ -132,6 +133,7 @@ def test_np_dot():
 
 
 @with_seed()
+@npx.use_np_shape
 def test_np_mean():
     @npx.use_np_shape
     class TestMean(HybridBlock):
@@ -185,8 +187,8 @@ def test_np_mean():
                         # test numeric
                         if itype == 'float32' and dtype == 'float32':
                             x_sym = mx.sym.Variable("x").as_np_ndarray()
-                            mx_sym = mx.sym.np.mean(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_classic_ndarray()
-                            check_numeric_gradient(mx_sym, [x.as_classic_ndarray()],
+                            mx_sym = mx.sym.np.mean(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_nd_ndarray()
+                            check_numeric_gradient(mx_sym, [x.as_nd_ndarray()],
                                                    numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32)
 
                         # test imperative
@@ -196,6 +198,7 @@ def test_np_mean():
 
 
 @with_seed()
+@npx.use_np_shape
 def test_np_transpose():
     # TODO(junwu): Add more test cases
     data = mx.sym.var('a').as_np_ndarray()
@@ -225,6 +228,7 @@ def test_np_transpose():
 
 
 @with_seed()
+@npx.use_np_shape
 def test_npx_relu():
     # TODO(junwu): Add more test cases
     data = mx.sym.var('data').as_np_ndarray()
@@ -240,6 +244,7 @@ def test_npx_relu():
 
 
 @with_seed()
+@npx.use_np_shape
 def test_npx_sigmoid():
     # TODO(junwu): Add more test cases
     data = mx.sym.var('data').as_np_ndarray()
@@ -255,6 +260,7 @@ def test_npx_sigmoid():
 
 
 @with_seed()
+@npx.use_np_shape
 def test_np_reshape():
     # TODO(junwu): Add more test cases
     data = mx.sym.var('a').as_np_ndarray()
@@ -270,6 +276,7 @@ def test_np_reshape():
 
 
 @with_seed()
+@npx.use_np_shape
 def test_np_maximum():
     # TODO(junwu): Add more test cases
     x1, x2 = mx.sym.var('x1').as_np_ndarray(), mx.sym.var('x2').as_np_ndarray()
@@ -290,6 +297,7 @@ def test_np_maximum():
 
 
 @with_seed()
+@npx.use_np_shape
 def test_np_minimum():
     # TODO(junwu): Add more test cases
     x1, x2 = mx.sym.var('x1').as_np_ndarray(), mx.sym.var('x2').as_np_ndarray()
@@ -310,6 +318,7 @@ def test_np_minimum():
 
 
 @with_seed()
+@npx.use_np_shape
 def test_np_unary_funcs():
     def check_unary_func(func, ref_grad, shape, low, high):
         @npx.use_np_shape
@@ -387,6 +396,7 @@ def test_np_unary_funcs():
 
 
 @with_seed()
+@npx.use_np_shape
 def test_np_stack():
     @npx.use_np_shape
     class TestStack(HybridBlock):
@@ -438,6 +448,8 @@ def test_np_stack():
                 assert same(mx_out.asnumpy(), np_out)
 
 
+@with_seed()
+@npx.use_np_shape
 def test_np_random():
     shapes = [(), (1,), (2, 3), (4, 0, 5), 6, (7, 8), None]
     dtypes = ['float16', 'float32', 'float64']
@@ -480,6 +492,7 @@ def test_np_random():
 
 
 @with_seed()
+@npx.use_np_shape
 def test_np_arange():
     configs = [
         (1, 10, 2),
@@ -543,6 +556,7 @@ def test_np_arange():
 
 
 @with_seed()
+@npx.use_np_shape
 def test_np_argmax():
     workloads = [
         ((), 0, False),
@@ -604,6 +618,7 @@ def test_np_argmax():
 
 
 @with_seed()
+@npx.use_np_shape
 def test_np_linalg_norm():
     @npx.use_np
     class TestLinalgNorm(HybridBlock):