You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by re...@apache.org on 2019/06/02 17:37:38 UTC

[incubator-mxnet] branch numpy updated: [numpy] Fix np branch after rebase (#15086)

This is an automated email from the ASF dual-hosted git repository.

reminisce pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/numpy by this push:
     new 49ee3a7  [numpy] Fix np branch after rebase (#15086)
49ee3a7 is described below

commit 49ee3a7b109f1201b980b694f36bb2875575eb4d
Author: reminisce <wu...@gmail.com>
AuthorDate: Sun Jun 2 10:37:15 2019 -0700

    [numpy] Fix np branch after rebase (#15086)
    
    * Add np_array semantics for Gluon
    
    Fix notebook
    
    Fix sanity
    
    Fix gluon deferred infer shape
    
    Add np.random.uniform
    
    Add random normal
    
    Add boolean comparison ops
    
    Add np.ndarray indexing
    
    Reformat test ndarray indexing
    
    Fix unit tests
    
    Add one more test of indexing
    
    Fix sanity
    
    Enable amp test
    
    Add np.arange
    
    Revert cython unit test to ctypes
    
    Delete unnecessary use_np_shape decorator from test
    
    Rebase with numpy branch
    
    Support range as index
    
    Fix Python 2 range type check
    
    Add argmax
    
    Disable clojure test
    
    * Fix ci
    
    * Add np.linalg.norm for ord='fro'
    
    * Fix pylint
---
 ci/jenkins/Jenkins_steps.groovy                    |  18 +-
 ci/jenkins/Jenkinsfile_unix_cpu                    |   4 +-
 example/numpy/demo.ipynb                           |   2 +-
 python/mxnet/__init__.py                           |   3 +-
 python/mxnet/_ctypes/ndarray.py                    |   2 +-
 python/mxnet/base.py                               |  10 +-
 python/mxnet/gluon/block.py                        |   3 +-
 python/mxnet/gluon/parameter.py                    |  20 +-
 python/mxnet/gluon/utils.py                        |   2 +-
 python/mxnet/ndarray/__init__.py                   |   2 +-
 python/mxnet/ndarray/numpy/_op.py                  |  78 ++++-
 python/mxnet/ndarray/numpy/linalg.py               |  50 +++-
 python/mxnet/ndarray/numpy/random.py               | 119 +++++++-
 python/mxnet/numpy/__init__.py                     |   1 -
 python/mxnet/numpy/linalg.py                       |  44 ++-
 python/mxnet/numpy/multiarray.py                   | 197 +++++++++++--
 python/mxnet/numpy/random.py                       |  82 +++++-
 python/mxnet/numpy_extension/__init__.py           |   3 +
 python/mxnet/symbol/__init__.py                    |   2 +-
 python/mxnet/symbol/numpy/_symbol.py               | 148 ++++++++--
 python/mxnet/symbol/numpy/linalg.py                |  49 +++-
 python/mxnet/symbol/numpy/random.py                | 120 +++++++-
 python/mxnet/test_utils.py                         |   2 +-
 python/mxnet/util.py                               | 230 ++++++++++++++-
 src/operator/numpy/np_broadcast_reduce_op_index.cc |  61 ++++
 ..._init_op.cu => np_broadcast_reduce_op_index.cu} |  20 +-
 src/operator/numpy/np_broadcast_reduce_op_value.cc |   2 +-
 src/operator/numpy/np_broadcast_reduce_op_value.cu |   2 +-
 src/operator/numpy/np_elemwise_unary_op_basic.cc   |   4 +-
 src/operator/numpy/np_elemwise_unary_op_basic.cu   |   4 +-
 src/operator/numpy/np_init_op.cc                   |  27 ++
 src/operator/numpy/np_init_op.cu                   |   3 +
 src/operator/random/sample_op.cc                   |   2 +
 src/operator/tensor/broadcast_reduce_op.h          |  50 +++-
 .../tensor/elemwise_binary_broadcast_op_logic.cc   |   6 +
 .../tensor/elemwise_binary_scalar_op_logic.cc      |   6 +
 tests/python/unittest/test_contrib_amp.py          |   3 -
 tests/python/unittest/test_numpy_gluon.py          |  12 +-
 tests/python/unittest/test_numpy_ndarray.py        | 319 +++++++++++++++++++--
 tests/python/unittest/test_numpy_op.py             | 229 +++++++++++++--
 tests/python/unittest/test_thread_local.py         |  36 +++
 41 files changed, 1813 insertions(+), 164 deletions(-)

diff --git a/ci/jenkins/Jenkins_steps.groovy b/ci/jenkins/Jenkins_steps.groovy
index 668d2f7..1d62d0a 100644
--- a/ci/jenkins/Jenkins_steps.groovy
+++ b/ci/jenkins/Jenkins_steps.groovy
@@ -112,7 +112,8 @@ def compile_unix_cpu_openblas() {
           timeout(time: max_time, unit: 'MINUTES') {
             utils.init_git()
             utils.docker_run('ubuntu_cpu', 'build_ubuntu_cpu_openblas', false)
-            utils.pack_lib('cpu', mx_lib_cython, true)
+            // utils.pack_lib('cpu', mx_lib_cython, true)
+            utils.pack_lib('cpu', mx_lib, true)
           }
         }
       }
@@ -266,7 +267,8 @@ def compile_unix_cmake_gpu() {
           timeout(time: max_time, unit: 'MINUTES') {
             utils.init_git()
             utils.docker_run('ubuntu_gpu_cu100', 'build_ubuntu_gpu_cmake', false)
-            utils.pack_lib('cmake_gpu', mx_cmake_lib_cython, true)
+            // utils.pack_lib('cmake_gpu', mx_cmake_lib_cython, true)
+            utils.pack_lib('cmake_gpu', mx_cmake_lib, true)
           }
         }
       }
@@ -643,8 +645,10 @@ def test_unix_python2_cpu() {
       node(NODE_LINUX_CPU) {
         ws('workspace/ut-python2-cpu') {
           try {
-            utils.unpack_and_init('cpu', mx_lib_cython, true)
-            python2_ut_cython('ubuntu_cpu')
+            // utils.unpack_and_init('cpu', mx_lib_cython, true)
+            // python2_ut_cython('ubuntu_cpu')
+            utils.unpack_and_init('cpu', mx_lib, true)
+            python2_ut('ubuntu_cpu')
             utils.publish_test_coverage()
           } finally {
             utils.collect_test_results_unix('nosetests_unittest.xml', 'nosetests_python2_cpu_unittest.xml')
@@ -745,8 +749,10 @@ def test_unix_python3_gpu() {
       node(NODE_LINUX_GPU) {
         ws('workspace/ut-python3-gpu') {
           try {
-            utils.unpack_and_init('gpu', mx_lib_cython, true)
-            python3_gpu_ut_cython('ubuntu_gpu_cu100')
+            // utils.unpack_and_init('gpu', mx_lib_cython, true)
+            // python3_gpu_ut_cython('ubuntu_gpu_cu100')
+            utils.unpack_and_init('gpu', mx_lib, true)
+            python3_gpu_ut('ubuntu_gpu_cu100')
             utils.publish_test_coverage()
           } finally {
             utils.collect_test_results_unix('nosetests_gpu.xml', 'nosetests_python3_gpu.xml')
diff --git a/ci/jenkins/Jenkinsfile_unix_cpu b/ci/jenkins/Jenkinsfile_unix_cpu
index fa09429..c3a1481 100644
--- a/ci/jenkins/Jenkinsfile_unix_cpu
+++ b/ci/jenkins/Jenkinsfile_unix_cpu
@@ -52,8 +52,8 @@ core_logic: {
     custom_steps.test_unix_python3_mkldnn_mkl_cpu(),
     custom_steps.test_unix_scala_cpu(),
     custom_steps.test_unix_scala_mkldnn_cpu(),
-    custom_steps.test_unix_clojure_cpu(),
-    custom_steps.test_unix_clojure_integration_cpu(),
+    // custom_steps.test_unix_clojure_cpu(),
+    // custom_steps.test_unix_clojure_integration_cpu(),
     custom_steps.test_unix_perl_cpu(),
     custom_steps.test_unix_r_cpu(),
     custom_steps.test_unix_r_mkldnn_cpu(),
diff --git a/example/numpy/demo.ipynb b/example/numpy/demo.ipynb
index 1f06275..31c13e9 100644
--- a/example/numpy/demo.ipynb
+++ b/example/numpy/demo.ipynb
@@ -372,7 +372,7 @@
     "from mxnet import gluon, autograd, np\n",
     "\n",
     "\n",
-    "@np.use_np_compat\n",
+    "@np.use_np\n",
     "class LinearRegression(gluon.HybridBlock):\n",
     "    def __init__(self, num_input_dim=1000, num_hidden_dim=100, num_output_dim=10):\n",
     "        super(LinearRegression, self).__init__()\n",
diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py
index 883e846..f288b4c 100644
--- a/python/mxnet/__init__.py
+++ b/python/mxnet/__init__.py
@@ -25,6 +25,7 @@ from .context import Context, current_context, cpu, gpu, cpu_pinned
 from . import engine
 from .base import MXNetError
 from .util import is_np_shape, set_np_shape, np_shape, use_np_shape
+from .util import is_np_array, np_array, use_np_array, use_np
 from . import base
 from . import contrib
 from . import ndarray
@@ -32,7 +33,7 @@ from . import ndarray as nd
 from . import numpy
 from . import numpy_extension
 from . import numpy as np
-from . import numpy_extension as npe
+from . import numpy_extension as npx
 from . import name
 # use mx.sym as short for symbol
 from . import symbol as sym
diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py
index 6404d89..dd429e6 100644
--- a/python/mxnet/_ctypes/ndarray.py
+++ b/python/mxnet/_ctypes/ndarray.py
@@ -118,7 +118,7 @@ class CachedOp(object):
         self.handle = CachedOpHandle()
 
         from ..symbol.numpy._symbol import _Symbol
-        self.is_np_sym = True if isinstance(sym, _Symbol) else False
+        self.is_np_sym = bool(isinstance(sym, _Symbol))
 
         check_call(_LIB.MXCreateCachedOpEx(
             sym.handle,
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 7149d2f..a348326 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -743,7 +743,7 @@ ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
 _NP_OP_PREFIX = '_np_'
 _NP_OP_SUBMODULE_LIST = ['_random_', '_linalg_']
 
-_NP_EXT_OP_PREFIX = '_npe_'
+_NP_EXT_OP_PREFIX = '_npx_'
 
 _NP_INTERNAL_OP_PREFIX = '_npi_'
 
@@ -800,14 +800,14 @@ def _init_np_op_module(root_module_name, np_module_name, mx_module_name, make_op
             op_names.append(name)
 
     if mx_module_name is None:
-        # register np/npe ops for imperative programming
+        # register np/npx ops for imperative programming
         op_module_name = "%s.%s._op" % (root_module_name, np_module_name)  # e.g. mxnet.numpy._op
         op_submodule_name = "%s.%s" % (root_module_name, np_module_name)  # e.g. mxnet.numpy.random
-    elif mx_module_name == 'ndarray' or mx_module_name == 'symbol':
-        # register numpy internal ops and np/npe ops for use in Gluon
+    elif mx_module_name in ('ndarray', 'symbol'):
+        # register numpy internal ops and np/npx ops for use in Gluon
         # np internal ops are registered in mxnet.ndarray/symbol.numpy._internal
         # np ops are registered in mxnet.ndarray/symbol.numpy._op
-        # npe ops are registered in mxnet.ndarray/symbol.numpy_extension._op
+        # npx ops are registered in mxnet.ndarray/symbol.numpy_extension._op
         op_module_name = "%s.%s.%s" % (root_module_name, mx_module_name, np_module_name)
         if op_name_prefix != _NP_INTERNAL_OP_PREFIX:
             op_module_name += '._op'
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 9a1d16e..845bb31 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -34,6 +34,7 @@ from .. import name as _name
 from .parameter import Parameter, ParameterDict, DeferredInitializationError
 from .utils import _indent, _brief_print_list, HookHandle
 from .utils import _check_same_symbol_type, _check_all_np_ndarrays
+from .. import numpy_extension as _mx_npx
 from .. import numpy as _mx_np
 
 
@@ -543,7 +544,7 @@ class Block(object):
 
         for hook in self._forward_hooks.values():
             hook(self, args, out)
-        if _mx_np.is_np_shape():
+        if _mx_npx.is_np_array():
             _check_all_np_ndarrays(_flatten(out, "output")[0])
         return out
 
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index 0b87bf0..eaab6dd 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -31,7 +31,7 @@ from .. import symbol, ndarray, initializer, context
 from ..context import Context, cpu
 from .. import autograd
 from .utils import _indent, _brief_print_list, shape_is_known
-from ..util import is_np_shape
+from ..util import is_np_shape, is_np_array
 
 # pylint: disable= invalid-name
 tensor_types = (symbol.Symbol, ndarray.NDArray)
@@ -156,16 +156,21 @@ class Parameter(object):
 
     @property
     def shape(self):
-        return self._shape
+        if self._shape is None:
+            return None
+        elif is_np_shape():
+            return tuple(i if i != 0 else -1 for i in self._shape)
+        else:
+            return self._shape
 
     @shape.setter
     def shape(self, new_shape):
         if self._shape is None:
             self._shape = new_shape
             return
-        unknown_dim_size = -1 if is_np_shape() else 0
+
         assert len(self._shape) == len(new_shape) and \
-            all(j in (unknown_dim_size, i) for i, j in zip(new_shape, self._shape)), \
+            all(j in (0, i) for i, j in zip(new_shape, self._shape)), \
             "Expected shape %s is incompatible with given shape %s."%(
                 str(new_shape), str(self._shape))
 
@@ -269,6 +274,7 @@ class Parameter(object):
             return
         init, ctx, default_init, data = self._deferred_init
         self._deferred_init = ()
+
         assert shape_is_known(self.shape), \
             "Cannot initialize Parameter '%s' because it has " \
             "invalid shape: %s. Please specify in_units, " \
@@ -282,7 +288,7 @@ class Parameter(object):
                 initializer.create(default_init)(
                     initializer.InitDesc(self.name, {'__init__': init}), data)
                 # TODO(junwu): use np random operators when available
-                if is_np_shape():
+                if is_np_array():
                     data = data.as_np_ndarray()  # convert to np.ndarray
 
             self._init_impl(data, ctx)
@@ -309,7 +315,7 @@ class Parameter(object):
         self._grad = [ndarray.zeros(shape=i.shape, dtype=i.dtype, ctx=i.context,
                                     stype=self._grad_stype) for i in self._data]
         # TODO(junwu): use np.zeros
-        if is_np_shape():
+        if is_np_array():
             self._grad = [arr.as_np_ndarray() for arr in self._grad]
 
         autograd.mark_variables(self._check_and_get(self._data, list),
@@ -558,7 +564,7 @@ class Parameter(object):
             self._var = symbol.var(self.name, shape=self.shape, dtype=self.dtype,
                                    lr_mult=self.lr_mult, wd_mult=self.wd_mult,
                                    init=self.init, stype=self._stype)
-            if is_np_shape():
+            if is_np_array():
                 self._var = self._var.as_np_ndarray()
         return self._var
 
diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index e8ffdd6..7e6e0e3 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -422,7 +422,7 @@ def _check_same_symbol_type(symbols):
     the symbols."""
     from ..symbol.numpy import _Symbol as np_symbol
     from ..symbol import Symbol as classic_symbol
-    is_np_sym = True if isinstance(symbols[0], np_symbol) else False
+    is_np_sym = bool(isinstance(symbols[0], np_symbol))
     for s in symbols[1:]:
         if is_np_sym != isinstance(s, np_symbol):
             raise TypeError('Found both classic symbol (mx.sym.Symbol) and numpy symbol '
diff --git a/python/mxnet/ndarray/__init__.py b/python/mxnet/ndarray/__init__.py
index c326850..f6b8712 100644
--- a/python/mxnet/ndarray/__init__.py
+++ b/python/mxnet/ndarray/__init__.py
@@ -31,7 +31,7 @@ from .utils import load, load_frombuffer, save, zeros, empty, array
 from .sparse import _ndarray_cls
 from .ndarray import _GRAD_REQ_MAP, _DTYPE_MX_TO_NP, _DTYPE_NP_TO_MX, _new_empty_handle
 from . import numpy as np
-from . import numpy_extension as npe
+from . import numpy_extension as npx
 
 __all__ = op.__all__ + ndarray.__all__ + utils.__all__ + \
           ['contrib', 'linalg', 'random', 'sparse', 'image', 'numpy', 'numpy_extension']
diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index 76825f1..34218e3 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -24,7 +24,7 @@ from ...util import _sanity_check_params, set_module
 from ...context import current_context
 from . import _internal as _npi
 
-__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack']
+__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax']
 
 
 @set_module('mxnet.ndarray.numpy')
@@ -201,3 +201,79 @@ def stack(arrays, axis=0, out=None):
 
     arrays = get_list(arrays)
     return _npi.stack(*arrays, axis=axis, out=out)
+
+
+@set_module('mxnet.ndarray.numpy')
+def arange(start, stop=None, step=1, dtype=None, ctx=None):
+    """Return evenly spaced values within a given interval.
+
+    Values are generated within the half-open interval ``[start, stop)``
+    (in other words, the interval including `start` but excluding `stop`).
+    For integer arguments the function is equivalent to the Python built-in
+    `range` function, but returns an ndarray rather than a list.
+
+    Parameters
+    ----------
+    start : number, optional
+        Start of interval. The interval includes this value.  The default
+        start value is 0.
+    stop : number
+        End of interval. The interval does not include this value, except
+        in some cases where `step` is not an integer and floating point
+        round-off affects the length of `out`.
+    step : number, optional
+        Spacing between values. For any output `out`, this is the distance
+        between two adjacent values, ``out[i+1] - out[i]``.  The default
+        step size is 1.  If `step` is specified as a position argument,
+        `start` must also be given.
+    dtype : dtype
+        The type of the output array. The default is `float32`.
+
+    Returns
+    -------
+    arange : ndarray
+        Array of evenly spaced values.
+
+        For floating point arguments, the length of the result is
+        ``ceil((stop - start)/step)``.  Because of floating point overflow,
+        this rule may result in the last element of `out` being greater
+        than `stop`.
+    """
+    if dtype is None:
+        dtype = 'float32'
+    if ctx is None:
+        ctx = current_context()
+    if stop is None:
+        stop = start
+        start = 0
+    if step is None:
+        step = 1
+    if start is None and stop is None:
+        raise ValueError('start and stop cannot be both None')
+    if step == 0:
+        raise ZeroDivisionError('step cannot be 0')
+    return _npi.arange(start=start, stop=stop, step=step, dtype=dtype, ctx=ctx)
+
+
+@set_module('mxnet.ndarray.numpy')
+def argmax(a, axis=None, out=None):
+    """Returns the indices of the maximum values along an axis.
+
+    Parameters
+    ----------
+    a : ndarray
+        Input array. Only support ndarrays of dtype `float16`, `float32`, and `float64`.
+    axis : int, optional
+        By default, the index is into the flattened array, otherwise
+        along the specified axis.
+    out : array, optional
+        If provided, the result will be inserted into this array. It should
+        be of the appropriate shape and dtype.
+
+    Returns
+    -------
+    index_array : ndarray of indices whose dtype is same as the input ndarray.
+        Array of indices into the array. It has the same shape as `a.shape`
+        with the dimension along `axis` removed.
+    """
+    return _npi.argmax(a, axis=axis, keepdims=False, out=out)
diff --git a/python/mxnet/ndarray/numpy/linalg.py b/python/mxnet/ndarray/numpy/linalg.py
index 8f521fd..36f3f21 100644
--- a/python/mxnet/ndarray/numpy/linalg.py
+++ b/python/mxnet/ndarray/numpy/linalg.py
@@ -17,4 +17,52 @@
 
 """Namespace for operators used in Gluon dispatched by F=ndarray."""
 
-__all__ = []
+from __future__ import absolute_import
+from . import _op as _mx_nd_np
+
+__all__ = ['norm']
+
+
+def norm(x, ord=None, axis=None, keepdims=False):
+    r"""Matrix or vector norm.
+
+    This function can only support Frobenius norm for now.
+    The Frobenius norm is given by [1]_:
+
+        :math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}`
+
+    Parameters
+    ----------
+    x : ndarray
+        Input array.
+    ord : {'fro'}, optional
+        Order of the norm.
+    axis : {int, 2-tuple of ints, None}, optional
+        If `axis` is an integer, it specifies the axis of `x` along which to
+        compute the vector norms.  If `axis` is a 2-tuple, it specifies the
+        axes that hold 2-D matrices, and the matrix norms of these matrices
+        are computed.  If `axis` is None, the norm of the whole ndarray is
+        returned.
+
+    keepdims : bool, optional
+        If this is set to True, the axes which are normed over are left in the
+        result as dimensions with size one.  With this option the result will
+        broadcast correctly against the original `x`.
+
+    Returns
+    -------
+    n : float or ndarray
+        Norm of the matrix or vector(s).
+
+    References
+    ----------
+    .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*,
+           Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
+    """
+    if ord is not None and ord != 'fro':
+        raise ValueError('only support Frobenius norm for now, received ord={}'.format(str(ord)))
+    if isinstance(axis, tuple) and len(axis) > 2:
+        raise ValueError('Improper number of dimensions to norm')
+    if ord == 'fro' and x.ndim > 2 and axis is None:
+        raise ValueError('Improper number of dimensions to norm')
+    return _mx_nd_np.sqrt(_mx_nd_np.sum(x * x, axis=axis, keepdims=keepdims))
diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py
index 8f521fd..3d9fd6a 100644
--- a/python/mxnet/ndarray/numpy/random.py
+++ b/python/mxnet/ndarray/numpy/random.py
@@ -16,5 +16,122 @@
 # under the License.
 
 """Namespace for operators used in Gluon dispatched by F=ndarray."""
+from __future__ import absolute_import
+from ...base import numeric_types
+from ...context import current_context
+from . import _internal as _npi
 
-__all__ = []
+__all__ = ['uniform', 'normal']
+
+
+def _random_helper(random, sampler, params, shape, dtype, ctx, out, kwargs):
+    """Helper function for random generators."""
+    from ...numpy import ndarray as np_ndarray
+    if isinstance(params[0], np_ndarray):
+        for i in params[1:]:
+            assert isinstance(i, np_ndarray), \
+                "Distribution parameters must all have the same type, but got " \
+                "both %s and %s." % (type(params[0]), type(i))
+        return sampler(*params, shape=shape, dtype=dtype, out=out, **kwargs)
+    elif isinstance(params[0], numeric_types):
+        if ctx is None:
+            ctx = current_context()
+        if shape is None and out is None:
+            shape = ()
+        for i in params[1:]:
+            assert isinstance(i, numeric_types), \
+                "Distribution parameters must all have the same type, but got " \
+                "both %s and %s."%(type(params[0]), type(i))
+        return random(*params, shape=shape, dtype=dtype, ctx=ctx, out=out, **kwargs)
+
+    raise ValueError("Distribution parameters must be either mxnet.numpy.ndarray or numbers, "
+                     "but got %s." % type(params[0]))
+
+
+def uniform(low=0.0, high=1.0, size=None, **kwargs):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval
+    ``[low, high)`` (includes low, but excludes high).  In other words,
+    any value within the given interval is equally likely to be drawn
+    by `uniform`.
+
+    Parameters
+    ----------
+    low : float, optional
+        Lower boundary of the output interval.  All values generated will be
+        greater than or equal to low.  The default value is 0.
+    high : float
+        Upper boundary of the output interval.  All values generated will be
+        less than high.  The default value is 1.0.
+    size : int or tuple of ints, optional
+        Output shape.  If the given shape is, e.g., ``(m, n, k)``, then
+        ``m * n * k`` samples are drawn.  If size is ``None`` (default),
+        a scalar tensor containing a single value is returned if
+        ``low`` and ``high`` are both scalars.
+    dtype : {'float16', 'float32', 'float64'}, optional
+        Data type of output samples. Default is 'float32'
+    ctx : Context, optional
+        Device context of output. Default is current context.
+    out : ndarray, optional
+        Store output to an existing ndarray.
+
+    Returns
+    -------
+    out : ndarray
+        Drawn samples from the parameterized uniform distribution.
+
+
+    Notes
+    -----
+    This function currently does not support ``low`` and ``high`` as ndarrays.
+    """
+    dtype = kwargs.pop('dtype', None)
+    if dtype is None:
+        dtype = 'float32'
+    ctx = kwargs.pop('ctx', None)
+    out = kwargs.pop('out', None)
+    return _random_helper(_npi.random_uniform, None,
+                          [low, high], size, dtype, ctx, out, kwargs)
+
+
+def normal(loc=0.0, scale=1.0, size=None, **kwargs):
+    """Draw random samples from a normal (Gaussian) distribution.
+
+    Samples are distributed according to a normal distribution parametrized
+    by *loc* (mean) and *scale* (standard deviation).
+
+
+    Parameters
+    ----------
+    loc : float, optional
+        Mean (centre) of the distribution.
+    scale : float, optional
+        Standard deviation (spread or "width") of the distribution.
+    size : int or tuple of ints, optional
+        Output shape. If the given shape is, e.g., `(m, n, k)`, then `m * n * k`
+        samples are drawn. If size is `None` (default), a scalar tensor containing
+        a single value is returned if loc and scale are both scalars.
+    dtype : {'float16', 'float32', 'float64'}, optional
+        Data type of output samples. Default is 'float32'
+    ctx : Context, optional
+        Device context of output. Default is current context.
+    out : ``ndarray``, optional
+        Store output to an existing ``ndarray``.
+
+    Returns
+    -------
+    out : ndarray
+        Drawn samples from the parameterized normal distribution.
+
+    Notes
+    -----
+    This function currently does not support ``loc`` and ``scale`` as ndarrays.
+    """
+    dtype = kwargs.pop('dtype', None)
+    if dtype is None:
+        dtype = 'float32'
+    ctx = kwargs.pop('ctx', None)
+    out = kwargs.pop('out', None)
+    return _random_helper(_npi.random_normal, None,
+                          [loc, scale], size, dtype, ctx, out, kwargs)
diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/numpy/__init__.py
index 6f1c02d..344483d 100644
--- a/python/mxnet/numpy/__init__.py
+++ b/python/mxnet/numpy/__init__.py
@@ -26,6 +26,5 @@ from .multiarray import *  # pylint: disable=wildcard-import
 from . import _op
 from . import _register
 from ._op import *  # pylint: disable=wildcard-import
-from ..util import use_np_shape, set_np_shape, np_shape, is_np_shape
 
 __all__ = []
diff --git a/python/mxnet/numpy/linalg.py b/python/mxnet/numpy/linalg.py
index e49bfcf..9758af4 100644
--- a/python/mxnet/numpy/linalg.py
+++ b/python/mxnet/numpy/linalg.py
@@ -17,4 +17,46 @@
 
 """Namespace for ops used in imperative programming."""
 
-__all__ = []
+from __future__ import absolute_import
+from ..ndarray import numpy as _mx_nd_np
+
+__all__ = ['norm']
+
+
+def norm(x, ord=None, axis=None, keepdims=False):
+    r"""Matrix or vector norm.
+
+    This function can only support Frobenius norm for now.
+    The Frobenius norm is given by [1]_:
+
+        :math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}`
+
+    Parameters
+    ----------
+    x : ndarray
+        Input array.
+    ord : {'fro'}, optional
+        Order of the norm.
+    axis : {int, 2-tuple of ints, None}, optional
+        If `axis` is an integer, it specifies the axis of `x` along which to
+        compute the vector norms.  If `axis` is a 2-tuple, it specifies the
+        axes that hold 2-D matrices, and the matrix norms of these matrices
+        are computed.  If `axis` is None, the norm of the whole ndarray is
+        returned.
+
+    keepdims : bool, optional
+        If this is set to True, the axes which are normed over are left in the
+        result as dimensions with size one.  With this option the result will
+        broadcast correctly against the original `x`.
+
+    Returns
+    -------
+    n : float or ndarray
+        Norm of the matrix or vector(s).
+
+    References
+    ----------
+    .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*,
+           Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
+    """
+    return _mx_nd_np.linalg.norm(x, ord, axis, keepdims)
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index da7e61e..212dfe3 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -23,19 +23,22 @@
 from __future__ import absolute_import
 from __future__ import division
 from array import array as native_array
+import sys
 import ctypes
+import warnings
 import numpy as _np
 from ..ndarray import NDArray, _DTYPE_NP_TO_MX, _GRAD_REQ_MAP
 from ..ndarray._internal import _set_np_ndarray_class
 from . import _op as _mx_np_op
 from ..base import check_call, _LIB, NDArrayHandle
-from ..base import mx_real_t, c_array_buf, mx_uint, numeric_types
+from ..base import mx_real_t, c_array_buf, mx_uint, numeric_types, integer_types
 from ..util import _sanity_check_params, set_module, use_np_shape
 from ..context import current_context
 from ..ndarray import numpy as _mx_nd_np
 from ..ndarray.numpy import _internal as _npi
 
-__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum', 'stack']
+__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange',
+           'argmax']
 
 
 # This function is copied from ndarray.py since pylint
@@ -74,6 +77,17 @@ def _np_ndarray_cls(handle, writable=True, stype=0):
 _set_np_ndarray_class(_np_ndarray_cls)
 
 
+def _get_index(idx):
+    if isinstance(idx, NDArray) and not isinstance(idx, ndarray):
+        raise TypeError('Cannot have mx.nd.NDArray as index')
+    if isinstance(idx, ndarray):
+        return idx._as_classic_ndarray()
+    elif sys.version_info[0] > 2 and isinstance(idx, range):
+        return arange(idx.start, idx.stop, idx.step, dtype='int32')._as_classic_ndarray()
+    else:
+        return idx
+
+
 @set_module('mxnet.numpy')  # pylint: disable=invalid-name
 @use_np_shape
 class ndarray(NDArray):
@@ -83,22 +97,57 @@ class ndarray(NDArray):
     floating point number, or something else, etc.). Arrays should be constructed using
     `array`, `zeros` or `empty`. Currently, only c-contiguous arrays are supported."""
 
-    def __getitem__(self, item):
-        # TODO(junwu): make output shape of integer indexing correct
-        raise NotImplementedError
+    def __getitem__(self, key):
+        # TODO(junwu): calling base class __setitem__ is a temp solution
+        if self.ndim == 0:
+            if key != ():
+                raise IndexError('scalar tensor can only accept `()` as index')
+        if isinstance(key, tuple) and len(key) == 0:
+            return self
+        if isinstance(key, integer_types):
+            key = (key,)
+        if isinstance(key, tuple) and len(key) == self.ndim\
+                and all(isinstance(idx, integer_types) for idx in key):
+            out = self._as_classic_ndarray()
+            for idx in key:
+                out = out[idx]
+            return out.reshape(()).as_np_ndarray()
+        if isinstance(key, ndarray):
+            key = key._as_classic_ndarray()
+        elif isinstance(key, tuple):
+            key = [_get_index(idx) for idx in key]
+            key = tuple(key)
+        elif isinstance(key, list):
+            key = [_get_index(idx) for idx in key]
+        elif sys.version_info[0] > 2 and isinstance(key, range):
+            key = _get_index(key)
+        return self._as_classic_ndarray().__getitem__(key).as_np_ndarray()
 
     def __setitem__(self, key, value):
-        if self.size == 0:
-            return
+        # TODO(junwu): calling base class __setitem__ is a temp solution
+        if isinstance(value, NDArray) and not isinstance(value, ndarray):
+            raise TypeError('Cannot assign mx.nd.NDArray to mxnet.numpy.ndarray')
         if self.ndim == 0:
-            if key != ():
+            if not isinstance(key, tuple) or len(key) != 0:
                 raise IndexError('scalar tensor can only accept `()` as index')
-            # TODO(junwu): Better handling of this situation
-            hdl = NDArrayHandle()
-            check_call(_LIB.MXShallowCopyNDArray(self.handle, ctypes.byref(hdl)))
-            classic_ndarray = NDArray(handle=hdl, writable=self.writable)
-            classic_ndarray.__setitem__(slice(None), value)
+        if isinstance(value, ndarray):
+            value = value._as_classic_ndarray()
+        # TODO(junwu): Better handling of this situation
+        if isinstance(key, tuple) and len(key) == 0:
+            self._as_classic_ndarray().__setitem__(slice(None), value)
             return
+
+        if isinstance(key, integer_types):
+            key = (key,)
+        if isinstance(key, ndarray):
+            key = key._as_classic_ndarray()
+        elif isinstance(key, tuple):
+            key = [_get_index(idx) for idx in key]
+            key = tuple(key)
+        elif isinstance(key, list):
+            key = [_get_index(idx) for idx in key]
+        elif sys.version_info[0] > 2 and isinstance(key, range):
+            key = _get_index(key)
         self._as_classic_ndarray().__setitem__(key, value)
 
     def __add__(self, other):
@@ -248,33 +297,78 @@ class ndarray(NDArray):
 
     def __eq__(self, other):
         """x.__eq__(y) <=> x == y"""
-        raise NotImplementedError
+        # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported
+        if isinstance(other, ndarray):
+            return _npi.equal(self, other)
+        elif isinstance(other, numeric_types):
+            return _npi.equal_scalar(self, float(other))
+        else:
+            raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
 
     def __hash__(self):
         raise NotImplementedError
 
     def __ne__(self, other):
         """x.__ne__(y) <=> x != y"""
-        raise NotImplementedError
+        # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported
+        if isinstance(other, ndarray):
+            return _npi.not_equal(self, other)
+        elif isinstance(other, numeric_types):
+            return _npi.not_equal_scalar(self, float(other))
+        else:
+            raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
 
     def __gt__(self, other):
         """x.__gt__(y) <=> x > y"""
-        raise NotImplementedError
+        # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported
+        if isinstance(other, ndarray):
+            return _npi.greater(self, other)
+        elif isinstance(other, numeric_types):
+            return _npi.greater_scalar(self, float(other))
+        else:
+            raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
 
     def __ge__(self, other):
         """x.__ge__(y) <=> x >= y"""
-        raise NotImplementedError
+        # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported
+        if isinstance(other, ndarray):
+            return _npi.greater_equal(self, other)
+        elif isinstance(other, numeric_types):
+            return _npi.greater_equal_scalar(self, float(other))
+        else:
+            raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
 
     def __lt__(self, other):
         """x.__lt__(y) <=> x < y"""
-        raise NotImplementedError
+        # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported
+        if isinstance(other, ndarray):
+            return _npi.less(self, other)
+        elif isinstance(other, numeric_types):
+            return _npi.less_scalar(self, float(other))
+        else:
+            raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
 
     def __le__(self, other):
         """x.__le__(y) <=> x <= y"""
-        raise NotImplementedError
+        # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported
+        if isinstance(other, ndarray):
+            return _npi.less_equal(self, other)
+        elif isinstance(other, numeric_types):
+            return _npi.less_equal_scalar(self, float(other))
+        else:
+            raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
 
     def __bool__(self):
-        raise NotImplementedError
+        num_elements = self.size
+        if num_elements == 0:
+            warnings.simplefilter('default')
+            warnings.warn('The truth value of an empty array is ambiguous. Returning False, but in'
+                          ' future this will result in an error.', DeprecationWarning)
+            return False
+        elif num_elements == 1:
+            return bool(self.item())
+        else:
+            raise ValueError("The truth value of an ndarray with multiple elements is ambiguous.")
 
     def __len__(self):
         """Number of elements along the first axis."""
@@ -1329,3 +1423,66 @@ def stack(arrays, axis=0, out=None):
     stacked : ndarray
         The stacked array has one more dimension than the input arrays."""
     return _mx_nd_np.stack(arrays, axis=axis, out=out)
+
+
+@set_module('mxnet.numpy')
+def arange(start, stop=None, step=1, dtype=None, ctx=None):
+    """Return evenly spaced values within a given interval.
+
+    Values are generated within the half-open interval ``[start, stop)``
+    (in other words, the interval including `start` but excluding `stop`).
+    For integer arguments the function is equivalent to the Python built-in
+    `range` function, but returns an ndarray rather than a list.
+
+    Parameters
+    ----------
+    start : number, optional
+        Start of interval. The interval includes this value.  The default
+        start value is 0.
+    stop : number
+        End of interval. The interval does not include this value, except
+        in some cases where `step` is not an integer and floating point
+        round-off affects the length of `out`.
+    step : number, optional
+        Spacing between values. For any output `out`, this is the distance
+        between two adjacent values, ``out[i+1] - out[i]``.  The default
+        step size is 1.  If `step` is specified as a position argument,
+        `start` must also be given.
+    dtype : dtype
+        The type of the output array. The default is `float32`.
+
+    Returns
+    -------
+    arange : ndarray
+        Array of evenly spaced values.
+
+        For floating point arguments, the length of the result is
+        ``ceil((stop - start)/step)``.  Because of floating point overflow,
+        this rule may result in the last element of `out` being greater
+        than `stop`.
+    """
+    return _mx_nd_np.arange(start, stop, step, dtype, ctx)
+
+
+@set_module('mxnet.numpy')
+def argmax(a, axis=None, out=None):
+    """Returns the indices of the maximum values along an axis.
+
+    Parameters
+    ----------
+    a : ndarray
+        Input array. Only support ndarrays of dtype `float16`, `float32`, and `float64`.
+    axis : int, optional
+        By default, the index is into the flattened array, otherwise
+        along the specified axis.
+    out : array, optional
+        If provided, the result will be inserted into this array. It should
+        be of the appropriate shape and dtype.
+
+    Returns
+    -------
+    index_array : ndarray of indices whose dtype is same as the input ndarray.
+        Array of indices into the array. It has the same shape as `a.shape`
+        with the dimension along `axis` removed.
+    """
+    return _mx_nd_np.argmax(a, axis, out)
diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py
index e49bfcf..baeab8b 100644
--- a/python/mxnet/numpy/random.py
+++ b/python/mxnet/numpy/random.py
@@ -17,4 +17,84 @@
 
 """Namespace for ops used in imperative programming."""
 
-__all__ = []
+from __future__ import absolute_import
+from ..ndarray import numpy as _mx_nd_np
+
+__all__ = ['uniform', 'normal']
+
+
+def uniform(low=0.0, high=1.0, size=None, **kwargs):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval
+    ``[low, high)`` (includes low, but excludes high).  In other words,
+    any value within the given interval is equally likely to be drawn
+    by `uniform`.
+
+    Parameters
+    ----------
+    low : float, optional
+        Lower boundary of the output interval.  All values generated will be
+        greater than or equal to low.  The default value is 0.
+    high : float
+        Upper boundary of the output interval.  All values generated will be
+        less than high.  The default value is 1.0.
+    size : int or tuple of ints, optional
+        Output shape.  If the given shape is, e.g., ``(m, n, k)``, then
+        ``m * n * k`` samples are drawn.  If size is ``None`` (default),
+        a scalar tensor containing a single value is returned if
+        ``low`` and ``high`` are both scalars.
+    dtype : {'float16', 'float32', 'float64'}, optional
+        Data type of output samples. Default is 'float32'
+    ctx : Context, optional
+        Device context of output. Default is current context.
+    out : ndarray, optional
+        Store output to an existing ndarray.
+
+    Returns
+    -------
+    out : ndarray
+        Drawn samples from the parameterized uniform distribution.
+
+
+    Notes
+    -----
+    This function currently does not support ``low`` and ``high`` as ndarrays.
+    """
+    return _mx_nd_np.random.uniform(low, high, size, **kwargs)
+
+
+def normal(loc=0.0, scale=1.0, size=None, **kwargs):
+    """Draw random samples from a normal (Gaussian) distribution.
+
+    Samples are distributed according to a normal distribution parametrized
+    by *loc* (mean) and *scale* (standard deviation).
+
+
+    Parameters
+    ----------
+    loc : float, optional
+        Mean (centre) of the distribution.
+    scale : float, optional
+        Standard deviation (spread or "width") of the distribution.
+    size : int or tuple of ints, optional
+        Output shape. If the given shape is, e.g., `(m, n, k)`, then `m * n * k`
+        samples are drawn. If size is `None` (default), a scalar tensor containing
+        a single value is returned if loc and scale are both scalars.
+    dtype : {'float16', 'float32', 'float64'}, optional
+        Data type of output samples. Default is 'float32'
+    ctx : Context, optional
+        Device context of output. Default is current context.
+    out : ``ndarray``, optional
+        Store output to an existing ``ndarray``.
+
+    Returns
+    -------
+    out : ndarray
+        Drawn samples from the parameterized normal distribution.
+
+    Notes
+    -----
+    This function currently does not support ``loc`` and ``scale`` as ndarrays.
+    """
+    return _mx_nd_np.random.normal(loc, scale, size, **kwargs)
diff --git a/python/mxnet/numpy_extension/__init__.py b/python/mxnet/numpy_extension/__init__.py
index bd51175..0c89a88 100644
--- a/python/mxnet/numpy_extension/__init__.py
+++ b/python/mxnet/numpy_extension/__init__.py
@@ -24,5 +24,8 @@ from . import _op
 from . import _register
 from ._op import *  # pylint: disable=wildcard-import
 from ..context import *  # pylint: disable=wildcard-import
+from ..util import use_np_shape, np_shape, is_np_shape
+from ..util import use_np_array, np_array, is_np_array, use_np
+from .. import autograd
 
 __all__ = []
diff --git a/python/mxnet/symbol/__init__.py b/python/mxnet/symbol/__init__.py
index 1cd8057..2ce395b 100644
--- a/python/mxnet/symbol/__init__.py
+++ b/python/mxnet/symbol/__init__.py
@@ -28,7 +28,7 @@ from .op import *
 from .symbol import *
 # pylint: enable=wildcard-import
 from . import numpy as np
-from . import numpy_extension as npe
+from . import numpy_extension as npx
 
 __all__ = op.__all__ + symbol.__all__\
           + ['contrib', 'linalg', 'random', 'sparse', 'image', 'numpy', 'numpy_extension']
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index d55a878..b2d8a5b 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -29,7 +29,7 @@ from ..symbol import Symbol
 from .._internal import _set_np_symbol_class
 from . import _internal as _npi
 
-__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack']
+__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax']
 
 
 @set_module('mxnet.symbol.numpy')
@@ -114,8 +114,7 @@ class _Symbol(Symbol):
         elif isinstance(other, numeric_types):
             return _npi.mod_scalar(self, float(other))
         else:
-            raise TypeError("_Symbol does not support type {} as operand"
-                            .format(str(type(other))))
+            raise TypeError("_Symbol does not support type {} as operand".format(str(type(other))))
 
     def __rmod__(self, other):
         """x.__rmod__(y) <=> y % x"""
@@ -124,8 +123,7 @@ class _Symbol(Symbol):
         elif isinstance(other, numeric_types):
             return _npi.rmod_scalar(self, float(other))
         else:
-            raise TypeError("_Symbol does not support type {} as operand"
-                            .format(str(type(other))))
+            raise TypeError("_Symbol does not support type {} as operand".format(str(type(other))))
 
     def __idiv__(self, other):
         raise NotImplementedError
@@ -137,8 +135,7 @@ class _Symbol(Symbol):
         elif isinstance(other, numeric_types):
             return _npi.true_divide_scalar(self, float(other))
         else:
-            raise TypeError("_Symbol does not support type {} as divisor"
-                            .format(str(type(other))))
+            raise TypeError("_Symbol does not support type {} as divisor".format(str(type(other))))
 
     def __rtruediv__(self, other):
         """x.__rtruediv__(y) <=> y / x"""
@@ -147,8 +144,7 @@ class _Symbol(Symbol):
         elif isinstance(other, numeric_types):
             return _npi.rtrue_divide_scalar(self, float(other)).as_np_ndarray()
         else:
-            raise TypeError("_Symbol does not support type {} as dividend"
-                            .format(str(type(other))))
+            raise TypeError("_Symbol does not support type {} as dividend".format(str(type(other))))
 
     def __itruediv__(self, other):
         raise NotImplementedError
@@ -160,8 +156,7 @@ class _Symbol(Symbol):
         elif isinstance(other, numeric_types):
             return _npi.power_scalar(self, float(other))
         else:
-            raise TypeError("_Symbol does not support type {} as operand"
-                            .format(str(type(other))))
+            raise TypeError("_Symbol does not support type {} as operand".format(str(type(other))))
 
     def __rpow__(self, other):
         """x.__rpow__(y) <=> y ** x"""
@@ -170,8 +165,7 @@ class _Symbol(Symbol):
         elif isinstance(other, numeric_types):
             return _npi.rpower_scalar(self, float(other))
         else:
-            raise TypeError("_Symbol does not support type {} as operand"
-                            .format(str(type(other))))
+            raise TypeError("_Symbol does not support type {} as operand".format(str(type(other))))
 
     def __neg__(self):
         """x.__neg__() <=> - x"""
@@ -182,27 +176,63 @@ class _Symbol(Symbol):
 
     def __eq__(self, other):
         """x.__eq__(y) <=> x == y"""
-        raise NotImplementedError
+        # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported
+        if isinstance(other, _Symbol):
+            return _npi.equal(self, other)
+        elif isinstance(other, numeric_types):
+            return _npi.equal_scalar(self, float(other))
+        else:
+            raise TypeError("_Symbol does not support type {} as operand".format(str(type(other))))
 
     def __ne__(self, other):
         """x.__ne__(y) <=> x != y"""
-        raise NotImplementedError
+        # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported
+        if isinstance(other, _Symbol):
+            return _npi.not_equal(self, other)
+        elif isinstance(other, numeric_types):
+            return _npi.not_equal_scalar(self, float(other))
+        else:
+            raise TypeError("_Symbol does not support type {} as operand".format(str(type(other))))
 
     def __gt__(self, other):
         """x.__gt__(y) <=> x > y"""
-        raise NotImplementedError
+        # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported
+        if isinstance(other, _Symbol):
+            return _npi.greater(self, other)
+        elif isinstance(other, numeric_types):
+            return _npi.greater_scalar(self, float(other))
+        else:
+            raise TypeError("_Symbol does not support type {} as operand".format(str(type(other))))
 
     def __ge__(self, other):
         """x.__ge__(y) <=> x >= y"""
-        raise NotImplementedError
+        # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported
+        if isinstance(other, _Symbol):
+            return _npi.greater_equal(self, other)
+        elif isinstance(other, numeric_types):
+            return _npi.greater_equal_scalar(self, float(other))
+        else:
+            raise TypeError("_Symbol does not support type {} as operand".format(str(type(other))))
 
     def __lt__(self, other):
         """x.__lt__(y) <=> x < y"""
-        raise NotImplementedError
+        # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported
+        if isinstance(other, _Symbol):
+            return _npi.less(self, other)
+        elif isinstance(other, numeric_types):
+            return _npi.less_scalar(self, float(other))
+        else:
+            raise TypeError("_Symbol does not support type {} as operand".format(str(type(other))))
 
     def __le__(self, other):
         """x.__le__(y) <=> x <= y"""
-        raise NotImplementedError
+        # TODO(junwu): Return boolean ndarray when dtype=bool_ is supported
+        if isinstance(other, _Symbol):
+            return _npi.less_equal(self, other)
+        elif isinstance(other, numeric_types):
+            return _npi.less_equal_scalar(self, float(other))
+        else:
+            raise TypeError("_Symbol does not support type {} as operand".format(str(type(other))))
 
     def __len__(self):
         raise NotImplementedError
@@ -228,8 +258,8 @@ class _Symbol(Symbol):
 
     def reshape(self, shape, order='C'):  # pylint: disable=arguments-differ
         if order != 'C':
-            raise NotImplementedError('ndarray.copy only supports order=\'C\', while '
-                                      'received {}'.format(str(order)))
+            raise NotImplementedError('only supports order=\'C\', while received {}'
+                                      .format(str(order)))
         return _mx_np_op.reshape(self, newshape=shape, order=order)
 
     def reshape_like(self, *args, **kwargs):
@@ -1030,4 +1060,80 @@ def stack(arrays, axis=0, out=None):
     return _npi.stack(*arrays, axis=axis, out=out)
 
 
+@set_module('mxnet.symbol.numpy')
+def arange(start, stop=None, step=1, dtype=None, ctx=None):
+    """Return evenly spaced values within a given interval.
+
+    Values are generated within the half-open interval ``[start, stop)``
+    (in other words, the interval including `start` but excluding `stop`).
+    For integer arguments the function is equivalent to the Python built-in
+    `range` function, but returns an ndarray rather than a list.
+
+    Parameters
+    ----------
+    start : number, optional
+        Start of interval. The interval includes this value.  The default
+        start value is 0.
+    stop : number
+        End of interval. The interval does not include this value, except
+        in some cases where `step` is not an integer and floating point
+        round-off affects the length of `out`.
+    step : number, optional
+        Spacing between values. For any output `out`, this is the distance
+        between two adjacent values, ``out[i+1] - out[i]``.  The default
+        step size is 1.  If `step` is specified as a position argument,
+        `start` must also be given.
+    dtype : dtype
+        The type of the output array. The default is `float32`.
+
+    Returns
+    -------
+    arange : ndarray
+        Array of evenly spaced values.
+
+        For floating point arguments, the length of the result is
+        ``ceil((stop - start)/step)``.  Because of floating point overflow,
+        this rule may result in the last element of `out` being greater
+        than `stop`.
+    """
+    if dtype is None:
+        dtype = 'float32'
+    if ctx is None:
+        ctx = current_context()
+    if stop is None:
+        stop = start
+        start = 0
+    if step is None:
+        step = 1
+    if start is None and stop is None:
+        raise ValueError('start and stop cannot be both None')
+    if step == 0:
+        raise ZeroDivisionError('step cannot be 0')
+    return _npi.arange(start=start, stop=stop, step=step, dtype=dtype, ctx=ctx)
+
+
+@set_module('mxnet.symbol.numpy')
+def argmax(a, axis=None, out=None):
+    """Returns the indices of the maximum values along an axis.
+
+    Parameters
+    ----------
+    a : ndarray
+        Input array. Only support ndarrays of dtype `float16`, `float32`, and `float64`.
+    axis : int, optional
+        By default, the index is into the flattened array, otherwise
+        along the specified axis.
+    out : array, optional
+        If provided, the result will be inserted into this array. It should
+        be of the appropriate shape and dtype.
+
+    Returns
+    -------
+    index_array : ndarray of indices whose dtype is same as the input ndarray.
+        Array of indices into the array. It has the same shape as `a.shape`
+        with the dimension along `axis` removed.
+    """
+    return _npi.argmax(a, axis=axis, keepdims=False, out=out)
+
+
 _set_np_symbol_class(_Symbol)
diff --git a/python/mxnet/symbol/numpy/linalg.py b/python/mxnet/symbol/numpy/linalg.py
index 869fdeb..2cb0d22 100644
--- a/python/mxnet/symbol/numpy/linalg.py
+++ b/python/mxnet/symbol/numpy/linalg.py
@@ -17,4 +17,51 @@
 
 """Namespace for operators used in Gluon dispatched by F=symbol."""
 
-__all__ = []
+from __future__ import absolute_import
+from . import _op as _mx_nd_np
+
+__all__ = ['norm']
+
+
+def norm(x, ord=None, axis=None, keepdims=False):
+    r"""Matrix or vector norm.
+
+    This function can only support Frobenius norm for now.
+    The Frobenius norm is given by [1]_:
+
+        :math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}`
+
+    Parameters
+    ----------
+    x : ndarray
+        Input array.
+    ord : {'fro'}, optional
+        Order of the norm.
+    axis : {int, 2-tuple of ints, None}, optional
+        If `axis` is an integer, it specifies the axis of `x` along which to
+        compute the vector norms.  If `axis` is a 2-tuple, it specifies the
+        axes that hold 2-D matrices, and the matrix norms of these matrices
+        are computed.  If `axis` is None, the norm of the whole ndarray is
+        returned.
+
+    keepdims : bool, optional
+        If this is set to True, the axes which are normed over are left in the
+        result as dimensions with size one.  With this option the result will
+        broadcast correctly against the original `x`.
+
+    Returns
+    -------
+    n : float or ndarray
+        Norm of the matrix or vector(s).
+
+    References
+    ----------
+    .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*,
+           Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
+    """
+    if ord is not None and ord != 'fro':
+        raise ValueError('only support Frobenius norm for now, received ord={}'.format(str(ord)))
+    if isinstance(axis, tuple) and len(axis) > 2:
+        raise ValueError('Improper number of dimensions to norm')
+    # TODO(junwu): When ord = 'fro', axis = None, and x.ndim > 2, raise exception
+    return _mx_nd_np.sqrt(_mx_nd_np.sum(x * x, axis=axis, keepdims=keepdims))
diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py
index 869fdeb..fd73478 100644
--- a/python/mxnet/symbol/numpy/random.py
+++ b/python/mxnet/symbol/numpy/random.py
@@ -17,4 +17,122 @@
 
 """Namespace for operators used in Gluon dispatched by F=symbol."""
 
-__all__ = []
+from __future__ import absolute_import
+from ...base import numeric_types
+from ...context import current_context
+from . import _internal as _npi
+
+__all__ = ['uniform', 'normal']
+
+
+def _random_helper(random, sampler, params, shape, dtype, ctx, out, kwargs):
+    """Helper function for random generators."""
+    from ._symbol import _Symbol as np_symbol
+    if isinstance(params[0], np_symbol):
+        for i in params[1:]:
+            assert isinstance(i, np_symbol), \
+                "Distribution parameters must all have the same type, but got " \
+                "both %s and %s." % (type(params[0]), type(i))
+        return sampler(*params, shape=shape, dtype=dtype, out=out, **kwargs)
+    elif isinstance(params[0], numeric_types):
+        if ctx is None:
+            ctx = current_context()
+        if shape is None and out is None:
+            shape = ()
+        for i in params[1:]:
+            assert isinstance(i, numeric_types), \
+                "Distribution parameters must all have the same type, but got " \
+                "both %s and %s."%(type(params[0]), type(i))
+        return random(*params, shape=shape, dtype=dtype, ctx=ctx, out=out, **kwargs)
+
+    raise ValueError("Distribution parameters must be either mxnet.numpy.ndarray or numbers, "
+                     "but got %s." % type(params[0]))
+
+
+def uniform(low=0.0, high=1.0, size=None, **kwargs):
+    """Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval
+    ``[low, high)`` (includes low, but excludes high).  In other words,
+    any value within the given interval is equally likely to be drawn
+    by `uniform`.
+
+    Parameters
+    ----------
+    low : float, optional
+        Lower boundary of the output interval.  All values generated will be
+        greater than or equal to low.  The default value is 0.
+    high : float
+        Upper boundary of the output interval.  All values generated will be
+        less than high.  The default value is 1.0.
+    size : int or tuple of ints, optional
+        Output shape.  If the given shape is, e.g., ``(m, n, k)``, then
+        ``m * n * k`` samples are drawn.  If size is ``None`` (default),
+        a scalar tensor containing a single value is returned if
+        ``low`` and ``high`` are both scalars.
+    dtype : {'float16', 'float32', 'float64'}, optional
+        Data type of output samples. Default is 'float32'
+    ctx : Context, optional
+        Device context of output. Default is current context.
+    out : ndarray, optional
+        Store output to an existing ndarray.
+
+    Returns
+    -------
+    out : _Symbol (symbol representing `mxnet.numpy.ndarray` in computational graphs)
+        Drawn samples from the parameterized uniform distribution.
+
+
+    Notes
+    -----
+    This function currently does not support ``low`` and ``high`` as symbols.
+    """
+    dtype = kwargs.pop('dtype', None)
+    if dtype is None:
+        dtype = 'float32'
+    ctx = kwargs.pop('ctx', None)
+    out = kwargs.pop('out', None)
+    return _random_helper(_npi.random_uniform, None,
+                          [low, high], size, dtype, ctx, out, kwargs)
+
+
+def normal(loc=0.0, scale=1.0, size=None, **kwargs):
+    """Draw random samples from a normal (Gaussian) distribution.
+
+    Samples are distributed according to a normal distribution parametrized
+    by *loc* (mean) and *scale* (standard deviation).
+
+
+    Parameters
+    ----------
+    loc : float, optional
+        Mean (centre) of the distribution.
+    scale : float, optional
+        Standard deviation (spread or "width") of the distribution.
+    size : int or tuple of ints, optional
+        Output shape. If the given shape is, e.g., `(m, n, k)`, then `m * n * k`
+        samples are drawn. If size is `None` (default), a scalar tensor containing
+        a single value is returned if loc and scale are both scalars.
+    dtype : {'float16', 'float32', 'float64'}, optional
+        Data type of output samples. Default is 'float32'
+    ctx : Context, optional
+        Device context of output. Default is current context.
+    out : ``ndarray``, optional
+        Store output to an existing ``ndarray``.
+
+    Returns
+    -------
+    out : _Symbol (symbol representing `mxnet.numpy.ndarray` in computational graphs)
+        Drawn samples from the parameterized normal distribution.
+
+    Notes
+    -----
+    This function currently does not support ``loc`` and ``scale`` as `_Symbol`s.
+    """
+    dtype = kwargs.pop('dtype', None)
+    if dtype is None:
+        dtype = 'float32'
+    ctx = kwargs.pop('ctx', None)
+    out = kwargs.pop('out', None)
+    return _random_helper(_npi.random_normal, None,
+                          [loc, scale], size, dtype, ctx, out, kwargs)
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index fd7bbb4..6f423f9 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -940,7 +940,7 @@ def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rto
     input_shape = {k: v.shape for k, v in location.items()}
     _, out_shape, _ = sym.infer_shape(**input_shape)
     proj = mx.sym.Variable("__random_proj")
-    is_np_sym = True if isinstance(sym, np_symbol) else False
+    is_np_sym = bool(isinstance(sym, np_symbol))
     if is_np_sym:  # convert to np symbol for using element-wise multiplication
         proj = proj.as_np_ndarray()
     out = sym * proj
diff --git a/python/mxnet/util.py b/python/mxnet/util.py
index 091adb7..ce4e7dd 100644
--- a/python/mxnet/util.py
+++ b/python/mxnet/util.py
@@ -22,6 +22,7 @@ import sys
 import functools
 import itertools
 import inspect
+import threading
 
 from .base import _LIB, check_call
 
@@ -84,8 +85,7 @@ def set_np_shape(active):
 
 
 def is_np_shape():
-    """
-    Checks whether the NumPy shape semantics is currently turned on.
+    """Checks whether the NumPy shape semantics is currently turned on.
     In NumPy shape semantics, `()` represents the shape of scalar tensors,
     and tuples with `0` elements, for example, `(0,)`, `(1, 0, 2)`, represent
     the shapes of zero-size tensors. This is turned off by default for keeping
@@ -264,12 +264,12 @@ def use_np_shape(func):
 
     Parameters
     ----------
-    func : a user-provided callable function or class to be scoped by the NumPy compatibility state.
+    func : a user-provided callable function or class to be scoped by the NumPy-shape semantics.
 
     Returns
     -------
     Function or class
-        A function or class wrapped in the NumPy compatibility scope.
+        A function or class wrapped in the NumPy-shape scope.
     """
 
     if inspect.isclass(func):
@@ -319,3 +319,225 @@ def set_module(module):
             func.__module__ = module
         return func
     return decorator
+
+
+class _NumpyArrayScope(object):
+    """Scope for managing NumPy array creation. This is often used
+    with `is_np_array=True` in initializer to enforce array creation
+    as type `mxnet.numpy.ndarray`, instead of `mx.nd.NDArray` in Gluon.
+
+    Do not use this class directly. Use `np_array(active)` instead.
+    """
+    _current = threading.local()
+
+    def __init__(self, is_np_array):  #pylint: disable=redefined-outer-name
+        self._old_scope = None
+        self._is_np_array = is_np_array
+
+    def __enter__(self):
+        if not hasattr(_NumpyArrayScope._current, "value"):
+            _NumpyArrayScope._current.value = _NumpyArrayScope(False)
+        self._old_scope = _NumpyArrayScope._current.value
+        _NumpyArrayScope._current.value = self
+        return self
+
+    def __exit__(self, ptype, value, trace):
+        assert self._old_scope
+        _NumpyArrayScope._current.value = self._old_scope
+
+
+def np_array(active=True):
+    """Returns an activated/deactivated NumPy-array scope to be used in 'with' statement
+    and captures code that needs the NumPy-array semantics.
+
+    Currently, this is used in Gluon to enforce array creation in `Block`s as type
+    `mxnet.numpy.ndarray`, instead of `mx.nd.NDArray`.
+
+    It is recommended to use the decorator `use_np_array` to decorate the classes
+    that need this semantics, instead of using this function in a `with` statement
+    unless you know exactly what has been scoped by this semantics.
+
+    Please note that this is designed as an infrastructure for the incoming
+    MXNet-NumPy operators. Legacy operators registered in the modules
+    `mx.nd` and `mx.sym` are not guaranteed to behave like their counterparts
+    in NumPy even within this scope.
+
+    Parameters
+    ----------
+    active : bool
+        Indicates whether to activate NumPy-array semantics.
+
+    Returns
+    -------
+    _NumpyShapeScope
+        A scope object for wrapping the code w/ or w/o NumPy-shape semantics.
+    """
+    return _NumpyArrayScope(active)
+
+
+def is_np_array():
+    """Checks whether the NumPy-array semantics is currently turned on.
+    This is currently used in Gluon for checking whether an array of type `mxnet.numpy.ndarray`
+    or `mx.nd.NDArray` should be created. For example, at the time when a parameter
+    is created in a `Block`, an `mxnet.numpy.ndarray` is created if this returns true; else
+    an `mx.nd.NDArray` is created.
+
+    Normally, users are not recommended to use this API directly unless you known exactly
+    what is going on under the hood.
+
+    Please note that this is designed as an infrastructure for the incoming
+    MXNet-NumPy operators. Legacy operators registered in the modules
+    `mx.nd` and `mx.sym` are not guaranteed to behave like their counterparts
+    in NumPy within this semantics.
+
+    Returns
+    -------
+        A bool value indicating whether the NumPy-array semantics is currently on.
+    """
+    return _NumpyArrayScope._current.value._is_np_array if hasattr(
+        _NumpyArrayScope._current, "value") else False
+
+
+def use_np_array(func):
+    """A decorator wrapping Gluon `Block`s and all its methods, properties, and static functions
+    with the semantics of NumPy-array, which means that where ndarrays are created,
+    `mxnet.numpy.ndarray`s should be created, instead of legacy ndarrays of type `mx.nd.NDArray`.
+    For example, at the time when a parameter is created in a `Block`, an `mxnet.numpy.ndarray`
+    is created if it's decorated with this decorator.
+
+    Example::
+        import mxnet as mx
+        from mxnet import gluon, np
+
+
+        class TestHybridBlock1(gluon.HybridBlock):
+            def __init__(self):
+                super(TestHybridBlock1, self).__init__()
+                self.w = self.params.get('w', shape=(2, 2))
+
+            def hybrid_forward(self, F, x, w):
+                return F.dot(x, w)
+
+
+        x = mx.nd.ones((2, 2))
+        net1 = TestHybridBlock1()
+        net1.initialize()
+        out = net1.forward(x)
+        for _, v in net1.collect_params().items():
+            assert type(v.data()) is mx.nd.NDArray
+        assert type(out) is mx.nd.NDArray
+
+
+        @np.use_np_array
+        class TestHybridBlock2(gluon.HybridBlock):
+            def __init__(self):
+                super(TestHybridBlock2, self).__init__()
+                self.w = self.params.get('w', shape=(2, 2))
+
+            def hybrid_forward(self, F, x, w):
+                return F.np.dot(x, w)
+
+
+        x = np.ones((2, 2))
+        net2 = TestHybridBlock2()
+        net2.initialize()
+        out = net2.forward(x)
+        for _, v in net2.collect_params().items():
+            print(type(v.data()))
+            assert type(v.data()) is np.ndarray
+        assert type(out) is np.ndarray
+
+    Parameters
+    ----------
+    func : a user-provided callable function or class to be scoped by the NumPy-array semantics.
+
+    Returns
+    -------
+    Function or class
+        A function or class wrapped in the NumPy-array scope.
+    """
+    if inspect.isclass(func):
+        for name, method in inspect.getmembers(
+                func,
+                predicate=
+                lambda f: inspect.isfunction(f) or inspect.ismethod(f) or isinstance(f, property)):
+            if isinstance(method, property):
+                setattr(func, name, property(use_np_array(method.__get__),
+                                             method.__set__,
+                                             method.__delattr__,
+                                             method.__doc__))
+            else:
+                setattr(func, name, use_np_array(method))
+        return func
+    elif callable(func):
+        @wraps_safely(func)
+        def _with_np_array(*args, **kwargs):
+            with np_array(active=True):
+                return func(*args, **kwargs)
+        return _with_np_array
+    else:
+        raise TypeError('use_np_array can only decorate classes and callable objects, '
+                        'while received a {}'.format(str(type(func))))
+
+
+def use_np(func):
+    """A convenience decorator for wrapping user provided functions and classes in the scope of
+    both NumPy-shape and NumPy-array semantics, which means that (1) empty tuples `()` and tuples
+    with zeros, such as `(0, 1)`, `(1, 0, 2)`, will be treated as scalar tensors' shapes and
+    zero-size tensors' shapes in shape inference functions of operators, instead of as unknown
+    in legacy mode; (2) ndarrays of type `mxnet.numpy.ndarray` should be created instead of
+    `mx.nd.NDArray`.
+
+    Example::
+        import mxnet as mx
+        from mxnet import gluon, np
+
+
+        class TestHybridBlock1(gluon.HybridBlock):
+            def __init__(self):
+                super(TestHybridBlock1, self).__init__()
+                self.w = self.params.get('w', shape=(2, 2))
+
+            def hybrid_forward(self, F, x, w):
+                return F.dot(x, w) + F.ones((1,))
+
+
+        x = mx.nd.ones((2, 2))
+        net1 = TestHybridBlock1()
+        net1.initialize()
+        out = net1.forward(x)
+        for _, v in net1.collect_params().items():
+            assert type(v.data()) is mx.nd.NDArray
+        assert type(out) is mx.nd.NDArray
+
+
+        @np.use_np
+        class TestHybridBlock2(gluon.HybridBlock):
+            def __init__(self):
+                super(TestHybridBlock2, self).__init__()
+                self.w = self.params.get('w', shape=(2, 2))
+
+            def hybrid_forward(self, F, x, w):
+                return F.np.dot(x, w) + F.np.ones(())
+
+
+        x = np.ones((2, 2))
+        net2 = TestHybridBlock2()
+        net2.initialize()
+        out = net2.forward(x)
+        for _, v in net2.collect_params().items():
+            print(type(v.data()))
+            assert type(v.data()) is np.ndarray
+        assert type(out) is np.ndarray
+
+    Parameters
+    ----------
+    func : a user-provided callable function or class to be scoped by the
+    NumPy-shape and NumPy-array semantics.
+
+    Returns
+    -------
+    Function or class
+        A function or class wrapped in the Numpy-shape and NumPy-array scope.
+    """
+    return use_np_array(use_np_shape(func))
diff --git a/src/operator/numpy/np_broadcast_reduce_op_index.cc b/src/operator/numpy/np_broadcast_reduce_op_index.cc
new file mode 100644
index 0000000..bd6915c
--- /dev/null
+++ b/src/operator/numpy/np_broadcast_reduce_op_index.cc
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file np_broadcast_reduce_op_index.cc
+ * \brief CPU Implementation of broadcast and reduce functions based on index.
+ */
+#include "./np_broadcast_reduce_op.h"
+
+namespace mxnet {
+namespace op {
+
+// FInferShape for reductions over at most one axis (registered for `_npi_argmax` below).
+// Converts the single optional `axis` of ReduceAxisParam into the optional tuple of axes
+// taken by NumpyReduceAxesShapeImpl, passing `keepdims` through. Returns false while the
+// input shape is unknown; otherwise returns whether the output shape is fully known.
+bool NumpyReduceAxisShape(const nnvm::NodeAttrs& attrs,
+                          std::vector<TShape> *in_attrs,
+                          std::vector<TShape> *out_attrs) {
+  // Exactly one input and one output.
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  if (!shape_is_known(in_attrs->at(0))) {
+    return false;
+  }
+  const ReduceAxisParam& param = nnvm::get<ReduceAxisParam>(attrs.parsed);
+  // When no axis is given, `axes` stays empty; presumably NumpyReduceAxesShapeImpl then
+  // reduces over all axes (TODO confirm against np_broadcast_reduce_op.h).
+  dmlc::optional<mxnet::Tuple<int>> axes;
+  if (param.axis.has_value()) {
+    mxnet::Tuple<int> t({param.axis.value()});
+    axes = dmlc::optional<mxnet::Tuple<int>>(t);
+  }
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0,
+                     NumpyReduceAxesShapeImpl((*in_attrs)[0], axes, param.keepdims));
+  return shape_is_known(out_attrs->at(0));
+}
+
+// `_npi_argmax`: NumPy-style argmax, CPU registration.
+// Shape inference goes through NumpyReduceAxisShape. The dtype is inferred with
+// ElemwiseType<1, 1>, presumably making the output dtype match the input's (confirm).
+// The kernel is SearchAxisCompute with the maximum reducer, and the op produces zero
+// gradients (MakeZeroGradNodes).
+NNVM_REGISTER_OP(_npi_argmax)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<ReduceAxisParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxisShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.add_argument("data", "NDArray-or-Symbol", "The input")
+.set_attr<FCompute>("FCompute<cpu>", SearchAxisCompute<cpu, mshadow::red::maximum>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_arguments(ReduceAxisParam::__FIELDS__());
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/numpy/np_init_op.cu b/src/operator/numpy/np_broadcast_reduce_op_index.cu
similarity index 66%
copy from src/operator/numpy/np_init_op.cu
copy to src/operator/numpy/np_broadcast_reduce_op_index.cu
index 2eb8ed6..aae66a6 100644
--- a/src/operator/numpy/np_init_op.cu
+++ b/src/operator/numpy/np_broadcast_reduce_op_index.cu
@@ -19,26 +19,16 @@
 
 /*!
  *  Copyright (c) 2019 by Contributors
- * \file np_init_op.cu
- * \brief GPU Implementation of numpy init op
+ * \file np_broadcast_reduce_op_index.cu
+ * \brief GPU Implementation of reduce functions based on index.
  */
-
-#include "../tensor/init_op.h"
+#include "np_broadcast_reduce_op.h"
 
 namespace mxnet {
 namespace op {
 
-NNVM_REGISTER_OP(_npi_zeros)
-.set_attr<FCompute>("FCompute<gpu>", FillCompute<gpu, 0>);
-
-NNVM_REGISTER_OP(_npi_ones)
-.set_attr<FCompute>("FCompute<gpu>", FillCompute<gpu, 1>);
-
-NNVM_REGISTER_OP(_np_zeros_like)
-.set_attr<FCompute>("FCompute<gpu>", FillCompute<gpu, 0>);
-
-NNVM_REGISTER_OP(_np_ones_like)
-.set_attr<FCompute>("FCompute<gpu>", FillCompute<gpu, 1>);
+// GPU kernel for `_npi_argmax`: SearchAxisCompute with the maximum reducer, mirroring the
+// CPU registration in np_broadcast_reduce_op_index.cc.
+NNVM_REGISTER_OP(_npi_argmax)
+.set_attr<FCompute>("FCompute<gpu>", SearchAxisCompute<gpu, mshadow::red::maximum>);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc b/src/operator/numpy/np_broadcast_reduce_op_value.cc
index a72efd9..078cd46 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_value.cc
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc
@@ -19,7 +19,7 @@
 
 /*!
  *  Copyright (c) 2019 by Contributors
- * \file np_reduce_op_value.cc
+ * \file np_broadcast_reduce_op_value.cc
  * \brief CPU Implementation of broadcast and reduce functions based on value.
  */
 
diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cu b/src/operator/numpy/np_broadcast_reduce_op_value.cu
index 2f50738..7740c03 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_value.cu
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.cu
@@ -19,7 +19,7 @@
 
 /*!
  *  Copyright (c) 2019 by Contributors
- * \file np_reduce_op_value.cu
+ * \file np_broadcast_reduce_op_value.cu
  * \brief GPU Implementation of reduce functions based on value.
  */
 #include "np_broadcast_reduce_op.h"
diff --git a/src/operator/numpy/np_elemwise_unary_op_basic.cc b/src/operator/numpy/np_elemwise_unary_op_basic.cc
index 87a765e..1acec6f 100644
--- a/src/operator/numpy/np_elemwise_unary_op_basic.cc
+++ b/src/operator/numpy/np_elemwise_unary_op_basic.cc
@@ -27,7 +27,7 @@
 namespace mxnet {
 namespace op {
 
-MXNET_OPERATOR_REGISTER_UNARY(_npe_relu)
+MXNET_OPERATOR_REGISTER_UNARY(_npx_relu)
 .describe(R"code(Computes rectified linear activation.
 
 .. math::
@@ -37,7 +37,7 @@ MXNET_OPERATOR_REGISTER_UNARY(_npe_relu)
 .set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::relu>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_relu"});
 
-MXNET_OPERATOR_REGISTER_UNARY(_npe_sigmoid)
+MXNET_OPERATOR_REGISTER_UNARY(_npx_sigmoid)
 .describe(R"code(Computes sigmoid of x element-wise.
 
 .. math::
diff --git a/src/operator/numpy/np_elemwise_unary_op_basic.cu b/src/operator/numpy/np_elemwise_unary_op_basic.cu
index a3cdff9..1323768 100644
--- a/src/operator/numpy/np_elemwise_unary_op_basic.cu
+++ b/src/operator/numpy/np_elemwise_unary_op_basic.cu
@@ -26,10 +26,10 @@
 namespace mxnet {
 namespace op {
 
-NNVM_REGISTER_OP(_npe_relu)
+NNVM_REGISTER_OP(_npx_relu)
 .set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::relu>);
 
-NNVM_REGISTER_OP(_npe_sigmoid)
+NNVM_REGISTER_OP(_npx_sigmoid)
 .set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::sigmoid>);
 
 NNVM_REGISTER_OP(_np_copy)
diff --git a/src/operator/numpy/np_init_op.cc b/src/operator/numpy/np_init_op.cc
index 83a44c8..9edfa20 100644
--- a/src/operator/numpy/np_init_op.cc
+++ b/src/operator/numpy/np_init_op.cc
@@ -28,6 +28,23 @@
 namespace mxnet {
 namespace op {
 
+// FInferShape for `_npi_arange`: the output is a 1-D tensor with
+// max(0, ceil((stop - start) / step)) elements. The op takes no inputs, requires `stop`
+// to be set and `step` to be non-zero, and only supports repeat == 1.
+inline bool NumpyRangeShape(const nnvm::NodeAttrs& attrs,
+                            mxnet::ShapeVector* in_shapes,
+                            mxnet::ShapeVector* out_shapes) {
+  const RangeParam& param = nnvm::get<RangeParam>(attrs.parsed);
+  CHECK_EQ(in_shapes->size(), 0U);
+  CHECK_EQ(out_shapes->size(), 1U);
+  CHECK_NE(param.step, 0) << "_npi_arange does not support step=0";
+  CHECK_EQ(param.repeat, 1) << "_npi_arange only supports repeat=1, received " << param.repeat;
+  CHECK(param.stop.has_value()) << "_npi_arange requires stop to have a value";
+  double out_size = std::ceil((param.stop.value() - param.start) / param.step);
+  // An interval that is empty in the direction of `step` yields a zero-length result.
+  if (out_size < 0) {
+    out_size = 0;
+  }
+  SHAPE_ASSIGN_CHECK(*out_shapes, 0, mxnet::TShape({static_cast<nnvm::dim_t>(out_size)}));
+  return true;
+}
+
 NNVM_REGISTER_OP(_npi_zeros)
 .describe("Return a new array of given shape, type, and context, filled with zeros.")
 .set_num_inputs(0)
@@ -107,5 +124,15 @@ Examples::
 .add_argument("a", "NDArray-or-Symbol",
               "The shape and data-type of a define these same attributes of the returned array.");
 
+// `_npi_arange`: NumPy-style arange, CPU registration. The op takes no inputs and is
+// configured entirely by RangeParam. Shape comes from NumpyRangeShape, dtype from
+// InitType<RangeParam>, and values from RangeCompute<cpu>.
+NNVM_REGISTER_OP(_npi_arange)
+.describe("Return evenly spaced values within a given interval.")
+.set_num_inputs(0)
+.set_num_outputs(1)
+.set_attr_parser(RangeParamParser)
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyRangeShape)
+.set_attr<nnvm::FInferType>("FInferType", InitType<RangeParam>)
+.set_attr<FCompute>("FCompute<cpu>", RangeCompute<cpu>)
+.add_arguments(RangeParam::__FIELDS__());
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/numpy/np_init_op.cu b/src/operator/numpy/np_init_op.cu
index 2eb8ed6..2c41e56 100644
--- a/src/operator/numpy/np_init_op.cu
+++ b/src/operator/numpy/np_init_op.cu
@@ -40,5 +40,8 @@ NNVM_REGISTER_OP(_np_zeros_like)
 NNVM_REGISTER_OP(_np_ones_like)
 .set_attr<FCompute>("FCompute<gpu>", FillCompute<gpu, 1>);
 
+// GPU kernel for `_npi_arange`; shape/type inference are registered in np_init_op.cc.
+NNVM_REGISTER_OP(_npi_arange)
+.set_attr<FCompute>("FCompute<gpu>", RangeCompute<gpu>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/random/sample_op.cc b/src/operator/random/sample_op.cc
index 56a162b..5431462 100644
--- a/src/operator/random/sample_op.cc
+++ b/src/operator/random/sample_op.cc
@@ -81,6 +81,7 @@ DMLC_REGISTER_PARAMETER(SampleGenNegBinomialLikeParam);
 MXNET_OPERATOR_REGISTER_SAMPLE(_random_uniform, SampleUniformParam)
 .add_alias("uniform")
 .add_alias("random_uniform")
+.add_alias("_npi_random_uniform")
 .describe(R"code(Draw random samples from a uniform distribution.
 
 .. note:: The existing alias ``uniform`` is deprecated.
@@ -99,6 +100,7 @@ Example::
 MXNET_OPERATOR_REGISTER_SAMPLE(_random_normal, SampleNormalParam)
 .add_alias("normal")
 .add_alias("random_normal")
+.add_alias("_npi_random_normal")
 .describe(R"code(Draw random samples from a normal (Gaussian) distribution.
 
 .. note:: The existing alias ``normal`` is deprecated.
diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h
index a6ee242..cba9821 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -168,15 +168,24 @@ struct BroadcastLikeParam : public dmlc::Parameter<BroadcastLikeParam> {
   }
 };
 
-inline int CheckAxis(int axis, int ndim) {
-  CHECK(axis < ndim && axis >= -ndim)
-    << "axis " << axis << " exceeds the input dimension of " << ndim;
-  return (axis + ndim)%ndim;
+// Validates `axis` for an array of rank `ndim` and returns it normalized to [0, ndim).
+// A 0-d array is handled as if it had one dimension: only axis 0 or -1 is accepted, and
+// 0 is returned.
+inline int CheckAxis(const int axis, const int ndim) {
+  if (ndim == 0) {
+    CHECK(axis == 0 || axis == -1) << "axis " << axis << " is out of bounds for array of"
+                                                         " dimension 1";
+    return 0;
+  } else {
+    CHECK(axis < ndim && axis >= -ndim)
+        << "axis " << axis << " exceeds the input dimension of " << ndim;
+    // Map negative axes (counted from the end) onto their non-negative equivalents.
+    return (axis + ndim) % ndim;
+  }
 }
 
 inline mxnet::TShape AxisShapeCompact(mxnet::TShape shape, int *axis, bool allow_2d) {
   int ndim = shape.ndim();
-  index_t leading = 1, trailing = 1, M = shape[*axis];
+  index_t leading = 1, trailing = 1, M = 1;
+  if (shape.ndim() > *axis) {
+    M = shape[*axis];
+  }
   for (int i = 0; i < *axis; ++i) leading *= shape[i];
   for (int i = *axis + 1; i < ndim; ++i) trailing *= shape[i];
   if (allow_2d && trailing == 1) {
@@ -553,14 +562,37 @@ void SearchAxisCompute(const nnvm::NodeAttrs& attrs,
   using namespace mshadow::expr;
   const ReduceAxisParam& param = nnvm::get<ReduceAxisParam>(attrs.parsed);
   Stream<xpu> *s = ctx.get_stream<xpu>();
-  if (!param.axis) LOG(FATAL) << "Global reduction not supported yet";
+  int axis = inputs[0].ndim();
+  TBlob input = inputs[0];
+  if (param.axis.has_value()) {
+    axis = param.axis.value();
+  } else {
+    // If global reduction, reshape the input tensor into 2D shape (1, inputs[0].shape_.Size())
+    // and search on axis = 1.
+    mxnet::TShape shape_2d(2, 1);
+    shape_2d[1] = input.shape_.Size();
+    input = TBlob(input.dptr_, shape_2d, input.dev_mask(), input.type_flag_, input.dev_id());
+    axis = 1;
+  }
 
-  int axis = CheckAxis(param.axis.value(), inputs[0].shape_.ndim());
-  mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, false);
+  axis = CheckAxis(axis, input.shape_.ndim());
+  if (inputs[0].shape_.ndim() != 0) {
+    if (param.axis.has_value()) {
+      // cannot do argmax in an empty dimension
+      CHECK_NE(inputs[0].shape_[axis], 0)
+          << "searching input tensor of shape " << inputs[0].shape_
+          << " along axis = " << axis << " of zero dim-size is not allowed";
+    } else {
+      // cannot do argmax on an empty array
+      CHECK_NE(inputs[0].shape_.Size(), 0U) << "attempt to search an empty sequence";
+    }
+  }
+  if (input.shape_.Size() == 0U) return;  // zero-size tensor
+  mxnet::TShape shape = AxisShapeCompact(input.shape_, &axis, false);
   MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
     Tensor<xpu, 2, DType> out = outputs[0].get_with_shape<xpu, 2, DType>(
       Shape2(shape[0], shape[2]), s);
-    Tensor<xpu, 3, DType> in = inputs[0].get_with_shape<xpu, 3, DType>(
+    Tensor<xpu, 3, DType> in = input.get_with_shape<xpu, 3, DType>(
       shape.get<3>(), s);
     CHECK(req[0] != kAddTo) << "AddTo is not supported";
     ASSIGN_DISPATCH(out, req[0], (reduce_with_axis<reducer, true>(in, 1)));
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_logic.cc b/src/operator/tensor/elemwise_binary_broadcast_op_logic.cc
index cd433e0..e3c2e0e 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op_logic.cc
+++ b/src/operator/tensor/elemwise_binary_broadcast_op_logic.cc
@@ -30,6 +30,7 @@ namespace mxnet {
 namespace op {
 
 MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_equal)
+.add_alias("_npi_equal")
 .describe(R"code(Returns the result of element-wise **equal to** (==) comparison operation with broadcasting.
 
 Example::
@@ -48,6 +49,7 @@ Example::
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
 
 MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_not_equal)
+.add_alias("_npi_not_equal")
 .describe(R"code(Returns the result of element-wise **not equal to** (!=) comparison operation with broadcasting.
 
 Example::
@@ -66,6 +68,7 @@ Example::
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
 
 MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_greater)
+.add_alias("_npi_greater")
 .describe(R"code(Returns the result of element-wise **greater than** (>) comparison operation with broadcasting.
 
 Example::
@@ -84,6 +87,7 @@ Example::
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
 
 MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_greater_equal)
+.add_alias("_npi_greater_equal")
 .describe(R"code(Returns the result of element-wise **greater than or equal to** (>=) comparison operation with broadcasting.
 
 Example::
@@ -102,6 +106,7 @@ Example::
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
 
 MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_lesser)
+.add_alias("_npi_less")
 .describe(R"code(Returns the result of element-wise **lesser than** (<) comparison operation with broadcasting.
 
 Example::
@@ -120,6 +125,7 @@ Example::
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
 
 MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_lesser_equal)
+.add_alias("_npi_less_equal")
 .describe(R"code(Returns the result of element-wise **lesser than or equal to** (<=) comparison operation with broadcasting.
 
 Example::
diff --git a/src/operator/tensor/elemwise_binary_scalar_op_logic.cc b/src/operator/tensor/elemwise_binary_scalar_op_logic.cc
index 17e7615..87ba394 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op_logic.cc
+++ b/src/operator/tensor/elemwise_binary_scalar_op_logic.cc
@@ -71,26 +71,32 @@ static bool BinaryScalarLogicStorageType(const nnvm::NodeAttrs& attrs,
 
 
 MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_equal_scalar, mshadow_op::eq)
+.add_alias("_npi_equal_scalar")
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .add_alias("_EqualScalar");
 
 MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_not_equal_scalar, mshadow_op::ne)
+.add_alias("_npi_not_equal_scalar")
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .add_alias("_NotEqualScalar");
 
 MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_greater_scalar, mshadow_op::gt)
+.add_alias("_npi_greater_scalar")
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .add_alias("_GreaterScalar");
 
 MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_greater_equal_scalar, mshadow_op::ge)
+.add_alias("_npi_greater_equal_scalar")
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .add_alias("_GreaterEqualScalar");
 
 MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_lesser_scalar, mshadow_op::lt)
+.add_alias("_npi_less_scalar")
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .add_alias("_LesserScalar");
 
 MXNET_OPERATOR_REGISTER_BINARY_SCALAR_LOGIC(_lesser_equal_scalar, mshadow_op::le)
+.add_alias("_npi_less_equal_scalar")
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .add_alias("_LesserEqualScalar");
 
diff --git a/tests/python/unittest/test_contrib_amp.py b/tests/python/unittest/test_contrib_amp.py
index c11d3f7..ef3a6d8 100644
--- a/tests/python/unittest/test_contrib_amp.py
+++ b/tests/python/unittest/test_contrib_amp.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import unittest
 import mxnet as mx
 import warnings
 import collections
@@ -23,8 +22,6 @@ import ctypes
 import mxnet.contrib.amp as amp
 
 
-# TODO(junwu): Enable test
-@unittest.skip("Temporarily disabled for adding new np ops")
 def test_amp_coverage():
     conditional = [item[0] for item in amp.lists.symbol.CONDITIONAL_FP32_FUNCS]
 
diff --git a/tests/python/unittest/test_numpy_gluon.py b/tests/python/unittest/test_numpy_gluon.py
index b7656b7..0fcb874 100644
--- a/tests/python/unittest/test_numpy_gluon.py
+++ b/tests/python/unittest/test_numpy_gluon.py
@@ -19,7 +19,7 @@
 from __future__ import absolute_import
 from __future__ import division
 import mxnet as mx
-from mxnet import gluon, autograd, np
+from mxnet import gluon, autograd, np, npx
 
 
 def test_create_np_param():
@@ -44,7 +44,7 @@ def test_create_np_param():
         def hybrid_forward(self, F, x, w):
             return F.dot(x, w)
 
-    @np.use_np_shape
+    @npx.use_np
     class TestBlock2(gluon.HybridBlock):
         def __init__(self):
             super(TestBlock2, self).__init__()
@@ -62,9 +62,9 @@ def test_create_np_param():
 
 
 def test_optimizer_with_np_ndarrays():
-    @np.use_np_shape
+    @npx.use_np
     class LinearRegression(gluon.HybridBlock):
-        def __init__(self, num_input_dim=-1, num_hidden_dim=100, num_output_dim=10):
+        def __init__(self, num_input_dim=0, num_hidden_dim=100, num_output_dim=10):
             super(LinearRegression, self).__init__()
             with self.name_scope():
                 self.w1 = self.params.get('w1', shape=(num_input_dim, num_hidden_dim),
@@ -74,11 +74,11 @@ def test_optimizer_with_np_ndarrays():
 
         def hybrid_forward(self, F, x, w1, w2):
             h = x.dot(w1)  # equivalent to F.np.dot(x, w1)
-            h_relu = F.npe.relu(h)  # equivalent to F.relu(h) but generating np.ndarray
+            h_relu = F.npx.relu(h)  # equivalent to F.relu(h) but generating np.ndarray
             y_pred = h_relu.dot(w2)  # equivalent to F.np.dot(h_relu, w2)
             return y_pred
 
-    @np.use_np_shape
+    @npx.use_np
     class TotalLoss(gluon.HybridBlock):
         def hybrid_forward(self, F, pred, label):
             return ((pred - label) ** 2).sum()  # equivalent to F.np.sum(F.np.square(pred - label))
diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py
index 188cb6f..1c71471 100644
--- a/tests/python/unittest/test_numpy_ndarray.py
+++ b/tests/python/unittest/test_numpy_ndarray.py
@@ -20,7 +20,7 @@ from __future__ import absolute_import
 from __future__ import division
 import numpy as _np
 import mxnet as mx
-from mxnet import np
+from mxnet import np, npx
 from mxnet.gluon import HybridBlock
 from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, assert_exception
 from common import with_seed
@@ -29,9 +29,15 @@ from common import with_seed
 @with_seed()
 def test_array_creation():
     dtypes = [_np.int8, _np.int32, _np.float16, _np.float32, _np.float64, None]
-    objects = [[], (), [[1, 2], [3, 4]],
-               _np.random.uniform(size=rand_shape_nd(3, allow_zero_size=True)),
-               mx.nd.array(_np.random.uniform(size=rand_shape_nd(3, allow_zero_size=True)))]
+    objects = [
+        [],
+        (),
+        [[1, 2], [3, 4]],
+        _np.random.uniform(size=rand_shape_nd(3)),
+        _np.random.uniform(size=(3, 0, 4)),
+        np.random.uniform(size=rand_shape_nd(3)),
+        np.random.uniform(size=(3, 0, 4))
+    ]
     for dtype in dtypes:
         for src in objects:
             mx_arr = np.array(src, dtype=dtype)
@@ -47,7 +53,7 @@ def test_array_creation():
 @with_seed()
 def test_zeros():
     # test np.zeros in Gluon
-    @np.use_np_shape
+    @npx.use_np_shape
     class TestZeros(HybridBlock):
         def __init__(self, shape, dtype=None):
             super(TestZeros, self).__init__()
@@ -57,13 +63,13 @@ def test_zeros():
         def hybrid_forward(self, F, x, *args, **kwargs):
             return x + F.np.zeros(shape, dtype)
 
-    @np.use_np_shape
+    @npx.use_np_shape
     class TestZerosOutputType(HybridBlock):
         def hybrid_forward(self, F, x, *args, **kwargs):
             return x, F.np.zeros(shape=())
 
     # test np.zeros in imperative
-    @np.use_np_shape
+    @npx.use_np_shape
     def check_zero_array_creation(shape, dtype):
         np_out = _np.zeros(shape=shape, dtype=dtype)
         mx_out = np.zeros(shape=shape, dtype=dtype)
@@ -97,7 +103,7 @@ def test_zeros():
 @with_seed()
 def test_ones():
     # test np.ones in Gluon
-    @np.use_np_shape
+    @npx.use_np_shape
     class TestOnes(HybridBlock):
         def __init__(self, shape, dtype=None):
             super(TestOnes, self).__init__()
@@ -107,13 +113,13 @@ def test_ones():
         def hybrid_forward(self, F, x, *args, **kwargs):
             return x * F.np.ones(shape, dtype)
 
-    @np.use_np_shape
+    @npx.use_np_shape
     class TestOnesOutputType(HybridBlock):
         def hybrid_forward(self, F, x, *args, **kwargs):
             return x, F.np.ones(shape=())
 
     # test np.ones in imperative
-    @np.use_np_shape
+    @npx.use_np_shape
     def check_ones_array_creation(shape, dtype):
         np_out = _np.ones(shape=shape, dtype=dtype)
         mx_out = np.ones(shape=shape, dtype=dtype)
@@ -146,17 +152,24 @@ def test_ones():
 
 @with_seed()
 def test_ndarray_binary_element_wise_ops():
-    # Cannot test operators like >, because boolean arrays are not supported yet.
-    np_op_map = {'+': _np.add, '*': _np.multiply, '-': _np.subtract, '/': _np.divide,
-                 'mod': _np.mod, 'pow': _np.power,
-                 # '>': _np.greater, '>=': _np.greater_equal,
-                 # '<': _np.less, '<=': _np.less_equal
-                 }
+    np_op_map = {
+        '+': _np.add,
+        '*': _np.multiply,
+        '-': _np.subtract,
+        '/': _np.divide,
+        'mod': _np.mod,
+        'pow': _np.power,
+        '==': _np.equal,
+        '>': _np.greater,
+        '>=': _np.greater_equal,
+        '<': _np.less,
+        '<=': _np.less_equal
+    }
 
     def get_np_ret(x1, x2, op):
         return np_op_map[op](x1, x2)
 
-    @np.use_np_shape
+    @npx.use_np_shape
     class TestBinaryElementWiseOp(HybridBlock):
         def __init__(self, op, scalar=None, reverse=False):
             super(TestBinaryElementWiseOp, self).__init__()
@@ -197,29 +210,34 @@ def test_ndarray_binary_element_wise_ops():
                     return x ** args[0] if not self._reverse else args[0] ** x
             elif self._op == '>':
                 if self._scalar is not None:
-                    return x > self._scalar
+                    return x > self._scalar if not self._reverse else self._scalar > x
                 else:
                     return x > args[0]
             elif self._op == '>=':
                 if self._scalar is not None:
-                    return x >= self._scalar
+                    return x >= self._scalar if not self._reverse else self._scalar >= x
                 else:
                     return x >= args[0]
             elif self._op == '<':
                 if self._scalar is not None:
-                    return x < self._scalar
+                    return x < self._scalar if not self._reverse else self._scalar < x
                 else:
                     return x < args[0]
             elif self._op == '<=':
                 if self._scalar is not None:
-                    return x <= self._scalar
+                    return x <= self._scalar if not self._reverse else self._scalar <= x
                 else:
                     return x <= args[0]
+            elif self._op == '==':
+                if self._scalar is not None:
+                    return x == self._scalar if not self._reverse else self._scalar == x
+                else:
+                    return x == args[0]
             else:
                 print(self._op)
                 assert False
 
-    @np.use_np_shape
+    @npx.use_np_shape
     def check_binary_op_result(shape1, shape2, op, dtype=None):
         if shape1 is None:
             mx_input1 = abs(_np.random.uniform()) + 1
@@ -289,10 +307,10 @@ def test_ndarray_binary_element_wise_ops():
 
 @with_seed()
 def test_hybrid_block_multiple_outputs():
-    @np.use_np_shape
+    @npx.use_np_shape
     class TestAllNumpyOutputs(HybridBlock):
         def hybrid_forward(self, F, x, *args, **kwargs):
-            return F.npe.relu(x), F.np.sum(x)
+            return F.npx.relu(x), F.np.sum(x)
 
     class TestAllClassicOutputs(HybridBlock):
         def hybrid_forward(self, F, x, *args, **kwargs):
@@ -309,7 +327,7 @@ def test_hybrid_block_multiple_outputs():
             assert type(out1) is expected_out_type
             assert type(out2) is expected_out_type
 
-    @np.use_np_shape
+    @npx.use_np_array
     class TestMixedTypeOutputsFailure(HybridBlock):
         def hybrid_forward(self, F, x, *args, **kwargs):
             return F.relu(x.as_classic_ndarray()), F.np.sum(x)
@@ -357,6 +375,257 @@ def test_np_ndarray_copy():
     assert same(mx_ret.asnumpy(), np_ret)
 
 
+@with_seed()
+def test_np_ndarray_indexing():
+    def test_getitem(np_array, index):
+        """`is_scalar` indicates whether we should expect a scalar for the result.
+        If so, the indexed array of NDArray should call asscalar to compare
+        with numpy's indexed array."""
+        np_index = index
+        if isinstance(index, np.ndarray):
+            np_index = index.asnumpy()
+        if isinstance(index, tuple):
+            np_index = []
+            for idx in index:
+                if isinstance(idx, np.ndarray):
+                    np_index.append(idx.asnumpy())
+                else:
+                    np_index.append(idx)
+            np_index = tuple(np_index)
+
+        np_indexed_array = np_array[np_index]
+        mx_array = np.array(np_array, dtype=np_array.dtype)
+        mx_indexed_array = mx_array[index].asnumpy()
+        assert same(np_indexed_array, mx_indexed_array), 'Failed with index=%s' % str(index)
+
+    def test_setitem(np_array, index):
+        def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None):
+            if np_value is not None:
+                np_array[np_index] = np_value
+            elif isinstance(mx_value, np.ndarray):
+                np_array[np_index] = mx_value.asnumpy()
+            else:
+                np_array[np_index] = mx_value
+            mx_array[mx_index] = mx_value
+            assert same(np_array, mx_array.asnumpy())
+
+        np_index = index
+        if isinstance(index, np.ndarray):
+            np_index = index.asnumpy()
+        if isinstance(index, tuple):
+            np_index = []
+            for idx in index:
+                if isinstance(idx, np.ndarray):
+                    np_index.append(idx.asnumpy())
+                else:
+                    np_index.append(idx)
+            np_index = tuple(np_index)
+
+        mx_array = np.array(np_array, dtype=np_array.dtype)
+        np_array = mx_array.asnumpy()
+        indexed_array_shape = np_array[np_index].shape
+        np_indexed_array = _np.random.randint(low=-10000, high=0, size=indexed_array_shape)
+        # test value is a numpy array without broadcast
+        assert_same(np_array, np_index, mx_array, index, np_indexed_array)
+        # test value is an numeric_type
+        assert_same(np_array, np_index, mx_array, index, _np.random.randint(low=-10000, high=0))
+        if len(indexed_array_shape) > 1:
+            # test ndarray with broadcast
+            assert_same(np_array, np_index, mx_array, index,
+                        np.random.uniform(low=-10000, high=0, size=(indexed_array_shape[-1],)))
+            # test numpy array with broadcast
+            assert_same(np_array, np_index, mx_array, index,
+                        _np.random.randint(low=-10000, high=0, size=(indexed_array_shape[-1],)))
+            # test list with broadcast
+            assert_same(np_array, np_index, mx_array, index,
+                        [_np.random.randint(low=-10000, high=0)] * indexed_array_shape[-1])
+
+    def test_getitem_autograd(np_array, index):
+        x = np.array(np_array, dtype=np_array.dtype)
+        x.attach_grad()
+        with npx.autograd.record():
+            y = x[index]
+        y.backward()
+        value = np.ones_like(y)
+        x_grad = np.zeros_like(x)
+        x_grad[index] = value
+        assert same(x_grad.asnumpy(), x.grad.asnumpy())
+
+    def test_setitem_autograd(np_array, index):
+        x = np.array(np_array, dtype=np_array.dtype)
+        out_shape = x[index].shape
+        y = np.random.uniform(size=out_shape)
+        y.attach_grad()
+        try:
+            with npx.autograd.record():
+                x[index] = y
+                assert False  # should not reach here
+        except mx.base.MXNetError as err:
+            assert str(err).find('Inplace operations (+=, -=, x[:]=, etc) are not supported when recording with') != -1
+
+    def np_int(index, int_type=_np.int32):
+        def convert(num):
+            if num is None:
+                return num
+            else:
+                return int_type(num)
+
+        if isinstance(index, slice):
+            return slice(convert(index.start), convert(index.stop), convert(index.step))
+        elif isinstance(index, tuple):  # tuple of slices and integers
+            ret = []
+            for elem in index:
+                if isinstance(elem, slice):
+                    ret.append(slice(convert(elem.start), convert(elem.stop), convert(elem.step)))
+                else:
+                    ret.append(convert(elem))
+            return tuple(ret)
+        else:
+            assert False
+
+    shape = (8, 16, 9, 9)
+    np_array = _np.arange(_np.prod(shape), dtype='int32').reshape(shape)
+    index_list = [
+        (),
+        0,
+        _np.int32(0),
+        _np.int64(0),
+        5,
+        _np.int32(5),
+        _np.int64(5),
+        -1,
+        _np.int32(-1),
+        _np.int64(-1),
+        slice(5),
+        np_int(slice(5), _np.int32),
+        np_int(slice(5), _np.int64),
+        slice(1, 5),
+        np_int(slice(1, 5), _np.int32),
+        np_int(slice(1, 5), _np.int64),
+        slice(1, 5, 2),
+        np_int(slice(1, 5, 2), _np.int32),
+        np_int(slice(1, 5, 2), _np.int64),
+        slice(7, 0, -1),
+        np_int(slice(7, 0, -1)),
+        np_int(slice(7, 0, -1), _np.int64),
+        slice(None, 6),
+        np_int(slice(None, 6)),
+        np_int(slice(None, 6), _np.int64),
+        slice(None, 6, 3),
+        np_int(slice(None, 6, 3)),
+        np_int(slice(None, 6, 3), _np.int64),
+        slice(1, None),
+        np_int(slice(1, None)),
+        np_int(slice(1, None), _np.int64),
+        slice(1, None, 3),
+        np_int(slice(1, None, 3)),
+        np_int(slice(1, None, 3), _np.int64),
+        slice(None, None, 2),
+        np_int(slice(None, None, 2)),
+        np_int(slice(None, None, 2), _np.int64),
+        slice(None, None, -1),
+        np_int(slice(None, None, -1)),
+        np_int(slice(None, None, -1), _np.int64),
+        slice(None, None, -2),
+        np_int(slice(None, None, -2), _np.int32),
+        np_int(slice(None, None, -2), _np.int64),
+        (slice(None), slice(None), 1, 8),
+        (slice(None), slice(None), -1, 8),
+        (slice(None), slice(None), 1, -8),
+        (slice(None), slice(None), -1, -8),
+        np_int((slice(None), slice(None), 1, 8)),
+        np_int((slice(None), slice(None), 1, 8), _np.int64),
+        (slice(None), slice(None), 1, 8),
+        np_int((slice(None), slice(None), -1, -8)),
+        np_int((slice(None), slice(None), -1, -8), _np.int64),
+        (slice(None), 2, slice(1, 5), 1),
+        np_int((slice(None), 2, slice(1, 5), 1)),
+        np_int((slice(None), 2, slice(1, 5), 1), _np.int64),
+        (1, 2, 3),
+        np_int((1, 2, 3)),
+        np_int((1, 2, 3), _np.int64),
+        (-1, -2, -3),
+        np_int((-1, -2, -3)),
+        np_int((-1, -2, -3), _np.int64),
+        (1, 2, 3, 4),
+        np_int((1, 2, 3, 4)),
+        np_int((1, 2, 3, 4), _np.int64),
+        (-4, -3, -2, -1),
+        np_int((-4, -3, -2, -1)),
+        np_int((-4, -3, -2, -1), _np.int64),
+        (slice(None, None, -1), 2, slice(1, 5), 1),
+        np_int((slice(None, None, -1), 2, slice(1, 5), 1)),
+        np_int((slice(None, None, -1), 2, slice(1, 5), 1), _np.int64),
+        (slice(None, None, -1), 2, slice(1, 7, 2), 1),
+        np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1)),
+        np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1), _np.int64),
+        (slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)),
+        np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3))),
+        np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)), _np.int64),
+        (slice(1, 8, 2), 1, slice(3, 8), 2),
+        np_int((slice(1, 8, 2), 1, slice(3, 8), 2)),
+        np_int((slice(1, 8, 2), 1, slice(3, 8), 2), _np.int64),
+        [1],
+        [1, 2],
+        [2, 1, 3],
+        [7, 5, 0, 3, 6, 2, 1],
+        _np.array([6, 3], dtype=_np.int32),
+        _np.array([[3, 4], [0, 6]], dtype=_np.int32),
+        _np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=_np.int32),
+        _np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=_np.int64),
+        _np.array([[2], [0], [1]], dtype=_np.int32),
+        _np.array([[2], [0], [1]], dtype=_np.int64),
+        np.array([4, 7], dtype=_np.int32),
+        np.array([4, 7], dtype=_np.int64),
+        np.array([[3, 6], [2, 1]], dtype=_np.int32),
+        np.array([[3, 6], [2, 1]], dtype=_np.int64),
+        np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=_np.int32),
+        np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=_np.int64),
+        (1, [2, 3]),
+        (1, [2, 3], _np.array([[3], [0]], dtype=_np.int32)),
+        (1, [2, 3]),
+        (1, [2, 3], _np.array([[3], [0]], dtype=_np.int64)),
+        (1, [2], _np.array([[5], [3]], dtype=_np.int32), slice(None)),
+        (1, [2], _np.array([[5], [3]], dtype=_np.int64), slice(None)),
+        (1, [2, 3], _np.array([[6], [0]], dtype=_np.int32), slice(2, 5)),
+        (1, [2, 3], _np.array([[6], [0]], dtype=_np.int64), slice(2, 5)),
+        (1, [2, 3], _np.array([[4], [7]], dtype=_np.int32), slice(2, 5, 2)),
+        (1, [2, 3], _np.array([[4], [7]], dtype=_np.int64), slice(2, 5, 2)),
+        (1, [2], _np.array([[3]], dtype=_np.int32), slice(None, None, -1)),
+        (1, [2], _np.array([[3]], dtype=_np.int64), slice(None, None, -1)),
+        (1, [2], _np.array([[3]], dtype=_np.int32), np.array([[5, 7], [2, 4]], dtype=_np.int64)),
+        (1, [2], np.array([[4]], dtype=_np.int32), np.array([[1, 3], [5, 7]], dtype='int64')),
+        [0],
+        [0, 1],
+        [1, 2, 3],
+        [2, 0, 5, 6],
+        ([1, 1], [2, 3]),
+        ([1], [4], [5]),
+        ([1], [4], [5], [6]),
+        ([[1]], [[2]]),
+        ([[1]], [[2]], [[3]], [[4]]),
+        (slice(0, 2), [[1], [6]], slice(0, 2), slice(0, 5, 2)),
+        ([[[[1]]]], [[1]], slice(0, 3), [1, 5]),
+        ([[[[1]]]], 3, slice(0, 3), [1, 3]),
+        ([[[[1]]]], 3, slice(0, 3), 0),
+        ([[[[1]]]], [[2], [12]], slice(0, 3), slice(None)),
+        ([1, 2], slice(3, 5), [2, 3], [3, 4]),
+        ([1, 2], slice(3, 5), (2, 3), [3, 4]),
+        range(4),
+        range(3, 0, -1),
+        (range(4,), [1]),
+        # slice(0, 0) does not support output zero-size tensor yet
+    ]
+    for index in index_list:
+        test_getitem(np_array, index)
+        test_setitem(np_array, index)
+        test_getitem_autograd(np_array, index)
+        if not isinstance(index, tuple) or len(index) != 0:
+            # When index = (), this is same a[()] = b is equivalent to b.copyto(a)
+            # which should have no problem to do autograd
+            test_setitem_autograd(np_array, index)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 3608690..9804aea 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -19,7 +19,7 @@
 from __future__ import absolute_import
 import numpy as _np
 import mxnet as mx
-from mxnet import np, npe
+from mxnet import np, npx
 from mxnet.gluon import HybridBlock
 from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray
 from mxnet.test_utils import check_numeric_gradient
@@ -79,7 +79,8 @@ def test_np_sum():
                         if itype == 'float32' and dtype == 'float32':
                             x_sym = mx.sym.Variable("x").as_np_ndarray()
                             mx_sym = mx.sym.np.sum(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_classic_ndarray()
-                            check_numeric_gradient(mx_sym, [x], numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32)
+                            check_numeric_gradient(mx_sym, [x.as_classic_ndarray()],
+                                                   numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32)
 
                         # test imperative
                         mx_out = np.sum(x, axis=axis, dtype=dtype, keepdims=keepdims)
@@ -88,7 +89,7 @@ def test_np_sum():
 
 
 @with_seed()
-@np.use_np_shape
+@npx.use_np_shape
 def test_np_dot():
     shapes = [
         ((3, 0), (0, 4)),
@@ -132,7 +133,7 @@ def test_np_dot():
 
 @with_seed()
 def test_np_mean():
-    @np.use_np_shape
+    @npx.use_np_shape
     class TestMean(HybridBlock):
         def __init__(self, axis=None, dtype=None, keepdims=False):
             super(TestMean, self).__init__()
@@ -185,7 +186,8 @@ def test_np_mean():
                         if itype == 'float32' and dtype == 'float32':
                             x_sym = mx.sym.Variable("x").as_np_ndarray()
                             mx_sym = mx.sym.np.mean(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_classic_ndarray()
-                            check_numeric_gradient(mx_sym, [x], numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32)
+                            check_numeric_gradient(mx_sym, [x.as_classic_ndarray()],
+                                                   numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32)
 
                         # test imperative
                         mx_out = np.mean(x, axis=axis, dtype=dtype, keepdims=keepdims)
@@ -194,7 +196,6 @@ def test_np_mean():
 
 
 @with_seed()
-@np.use_np_shape
 def test_np_transpose():
     # TODO(junwu): Add more test cases
     data = mx.sym.var('a').as_np_ndarray()
@@ -224,39 +225,36 @@ def test_np_transpose():
 
 
 @with_seed()
-@np.use_np_shape
-def test_relu():
+def test_npx_relu():
     # TODO(junwu): Add more test cases
     data = mx.sym.var('data').as_np_ndarray()
-    ret = mx.sym.npe.relu(data)
+    ret = mx.sym.npx.relu(data)
     assert type(ret) == mx.sym.np._Symbol
 
     shapes = [(), (0, 2, 0)]
     shapes.extend([rand_shape_nd(ndim, allow_zero_size=True) for ndim in range(5)])
     for shape in shapes:
         data = np.array(_np.random.uniform(size=shape).astype('float32'))
-        ret = npe.relu(data)
+        ret = npx.relu(data)
         assert type(ret) == np.ndarray
 
 
 @with_seed()
-@np.use_np_shape
-def test_sigmoid():
+def test_npx_sigmoid():
     # TODO(junwu): Add more test cases
     data = mx.sym.var('data').as_np_ndarray()
-    ret = mx.sym.npe.sigmoid(data)
+    ret = mx.sym.npx.sigmoid(data)
     assert type(ret) == mx.sym.np._Symbol
 
     shapes = [(), (0, 2, 0)]
     shapes.extend([rand_shape_nd(ndim, allow_zero_size=True) for ndim in range(5)])
     for shape in shapes:
         data = np.array(_np.random.uniform(size=shape).astype('float32'))
-        ret = npe.sigmoid(data)
+        ret = npx.sigmoid(data)
         assert type(ret) == np.ndarray
 
 
 @with_seed()
-@np.use_np_shape
 def test_np_reshape():
     # TODO(junwu): Add more test cases
     data = mx.sym.var('a').as_np_ndarray()
@@ -272,7 +270,6 @@ def test_np_reshape():
 
 
 @with_seed()
-@np.use_np_shape
 def test_np_maximum():
     # TODO(junwu): Add more test cases
     x1, x2 = mx.sym.var('x1').as_np_ndarray(), mx.sym.var('x2').as_np_ndarray()
@@ -293,7 +290,6 @@ def test_np_maximum():
 
 
 @with_seed()
-@np.use_np_shape
 def test_np_minimum():
     # TODO(junwu): Add more test cases
     x1, x2 = mx.sym.var('x1').as_np_ndarray(), mx.sym.var('x2').as_np_ndarray()
@@ -314,9 +310,9 @@ def test_np_minimum():
 
 
 @with_seed()
-@mx.use_np_shape
 def test_np_unary_funcs():
     def check_unary_func(func, ref_grad, shape, low, high):
+        @npx.use_np_shape
         class TestUnary(HybridBlock):
             def __init__(self, func):
                 super(TestUnary, self).__init__()
@@ -391,8 +387,8 @@ def test_np_unary_funcs():
 
 
 @with_seed()
-@mx.use_np_shape
 def test_np_stack():
+    @npx.use_np_shape
     class TestStack(HybridBlock):
         def __init__(self, axis=None):
             super(TestStack, self).__init__()
@@ -442,6 +438,201 @@ def test_np_stack():
                 assert same(mx_out.asnumpy(), np_out)
 
 
+def test_np_random():
+    shapes = [(), (1,), (2, 3), (4, 0, 5), 6, (7, 8), None]
+    dtypes = ['float16', 'float32', 'float64']
+    op_names = ['uniform', 'normal']
+    for shape in shapes:
+        for dtype in dtypes:
+            for op_name in op_names:
+                op = getattr(np.random, op_name, None)
+                assert op is not None
+                out = op(size=shape, dtype=dtype)
+                expected_shape = shape
+                if not isinstance(shape, tuple):
+                    expected_shape = () if shape is None else (shape,)
+                assert out.shape == expected_shape
+
+    @npx.use_np
+    class TestRandom(HybridBlock):
+        def __init__(self, shape, op_name):
+            super(TestRandom, self).__init__()
+            self._shape = shape
+            self._op_name = op_name
+
+        def hybrid_forward(self, F, x):
+            op = getattr(F.np.random, self._op_name, None)
+            assert op is not None
+            return x + op(size=shape)
+
+    x = np.ones(())
+    for op_name in op_names:
+        for shape in shapes:
+            for hybridize in [False, True]:
+                net = TestRandom(shape, op_name)
+                if hybridize:
+                    net.hybridize()
+                out = net(x)
+                expected_shape = shape
+                if not isinstance(shape, tuple):
+                    expected_shape = () if shape is None else (shape,)
+                assert out.shape == expected_shape
+
+
+@with_seed()
+def test_np_arange():
+    configs = [
+        (1, 10, 2),
+        (1, 10, 4),
+        (1, -10, 4),
+        (1, -10, -2),
+        (1, -10, -4),
+        (2, 3),
+        (2, -3),
+        (-2, -3),
+        (-2, 3),
+        (4, 0, 5),
+        (-4, 0, 5),
+        (-4, 0, -5),
+        (0, 0),
+        (11, 11),
+        (0, 0, 2),
+        (0, 0, -2),
+        (0, 5, None),
+        (0, -5, None),
+        0,
+        6,
+    ]
+    dtypes = ['int32', 'float16', 'float32', 'float64', None]
+    for config in configs:
+        for dtype in dtypes:
+            if isinstance(config, tuple):
+                mx_ret = np.arange(*config, dtype=dtype)
+                np_ret = _np.arange(*config, dtype=dtype)
+            else:
+                mx_ret = np.arange(config, dtype=dtype)
+                np_ret = _np.arange(config, dtype=dtype)
+            assert same(mx_ret.asnumpy(), np_ret)
+
+    @npx.use_np
+    class TestRange(HybridBlock):
+        def __init__(self, start, stop=None, step=None, dtype=None):
+            super(TestRange, self).__init__()
+            self._start = start
+            self._stop = stop
+            self._step = step
+            self._dtype = dtype
+
+        def hybrid_forward(self, F, x):
+            return x + F.np.arange(self._start, self._stop, self._step, dtype=self._dtype)
+
+    for dtype in dtypes:
+        x = np.zeros(shape=(), dtype=dtype)
+        for config in configs:
+            for hybridize in [False, True]:
+                if isinstance(config, tuple):
+                    net = TestRange(*config, dtype=dtype)
+                    np_out = _np.arange(*config, dtype=dtype)
+                else:
+                    net = TestRange(config, dtype=dtype)
+                    np_out = _np.arange(config, dtype=dtype)
+                if hybridize:
+                    net.hybridize()
+                mx_out = net(x)
+                assert same(mx_out.asnumpy(), np_out)
+
+
+@with_seed()
+def test_np_argmax():
+    workloads = [
+        ((), 0, False),
+        ((), -1, False),
+        ((), 1, True),
+        ((5, 3), None, False),
+        ((5, 3), -1, False),
+        ((5, 3), 1, False),
+        ((5, 3), 3, True),
+        ((5, 0, 3), 0, False),
+        ((5, 0, 3), -1, False),
+        ((5, 0, 3), None, True),
+        ((5, 0, 3), 1, True),
+    ]
+    dtypes = ['float16', 'float32', 'float64']
+
+    @npx.use_np
+    class TestArgMax(HybridBlock):
+        def __init__(self, axis=None):
+            super(TestArgMax, self).__init__()
+            self._axis = axis
+
+        def hybrid_forward(self, F, x):
+            return F.np.argmax(x, self._axis)
+
+    for shape, axis, throw_exception in workloads:
+        for dtype in dtypes:
+            a = np.random.uniform(size=shape, dtype=dtype)
+            if throw_exception:
+                # Cannot use assert_exception because sometimes the main thread
+                # proceeds to `assert False` before the exception is thrown
+                # in the worker thread. Have to use mx.nd.waitall() here
+                # to block the main thread.
+                try:
+                    np.argmax(a, axis)
+                    mx.nd.waitall()
+                    assert False
+                except mx.MXNetError:
+                    pass
+            else:
+                mx_ret = np.argmax(a, axis=axis)
+                np_ret = _np.argmax(a.asnumpy(), axis=axis)
+                assert same(mx_ret.asnumpy(), np_ret)
+
+            for hybridize in [False, True]:
+                net = TestArgMax(axis)
+                if hybridize:
+                    net.hybridize()
+                if throw_exception:
+                    try:
+                        net(a)
+                        mx.nd.waitall()
+                        assert False
+                    except mx.MXNetError:
+                        pass
+                else:
+                    mx_ret = net(a)
+                    assert same(mx_ret.asnumpy(), np_ret)
+
+
+@with_seed()
+def test_np_linalg_norm():
+    @npx.use_np
+    class TestLinalgNorm(HybridBlock):
+        def __init__(self, ord=None, axis=None, keepdims=False):
+            super(TestLinalgNorm, self).__init__()
+            self._ord = ord
+            self._axis = axis
+            self._keepdims = keepdims
+
+        def hybrid_forward(self, F, x):
+            return F.np.linalg.norm(x, ord=self._ord, axis=self._axis, keepdims=self._keepdims)
+
+    a = np.arange(5 * 6 * 7 * 8).reshape((5, 6, 7, 8))
+    ords = [None, 'fro']
+    axes = [None, (0, 2), (1, 0), (1, 2)]
+    for ord in ords:
+        for axis in axes:
+            if ord == 'fro' and axis is None and a.ndim > 2:
+                continue
+            for keepdims in [False, True]:
+                for hybridize in [False, True]:
+                    net = TestLinalgNorm(ord, axis, keepdims)
+                    if hybridize:
+                        net.hybridize()
+                    mx_ret = net(a)
+                    np_ret = _np.linalg.norm(a.asnumpy(), ord=ord, axis=axis, keepdims=keepdims)
+                    assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-5, rtol=1e-4)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
diff --git a/tests/python/unittest/test_thread_local.py b/tests/python/unittest/test_thread_local.py
index b553299..ee56ba7 100644
--- a/tests/python/unittest/test_thread_local.py
+++ b/tests/python/unittest/test_thread_local.py
@@ -23,6 +23,7 @@ from mxnet.context import Context
 from mxnet.attribute import AttrScope
 from mxnet.name import NameManager
 from mxnet.test_utils import set_default_context
+from mxnet.util import _NumpyArrayScope
 
 def test_context():
     ctx_list = []
@@ -163,6 +164,41 @@ def test_symbol():
     thread.join()
     assert status[0], "Failed to execute a symbolic graph within a thread"
 
+
+def test_np_array_scope():
+    np_array_scope_list = []
+    _NumpyArrayScope._current = _NumpyArrayScope(False)
+    np_array_scope_list.append(_NumpyArrayScope._current)
+
+    def f():
+        _NumpyArrayScope._current = _NumpyArrayScope(True)
+        np_array_scope_list.append(_NumpyArrayScope._current)
+
+    thread = threading.Thread(target=f)
+    thread.start()
+    thread.join()
+    assert len(np_array_scope_list) == 2
+    assert not np_array_scope_list[0]._is_np_array
+    assert np_array_scope_list[1]._is_np_array
+
+    event = threading.Event()
+    status = [False]
+
+    def g():
+        with mx.np_array(False):
+            event.wait()
+            if not mx.is_np_array():
+                status[0] = True
+
+    thread = threading.Thread(target=g)
+    thread.start()
+    _NumpyArrayScope._current = _NumpyArrayScope(True)
+    event.set()
+    thread.join()
+    event.clear()
+    assert status[0], "Spawned thread didn't set status correctly"
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()