Posted to commits@mxnet.apache.org by re...@apache.org on 2019/06/04 06:58:24 UTC

[incubator-mxnet] 07/13: [numpy] Refactor np modules (#14989)

This is an automated email from the ASF dual-hosted git repository.

reminisce pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 043f01e45462496f70bd75179847f3d24efa4a9b
Author: reminisce <wu...@gmail.com>
AuthorDate: Sat May 18 13:30:29 2019 -0700

    [numpy] Refactor np modules (#14989)
    
    * Refactor
    
    * Initial refactoring
    
    * Fix notebook
    
    * Move numpy op check from backend to frontend
    
    * Add homogeneous ndarray check
    
    * Fix grouping inhomogeneous types of symbols
    
    * Improve error handling of different types of symbols as outputs
    
    * Fix test
    
    * Fix numpy test
    
    * Fix ci
    
    * Try to fix gpu ci failure
---
 example/numpy/demo.ipynb                           |  73 ++----
 include/mxnet/c_api.h                              |  17 --
 include/mxnet/op_attr_types.h                      |   9 -
 python/mxnet/__init__.py                           |   3 +
 python/mxnet/_ctypes/ndarray.py                    |  19 +-
 python/mxnet/_ctypes/symbol.py                     |  10 +-
 python/mxnet/base.py                               | 119 +++++++---
 python/mxnet/gluon/block.py                        |   6 +-
 python/mxnet/gluon/utils.py                        |  22 ++
 python/mxnet/ndarray/__init__.py                   |   3 +-
 python/mxnet/ndarray/ndarray.py                    |  48 +---
 python/mxnet/ndarray/numpy/__init__.py             |   5 +-
 .../{numpy/ext.py => ndarray/numpy/_internal.py}   |   2 +-
 python/mxnet/ndarray/numpy/_op.py                  |  20 +-
 python/mxnet/ndarray/numpy/_register.py            |   8 +-
 python/mxnet/ndarray/numpy/linalg.py               |   2 +-
 python/mxnet/ndarray/numpy/random.py               |   2 +-
 .../ndarray/{numpy => numpy_extension}/__init__.py |   5 +-
 .../{numpy/ext.py => numpy_extension/_op.py}       |   3 +-
 .../{numpy => numpy_extension}/_register.py        |   5 +-
 python/mxnet/ndarray/register.py                   |  66 +++++-
 python/mxnet/numpy/__init__.py                     |   4 +-
 python/mxnet/numpy/_op.py                          |   2 +-
 python/mxnet/numpy/_register.py                    |   5 +-
 python/mxnet/numpy/linalg.py                       |   2 +-
 python/mxnet/numpy/multiarray.py                   | 185 ++++++++++-----
 python/mxnet/numpy/random.py                       |   2 +-
 .../mxnet/{numpy => numpy_extension}/__init__.py   |   7 +-
 python/mxnet/{numpy => numpy_extension}/_op.py     |   2 +-
 .../mxnet/{numpy => numpy_extension}/_register.py  |   5 +-
 python/mxnet/symbol/__init__.py                    |   4 +-
 python/mxnet/symbol/numpy/__init__.py              |   7 +-
 .../{numpy/_op.py => symbol/numpy/_internal.py}    |   2 +-
 python/mxnet/symbol/numpy/_op.py                   |   2 +-
 python/mxnet/symbol/numpy/_register.py             |   9 +-
 python/mxnet/symbol/numpy/_symbol.py               | 258 ++++++++++-----------
 python/mxnet/symbol/numpy/ext.py                   |  20 --
 python/mxnet/symbol/numpy/linalg.py                |   2 +-
 python/mxnet/symbol/numpy/random.py                |   2 +-
 .../numpy => symbol/numpy_extension}/__init__.py   |   5 +-
 .../mxnet/symbol/{numpy => numpy_extension}/_op.py |   3 +-
 .../symbol/{numpy => numpy_extension}/_register.py |   5 +-
 python/mxnet/symbol/register.py                    |  74 +++++-
 python/mxnet/symbol/symbol.py                      |  57 ++---
 python/mxnet/test_utils.py                         |   6 +
 src/c_api/c_api_common.h                           |  17 --
 src/c_api/c_api_ndarray.cc                         |  16 --
 src/operator/numpy/np_broadcast_reduce_op.h        |   1 +
 src/operator/numpy/np_broadcast_reduce_op_value.cc |  14 +-
 src/operator/numpy/np_broadcast_reduce_op_value.cu |   8 +-
 src/operator/numpy/np_dot-inl.h                    |  11 +-
 src/operator/numpy/np_dot.cc                       |   2 +-
 src/operator/numpy/np_dot.cu                       |   2 +-
 src/operator/numpy/np_elemwise_broadcast_op.cc     |  56 ++---
 src/operator/numpy/np_elemwise_broadcast_op.cu     |  34 +--
 src/operator/numpy/np_elemwise_unary_op_basic.cc   |  28 ++-
 src/operator/numpy/np_elemwise_unary_op_basic.cu   |   4 +-
 src/operator/numpy/np_init_op.cc                   |  64 ++++-
 src/operator/numpy/np_init_op.cu                   |  10 +-
 src/operator/numpy/np_matrix_op.cc                 |   6 +-
 src/operator/numpy/np_matrix_op.cu                 |   4 +-
 src/operator/numpy/np_true_divide.cc               |   9 +-
 src/operator/numpy/np_true_divide.cu               |   6 +-
 tests/python/unittest/test_numpy_ndarray.py        |  95 ++++----
 tests/python/unittest/test_numpy_op.py             |  78 ++++---
 65 files changed, 875 insertions(+), 707 deletions(-)
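
Taken together, the user-visible effect of this refactor is that the old `mxnet.numpy.ext` namespace disappears in favor of a separate `mxnet.numpy_extension` package, with `np` and `npe` registered as top-level aliases. A minimal sketch of the resulting imperative API, assuming a build of this numpy branch:

    import mxnet as mx

    mx.set_np_compat(True)     # numpy-compatible semantics (e.g. scalar tensors)
    a = mx.np.array(3.14)      # mxnet.numpy.ndarray holding a scalar
    b = mx.np.zeros((2, 3))    # regular NumPy-style operator
    c = mx.npe.relu(b)         # MXNet-only operator, now under numpy_extension
    print(type(a), c.shape)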

diff --git a/example/numpy/demo.ipynb b/example/numpy/demo.ipynb
index d8e6e06..7ba184d 100644
--- a/example/numpy/demo.ipynb
+++ b/example/numpy/demo.ipynb
@@ -6,21 +6,21 @@
    "source": [
     "# Fundamentals of MXNet Numpy Module\n",
     "\n",
-    "## Operator Namespaces for Imperative Programming\n",
+    "## Namespaces for Imperative Programming\n",
     "- `mxnet.numpy`: Regular NumPy operators\n",
     "- `mxnet.numpy.random`: NumPy random operators\n",
     "- `mxnet.numpy.linalg`: NumPy linear algebra operators\n",
-    "- `mxnet.numpy.ext`: Operators implemented in MXNet that do not exist in official NumPy\n",
+    "- `mxnet.numpy_extension`: Operators implemented in MXNet that do not exist in the official NumPy\n",
     "\n",
     "## Operator Namespaces for Gluon\n",
-    "`F` can be either `mxnet.ndarray` or `mxnet.symbol`.\n",
+    "`F` can be either `mxnet.ndarray` or `mxnet.symbol`. Note that `np` and `npe` are aliases of `numpy` and `numpy_extension`, respectively.\n",
     "- `F.np`: Regular NumPy operators\n",
     "- `F.np.random`: NumPy random operators\n",
     "- `F.np.linalg`: NumPy linear algebra operators\n",
-    "- `F.np.ext`: Operators implemented in MXNet that do not exist in official NumPy\n",
+    "- `F.npe`: Operators implemented in MXNet that do not exist in official NumPy\n",
     "\n",
     "## New `ndarray` and `symbol`\n",
-    "`mxnet.numpy.ndarray` and `mxnet.symbol.numpy._NumpySymbol` (not visible to users)\n",
+    "`mxnet.numpy.ndarray` (visible to users) and `mxnet.symbol.numpy._Symbol` (not visible to users)\n",
     "- Same name as in the official NumPy package\n",
     "- Dispatch convience fluent method calls to MXNet Numpy operators\n",
     "- Override many convenience fluent methods that do not exist in the official NumPy ndarray\n",
@@ -46,7 +46,7 @@
     "\n",
     "# create a scalar tensor\n",
     "x = np.array(3.14)\n",
-    "print(x)"
+    "print(x)  # x is actually an ndarray, but a scalar value will be printed"
    ]
   },
   {
@@ -170,13 +170,15 @@
     "from mxnet import gluon\n",
     "class TestBinaryBroadcast(gluon.HybridBlock):\n",
     "    def hybrid_forward(self, F, x1, x2):\n",
-    "        print(\"x1 type:\", str(type(x1)))\n",
-    "        print(\"x2 type:\", str(type(x2)))\n",
+    "        print(\"x1 type in hybrid_forward:\", str(type(x1)))\n",
+    "        print(\"x2 type in hybrid_forward:\", str(type(x2)))\n",
     "        return x1 + x2\n",
     "\n",
     "net = TestBinaryBroadcast()\n",
     "x1 = mx.nd.ones((2, 1))\n",
     "x2 = mx.nd.ones((1, 3))\n",
+    "print('x1 input tensor type: ', str(type(x1)))\n",
+    "print('x2 input tensor type: ', str(type(x2)))\n",
     "out = net(x1, x2)  # ok: imperative execution supports broadcasting\n",
     "print(out)"
    ]
@@ -203,13 +205,15 @@
    "source": [
     "class TestBinaryBroadcast2(gluon.HybridBlock):\n",
     "    def hybrid_forward(self, F, x1, x2):\n",
-    "        print(\"x1 type:\", str(type(x1)))\n",
-    "        print(\"x2 type:\", str(type(x2)))\n",
+    "        print(\"x1 type in hybrid_forward:\", str(type(x1)))\n",
+    "        print(\"x2 type in hybrid_forward:\", str(type(x2)))\n",
     "        return x1.as_np_ndarray() + x2  # convert x1 to new numpy ndarray/symbol\n",
     "\n",
     "net2 = TestBinaryBroadcast2()\n",
     "net2.hybridize()\n",
     "\n",
+    "print('x1 input tensor type: ', str(type(x1)))\n",
+    "print('x2 input tensor type: ', str(type(x2)))\n",
     "out =net2(x1, x2)\n",
     "print(out)"
    ]
@@ -224,7 +228,9 @@
     "net.hybridize()  # mark the block for execution using a computational graph\n",
     "\n",
     "x1 = x1.as_np_ndarray()  # convert x1 to np.ndarray so that _NumpySymbol will be used in graph construction\n",
+    "print('x1 input tensor type: ', str(type(x1)))\n",
     "x2 = x2.as_np_ndarray()  # convert x2 to np.ndarray so that _NumpySymbol will be used in graph construction\n",
+    "print('x2 input tensor type: ', str(type(x2)))\n",
     "out = net(x1, x2)  # ok: `+` operation supports broadcasting for _NumpySymbol\n",
     "print(out)  # mxnet.numpy.ndarray type, because it's from a np operator"
    ]
@@ -245,7 +251,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## MXNet Numpy Operators in Imperative Programming"
+    "### MXNet Numpy Operators in Imperative Programming"
    ]
   },
   {
@@ -255,15 +261,9 @@
    "outputs": [],
    "source": [
     "import mxnet as mx\n",
-    "from mxnet import numpy as np\n",
+    "from mxnet import numpy as np, numpy_extension as npe\n",
     "from mxnet import autograd\n",
-    "try:\n",
-    "    from mxboard import SummaryWriter\n",
-    "except ImportError:\n",
-    "    SummaryWriter = None\n",
     "\n",
-    "# create a summary writer for visualization\n",
-    "sw = SummaryWriter(logdir='./logs', flush_secs=2) if SummaryWriter is not None else None\n",
     "\n",
     "# Use numpy-compatible semantics to support scalar tensors\n",
     "mx.set_np_compat(True)\n",
@@ -285,11 +285,11 @@
     "learning_rate = 1e-6\n",
     "\n",
     "\n",
-    "for t in range(1000):\n",
+    "for t in range(50):\n",
     "    with autograd.record():\n",
     "        # Forward pass: compute predicted y\n",
     "        h = x.dot(w1)  # equivalent to np.dot(x, w1)\n",
-    "        h_relu = np.ext.relu(h)  # equivalent to mx.nd.relu(h)\n",
+    "        h_relu = npe.relu(h)  # equivalent to mx.nd.relu(h)\n",
     "        y_pred = h_relu.dot(w2)  # equivalent to np.dot(h_relu, w2)\n",
     "\n",
     "        # Compute loss\n",
@@ -302,23 +302,14 @@
     "\n",
     "    # Update weights\n",
     "    w1 -= learning_rate * w1.grad\n",
-    "    w2 -= learning_rate * w2.grad\n",
-    "\n",
-    "    if sw is not None:\n",
-    "        sw.add_scalar('loss', loss.item(), global_step=t)  # loss.item() copies the tensor element to a python scalar\n",
-    "        if t % 50 == 0:\n",
-    "            sw.add_histogram(tag='w1', values=w1, global_step=t)\n",
-    "            sw.add_histogram(tag='w2', values=w2, global_step=t)\n",
-    "\n",
-    "if sw is not None:\n",
-    "    sw.close()"
+    "    w2 -= learning_rate * w2.grad"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## MXNet Numpy Operators in Gluon `HybridBlock`"
+    "### MXNet Numpy Operators in Gluon `HybridBlock`"
    ]
   },
   {
@@ -329,13 +320,7 @@
    "source": [
     "import mxnet as mx\n",
     "from mxnet import gluon, autograd\n",
-    "try:\n",
-    "    from mxboard import SummaryWriter\n",
-    "except ImportError:\n",
-    "    SummaryWriter = None\n",
     "\n",
-    "# create a summary writer for visualization\n",
-    "sw = SummaryWriter(logdir='./logs', flush_secs=2) if SummaryWriter is not None else None\n",
     "\n",
     "# Use numpy-compatible semantics to support scalar tensors\n",
     "mx.set_np_compat(True)\n",
@@ -352,7 +337,7 @@
     "\n",
     "    def hybrid_forward(self, F, x, w1, w2):\n",
     "        h = x.dot(w1)  # equivalent to F.np.dot(x, w1)\n",
-    "        h_relu = F.np.ext.relu(h)  # equivalent to F.relu(h)\n",
+    "        h_relu = F.npe.relu(h)  # equivalent to F.relu(h)\n",
     "        y_pred = h_relu.dot(w2)  # equivalent to F.np.dot(h_relu, w2)\n",
     "        return y_pred\n",
     "\n",
@@ -373,21 +358,13 @@
     "total_loss = TotalLoss()\n",
     "trainer = gluon.Trainer(regressor.collect_params(), 'sgd', {'learning_rate': 1e-3, 'momentum': 0.9})\n",
     "\n",
-    "for t in range(1000):\n",
+    "for t in range(50):\n",
     "    with autograd.record():\n",
     "        output = regressor(x)  # output is a type of np.ndarray because np.dot is the last op in the network\n",
     "        loss = total_loss(output, y)  # loss is a scalar np.ndarray\n",
     "    loss.backward()\n",
     "    print(t, loss)  # note that loss.asnumpy() is called\n",
-    "    trainer.step(1)\n",
-    "    if sw is not None:\n",
-    "        sw.add_scalar('loss', loss.item(), global_step=t)  # loss.item() copies the tensor element to a python scalar\n",
-    "        if t % 50 == 0:\n",
-    "            for k, v in regressor.collect_params().items():\n",
-    "                sw.add_histogram(tag=k, values=v.data(), global_step=t)\n",
-    "\n",
-    "if sw is not None:\n",
-    "    sw.close()"
+    "    trainer.step(1)"
    ]
   }
  ],
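
The notebook edits above all circle one pitfall: once a block is hybridized, classic `Symbol` and numpy `_Symbol` operands cannot be mixed in `hybrid_forward`. A condensed sketch of the working pattern from the updated notebook (same branch assumption as above):

    import mxnet as mx
    from mxnet import gluon

    class Add(gluon.HybridBlock):
        def hybrid_forward(self, F, x1, x2):
            return x1 + x2  # broadcasting add

    net = Add()
    net.hybridize()                            # trace a computation graph
    x1 = mx.nd.ones((2, 1)).as_np_ndarray()    # numpy _Symbol used during tracing
    x2 = mx.nd.ones((1, 3)).as_np_ndarray()
    out = net(x1, x2)                          # mxnet.numpy.ndarray, shape (2, 3)
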
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 5a92644..e9e9a37 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -2812,14 +2812,6 @@ MXNET_DLL int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,
                                EngineFnPropertyHandle prop_handle DEFAULT(NULL),
                                int priority DEFAULT(0), const char* opr_name DEFAULT(NULL));
 /*!
-  * \brief Determines if an op is a Numpy op by its name prefix.
-  * Every Numpy op starts with a prefix string "_numpy_".
-  * \param creator Operator handle
-  * \param is_np_op Indicator of whether creator is a numpy op handle
-  */
-MXNET_DLL int MXIsNumpyCompatOp(AtomicSymbolCreator creator,
-                                int* is_np_op);
-/*!
  * \brief Create an NDArray from source sharing the same data chunk.
  * \param src source NDArray
  * \param out new NDArray sharing the same data chunk with src
@@ -2831,15 +2823,6 @@ MXNET_DLL int MXShallowCopyNDArray(NDArrayHandle src, NDArrayHandle* out);
  * \param out new Symbol sharing the same graph structure with src
  */
 MXNET_DLL int MXShallowCopySymbol(SymbolHandle src, SymbolHandle * out);
-/*!
- * \brief Checks if an output of CachedOp is from a numpy op.
- * \param handle CachedOp shared ptr
- * \param output_idx index of the output of the CachedOp
- * \param is_from_np_op indicator of whether the output is from a numpy op
- */
-MXNET_DLL int MXIsCachedOpOutputFromNumpyCompatOp(CachedOpHandle handle,
-                                                  int output_idx,
-                                                  int* is_from_np_op);
 
 #ifdef __cplusplus
 }
diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h
index 0e4e322..889b502 100644
--- a/include/mxnet/op_attr_types.h
+++ b/include/mxnet/op_attr_types.h
@@ -319,15 +319,6 @@ using FNeedRequantize = std::function<bool (const NodeAttrs& attrs)>;
 using FAvoidQuantizeInput = std::function<bool (const NodeAttrs& attrs,
                                                 size_t index)>;
 
-/*!
- * \brief Indicates whether this operator is NumPy compatible.
- * It is for distinguishing the operator from classic MXNet operators
- * which do not support zero-dim and zero-size tensors.
- * In Python, it is used to determine whether to output numpy ndarrays
- * or symbols that are NumPy compatible.
- */
-using TIsNumpyCompatible = bool;
-
 }  // namespace mxnet
 
 #endif  // MXNET_OP_ATTR_TYPES_H_
diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py
index 7c8150b..883e846 100644
--- a/python/mxnet/__init__.py
+++ b/python/mxnet/__init__.py
@@ -30,6 +30,9 @@ from . import contrib
 from . import ndarray
 from . import ndarray as nd
 from . import numpy
+from . import numpy_extension
+from . import numpy as np
+from . import numpy_extension as npe
 from . import name
 # use mx.sym as short for symbol
 from . import symbol as sym
diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py
index 60ec248..6404d89 100644
--- a/python/mxnet/_ctypes/ndarray.py
+++ b/python/mxnet/_ctypes/ndarray.py
@@ -26,7 +26,7 @@ import ctypes
 from ..base import _LIB
 from ..base import c_str_array, c_handle_array
 from ..base import NDArrayHandle, CachedOpHandle
-from ..base import check_call, _is_np_compat_op
+from ..base import check_call
 
 
 class NDArrayBase(object):
@@ -70,7 +70,7 @@ def _set_np_ndarray_class(cls):
     _np_ndarray_cls = cls
 
 
-def _imperative_invoke(handle, ndargs, keys, vals, out):
+def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op):
     """ctypes implementation of imperative invoke wrapper"""
     if out is not None:
         original_output = out
@@ -99,9 +99,9 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
         c_str_array([str(s) for s in vals]),
         ctypes.byref(out_stypes)))
 
+    create_ndarray_fn = _np_ndarray_cls if is_np_op else _ndarray_cls
     if original_output is not None:
         return original_output
-    create_ndarray_fn = _np_ndarray_cls if _is_np_compat_op(handle) else _ndarray_cls
     if num_output.value == 1:
         return create_ndarray_fn(ctypes.cast(output_vars[0], NDArrayHandle),
                                  stype=out_stypes[0])
@@ -112,11 +112,14 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
 
 class CachedOp(object):
     """Cached operator handle."""
-    __slots__ = ["handle"]
+    __slots__ = ["handle", "is_np_sym"]
 
     def __init__(self, sym, flags=()):
         self.handle = CachedOpHandle()
 
+        from ..symbol.numpy._symbol import _Symbol
+        self.is_np_sym = isinstance(sym, _Symbol)
+
         check_call(_LIB.MXCreateCachedOpEx(
             sym.handle,
             len(flags),
@@ -167,12 +170,10 @@ class CachedOp(object):
 
         if original_output is not None:
             return original_output
+        create_ndarray_fn = _np_ndarray_cls if self.is_np_sym else _ndarray_cls
         if num_output.value == 1:
-            create_ndarray_fn = _np_ndarray_cls if self._is_from_np_compat_op(0) else _ndarray_cls
             return create_ndarray_fn(ctypes.cast(output_vars[0], NDArrayHandle),
                                      stype=out_stypes[0])
         else:
-            return [_np_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle), stype=out_stypes[i])
-                    if self._is_from_np_compat_op(i) else
-                    _ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle), stype=out_stypes[i])
-                    for i in range(num_output.value)]
+            return [create_ndarray_fn(ctypes.cast(output_vars[i], NDArrayHandle),
+                                      stype=out_stypes[i]) for i in range(num_output.value)]
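
The `CachedOp` change is representative of the refactor as a whole: instead of querying the backend per output (`MXIsCachedOpOutputFromNumpyCompatOp`, removed above), the frontend records once at construction whether the cached symbol is a numpy symbol and reuses that flag for every output. A pure-Python sketch of the dispatch, with placeholder classes standing in for the real array types:

    class ClassicArray(object):
        """Stand-in for mxnet.ndarray.NDArray (placeholder only)."""

    class NumpyArray(object):
        """Stand-in for mxnet.numpy.ndarray (placeholder only)."""

    def wrap_outputs(output_handles, is_np_sym):
        # one constructor serves every output; no per-output backend query
        create_ndarray_fn = NumpyArray if is_np_sym else ClassicArray
        outs = [create_ndarray_fn() for _ in output_handles]
        return outs[0] if len(outs) == 1 else outs

    print(wrap_outputs([0], True))        # a single NumpyArray
    print(wrap_outputs([0, 1], False))    # a list of ClassicArray
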
diff --git a/python/mxnet/_ctypes/symbol.py b/python/mxnet/_ctypes/symbol.py
index 7aea0a2..fc159f8 100644
--- a/python/mxnet/_ctypes/symbol.py
+++ b/python/mxnet/_ctypes/symbol.py
@@ -22,7 +22,7 @@ from __future__ import absolute_import as _abs
 
 import ctypes
 from ..base import _LIB
-from ..base import c_str_array, c_handle_array, c_str, mx_uint, _is_np_compat_op
+from ..base import c_str_array, c_handle_array, c_str, mx_uint
 from ..base import SymbolHandle
 from ..base import check_call
 
@@ -122,7 +122,7 @@ def _set_np_symbol_class(cls):
     _np_symbol_cls = cls
 
 
-def _symbol_creator(handle, args, kwargs, keys, vals, name):
+def _symbol_creator(handle, args, kwargs, keys, vals, name, is_np_op):
     sym_handle = SymbolHandle()
     check_call(_LIB.MXSymbolCreateAtomicSymbol(
         ctypes.c_void_p(handle),
@@ -135,10 +135,8 @@ def _symbol_creator(handle, args, kwargs, keys, vals, name):
         raise TypeError(
             'Operators with variable length input can only accept input'
             'Symbols either as positional or keyword arguments, not both')
-    if _is_np_compat_op(handle):
-        s = _np_symbol_cls(sym_handle)
-    else:
-        s = _symbol_cls(sym_handle)
+    create_symbol_fn = _np_symbol_cls if is_np_op else _symbol_cls
+    s = create_symbol_fn(sym_handle)
     if args:
         s._compose(*args, name=name)
     elif kwargs:
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 131cb4d..85dd525 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -16,7 +16,7 @@
 # under the License.
 
 # coding: utf-8
-# pylint: disable=invalid-name, no-member, trailing-comma-tuple, bad-mcs-classmethod-argument, unnecessary-pass
+# pylint: disable=invalid-name, no-member, trailing-comma-tuple, bad-mcs-classmethod-argument, unnecessary-pass, too-many-lines
 """ctypes library of mxnet and helper functions."""
 from __future__ import absolute_import
 
@@ -598,7 +598,9 @@ def _init_op_module(root_namespace, module_name, make_op_func):
                                      ctypes.byref(plist)))
     op_names = []
     for i in range(size.value):
-        op_names.append(py_str(plist[i]))
+        op_name = py_str(plist[i])
+        if not _is_np_op(op_name):
+            op_names.append(op_name)
 
     module_op = sys.modules["%s.%s.op" % (root_namespace, module_name)]
     module_internal = sys.modules["%s.%s._internal" % (root_namespace, module_name)]
@@ -692,7 +694,9 @@ def _generate_op_module_signature(root_namespace, module_name, op_code_gen_func)
                                      ctypes.byref(plist)))
     op_names = []
     for i in range(size.value):
-        op_names.append(py_str(plist[i]))
+        op_name = py_str(plist[i])
+        if not _is_np_op(op_name):
+            op_names.append(op_name)
 
     module_op_file = get_module_file("%s.%s.op" % (root_namespace, module_name))
     module_op_all = []
@@ -743,19 +747,28 @@ def _sanity_check_params(func_name, unsupported_params, param_dict):
                                       .format(func_name, param_name))
 
 
-_NP_OP_SUBMODULE_LIST = ['_ext_', '_random_', '_linalg_']
-_NP_OP_PREFIX = '_numpy_'
+_NP_OP_PREFIX = '_np_'
+_NP_OP_SUBMODULE_LIST = ['_random_', '_linalg_']
 
+_NP_EXT_OP_PREFIX = '_npe_'
 
-def _get_np_op_submodule_name(op_name):
-    assert op_name.startswith(_NP_OP_PREFIX)
-    for name in _NP_OP_SUBMODULE_LIST:
-        if op_name[len(_NP_OP_PREFIX):].startswith(name):
-            return name
+_NP_INTERNAL_OP_PREFIX = '_npi_'
+
+
+def _is_np_op(op_name):
+    return op_name.startswith(_NP_OP_PREFIX) or op_name.startswith(_NP_EXT_OP_PREFIX)\
+           or op_name.startswith(_NP_INTERNAL_OP_PREFIX)
+
+
+def _get_op_submodule_name(op_name, op_name_prefix, submodule_name_list):
+    assert op_name.startswith(op_name_prefix)
+    for submodule_name in submodule_name_list:
+        if op_name[len(op_name_prefix):].startswith(submodule_name):
+            return submodule_name
     return ""
 
 
-def _init_np_op_module(root_namespace, module_name, make_op_func):
+def _init_np_op_module(root_module_name, np_module_name, mx_module_name, make_op_func):
     """
     Register numpy operators in namespaces `mxnet.numpy`, `mxnet.ndarray.numpy`
     and `mxnet.symbol.numpy`. They are used in imperative mode, Gluon APIs w/o hybridization,
@@ -765,51 +778,89 @@ def _init_np_op_module(root_namespace, module_name, make_op_func):
 
     Parameters
     ----------
-    root_namespace : str
+    root_module_name : str
         Top level module name, `mxnet` in the current case.
-    module_name : str
-        Second level module name, `ndarray` or `symbol` in the current case.
+    np_module_name : str
+        Second level module name, `numpy` or `numpy_extension` in the current case.
     make_op_func : function
         Function for creating op functions.
     """
+    if np_module_name == 'numpy':
+        op_name_prefix = _NP_OP_PREFIX
+        submodule_name_list = _NP_OP_SUBMODULE_LIST
+    elif np_module_name == 'numpy_extension':
+        op_name_prefix = _NP_EXT_OP_PREFIX
+        submodule_name_list = []
+    elif np_module_name == 'numpy._internal':
+        op_name_prefix = _NP_INTERNAL_OP_PREFIX
+        submodule_name_list = []
+    else:
+        raise ValueError('unsupported np module name {}'.format(np_module_name))
+
     plist = ctypes.POINTER(ctypes.c_char_p)()
     size = ctypes.c_uint()
-
     check_call(_LIB.MXListAllOpNames(ctypes.byref(size), ctypes.byref(plist)))
     op_names = []
     for i in range(size.value):
         name = py_str(plist[i])
-        if name.startswith(_NP_OP_PREFIX):
+        if name.startswith(op_name_prefix):
             op_names.append(name)
 
-    if module_name == 'numpy':
-        # register ops for mxnet.numpy
-        module_pattern = "%s.%s._op"
-        submodule_pattern = "%s.%s.%s"
+    if mx_module_name is None:
+        # register np/npe ops for imperative programming
+        op_module_name = "%s.%s._op" % (root_module_name, np_module_name)  # e.g. mxnet.numpy._op
+        op_submodule_name = "%s.%s" % (root_module_name, np_module_name)  # e.g. mxnet.numpy.random
+    elif mx_module_name == 'ndarray' or mx_module_name == 'symbol':
+        # register numpy internal ops and np/npe ops for use in Gluon
+        # np internal ops are registered in mxnet.ndarray/symbol.numpy._internal
+        # np ops are registered in mxnet.ndarray/symbol.numpy._op
+        # npe ops are registered in mxnet.ndarray/symbol.numpy_extension._op
+        op_module_name = "%s.%s.%s" % (root_module_name, mx_module_name, np_module_name)
+        if op_name_prefix != _NP_INTERNAL_OP_PREFIX:
+            op_module_name += '._op'
+        # e.g. mxnet.symbol.numpy.random
+        op_submodule_name = "%s.%s.%s" % (root_module_name, mx_module_name, np_module_name)
     else:
-        # register ops for mxnet.ndarray.numpy or mxnet.symbol.numpy
-        module_pattern = "%s.%s.numpy._op"
-        submodule_pattern = "%s.%s.numpy.%s"
-    module_np_op = sys.modules[module_pattern % (root_namespace, module_name)]
+        raise ValueError('unsupported mxnet module {}'.format(mx_module_name))
+    op_submodule_name += '.%s'
+
+    op_module = sys.modules[op_module_name]
     submodule_dict = {}
-    for submodule_name in _NP_OP_SUBMODULE_LIST:
-        submodule_dict[submodule_name] = \
-            sys.modules[submodule_pattern % (root_namespace, module_name, submodule_name[1:-1])]
+    for submodule_name in submodule_name_list:
+        submodule_dict[submodule_name] = sys.modules[op_submodule_name % submodule_name[1:-1]]
     for name in op_names:
         hdl = OpHandle()
         check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
-        submodule_name = _get_np_op_submodule_name(name)
-        module_name_local = module_name
+        submodule_name = _get_op_submodule_name(name, op_name_prefix, submodule_name_list)
         if len(submodule_name) > 0:
-            func_name = name[(len(_NP_OP_PREFIX) + len(submodule_name)):]
+            func_name = name[(len(op_name_prefix) + len(submodule_name)):]
             cur_module = submodule_dict[submodule_name]
-            module_name_local = submodule_pattern % (root_namespace,
-                                                     module_name, submodule_name[1:-1])
+            module_name_local = op_submodule_name % submodule_name[1:-1]
         else:
-            func_name = name[len(_NP_OP_PREFIX):]
-            cur_module = module_np_op
+            func_name = name[len(op_name_prefix):]
+            cur_module = op_module
+            module_name_local =\
+                op_module_name[:-len('._op')] if op_module_name.endswith('._op') else op_module_name
 
         function = make_op_func(hdl, name, func_name)
         function.__module__ = module_name_local
         setattr(cur_module, function.__name__, function)
         cur_module.__all__.append(function.__name__)
+
+
+def set_module(module):
+    """Decorator for overriding __module__ on a function or class.
+
+    Example usage::
+
+        @set_module('mxnet.numpy')
+        def example():
+            pass
+
+        assert example.__module__ == 'mxnet.numpy'
+    """
+    def decorator(func):
+        if module is not None:
+            func.__module__ = module
+        return func
+    return decorator
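
Two details in base.py are worth spelling out: operator kind is now encoded purely in the registered name prefix (`_np_`, `_npe_`, `_npi_`), and `set_module` rewrites `__module__` so that generated functions present themselves under the public namespace. A self-contained sketch mirroring the code above:

    _NP_OP_PREFIX = '_np_'
    _NP_EXT_OP_PREFIX = '_npe_'
    _NP_INTERNAL_OP_PREFIX = '_npi_'

    def _is_np_op(op_name):
        return (op_name.startswith(_NP_OP_PREFIX)
                or op_name.startswith(_NP_EXT_OP_PREFIX)
                or op_name.startswith(_NP_INTERNAL_OP_PREFIX))

    def set_module(module):
        """Decorator overriding __module__ on a function or class."""
        def decorator(func):
            if module is not None:
                func.__module__ = module
            return func
        return decorator

    @set_module('mxnet.numpy')
    def example():
        pass

    assert _is_np_op('_np_zeros') and not _is_np_op('relu')
    assert example.__module__ == 'mxnet.numpy'
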
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index a578d34..46aca12 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -32,7 +32,7 @@ from ..symbol import Symbol
 from ..ndarray import NDArray
 from .. import name as _name
 from .parameter import Parameter, ParameterDict, DeferredInitializationError
-from .utils import _indent, _brief_print_list, HookHandle
+from .utils import _indent, _brief_print_list, HookHandle, _check_same_symbol_type
 from .. import numpy as _mx_np
 
 
@@ -746,7 +746,7 @@ class HybridBlock(Block):
                 out = self.hybrid_forward(symbol, *grouped_inputs, **params)  # pylint: disable=no-value-for-parameter
             out, self._out_format = _flatten(out, "output")
 
-            self._cached_graph = inputs, symbol.Group(out)
+            self._cached_graph = inputs, symbol.Group(out, _check_same_symbol_type(out))
 
         return self._cached_graph
 
@@ -1049,7 +1049,7 @@ class SymbolBlock(HybridBlock):
 
         syms, self._in_format = _flatten(inputs, "input")
         out, self._out_format = _flatten(outputs, "output")
-        out = symbol.Group(out)
+        out = symbol.Group(out, _check_same_symbol_type(out))
 
         input_names = set()
         for i in syms:
diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py
index 3957b74..08c7850 100644
--- a/python/mxnet/gluon/utils.py
+++ b/python/mxnet/gluon/utils.py
@@ -430,3 +430,25 @@ def shape_is_known(shape):
         assert dim_size > unknown_dim_size, "shape dimension size cannot be less than {}, while " \
                                             "received {}".format(unknown_dim_size, dim_size)
     return True
+
+def _check_same_symbol_type(symbols):
+    """Check whether all the symbols in the list are of the same type.
+    Raise a TypeError if the types are different. Return the class of
+    the symbols."""
+    from ..symbol.numpy import _Symbol as np_symbol
+    from ..symbol import Symbol as classic_symbol
+    is_np_sym = isinstance(symbols[0], np_symbol)
+    for s in symbols[1:]:
+        if is_np_sym != isinstance(s, np_symbol):
+            raise TypeError('Found both classic symbol (mx.sym.Symbol) and numpy symbol '
+                            '(mx.sym.np._Symbol) in outputs. This will prevent you from building '
+                            'a computation graph by grouping them since different types of symbols '
+                            'are not allowed to be grouped in Gluon to form a computation graph. '
+                            'You will need to convert them to the same type of symbols, either '
+                            'classic or numpy following this rule: if you want numpy ndarray '
+                            'output(s) from the computation graph, please convert all the classic '
+                            'symbols in the list to numpy symbols by calling `as_np_ndarray()` '
+                            'on each of them; if you want classic ndarray output(s) from the '
+                            'computation graph, please convert all the numpy symbols in the list '
+                            'to classic symbols by calling `as_classic_ndarray()` on each of them.')
+    return np_symbol if is_np_sym else classic_symbol
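
The effect of `_check_same_symbol_type` is easiest to see from the failure it guards against; a hedged sketch, assuming this branch (the error raised is the long message above):

    import mxnet as mx
    from mxnet import gluon

    class Mixed(gluon.HybridBlock):
        def hybrid_forward(self, F, x):
            # one classic output plus one numpy output cannot be grouped
            return x, x.as_np_ndarray()

    net = Mixed()
    net.hybridize()
    try:
        net(mx.nd.ones((2, 2)))
    except TypeError as err:
        print('grouping rejected:', err)
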
diff --git a/python/mxnet/ndarray/__init__.py b/python/mxnet/ndarray/__init__.py
index f0e6edb..c326850 100644
--- a/python/mxnet/ndarray/__init__.py
+++ b/python/mxnet/ndarray/__init__.py
@@ -31,6 +31,7 @@ from .utils import load, load_frombuffer, save, zeros, empty, array
 from .sparse import _ndarray_cls
 from .ndarray import _GRAD_REQ_MAP, _DTYPE_MX_TO_NP, _DTYPE_NP_TO_MX, _new_empty_handle
 from . import numpy as np
+from . import numpy_extension as npe
 
 __all__ = op.__all__ + ndarray.__all__ + utils.__all__ + \
-          ['contrib', 'linalg', 'random', 'sparse', 'image']
+          ['contrib', 'linalg', 'random', 'sparse', 'image', 'numpy', 'numpy_extension']
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index b73b0aa..70190d9 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -187,15 +187,15 @@ fixed-size items.
 
     def as_np_ndarray(self):
         """Convert mxnet.ndarray.NDArray to mxnet.numpy.ndarray."""
+        storage_type = self.stype
+        if storage_type != 'default':
+            raise ValueError('cannot convert ndarray of stype {} to numpy ndarray'
+                             .format(storage_type))
         from ..numpy import ndarray
         hdl = NDArrayHandle()
         check_call(_LIB.MXShallowCopyNDArray(self.handle, ctypes.byref(hdl)))
         return ndarray(handle=hdl, writable=self.writable)
 
-    def _is_np_compat(self):
-        """Always returns False except for mxnet.numpy.ndarray."""
-        return False
-
     @property
     def _tvm_handle(self):
         return self.handle.value
@@ -220,8 +220,6 @@ fixed-size items.
     def __add__(self, other):
         """x.__add__(y) <=> x+y <=> mx.nd.add(x, y) """
         # other may be the type of mxnet.numpy.ndarray
-        if isinstance(other, NDArray) and other._is_np_compat():
-            return other.__add__(self)
         return add(self, other)
 
     def __iadd__(self, other):
@@ -236,15 +234,11 @@ fixed-size items.
             raise TypeError('type %s not supported' % str(type(other)))
 
     def __radd__(self, other):
-        if isinstance(other, NDArray) and other._is_np_compat():
-            return other.__add__(self)
         return self.__add__(other)
 
     def __sub__(self, other):
         """x.__sub__(y) <=> x-y <=> mx.nd.subtract(x, y) """
         # other may be the type of mxnet.numpy.ndarray
-        if isinstance(other, NDArray) and other._is_np_compat():
-            return other.__rsub__(self)
         return subtract(self, other)
 
     def __isub__(self, other):
@@ -260,14 +254,10 @@ fixed-size items.
 
     def __rsub__(self, other):
         """x.__rsub__(y) <=> y-x <=> mx.nd.subtract(y, x) """
-        if isinstance(other, NDArray) and other._is_np_compat():
-            return other.__sub__(self)
         return subtract(other, self)
 
     def __mul__(self, other):
         """x.__mul__(y) <=> x*y <=> mx.nd.multiply(x, y) """
-        if isinstance(other, NDArray) and other._is_np_compat():
-            return other.__mul__(self)
         return multiply(self, other)
 
     def __neg__(self):
@@ -286,20 +276,14 @@ fixed-size items.
             raise TypeError('type %s not supported' % str(type(other)))
 
     def __rmul__(self, other):
-        if isinstance(other, NDArray) and other._is_np_compat():
-            return other.__mul__(self)
         return self.__mul__(other)
 
     def __div__(self, other):
         """x.__div__(y) <=> x/y <=> mx.nd.divide(x, y) """
-        if isinstance(other, NDArray) and other._is_np_compat():
-            return other.__rtruediv__(self)
         return divide(self, other)
 
     def __rdiv__(self, other):
         """x.__rdiv__(y) <=> y/x <=> mx.nd.divide(y, x) """
-        if isinstance(other, NDArray) and other._is_np_compat():
-            return other.__truediv__(self)
         return divide(other, self)
 
     def __idiv__(self, other):
@@ -314,13 +298,9 @@ fixed-size items.
             raise TypeError('type %s not supported' % str(type(other)))
 
     def __truediv__(self, other):
-        if isinstance(other, NDArray) and other._is_np_compat():
-            return other.__rtruediv__(self)
         return divide(self, other)
 
     def __rtruediv__(self, other):
-        if isinstance(other, NDArray) and other._is_np_compat():
-            return other.__truediv__(self)
         return divide(other, self)
 
     def __itruediv__(self, other):
@@ -328,14 +308,10 @@ fixed-size items.
 
     def __mod__(self, other):
         """x.__mod__(y) <=> x%y <=> mx.nd.modulo(x, y) """
-        if isinstance(other, NDArray) and other._is_np_compat():
-            return other.__rmod__(self)
         return modulo(self, other)
 
     def __rmod__(self, other):
         """x.__rmod__(y) <=> y%x <=> mx.nd.modulo(y, x) """
-        if isinstance(other, NDArray) and other._is_np_compat():
-            return other.__mod__(self)
         return modulo(other, self)
 
     def __imod__(self, other):
@@ -351,20 +327,14 @@ fixed-size items.
 
     def __pow__(self, other):
         """x.__pow__(y) <=> x**y <=> mx.nd.power(x,y) """
-        if isinstance(other, NDArray) and other._is_np_compat():
-            return other.__rpow__(self)
         return power(self, other)
 
     def __rpow__(self, other):
         """x.__pow__(y) <=> y**x <=> mx.nd.power(y,x) """
-        if isinstance(other, NDArray) and other._is_np_compat():
-            return other.__pow__(self)
         return power(other, self)
 
     def __eq__(self, other):
         """x.__eq__(y) <=> x==y <=> mx.nd.equal(x, y) """
-        if isinstance(other, NDArray) and other._is_np_compat():
-            return other.__eq__(self)
         return equal(self, other)
 
     def __hash__(self):
@@ -373,32 +343,22 @@ fixed-size items.
 
     def __ne__(self, other):
         """x.__ne__(y) <=> x!=y <=> mx.nd.not_equal(x, y) """
-        if isinstance(other, NDArray) and other._is_np_compat():
-            return other.__ne__(self)
         return not_equal(self, other)
 
     def __gt__(self, other):
         """x.__gt__(y) <=> x>y <=> mx.nd.greater(x, y) """
-        if isinstance(other, NDArray) and other._is_np_compat():
-            return other.__lt__(self)
         return greater(self, other)
 
     def __ge__(self, other):
         """x.__ge__(y) <=> x>=y <=> mx.nd.greater_equal(x, y) """
-        if isinstance(other, NDArray) and other._is_np_compat():
-            return other.__le__(self)
         return greater_equal(self, other)
 
     def __lt__(self, other):
         """x.__lt__(y) <=> x<y <=> mx.nd.lesser(x, y) """
-        if isinstance(other, NDArray) and other._is_np_compat():
-            return other.__gt__(self)
         return lesser(self, other)
 
     def __le__(self, other):
         """x.__le__(y) <=> x<=y <=> mx.nd.less_equal(x, y) """
-        if isinstance(other, NDArray) and other._is_np_compat():
-            return other.__ge__(self)
         return lesser_equal(self, other)
 
     def __bool__(self):
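
With `_is_np_compat` gone, a classic `NDArray` no longer defers its arithmetic to a numpy operand behind the scenes; crossing between the two worlds is explicit, and only dense (`default` stype) arrays can cross. A short sketch under the same branch assumption:

    import mxnet as mx

    x = mx.nd.ones((2, 3))     # classic NDArray
    y = x.as_np_ndarray()      # shallow copy sharing the same data chunk
    z = y + 1.0                # dispatched to numpy ops (_npi.add_scalar)
    print(type(x), type(y), type(z))

    s = mx.nd.ones((2, 3)).tostype('row_sparse')
    # s.as_np_ndarray()        # would raise ValueError: non-'default' stype
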
diff --git a/python/mxnet/ndarray/numpy/__init__.py b/python/mxnet/ndarray/numpy/__init__.py
index d97e808..7eb478f 100644
--- a/python/mxnet/ndarray/numpy/__init__.py
+++ b/python/mxnet/ndarray/numpy/__init__.py
@@ -15,12 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""numpy module for numpy ops under mxnet.ndarray."""
+"""Module for numpy ops under mxnet.ndarray."""
 
-from . import ext
 from . import random
 from . import linalg
-from . import _op
+from . import _op, _internal
 from . import _register
 from ._op import *  # pylint: disable=wildcard-import
 
diff --git a/python/mxnet/numpy/ext.py b/python/mxnet/ndarray/numpy/_internal.py
similarity index 91%
rename from python/mxnet/numpy/ext.py
rename to python/mxnet/ndarray/numpy/_internal.py
index e4c8251..c5f2928 100644
--- a/python/mxnet/numpy/ext.py
+++ b/python/mxnet/ndarray/numpy/_internal.py
@@ -15,6 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""namespace for registering numpy.ext ops for imperative programming."""
+"""Namespace for numpy internal ops."""
 
 __all__ = []
diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index 9b32c31..e905fdf 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -15,18 +15,19 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""numpy namespace for operators used in Gluon APIs dispatched by F=ndarray module."""
+"""Namespace for numpy operators used in Gluon dispatched by F=ndarray."""
 
 from __future__ import absolute_import
 import numpy as _np
-from ...base import _sanity_check_params, use_np_compat, numeric_types
+from ...base import _sanity_check_params, use_np_compat, numeric_types, set_module
 from ...context import current_context
-from .. import _internal
+from . import _internal as _npi
 from ..ndarray import NDArray
 
 __all__ = ['zeros', 'ones', 'maximum', 'minimum']
 
 
+@set_module('mxnet.ndarray.numpy')
 @use_np_compat
 def zeros(shape, dtype=_np.float32, **kwargs):
     """Return a new array of given shape and type, filled with zeros.
@@ -55,9 +56,10 @@ def zeros(shape, dtype=_np.float32, **kwargs):
     if ctx is None:
         ctx = current_context()
     dtype = _np.float32 if dtype is None else dtype
-    return _internal._np_zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
+    return _npi.zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
 
 
+@set_module('mxnet.ndarray.numpy')
 @use_np_compat
 def ones(shape, dtype=None, **kwargs):
     """Return a new array of given shape and type, filled with ones.
@@ -86,7 +88,7 @@ def ones(shape, dtype=None, **kwargs):
     if ctx is None:
         ctx = current_context()
     dtype = _np.float32 if dtype is None else dtype
-    return _internal._np_ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
+    return _npi.ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
 
 
 #pylint: disable= too-many-arguments, no-member, protected-access
@@ -138,6 +140,7 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, ou
 #pylint: enable= too-many-arguments, no-member, protected-access
 
 
+@set_module('mxnet.ndarray.numpy')
 @use_np_compat
 def maximum(x1, x2, out=None):
     """Returns element-wise maximum of the input arrays with broadcasting.
@@ -152,10 +155,10 @@ def maximum(x1, x2, out=None):
     -------
     out : mxnet.numpy.ndarray or scalar
         The maximum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars."""
-    return _ufunc_helper(x1, x2, _internal._np_maximum, _np.maximum,
-                         _internal._np_maximum_scalar, None, out)
+    return _ufunc_helper(x1, x2, _npi.maximum, _np.maximum, _npi.maximum_scalar, None, out)
 
 
+@set_module('mxnet.ndarray.numpy')
 @use_np_compat
 def minimum(x1, x2, out=None):
     """Returns element-wise minimum of the input arrays with broadcasting.
@@ -170,5 +173,4 @@ def minimum(x1, x2, out=None):
     -------
     out : mxnet.numpy.ndarray or scalar
         The minimum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars."""
-    return _ufunc_helper(x1, x2, _internal._np_minimum, _np.minimum,
-                         _internal._np_minimum_scalar, None, out)
+    return _ufunc_helper(x1, x2, _npi.minimum, _np.minimum, _npi.minimum_scalar, None, out)
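
The four public ops defined here keep NumPy's calling conventions; `_ufunc_helper` picks the array/array, array/scalar, or scalar/scalar path. A quick usage sketch (branch assumption as before):

    import mxnet as mx
    from mxnet import numpy as np

    mx.set_np_compat(True)
    a = np.zeros((2, 3))
    b = np.ones((2, 3))
    c = np.maximum(a, b)      # array/array  -> _npi.maximum
    d = np.minimum(c, 0.5)    # array/scalar -> _npi.minimum_scalar
    print(d)
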
diff --git a/python/mxnet/ndarray/numpy/_register.py b/python/mxnet/ndarray/numpy/_register.py
index 840797f..3ac464e 100644
--- a/python/mxnet/ndarray/numpy/_register.py
+++ b/python/mxnet/ndarray/numpy/_register.py
@@ -15,10 +15,14 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""module for registering numpy ops under mxnet.ndarray.numpy."""
+"""Registering numpy ops."""
 
 from ...base import _init_np_op_module
 from ..register import _make_ndarray_function
 
 
-_init_np_op_module('mxnet', 'ndarray', _make_ndarray_function)
+_init_np_op_module(root_module_name='mxnet', np_module_name='numpy',
+                   mx_module_name='ndarray', make_op_func=_make_ndarray_function)
+
+_init_np_op_module(root_module_name='mxnet', np_module_name='numpy._internal',
+                   mx_module_name='ndarray', make_op_func=_make_ndarray_function)
diff --git a/python/mxnet/ndarray/numpy/linalg.py b/python/mxnet/ndarray/numpy/linalg.py
index b8f10b3..8f521fd 100644
--- a/python/mxnet/ndarray/numpy/linalg.py
+++ b/python/mxnet/ndarray/numpy/linalg.py
@@ -15,6 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""numpy.linalg namespace for operators used in Gluon APIs dispatched by F=symbol module."""
+"""Namespace for operators used in Gluon dispatched by F=ndarray."""
 
 __all__ = []
diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py
index 60908b5..8f521fd 100644
--- a/python/mxnet/ndarray/numpy/random.py
+++ b/python/mxnet/ndarray/numpy/random.py
@@ -15,6 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""numpy.random namespace for operators used in Gluon APIs dispatched by F=ndarray module."""
+"""Namespace for operators used in Gluon dispatched by F=ndarray."""
 
 __all__ = []
diff --git a/python/mxnet/ndarray/numpy/__init__.py b/python/mxnet/ndarray/numpy_extension/__init__.py
similarity index 88%
copy from python/mxnet/ndarray/numpy/__init__.py
copy to python/mxnet/ndarray/numpy_extension/__init__.py
index d97e808..a718274 100644
--- a/python/mxnet/ndarray/numpy/__init__.py
+++ b/python/mxnet/ndarray/numpy_extension/__init__.py
@@ -15,11 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""numpy module for numpy ops under mxnet.ndarray."""
+"""Module for the ops not belonging to the official numpy package."""
 
-from . import ext
-from . import random
-from . import linalg
 from . import _op
 from . import _register
 from ._op import *  # pylint: disable=wildcard-import
diff --git a/python/mxnet/ndarray/numpy/ext.py b/python/mxnet/ndarray/numpy_extension/_op.py
similarity index 86%
rename from python/mxnet/ndarray/numpy/ext.py
rename to python/mxnet/ndarray/numpy_extension/_op.py
index e13423f..22738a0 100644
--- a/python/mxnet/ndarray/numpy/ext.py
+++ b/python/mxnet/ndarray/numpy_extension/_op.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""numpy.ext namespace for operators used in Gluon APIs dispatched by F=ndarray module."""
+"""Namespace for the operators not belonging to the official numpy package
+used in Gluon dispatched by F=ndarray module."""
 
 __all__ = []
diff --git a/python/mxnet/ndarray/numpy/_register.py b/python/mxnet/ndarray/numpy_extension/_register.py
similarity index 81%
copy from python/mxnet/ndarray/numpy/_register.py
copy to python/mxnet/ndarray/numpy_extension/_register.py
index 840797f..32cd068 100644
--- a/python/mxnet/ndarray/numpy/_register.py
+++ b/python/mxnet/ndarray/numpy_extension/_register.py
@@ -15,10 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""module for registering numpy ops under mxnet.ndarray.numpy."""
+"""Registering numpy_extension ops."""
 
 from ...base import _init_np_op_module
 from ..register import _make_ndarray_function
 
 
-_init_np_op_module('mxnet', 'ndarray', _make_ndarray_function)
+_init_np_op_module(root_module_name='mxnet', np_module_name='numpy_extension',
+                   mx_module_name='ndarray', make_op_func=_make_ndarray_function)
diff --git a/python/mxnet/ndarray/register.py b/python/mxnet/ndarray/register.py
index 1ccf228..a285e50 100644
--- a/python/mxnet/ndarray/register.py
+++ b/python/mxnet/ndarray/register.py
@@ -24,12 +24,60 @@ import numpy as _np  # pylint: disable=unused-import
 from ._internal import NDArrayBase, _imperative_invoke # pylint: disable=unused-import
 from ..ndarray_doc import _build_doc
 
-from ..base import mx_uint, check_call, _LIB, py_str, _init_op_module, _Null # pylint: disable=unused-import
+from ..base import mx_uint, check_call, _LIB, py_str, _init_op_module, _Null, _is_np_op  # pylint: disable=unused-import
+
+
+def _verify_all_np_ndarrays(op_name, func_name, *array_list):
+    """Verify if all the arrays are numpy ndarrays.
+
+    Parameters
+    ----------
+    op_name : str
+        Operator full name registered in backend.
+    func_name : str
+        Operator name exposed to users. This is usually the name by stripping off
+        the prefix of the full operator names registered in backend.
+    array_list : list of arrays
+    """
+    from ..numpy import ndarray as np_ndarray
+    for array in array_list:
+        if (array is not None) and (not isinstance(array, np_ndarray)):
+            raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. '
+                            'This is a numpy operator which can only accept '
+                            'MXNet numpy ndarrays, while received a classic ndarray. '
+                            'Please call `as_np_ndarray()` upon the classic ndarray to '
+                            'convert it to an MXNet numpy ndarray, and then feed the converted '
+                            'array to this operator.'
+                            .format(op_name, func_name))
+
+
+def _verify_all_classic_ndarrays(op_name, func_name, *array_list):
+    """Verify if all the arrays are classic ndarrays.
+
+    Parameters
+    ----------
+    op_name : str
+        Operator full name registered in backend.
+    func_name : str
+        Operator name exposed to users. This is usually the name by stripping off
+        the prefix of the full operator names registered in backend.
+    array_list : list of arrays
+    """
+    from ..numpy import ndarray as np_ndarray
+    for array in array_list:
+        if (array is not None) and (isinstance(array, np_ndarray)):
+            raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. '
+                            'This is a classic operator which can only accept '
+                            'classic ndarrays, while received an MXNet numpy ndarray. '
+                            'Please call `as_classic_ndarray()` upon the numpy ndarray to '
+                            'convert it to a classic ndarray, and then feed the converted '
+                            'array to this operator.'
+                            .format(op_name, func_name))
 
 
 # pylint: disable=too-many-locals
-def _generate_ndarray_function_code(handle, name, func_name, signature_only=False):
-    """Generate function for ndarray op by handle and function name."""
+def _generate_ndarray_function_code(handle, op_name, func_name, signature_only=False):
+    """Generate function for ndarray op by handle and function op_name."""
     real_name = ctypes.c_char_p()
     desc = ctypes.c_char_p()
     num_args = mx_uint()
@@ -52,7 +100,7 @@ def _generate_ndarray_function_code(handle, name, func_name, signature_only=Fals
     arg_types = [py_str(arg_types[i]) for i in range(narg)]
     key_var_num_args = py_str(key_var_num_args.value)
     ret_type = py_str(ret_type.value) if ret_type.value is not None else ''
-    doc_str = _build_doc(name,
+    doc_str = _build_doc(op_name,
                          py_str(desc.value),
                          arg_names,
                          arg_types,
@@ -139,10 +187,16 @@ def %s(%s):"""%(func_name, ', '.join(signature)))
         keys.append('%s')
         vals.append(_np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name))
 
+    is_np_op = _is_np_op(op_name)
+    verify_ndarrays_fn =\
+        _verify_all_np_ndarrays.__name__ if is_np_op else _verify_all_classic_ndarrays.__name__
     if not signature_only:
         code.append("""
-    return _imperative_invoke(%d, ndargs, keys, vals, out)"""%(
-        handle.value))
+    {}("{}", "{}", out, *ndargs)
+        """.format(verify_ndarrays_fn, op_name, func_name))
+        code.append("""
+    return _imperative_invoke(%d, ndargs, keys, vals, out, %s)"""%(
+        handle.value, str(is_np_op)))
     else:
         code.append("""
     return (0,)""")
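
For orientation, the template strings above expand to a wrapper of roughly this shape for a numpy op. The body below is a reconstruction: the argument marshalling is elided, and the handle integer is purely illustrative (it is the op handle value captured at generation time):

    def zeros(shape=_Null, ctx=_Null, dtype=_Null, out=None, **kwargs):
        ndargs = []
        keys = list(kwargs.keys())
        vals = list(kwargs.values())
        # ... positional/keyword marshalling emitted by the generator ...
        _verify_all_np_ndarrays("_npi_zeros", "zeros", out, *ndargs)
        return _imperative_invoke(94371, ndargs, keys, vals, out, True)
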
diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/numpy/__init__.py
index 2a58f27..0f3c3c7 100644
--- a/python/mxnet/numpy/__init__.py
+++ b/python/mxnet/numpy/__init__.py
@@ -17,15 +17,15 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""numpy module for imperative programming."""
+"""Module for numpy ops used in imperative programming."""
 
 from __future__ import absolute_import
 from . import random
 from . import linalg
-from . import ext
 from .multiarray import *  # pylint: disable=wildcard-import
 from . import _op
 from . import _register
 from ._op import *  # pylint: disable=wildcard-import
+from ..base import use_np_compat, set_np_compat, np_compat
 
 __all__ = []
diff --git a/python/mxnet/numpy/_op.py b/python/mxnet/numpy/_op.py
index e6a918c..8f6f9cc 100644
--- a/python/mxnet/numpy/_op.py
+++ b/python/mxnet/numpy/_op.py
@@ -15,6 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""namespace for registering numpy ops for imperative programming."""
+"""Namespace for registering numpy ops for imperative programming."""
 
 __all__ = []
diff --git a/python/mxnet/numpy/_register.py b/python/mxnet/numpy/_register.py
index 53ceecd..8a2d2ea 100644
--- a/python/mxnet/numpy/_register.py
+++ b/python/mxnet/numpy/_register.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""Register backend ops in mxnet.ndarray namespace."""
+"""Registering ops in mxnet.numpy for imperative programming."""
 
 from __future__ import absolute_import
 
@@ -23,4 +23,5 @@ from ..base import _init_np_op_module
 from ..ndarray.register import _make_ndarray_function
 
 
-_init_np_op_module('mxnet', 'numpy', _make_ndarray_function)
+_init_np_op_module(root_module_name='mxnet', np_module_name='numpy',
+                   mx_module_name=None, make_op_func=_make_ndarray_function)
diff --git a/python/mxnet/numpy/linalg.py b/python/mxnet/numpy/linalg.py
index 96c7ddc..e49bfcf 100644
--- a/python/mxnet/numpy/linalg.py
+++ b/python/mxnet/numpy/linalg.py
@@ -15,6 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""namespace for registering numpy.linalg ops for imperative programming."""
+"""Namespace for ops used in imperative programming."""
 
 __all__ = []
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 6c414b4..dfcce0b 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -25,14 +25,14 @@ from __future__ import division
 from array import array as native_array
 import ctypes
 import numpy as _np
-from ..ndarray import NDArray, _DTYPE_NP_TO_MX
+from ..ndarray import NDArray, _DTYPE_NP_TO_MX, _GRAD_REQ_MAP
 from ..ndarray._internal import _set_np_ndarray_class
 from . import _op as _mx_np_op
 from ..base import use_np_compat, check_call, _LIB, NDArrayHandle, _sanity_check_params
-from ..base import mx_real_t, c_array_buf, mx_uint, numeric_types
+from ..base import mx_real_t, c_array_buf, mx_uint, numeric_types, set_module
 from ..context import current_context
 from ..ndarray import numpy as _mx_nd_np
-from ..ndarray import _internal as _nd_internal
+from ..ndarray.numpy import _internal as _npi
 
 __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum']
 
@@ -73,16 +73,14 @@ def _np_ndarray_cls(handle, writable=True, stype=0):
 _set_np_ndarray_class(_np_ndarray_cls)
 
 
-class ndarray(NDArray):  # pylint: disable=invalid-name
+@set_module('mxnet.numpy')  # pylint: disable=invalid-name
+class ndarray(NDArray):
     """An array object represents a multidimensional, homogeneous array of fixed-size items.
     An associated data-type object describes the format of each element in the array
     (its byte-order, how many bytes it occupies in memory, whether it is an integer, a
     floating point number, or something else, etc.). Arrays should be constructed using
     `array`, `zeros` or `empty`. Currently, only c-contiguous arrays are supported."""
 
-    def _is_np_compat(self):
-        return True
-
     @use_np_compat
     def __getitem__(self, item):
         # TODO(junwu): make output shape of integer indexing correct
@@ -90,15 +88,15 @@ class ndarray(NDArray):  # pylint: disable=invalid-name
 
     @use_np_compat
     def __setitem__(self, key, value):
-        super(ndarray, self).__setitem__(key, value)
+        self.as_classic_ndarray().__setitem__(key, value)
 
     @use_np_compat
     def __add__(self, other):
         """x.__add__(y) <=> x + y"""
-        if isinstance(other, NDArray):
-            return _nd_internal._np_add(self, other)
+        if isinstance(other, ndarray):
+            return _npi.add(self, other)
         elif isinstance(other, numeric_types):
-            return _nd_internal._np_add_scalar(self, float(other))
+            return _npi.add_scalar(self, float(other))
         else:
             raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
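This dispatch pattern repeats for every arithmetic dunder below: an `ndarray`
operand routes to the corresponding `_npi` backend op, a Python scalar routes
to its `_scalar` variant, and anything else raises. Note that the isinstance
check is tightened from `NDArray` to the new `ndarray`, so mixing in classic
NDArrays now fails fast. For example:

    import mxnet.numpy as np

    a = np.ones((2, 3))
    b = np.ones((2, 3))
    c = a + b      # dispatches to _npi.add(a, b)
    d = a + 1.5    # dispatches to _npi.add_scalar(a, 1.5)
    # a + mx.nd.ones((2, 3)) would now raise TypeError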
 
@@ -107,20 +105,20 @@ class ndarray(NDArray):  # pylint: disable=invalid-name
         """x.__iadd__(y) <=> x += y"""
         if not self.writable:
             raise ValueError('trying to add to a readonly ndarray')
-        if isinstance(other, NDArray):
-            return _nd_internal._np_add(self, other, out=self)
+        if isinstance(other, ndarray):
+            return _npi.add(self, other, out=self)
         elif isinstance(other, numeric_types):
-            return _nd_internal._np_add_scalar(self, float(other), out=self)
+            return _npi.add_scalar(self, float(other), out=self)
         else:
             raise TypeError('type {} is not supported'.format(str(type(other))))
 
     @use_np_compat
     def __sub__(self, other):
         """x.__sub__(y) <=> x - y"""
-        if isinstance(other, NDArray):
-            return _nd_internal._np_subtract(self, other)
+        if isinstance(other, ndarray):
+            return _npi.subtract(self, other)
         elif isinstance(other, numeric_types):
-            return _nd_internal._np_subtract_scalar(self, float(other))
+            return _npi.subtract_scalar(self, float(other))
         else:
             raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
 
@@ -129,30 +127,30 @@ class ndarray(NDArray):  # pylint: disable=invalid-name
         """x.__isub__(y) <=> x -= y"""
         if not self.writable:
             raise ValueError('trying to subtract from a readonly ndarray')
-        if isinstance(other, NDArray):
-            return _nd_internal._np_subtract(self, other, out=self)
+        if isinstance(other, ndarray):
+            return _npi.subtract(self, other, out=self)
         elif isinstance(other, numeric_types):
-            return _nd_internal._np_subtract_scalar(self, float(other), out=self)
+            return _npi.subtract_scalar(self, float(other), out=self)
         else:
             raise TypeError('type {} is not supported'.format(str(type(other))))
 
     @use_np_compat
     def __rsub__(self, other):
         """x.__rsub__(y) <=> y - x"""
-        if isinstance(other, NDArray):
-            return _nd_internal._np_subtract(other, self)
+        if isinstance(other, ndarray):
+            return _npi.subtract(other, self)
         elif isinstance(other, numeric_types):
-            return _nd_internal._np_rsubtract_scalar(self, float(other))
+            return _npi.rsubtract_scalar(self, float(other))
         else:
             raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
 
     @use_np_compat
     def __mul__(self, other):
         """x.__mul__(y) <=> x * y"""
-        if isinstance(other, NDArray):
-            return _nd_internal._np_multiply(self, other)
+        if isinstance(other, ndarray):
+            return _npi.multiply(self, other)
         elif isinstance(other, numeric_types):
-            return _nd_internal._np_multiply_scalar(self, float(other))
+            return _npi.multiply_scalar(self, float(other))
         else:
             raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
 
@@ -190,20 +188,20 @@ class ndarray(NDArray):  # pylint: disable=invalid-name
     @use_np_compat
     def __truediv__(self, other):
         """x.__truediv__(y) <=> x / y"""
-        if isinstance(other, NDArray):
-            return _nd_internal._true_divide(self, other)
+        if isinstance(other, ndarray):
+            return _npi.true_divide(self, other)
         elif isinstance(other, numeric_types):
-            return _nd_internal._true_divide_scalar(self, float(other))
+            return _npi.true_divide_scalar(self, float(other))
         else:
             raise TypeError("ndarray does not support type {} as divisor".format(str(type(other))))
 
     @use_np_compat
     def __rtruediv__(self, other):
         """x.__rtruediv__(y) <=> y / x"""
-        if isinstance(other, NDArray):
-            return _nd_internal._true_divide(other, self)
+        if isinstance(other, ndarray):
+            return _npi.true_divide(other, self)
         elif isinstance(other, numeric_types):
-            return _nd_internal._rtrue_divide_scalar(self, float(other))
+            return _npi.rtrue_divide_scalar(self, float(other))
         else:
             raise TypeError("ndarray does not support type {} as dividend".format(str(type(other))))
 
@@ -214,20 +212,20 @@ class ndarray(NDArray):  # pylint: disable=invalid-name
     @use_np_compat
     def __mod__(self, other):
         """x.__mod__(y) <=> x % y"""
-        if isinstance(other, NDArray):
-            return _nd_internal._np_mod(self, other)
+        if isinstance(other, ndarray):
+            return _npi.mod(self, other)
         elif isinstance(other, numeric_types):
-            return _nd_internal._np_mod_scalar(self, float(other))
+            return _npi.mod_scalar(self, float(other))
         else:
             raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
 
     @use_np_compat
     def __rmod__(self, other):
         """x.__rmod__(y) <=> y % x"""
-        if isinstance(other, NDArray):
-            return _nd_internal._np_mod(other, self)
+        if isinstance(other, ndarray):
+            return _npi.mod(other, self)
         elif isinstance(other, numeric_types):
-            return _nd_internal._np_rmod_scalar(self, float(other))
+            return _npi.rmod_scalar(self, float(other))
         else:
             raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
 
@@ -238,20 +236,20 @@ class ndarray(NDArray):  # pylint: disable=invalid-name
     @use_np_compat
     def __pow__(self, other):
         """x.__pow__(y) <=> x ** y"""
-        if isinstance(other, NDArray):
-            return _nd_internal._np_power(self, other)
+        if isinstance(other, ndarray):
+            return _npi.power(self, other)
         elif isinstance(other, numeric_types):
-            return _nd_internal._np_power_scalar(self, float(other))
+            return _npi.power_scalar(self, float(other))
         else:
             raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
 
     @use_np_compat
     def __rpow__(self, other):
         """x.__rpow__(y) <=> y ** x"""
-        if isinstance(other, NDArray):
-            return _nd_internal._np_power(other, self)
+        if isinstance(other, ndarray):
+            return _npi.power(other, self)
         elif isinstance(other, numeric_types):
-            return _nd_internal._np_rpower_scalar(self, float(other))
+            return _npi.rpower_scalar(self, float(other))
         else:
             raise TypeError("ndarray does not support type {} as operand".format(str(type(other))))
 
@@ -355,15 +353,41 @@ class ndarray(NDArray):  # pylint: disable=invalid-name
 
     @use_np_compat
     def __repr__(self):
-        """Returns a string representation of the array."""
-        return '%s\n<%s shape=%s ctx=%s>' % (str(self.asnumpy()), self.__class__.__name__,
-                                             self.shape, self.context)
+        """Returns a string representation of the array using the following rules:
+        1. If the `ndarray` is a scalar tensor, only the string of the scalar is returned.
+        2. Else if the `ndarray` is allocated on cpu, the string of its numpy form, class name,
+        and shape is returned.
+        3. Else (the `ndarray` is allocated on gpu), the string of its numpy form, class name,
+        shape, and context is returned."""
+        array_str = str(self.asnumpy())
+        if self.ndim == 0:  # scalar tensor
+            return array_str
+        context = self.context
+        if context.device_type == 'gpu':
+            return '%s\n<%s shape=%s ctx=%s>' % (array_str, self.__class__.__name__, self.shape,
+                                                 context)
+        else:
+            return '%s\n<%s shape=%s>' % (array_str, self.__class__.__name__, self.shape)
 
     @use_np_compat
-    def attach_grad(self, grad_req='write', stype=None):
-        if stype is not None:
-            raise NotImplementedError('mxnet.numpy.ndarray currently does not support stype')
-        super(ndarray, self).attach_grad(grad_req, stype)
+    def attach_grad(self, grad_req='write'):  # pylint: disable=arguments-differ
+        """Attach a gradient buffer to this ndarray, so that `backward`
+        can compute gradient with respect to it.
+
+        Parameters
+        ----------
+        grad_req : {'write', 'add', 'null'}
+            How gradient will be accumulated.
+            - 'write': gradient will be overwritten on every backward.
+            - 'add': gradient will be added to existing value on every backward.
+            - 'null': do not compute gradient for this ndarray.
+        """
+        grad = _mx_np_op.zeros_like(self)  # pylint: disable=undefined-variable
+        grad_req = _GRAD_REQ_MAP[grad_req]
+        check_call(_LIB.MXAutogradMarkVariables(
+            1, ctypes.pointer(self.handle),
+            ctypes.pointer(mx_uint(grad_req)),
+            ctypes.pointer(grad.handle)))
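With `attach_grad` reimplemented this way (a `zeros_like` buffer marked through
`MXAutogradMarkVariables`), the usual autograd flow should carry over to the
numpy ndarray; a small sketch, assuming the fluent `sum` and the inherited
`backward` behave as on classic NDArrays:

    from mxnet import autograd
    import mxnet.numpy as np

    x = np.ones((2, 3))
    x.attach_grad()               # allocates a zeros_like gradient buffer
    with autograd.record():
        y = (x * x).sum()
    y.backward()
    print(x.grad)                 # expected: 2 * x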
 
     @property
     def grad(self):
@@ -412,6 +436,43 @@ class ndarray(NDArray):  # pylint: disable=invalid-name
         self.copyto(res)
         return res
 
+    @use_np_compat
+    def copyto(self, other):
+        """Copies the value of this array to another array.
+
+        If ``other`` is an ``ndarray`` object, then ``other.shape`` and
+        ``self.shape`` should be the same. This function copies the value from
+        ``self`` to ``other``.
+
+        If ``other`` is a context, a new ``ndarray`` will first be created on
+        the target context, and the value of ``self`` is copied.
+
+        Parameters
+        ----------
+        other : ndarray or Context
+            The destination array or context.
+
+        Returns
+        -------
+        ndarray
+            The copied array. If ``other`` is an ``ndarray``, then the return value
+            and ``other`` will point to the same ``ndarray``.
+
+        Examples
+        --------
+        >>> x = np.ones((2,3))
+        >>> y = np.zeros((2,3), ctx=mx.gpu(0))
+        >>> z = x.copyto(y)
+        >>> z is y
+        True
+        >>> y.asnumpy()
+        array([[ 1.,  1.,  1.],
+               [ 1.,  1.,  1.]], dtype=float32)
+        """
+        if isinstance(other, ndarray):
+            other = other.as_classic_ndarray()
+        return self.as_classic_ndarray().copyto(other).as_np_ndarray()
+
     def asscalar(self):
         raise AttributeError('mxnet.numpy.ndarray object has no attribute as_scalar')
 
@@ -435,7 +496,7 @@ class ndarray(NDArray):  # pylint: disable=invalid-name
         if order != 'C':
             raise NotImplementedError('reshape only supports C-order,'
                                       ' while received {}'.format(order))
-        return _mx_np_op.reshape(self, shape=shape, order=order)
+        return _mx_np_op.reshape(self, newshape=shape, order=order)
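The keyword rename (`shape` -> `newshape`) tracks the signature of official
`numpy.reshape`; callers are unaffected:

    import mxnet.numpy as np

    x = np.zeros((2, 3))
    y = x.reshape((3, 2))   # forwarded as _mx_np_op.reshape(x, newshape=(3, 2))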
 
     def reshape_like(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`reshape_like`.
@@ -1117,15 +1178,11 @@ class ndarray(NDArray):  # pylint: disable=invalid-name
         """Number of elements in the array."""
         return super(ndarray, self).size
 
-    @property
-    @use_np_compat
-    def stype(self):
-        raise AttributeError('mxnet.numpy.ndarray object has no attribute stype')
-
     def tostype(self, stype):
         raise AttributeError('mxnet.numpy.ndarray object has no attribute tostype')
 
 
+@set_module('mxnet.numpy')
 @use_np_compat
 def empty(shape, dtype=None, **kwargs):
     """Return a new array of given shape and type, without initializing entries.
@@ -1158,6 +1215,7 @@ def empty(shape, dtype=None, **kwargs):
     return ndarray(handle=_new_alloc_handle(shape, ctx, False, dtype))
 
 
+@set_module('mxnet.numpy')
 @use_np_compat
 def array(object, dtype=None, **kwargs):
     """
@@ -1169,10 +1227,7 @@ def array(object, dtype=None, **kwargs):
         An array, any object exposing the array interface, an object whose
         __array__ method returns an array, or any (nested) sequence.
     dtype : data-type, optional
-        The desired data-type for the array.  If not given, then the type will
-        be determined as the minimum type required to hold the objects in the
-        sequence. This argument can only be used to 'upcast' the array.  For
-        downcasting, use the .astype(t) method.
+        The desired data-type for the array. Default is `float32`.
     ctx : device context, optional
         Device context on which the memory is allocated. Default is
         `mxnet.context.current_context()`.
@@ -1186,18 +1241,19 @@ def array(object, dtype=None, **kwargs):
     ctx = kwargs.get('ctx', current_context())
     if ctx is None:
         ctx = current_context()
+    if dtype is None:
+        dtype = _np.float32
     if not isinstance(object, (ndarray, NDArray, _np.ndarray)):
         try:
             object = _np.array(object, dtype=dtype)
         except:
             raise TypeError('source array must be an array like object')
-    if dtype is None:
-        dtype = object.dtype
     ret = empty(object.shape, dtype=dtype, ctx=ctx)
     ret[:] = object
     return ret
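Note the behavioral change here: `dtype` now defaults to `float32` up front
instead of being inferred from the source object. For instance:

    import numpy as _np
    import mxnet.numpy as np

    src = _np.arange(6, dtype=_np.int64).reshape(2, 3)
    a = np.array(src)                    # stored as float32 after this change
    b = np.array(src, dtype=_np.int64)   # an explicit dtype is still honored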
 
 
+@set_module('mxnet.numpy')
 def zeros(shape, dtype=_np.float32, **kwargs):
     """Return a new array of given shape and type, filled with zeros.
     This function currently only supports storing multi-dimensional data
@@ -1223,6 +1279,7 @@ def zeros(shape, dtype=_np.float32, **kwargs):
     return _mx_nd_np.zeros(shape, dtype, **kwargs)
 
 
+@set_module('mxnet.numpy')
 def ones(shape, dtype=None, **kwargs):
     """Return a new array of given shape and type, filled with ones.
     This function currently only supports storing multi-dimensional data
@@ -1248,6 +1305,7 @@ def ones(shape, dtype=None, **kwargs):
     return _mx_nd_np.ones(shape, dtype, **kwargs)
 
 
+@set_module('mxnet.numpy')
 def maximum(x1, x2, out=None):
     """Returns element-wise maximum of the input arrays with broadcasting.
 
@@ -1264,6 +1322,7 @@ def maximum(x1, x2, out=None):
     return _mx_nd_np.maximum(x1, x2, out=out)
 
 
+@set_module('mxnet.numpy')
 def minimum(x1, x2, out=None):
     """Returns element-wise minimum of the input arrays with broadcasting.
 
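All of the module-level functions above (`zeros`, `ones`, `maximum`, `minimum`)
are thin wrappers forwarding to their `mxnet.ndarray.numpy` counterparts. A
quick sketch of the user-facing result:

    import mxnet as mx
    import mxnet.numpy as np

    a = np.ones((2, 3))
    b = np.zeros((2, 3), ctx=mx.cpu())
    c = np.maximum(a, b)     # forwards to _mx_nd_np.maximum
    d = np.minimum(a, 0.5)   # scalar operands take the _scalar path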
diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py
index b1f4b02..e49bfcf 100644
--- a/python/mxnet/numpy/random.py
+++ b/python/mxnet/numpy/random.py
@@ -15,6 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""namespace for registering numpy.random ops for imperative programming."""
+"""Namespace for ops used in imperative programming."""
 
 __all__ = []
diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/numpy_extension/__init__.py
similarity index 85%
copy from python/mxnet/numpy/__init__.py
copy to python/mxnet/numpy_extension/__init__.py
index 2a58f27..bd51175 100644
--- a/python/mxnet/numpy/__init__.py
+++ b/python/mxnet/numpy_extension/__init__.py
@@ -17,15 +17,12 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""numpy module for imperative programming."""
+"""Module for ops not belonging to the official numpy package for imperative programming."""
 
 from __future__ import absolute_import
-from . import random
-from . import linalg
-from . import ext
-from .multiarray import *  # pylint: disable=wildcard-import
 from . import _op
 from . import _register
 from ._op import *  # pylint: disable=wildcard-import
+from ..context import *  # pylint: disable=wildcard-import
 
 __all__ = []
diff --git a/python/mxnet/numpy/_op.py b/python/mxnet/numpy_extension/_op.py
similarity index 90%
copy from python/mxnet/numpy/_op.py
copy to python/mxnet/numpy_extension/_op.py
index e6a918c..a995e48 100644
--- a/python/mxnet/numpy/_op.py
+++ b/python/mxnet/numpy_extension/_op.py
@@ -15,6 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""namespace for registering numpy ops for imperative programming."""
+"""Namespace for registering numpy_extension ops for imperative programming."""
 
 __all__ = []
diff --git a/python/mxnet/numpy/_register.py b/python/mxnet/numpy_extension/_register.py
similarity index 79%
copy from python/mxnet/numpy/_register.py
copy to python/mxnet/numpy_extension/_register.py
index 53ceecd..8abb725 100644
--- a/python/mxnet/numpy/_register.py
+++ b/python/mxnet/numpy_extension/_register.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""Register backend ops in mxnet.ndarray namespace."""
+"""Registering ops in mxnet.numpy_extension for imperative programming."""
 
 from __future__ import absolute_import
 
@@ -23,4 +23,5 @@ from ..base import _init_np_op_module
 from ..ndarray.register import _make_ndarray_function
 
 
-_init_np_op_module('mxnet', 'numpy', _make_ndarray_function)
+_init_np_op_module(root_module_name='mxnet', np_module_name='numpy_extension',
+                   mx_module_name=None, make_op_func=_make_ndarray_function)
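As with `mxnet.numpy` above, this call materializes backend operators into the
new `mxnet.numpy_extension` module, and the `from ..context import *` re-export
in its `__init__.py` lets device helpers travel with the package. A small
sketch, where `some_op` is only a placeholder for whatever non-NumPy ops the
backend registers here:

    import mxnet.numpy_extension as npe

    ctx = npe.cpu()        # re-exported from mxnet.context (assumed to be
                           # part of the wildcard export)
    # y = npe.some_op(x)   # placeholder; the actual op set depends on backend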
diff --git a/python/mxnet/symbol/__init__.py b/python/mxnet/symbol/__init__.py
index ae9477a..1cd8057 100644
--- a/python/mxnet/symbol/__init__.py
+++ b/python/mxnet/symbol/__init__.py
@@ -28,5 +28,7 @@ from .op import *
 from .symbol import *
 # pylint: enable=wildcard-import
 from . import numpy as np
+from . import numpy_extension as npe
 
-__all__ = op.__all__ + symbol.__all__ + ['contrib', 'linalg', 'random', 'sparse', 'image']
+__all__ = op.__all__ + symbol.__all__\
+          + ['contrib', 'linalg', 'random', 'sparse', 'image', 'numpy', 'numpy_extension']
diff --git a/python/mxnet/symbol/numpy/__init__.py b/python/mxnet/symbol/numpy/__init__.py
index 1f20c03..857849c 100644
--- a/python/mxnet/symbol/numpy/__init__.py
+++ b/python/mxnet/symbol/numpy/__init__.py
@@ -15,13 +15,12 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""numpy module for numpy ops under mxnet.symbol."""
+"""Module for numpy ops under mxnet.symbol."""
 
 from . import random
 from . import linalg
-from . import ext
-from . import _op, _symbol
-from ._symbol import _NumpySymbol
+from . import _op, _symbol, _internal
+from ._symbol import _Symbol
 from . import _register
 from ._op import *  # pylint: disable=wildcard-import
 from ._symbol import *  # pylint: disable=wildcard-import
diff --git a/python/mxnet/numpy/_op.py b/python/mxnet/symbol/numpy/_internal.py
similarity index 91%
copy from python/mxnet/numpy/_op.py
copy to python/mxnet/symbol/numpy/_internal.py
index e6a918c..c5f2928 100644
--- a/python/mxnet/numpy/_op.py
+++ b/python/mxnet/symbol/numpy/_internal.py
@@ -15,6 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""namespace for registering numpy ops for imperative programming."""
+"""Namespace for numpy internal ops."""
 
 __all__ = []
diff --git a/python/mxnet/symbol/numpy/_op.py b/python/mxnet/symbol/numpy/_op.py
index 96da828..a4a979f 100644
--- a/python/mxnet/symbol/numpy/_op.py
+++ b/python/mxnet/symbol/numpy/_op.py
@@ -15,6 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""numpy namespace for operators used in Gluon APIs dispatched by F=symbol module."""
+"""Namespace for operators used in Gluon dispatched by F=symbol module."""
 
 __all__ = []
diff --git a/python/mxnet/symbol/numpy/_register.py b/python/mxnet/symbol/numpy/_register.py
index 36dfd78..3245c8d 100644
--- a/python/mxnet/symbol/numpy/_register.py
+++ b/python/mxnet/symbol/numpy/_register.py
@@ -15,9 +15,14 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""module for registering numpy ops under mxnet.symbol.numpy."""
+"""Registering numpy ops."""
 
 from ...base import _init_np_op_module
 from ..register import _make_symbol_function
 
-_init_np_op_module('mxnet', 'symbol', _make_symbol_function)
+_init_np_op_module(root_module_name='mxnet', np_module_name='numpy',
+                   mx_module_name='symbol', make_op_func=_make_symbol_function)
+
+
+_init_np_op_module(root_module_name='mxnet', np_module_name='numpy._internal',
+                   mx_module_name='symbol', make_op_func=_make_symbol_function)
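Two registration passes populate both the public `mxnet.symbol.numpy` namespace
and the private `mxnet.symbol.numpy._internal` namespace that the `_Symbol`
dunders call into. Roughly:

    from mxnet.symbol import numpy as np_sym
    from mxnet.symbol.numpy import _internal as _npi

    s = np_sym.zeros(shape=(2, 3))   # public op, defined later in this diff
    # _npi.add, _npi.mod_scalar, ... are the internal ops used by _Symbol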
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index 8cf6e30..0bbd96b 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -23,21 +23,17 @@ import ctypes
 import numpy as _np
 from . import _op as _mx_np_op
 from ...base import _sanity_check_params, use_np_compat, check_call, _LIB, SymbolHandle
-from ...base import numeric_types
+from ...base import numeric_types, set_module
 from ...context import current_context
-from .. import _internal
 from ..symbol import Symbol
 from .._internal import _set_np_symbol_class
-from .. import _internal as _sym_internal
+from . import _internal as _npi
 
 __all__ = ['zeros', 'ones', 'maximum', 'minimum']
 
 
-class _NumpySymbol(Symbol):
-
-    def _is_np_compat(self):
-        return True
-
+@set_module('mxnet.symbol.numpy')
+class _Symbol(Symbol):
     def __getitem__(self, item):
         raise NotImplementedError
 
@@ -45,72 +41,72 @@ class _NumpySymbol(Symbol):
         raise NotImplementedError
 
     def __iter__(self):
-        raise AttributeError('_NumpySymbol object has no attribute __iter__')
+        raise AttributeError('_Symbol object has no attribute __iter__')
 
     @use_np_compat
     def __add__(self, other):
         """x.__add__(y) <=> x + y"""
-        if isinstance(other, Symbol):
-            return _sym_internal._np_add(self, other)
+        if isinstance(other, _Symbol):
+            return _npi.add(self, other)
         elif isinstance(other, numeric_types):
-            return _sym_internal._np_add_scalar(self, float(other))
+            return _npi.add_scalar(self, float(other))
         else:
-            raise TypeError("_NumpySymbol does not support type {} as operand"
+            raise TypeError("_Symbol does not support type {} as operand"
                             .format(str(type(other))))
 
     @use_np_compat
     def __sub__(self, other):
         """x.__sub__(y) <=> x - y"""
-        if isinstance(other, Symbol):
-            return _sym_internal._np_subtract(self, other)
+        if isinstance(other, _Symbol):
+            return _npi.subtract(self, other)
         elif isinstance(other, numeric_types):
-            return _sym_internal._np_subtract_scalar(self, float(other))
+            return _npi.subtract_scalar(self, float(other))
         else:
-            raise TypeError("_NumpySymbol does not support type {} as operand"
+            raise TypeError("_Symbol does not support type {} as operand"
                             .format(str(type(other))))
 
     @use_np_compat
     def __rsub__(self, other):
         """x.__rsub__(y) <=> y - x"""
-        if isinstance(other, Symbol):
-            return _sym_internal._np_subtract(other, self)
+        if isinstance(other, _Symbol):
+            return _npi.subtract(other, self)
         elif isinstance(other, numeric_types):
-            return _sym_internal._np_rsubtract_scalar(self, float(other))
+            return _npi.rsubtract_scalar(self, float(other))
         else:
-            raise TypeError("_NumpySymbol does not support type {} as operand"
+            raise TypeError("_Symbol does not support type {} as operand"
                             .format(str(type(other))))
 
     @use_np_compat
     def __mul__(self, other):
         """x.__mul__(y) <=> x * y"""
-        if isinstance(other, Symbol):
-            return _sym_internal._np_multiply(self, other)
+        if isinstance(other, _Symbol):
+            return _npi.multiply(self, other)
         elif isinstance(other, numeric_types):
-            return _sym_internal._np_multiply_scalar(self, float(other))
+            return _npi.multiply_scalar(self, float(other))
         else:
-            raise TypeError("_NumpySymbol does not support type {} as operand"
+            raise TypeError("_Symbol does not support type {} as operand"
                             .format(str(type(other))))
 
     @use_np_compat
     def __rmul__(self, other):
         """x.__rmul__(y) <=> y * x"""
-        if isinstance(other, Symbol):
-            return _sym_internal._np_multiply(self, other)
+        if isinstance(other, _Symbol):
+            return _npi.multiply(self, other)
         elif isinstance(other, numeric_types):
-            return _sym_internal._np_multiply_scalar(self, float(other))
+            return _npi.multiply_scalar(self, float(other))
         else:
-            raise TypeError("_NumpySymbol does not support type {} as operand"
+            raise TypeError("_Symbol does not support type {} as operand"
                             .format(str(type(other))))
 
     def __div__(self, other):
-        raise AttributeError('_NumpySymbol.__div__ is replaced by __truediv__. If you are using'
+        raise AttributeError('_Symbol.__div__ is replaced by __truediv__. If you are using'
                              ' Python2, please use the statement from __future__ import division'
                              ' to change the / operator to mean true division throughout the'
                              ' module. If you are using Python3, this error should not have'
                              ' been encountered.')
 
     def __rdiv__(self, other):
-        raise AttributeError('_NumpySymbol.__rdiv__ is replaced by __rtruediv__. If you are using'
+        raise AttributeError('_Symbol.__rdiv__ is replaced by __rtruediv__. If you are using'
                              ' Python2, please use the statement from __future__ import division'
                              ' to change the / operator to mean true division throughout the'
                              ' module. If you are using Python3, this error should not have'
@@ -119,23 +115,23 @@ class _NumpySymbol(Symbol):
     @use_np_compat
     def __mod__(self, other):
         """x.__mod__(y) <=> x % y"""
-        if isinstance(other, Symbol):
-            return _sym_internal._np_mod(self, other)
+        if isinstance(other, _Symbol):
+            return _npi.mod(self, other)
         elif isinstance(other, numeric_types):
-            return _sym_internal._np_mod_scalar(self, float(other))
+            return _npi.mod_scalar(self, float(other))
         else:
-            raise TypeError("_NumpySymbol does not support type {} as operand"
+            raise TypeError("_Symbol does not support type {} as operand"
                             .format(str(type(other))))
 
     @use_np_compat
     def __rmod__(self, other):
         """x.__rmod__(y) <=> y % x"""
-        if isinstance(other, Symbol):
-            return _sym_internal._np_mod(other, self)
+        if isinstance(other, _Symbol):
+            return _npi.mod(other, self)
         elif isinstance(other, numeric_types):
-            return _sym_internal._np_rmod_scalar(self, float(other))
+            return _npi.rmod_scalar(self, float(other))
         else:
-            raise TypeError("_NumpySymbol does not support type {} as operand"
+            raise TypeError("_Symbol does not support type {} as operand"
                             .format(str(type(other))))
 
     @use_np_compat
@@ -145,23 +141,23 @@ class _NumpySymbol(Symbol):
     @use_np_compat
     def __truediv__(self, other):
         """x.__truediv__(y) <=> x / y"""
-        if isinstance(other, Symbol):
-            return _sym_internal._true_divide(self, other)
+        if isinstance(other, _Symbol):
+            return _npi.true_divide(self, other)
         elif isinstance(other, numeric_types):
-            return _sym_internal._true_divide_scalar(self, float(other))
+            return _npi.true_divide_scalar(self, float(other))
         else:
-            raise TypeError("_NumpySymbol does not support type {} as divisor"
+            raise TypeError("_Symbol does not support type {} as divisor"
                             .format(str(type(other))))
 
     @use_np_compat
     def __rtruediv__(self, other):
         """x.__rtruediv__(y) <=> y / x"""
-        if isinstance(other, Symbol):
-            return _sym_internal._true_divide(other, self)
+        if isinstance(other, _Symbol):
+            return _npi.true_divide(other, self)
         elif isinstance(other, numeric_types):
-            return _sym_internal._rtrue_divide_scalar(self, float(other)).as_np_ndarray()
+            return _npi.rtrue_divide_scalar(self, float(other)).as_np_ndarray()
         else:
-            raise TypeError("_NumpySymbol does not support type {} as dividend"
+            raise TypeError("_Symbol does not support type {} as dividend"
                             .format(str(type(other))))
 
     @use_np_compat
@@ -171,23 +167,23 @@ class _NumpySymbol(Symbol):
     @use_np_compat
     def __pow__(self, other):
         """x.__pow__(y) <=> x ** y"""
-        if isinstance(other, Symbol):
-            return _sym_internal._np_power(self, other)
+        if isinstance(other, _Symbol):
+            return _npi.power(self, other)
         elif isinstance(other, numeric_types):
-            return _sym_internal._np_power_scalar(self, float(other))
+            return _npi.power_scalar(self, float(other))
         else:
-            raise TypeError("_NumpySymbol does not support type {} as operand"
+            raise TypeError("_Symbol does not support type {} as operand"
                             .format(str(type(other))))
 
     @use_np_compat
     def __rpow__(self, other):
         """x.__rpow__(y) <=> y ** x"""
-        if isinstance(other, Symbol):
-            return _sym_internal._np_power(other, self)
+        if isinstance(other, _Symbol):
+            return _npi.power(other, self)
         elif isinstance(other, numeric_types):
-            return _sym_internal._np_rpower_scalar(self, float(other))
+            return _npi.rpower_scalar(self, float(other))
         else:
-            raise TypeError("_NumpySymbol does not support type {} as operand"
+            raise TypeError("_Symbol does not support type {} as operand"
                             .format(str(type(other))))
 
     @use_np_compat
@@ -197,7 +193,7 @@ class _NumpySymbol(Symbol):
 
     @use_np_compat
     def __deepcopy__(self, _):
-        return super(_NumpySymbol, self).as_np_ndarray()
+        return super(_Symbol, self).as_np_ndarray()
 
     @use_np_compat
     def __eq__(self, other):
@@ -233,7 +229,7 @@ class _NumpySymbol(Symbol):
         raise NotImplementedError
 
     def as_classic_ndarray(self):
-        """Convert _NumpySymbol to mxnet.symbol.Symbol to use its convenience fluent methods."""
+        """Convert _Symbol to mxnet.symbol.Symbol to use its convenience fluent methods."""
         hdl = SymbolHandle()
         check_call(_LIB.MXShallowCopySymbol(self.handle, ctypes.byref(hdl)))
         return Symbol(handle=hdl)
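A shallow handle copy suffices here because `_Symbol` and `Symbol` share the
same backend representation; only the Python-side type, and hence the set of
available fluent methods, changes. For example:

    import mxnet as mx

    x = mx.sym.var('x').as_np_ndarray()   # classic Symbol -> _Symbol
    y = x.as_classic_ndarray()            # back to mxnet.symbol.Symbol
    # y exposes the classic fluent methods that _Symbol disables below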
@@ -258,7 +254,7 @@ class _NumpySymbol(Symbol):
         if order != 'C':
             raise NotImplementedError('reshape only supports order=\'C\', while '
                                       'received {}'.format(str(order)))
-        return _mx_np_op.reshape(self, shape=shape, order=order)
+        return _mx_np_op.reshape(self, newshape=shape, order=order)
 
     def reshape_like(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`reshape_like`.
@@ -266,7 +262,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`reshape_like`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute reshape_like')
+        raise AttributeError('_Symbol object has no attribute reshape_like')
 
     def zeros_like(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`zeros_like`.
@@ -274,7 +270,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`zeros_like`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute zeros_like')
+        raise AttributeError('_Symbol object has no attribute zeros_like')
 
     def ones_like(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`ones_like`.
@@ -282,7 +278,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`ones_like`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute ones_like')
+        raise AttributeError('_Symbol object has no attribute ones_like')
 
     def broadcast_axes(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`broadcast_axes`.
@@ -290,7 +286,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`broadcast_axes`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute broadcast_like')
+        raise AttributeError('_Symbol object has no attribute broadcast_axes')
 
     @use_np_compat
     def repeat(self, *args, **kwargs):
@@ -307,7 +303,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`pad`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute pad')
+        raise AttributeError('_Symbol object has no attribute pad')
 
     @use_np_compat
     def swapaxes(self, *args, **kwargs):
@@ -324,7 +320,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`split`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute split')
+        raise AttributeError('_Symbol object has no attribute split')
 
     def split_v2(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`split_v2`.
@@ -332,7 +328,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`split_v2`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute split_v2')
+        raise AttributeError('_Symbol object has no attribute split_v2')
 
     def slice(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`slice`.
@@ -340,7 +336,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`slice`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute slice')
+        raise AttributeError('_Symbol object has no attribute slice')
 
     def slice_axis(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`slice_axis`.
@@ -348,7 +344,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`slice_axis`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute slice_axis')
+        raise AttributeError('_Symbol object has no attribute slice_axis')
 
     def slice_like(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`slice_like`.
@@ -356,7 +352,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`slice_like`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute slice_like')
+        raise AttributeError('_Symbol object has no attribute slice_like')
 
     @use_np_compat
     def take(self, *args, **kwargs):
@@ -373,7 +369,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`one_hot`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute one_hot')
+        raise AttributeError('_Symbol object has no attribute one_hot')
 
     def pick(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`pick`.
@@ -381,7 +377,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`pick`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute pick')
+        raise AttributeError('_Symbol object has no attribute pick')
 
     @use_np_compat
     def sort(self, *args, **kwargs):
@@ -398,7 +394,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`topk`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute topk')
+        raise AttributeError('_Symbol object has no attribute topk')
 
     @use_np_compat
     def argsort(self, *args, **kwargs):
@@ -424,7 +420,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`argmax_channel`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute argmax_channel')
+        raise AttributeError('_Symbol object has no attribute argmax_channel')
 
     @use_np_compat
     def argmin(self, *args, **kwargs):
@@ -450,7 +446,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`abs`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute abs')
+        raise AttributeError('_Symbol object has no attribute abs')
 
     def sign(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`sign`.
@@ -458,7 +454,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`sign`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute abs')
+        raise AttributeError('_Symbol object has no attribute sign')
 
     @use_np_compat
     def flatten(self, *args, **kwargs):
@@ -475,7 +471,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`shape_array`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute shape_array')
+        raise AttributeError('_Symbol object has no attribute shape_array')
 
     def size_array(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`size_array`.
@@ -483,7 +479,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`size_array`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute size_array')
+        raise AttributeError('_Symbol object has no attribute size_array')
 
     def expand_dims(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`expand_dims`.
@@ -491,7 +487,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`expand_dims`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute expand_dims')
+        raise AttributeError('_Symbol object has no attribute expand_dims')
 
     def tile(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`tile`.
@@ -499,7 +495,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`tile`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute tile')
+        raise AttributeError('_Symbol object has no attribute tile')
 
     @use_np_compat
     def transpose(self, *axes):  # pylint: disable=arguments-differ
@@ -516,7 +512,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`flip`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute flip')
+        raise AttributeError('_Symbol object has no attribute flip')
 
     def depth_to_space(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`depth_to_space`.
@@ -524,7 +520,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`depth_to_space`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute depth_to_space')
+        raise AttributeError('_Symbol object has no attribute depth_to_space')
 
     def space_to_depth(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`space_to_depth`.
@@ -532,7 +528,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`space_to_depth`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute space_to_depth')
+        raise AttributeError('_Symbol object has no attribute space_to_depth')
 
     def diag(self, k=0, **kwargs):
         """Convenience fluent method for :py:func:`diag`.
@@ -540,7 +536,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`diag`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute diag')
+        raise AttributeError('_Symbol object has no attribute diag')
 
     @use_np_compat
     def sum(self, axis=None, dtype=None, out=None, keepdims=False):  # pylint: disable=arguments-differ
@@ -557,7 +553,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`nansum`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute nansum')
+        raise AttributeError('_Symbol object has no attribute nansum')
 
     @use_np_compat
     def prod(self, *args, **kwargs):
@@ -574,7 +570,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`nanprod`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute nanprod')
+        raise AttributeError('_Symbol object has no attribute nanprod')
 
     @use_np_compat
     def mean(self, *args, **kwargs):
@@ -609,7 +605,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`norm`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute norm')
+        raise AttributeError('_Symbol object has no attribute norm')
 
     @use_np_compat
     def round(self, *args, **kwargs):
@@ -626,7 +622,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`rint`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute rint')
+        raise AttributeError('_Symbol object has no attribute rint')
 
     def fix(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`fix`.
@@ -634,7 +630,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`fix`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute fix')
+        raise AttributeError('_Symbol object has no attribute fix')
 
     def floor(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`floor`.
@@ -642,7 +638,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`floor`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute floor')
+        raise AttributeError('_Symbol object has no attribute floor')
 
     def ceil(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`ceil`.
@@ -650,7 +646,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`ceil`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute ceil')
+        raise AttributeError('_Symbol object has no attribute ceil')
 
     def trunc(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`trunc`.
@@ -658,7 +654,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`trunc`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute trunc')
+        raise AttributeError('_Symbol object has no attribute trunc')
 
     def sin(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`sin`.
@@ -666,7 +662,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`sin`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute sin')
+        raise AttributeError('_Symbol object has no attribute sin')
 
     def cos(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`cos`.
@@ -674,7 +670,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`cos`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute cos')
+        raise AttributeError('_Symbol object has no attribute cos')
 
     def tan(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`tan`.
@@ -682,7 +678,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`tan`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute tan')
+        raise AttributeError('_Symbol object has no attribute tan')
 
     def arcsin(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`arcsin`.
@@ -690,7 +686,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`arcsin`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute arcsin')
+        raise AttributeError('_Symbol object has no attribute arcsin')
 
     def arccos(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`arccos`.
@@ -698,7 +694,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`arccos`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute arccos')
+        raise AttributeError('_Symbol object has no attribute arccos')
 
     def arctan(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`arctan`.
@@ -706,7 +702,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`arctan`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute arctan')
+        raise AttributeError('_Symbol object has no attribute arctan')
 
     def degrees(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`degrees`.
@@ -714,7 +710,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`degrees`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute degrees')
+        raise AttributeError('_Symbol object has no attribute degrees')
 
     def radians(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`radians`.
@@ -722,7 +718,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`radians`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute radians')
+        raise AttributeError('_Symbol object has no attribute radians')
 
     def sinh(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`sinh`.
@@ -730,7 +726,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`sinh`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute sinh')
+        raise AttributeError('_Symbol object has no attribute sinh')
 
     def cosh(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`cosh`.
@@ -738,7 +734,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`cosh`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute cosh')
+        raise AttributeError('_Symbol object has no attribute cosh')
 
     def tanh(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`tanh`.
@@ -746,7 +742,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`tanh`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute tanh')
+        raise AttributeError('_Symbol object has no attribute tanh')
 
     def arcsinh(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`arcsinh`.
@@ -754,7 +750,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`arcsinh`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute arcsinh')
+        raise AttributeError('_Symbol object has no attribute arcsinh')
 
     def arccosh(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`arccosh`.
@@ -762,7 +758,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`arccosh`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute arccosh')
+        raise AttributeError('_Symbol object has no attribute arccosh')
 
     def arctanh(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`arctanh`.
@@ -770,7 +766,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`arctanh`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute arctanh')
+        raise AttributeError('_Symbol object has no attribute arctanh')
 
     def exp(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`exp`.
@@ -778,7 +774,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`exp`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute exp')
+        raise AttributeError('_Symbol object has no attribute exp')
 
     def expm1(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`expm1`.
@@ -786,7 +782,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`expm1`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute expm1')
+        raise AttributeError('_Symbol object has no attribute expm1')
 
     def log(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`log`.
@@ -794,7 +790,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`log`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute log')
+        raise AttributeError('_Symbol object has no attribute log')
 
     def log10(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`log10`.
@@ -802,7 +798,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`log10`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute log10')
+        raise AttributeError('_Symbol object has no attribute log10')
 
     def log2(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`log2`.
@@ -810,7 +806,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`log2`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute log2')
+        raise AttributeError('_Symbol object has no attribute log2')
 
     def log1p(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`log1p`.
@@ -818,7 +814,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`log1p`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute log1p')
+        raise AttributeError('_Symbol object has no attribute log1p')
 
     def sqrt(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`sqrt`.
@@ -826,7 +822,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`sqrt`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute sqrt')
+        raise AttributeError('_Symbol object has no attribute sqrt')
 
     def rsqrt(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`rsqrt`.
@@ -834,7 +830,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`rsqrt`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute rsqrt')
+        raise AttributeError('_Symbol object has no attribute rsqrt')
 
     def cbrt(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`cbrt`.
@@ -842,7 +838,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`cbrt`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute cqrt')
+        raise AttributeError('_Symbol object has no attribute cbrt')
 
     def rcbrt(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`rcbrt`.
@@ -850,7 +846,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`rcbrt`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute rcqrt')
+        raise AttributeError('_Symbol object has no attribute rcbrt')
 
     def square(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`square`.
@@ -858,7 +854,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`square`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute square')
+        raise AttributeError('_Symbol object has no attribute square')
 
     def reciprocal(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`reciprocal`.
@@ -866,7 +862,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`reciprocal`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute reciprocal')
+        raise AttributeError('_Symbol object has no attribute reciprocal')
 
     def relu(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`relu`.
@@ -874,7 +870,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`relu`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute relu')
+        raise AttributeError('_Symbol object has no attribute relu')
 
     def sigmoid(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`sigmoid`.
@@ -882,7 +878,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`sigmoid`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute sigmoid')
+        raise AttributeError('_Symbol object has no attribute sigmoid')
 
     def softmax(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`softmax`.
@@ -890,7 +886,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`softmax`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute softmax')
+        raise AttributeError('_Symbol object has no attribute softmax')
 
     def log_softmax(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`log_softmax`.
@@ -898,7 +894,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`log_softmax`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute log_softmax')
+        raise AttributeError('_Symbol object has no attribute log_softmax')
 
     def softmin(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`softmin`.
@@ -906,7 +902,7 @@ class _NumpySymbol(Symbol):
         The arguments are the same as for :py:func:`softmin`, with
         this array as data.
         """
-        raise AttributeError('_NumpySymbol object has no attribute softmin')
+        raise AttributeError('_Symbol object has no attribute softmin')
 
     @use_np_compat
     def squeeze(self, *args, **kwargs):
@@ -918,12 +914,13 @@ class _NumpySymbol(Symbol):
         raise NotImplementedError
 
     def broadcast_to(self, *args, **kwargs):
-        raise AttributeError('_NumpySymbol object has no attribute broadcast_to')
+        raise AttributeError('_Symbol object has no attribute broadcast_to')
 
     def broadcast_like(self, *args, **kwargs):
-        raise AttributeError('_NumpySymbol object has no attribute broadcast_like')
+        raise AttributeError('_Symbol object has no attribute broadcast_like')
 
 
+@set_module('mxnet.symbol.numpy')
 @use_np_compat
 def zeros(shape, dtype=_np.float32, **kwargs):
     """Return a new array of given shape and type, filled with zeros.
@@ -952,9 +949,10 @@ def zeros(shape, dtype=_np.float32, **kwargs):
     if ctx is None:
         ctx = current_context()
     dtype = _np.float32 if dtype is None else dtype
-    return _internal._np_zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
+    return _npi.zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
 
 
+@set_module('mxnet.symbol.numpy')
 @use_np_compat
 def ones(shape, dtype=None, **kwargs):
     """Return a new array of given shape and type, filled with zeros.
@@ -983,7 +981,7 @@ def ones(shape, dtype=None, **kwargs):
     if ctx is None:
         ctx = current_context()
     dtype = _np.float32 if dtype is None else dtype
-    return _internal._np_ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
+    return _npi.ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
 
 
 #pylint: disable= too-many-arguments, no-member, protected-access
@@ -1035,16 +1033,16 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, ou
 #pylint: enable= too-many-arguments, no-member, protected-access
 
 
+@set_module('mxnet.symbol.numpy')
 @use_np_compat
 def maximum(x1, x2, out=None):
-    return _ufunc_helper(x1, x2, _internal._np_maximum, _np.maximum,
-                         _internal._np_maximum_scalar, None, out)
+    return _ufunc_helper(x1, x2, _npi.maximum, _np.maximum, _npi.maximum_scalar, None, out)
 
 
+@set_module('mxnet.symbol.numpy')
 @use_np_compat
 def minimum(x1, x2, out=None):
-    return _ufunc_helper(x1, x2, _internal._np_minimum, _np.minimum,
-                         _internal._np_minimum_scalar, None, out)
+    return _ufunc_helper(x1, x2, _npi.minimum, _np.minimum, _npi.minimum_scalar, None, out)
 
 
-_set_np_symbol_class(_NumpySymbol)
+_set_np_symbol_class(_Symbol)
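
For context, a minimal sketch of how the wrappers above are meant to be called on
this branch (only names visible in these hunks are assumed):

    import mxnet as mx

    z = mx.sym.np.zeros(shape=(2, 3))                    # forwards to _npi.zeros
    o = mx.sym.np.ones(shape=(2, 3), dtype='float64')    # forwards to _npi.ones
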
diff --git a/python/mxnet/symbol/numpy/ext.py b/python/mxnet/symbol/numpy/ext.py
deleted file mode 100644
index 12c5f15..0000000
--- a/python/mxnet/symbol/numpy/ext.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""numpy.ext namespace for operators used in Gluon APIs dispatched by F=symbol module."""
-
-__all__ = []
diff --git a/python/mxnet/symbol/numpy/linalg.py b/python/mxnet/symbol/numpy/linalg.py
index b8f10b3..869fdeb 100644
--- a/python/mxnet/symbol/numpy/linalg.py
+++ b/python/mxnet/symbol/numpy/linalg.py
@@ -15,6 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""numpy.linalg namespace for operators used in Gluon APIs dispatched by F=symbol module."""
+"""Namespace for operators used in Gluon dispatched by F=symbol."""
 
 __all__ = []
diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py
index 79c73d8..869fdeb 100644
--- a/python/mxnet/symbol/numpy/random.py
+++ b/python/mxnet/symbol/numpy/random.py
@@ -15,6 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""numpy.random namespace for operators used in Gluon APIs dispatched by F=symbol module."""
+"""Namespace for operators used in Gluon dispatched by F=symbol."""
 
 __all__ = []
diff --git a/python/mxnet/ndarray/numpy/__init__.py b/python/mxnet/symbol/numpy_extension/__init__.py
similarity index 88%
copy from python/mxnet/ndarray/numpy/__init__.py
copy to python/mxnet/symbol/numpy_extension/__init__.py
index d97e808..a718274 100644
--- a/python/mxnet/ndarray/numpy/__init__.py
+++ b/python/mxnet/symbol/numpy_extension/__init__.py
@@ -15,11 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""numpy module for numpy ops under mxnet.ndarray."""
+"""Module for the ops not belonging to the official numpy package."""
 
-from . import ext
-from . import random
-from . import linalg
 from . import _op
 from . import _register
 from ._op import *  # pylint: disable=wildcard-import
diff --git a/python/mxnet/symbol/numpy/_op.py b/python/mxnet/symbol/numpy_extension/_op.py
similarity index 86%
copy from python/mxnet/symbol/numpy/_op.py
copy to python/mxnet/symbol/numpy_extension/_op.py
index 96da828..82eaa8e 100644
--- a/python/mxnet/symbol/numpy/_op.py
+++ b/python/mxnet/symbol/numpy_extension/_op.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""numpy namespace for operators used in Gluon APIs dispatched by F=symbol module."""
+"""Namespace for operators not belonging to the official numpy package
+used in Gluon APIs dispatched by F=symbol module."""
 
 __all__ = []
diff --git a/python/mxnet/symbol/numpy/_register.py b/python/mxnet/symbol/numpy_extension/_register.py
similarity index 81%
copy from python/mxnet/symbol/numpy/_register.py
copy to python/mxnet/symbol/numpy_extension/_register.py
index 36dfd78..b118987 100644
--- a/python/mxnet/symbol/numpy/_register.py
+++ b/python/mxnet/symbol/numpy_extension/_register.py
@@ -15,9 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""module for registering numpy ops under mxnet.symbol.numpy."""
+"""Registering numpy_extension ops."""
 
 from ...base import _init_np_op_module
 from ..register import _make_symbol_function
 
-_init_np_op_module('mxnet', 'symbol', _make_symbol_function)
+_init_np_op_module(root_module_name='mxnet', np_module_name='numpy_extension',
+                   mx_module_name='symbol', make_op_func=_make_symbol_function)
diff --git a/python/mxnet/symbol/register.py b/python/mxnet/symbol/register.py
index ac59f8b..a835e2e 100644
--- a/python/mxnet/symbol/register.py
+++ b/python/mxnet/symbol/register.py
@@ -27,12 +27,58 @@ from ._internal import SymbolBase, _symbol_creator
 from ..attribute import AttrScope
 from ..base import mx_uint, check_call, _LIB, py_str
 from ..symbol_doc import _build_doc
-from ..base import _Null, _init_op_module
+from ..base import _Null, _init_op_module, _is_np_op
 from ..name import NameManager
 # pylint: enable=unused-import
 
 
-def _generate_symbol_function_code(handle, name, func_name, signature_only=False):
+def _verify_np_symbol(op_name, func_name, sym):
+    """Verify if the sym is a numpy symbol.
+
+    Parameters
+    ----------
+    op_name : str
+        Full operator name registered in the backend.
+    func_name : str
+        Operator name exposed to users. This is usually obtained by stripping off
+        the prefix of the full operator name registered in the backend.
+    sym : Symbol to be verified
+    """
+    from .numpy._symbol import _Symbol as np_symbol
+    if not isinstance(sym, np_symbol):
+        raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. '
+                        'This is a numpy operator which can only accept '
+                        'MXNet numpy symbols, but received a classic symbol. '
+                        'Please call `as_np_ndarray()` on the classic symbol to '
+                        'convert it to an MXNet numpy symbol, and then feed the converted '
+                        'symbol to this operator.'
+                        .format(op_name, func_name))
+
+
+def _verify_classic_symbol(op_name, func_name, sym):
+    """Verify if the sym is a classic symbol.
+
+    Parameters
+    ----------
+    op_name : str
+        Full operator name registered in the backend.
+    func_name : str
+        Operator name exposed to users. This is usually obtained by stripping off
+        the prefix of the full operator name registered in the backend.
+    sym : Symbol to be verified
+    """
+    from .numpy._symbol import _Symbol as np_symbol
+    if isinstance(sym, np_symbol):
+        raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. '
+                        'This is a classic operator which can only accept '
+                        'classic symbols, but received an MXNet numpy symbol. '
+                        'Please call `as_classic_ndarray()` on the numpy symbol to '
+                        'convert it to a classic symbol, and then feed the converted '
+                        'symbol to this operator.'
+                        .format(op_name, func_name))
+
+
+def _generate_symbol_function_code(handle, op_name, func_name, signature_only=False):
     """Generate function for symbol op by handle and function name."""
     real_name = ctypes.c_char_p()
     desc = ctypes.c_char_p()
@@ -56,7 +102,7 @@ def _generate_symbol_function_code(handle, name, func_name, signature_only=False
     arg_types = [py_str(arg_types[i]) for i in range(narg)]
     key_var_num_args = py_str(key_var_num_args.value)
     ret_type = py_str(ret_type.value) if ret_type.value is not None else ''
-    doc_str = _build_doc(name,
+    doc_str = _build_doc(op_name,
                          py_str(desc.value),
                          arg_names,
                          arg_types,
@@ -95,6 +141,8 @@ def _generate_symbol_function_code(handle, name, func_name, signature_only=False
     signature.append('**kwargs')
     signature = ndsignature + signature
 
+    is_np_op = _is_np_op(op_name)
+    verify_symbol_fn = _verify_np_symbol.__name__ if is_np_op else _verify_classic_symbol.__name__
     code = []
     if arr_name:
         code.append("""
@@ -106,7 +154,8 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name))
         assert isinstance(i, SymbolBase), \\
             "Positional arguments must be Symbol instances, " \\
             "but got %s"%str(i)
-        sym_args.append(i)""".format(arr_name))
+        {}('{}', '{}', i)
+        sym_args.append(i)""".format(arr_name, verify_symbol_fn, op_name, func_name))
             if dtype_name is not None:
                 code.append("""
     if '%s' in kwargs:
@@ -128,9 +177,10 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name))
     for k, v in kwargs.items():
         if isinstance(v, SymbolBase):
             sym_kwargs[k] = v
+            %s('%s', '%s', v)
         else:
             keys.append(k)
-            vals.append(v)"""%(func_name.lower()))
+            vals.append(v)"""%(func_name.lower(), verify_symbol_fn, op_name, func_name))
             if key_var_num_args: # pylint: disable=using-constant-test
                 code.append("""
     if '%s' not in kwargs:
@@ -139,8 +189,8 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name))
             key_var_num_args, key_var_num_args))
 
             code.append("""
-    return _symbol_creator(%d, sym_args, sym_kwargs, keys, vals, name)"""%(
-        handle.value))
+    return _symbol_creator(%d, sym_args, sym_kwargs, keys, vals, name, %s)"""%(
+        handle.value, str(is_np_op)))
     else:
         code.append("""
 def %s(%s):"""%(func_name, ', '.join(signature)))
@@ -155,9 +205,10 @@ def %s(%s):"""%(func_name, ', '.join(signature)))
     for _k, _v in kwargs.items():
         if isinstance(_v, SymbolBase):
             sym_kwargs[_k] = _v
+            {}('{}', '{}', _v)
         else:
             _keys.append(_k)
-            _vals.append(_v)""")
+            _vals.append(_v)""".format(verify_symbol_fn, op_name, func_name))
             # NDArray args
             for name in ndarg_names: # pylint: disable=redefined-argument-from-local
                 code.append("""
@@ -165,6 +216,9 @@ def %s(%s):"""%(func_name, ', '.join(signature)))
         assert isinstance({name}, SymbolBase), \\
             "Argument {name} must be Symbol instances, but got %s"%str({name})
         sym_kwargs['{name}'] = {name}""".format(name=name))
+                code.append("""
+        {}('{}', '{}', {name})
+                """.format(verify_symbol_fn, op_name, func_name, name=name))
             # kwargs
             for name in kwarg_names: # pylint: disable=redefined-argument-from-local
                 code.append("""
@@ -182,8 +236,8 @@ def %s(%s):"""%(func_name, ', '.join(signature)))
     if not hasattr(NameManager._current, "value"):
         NameManager._current.value = NameManager()
     name = NameManager._current.value.get(name, '%s')
-    return _symbol_creator(%d, None, sym_kwargs, _keys, _vals, name)"""%(
-        func_name.lower(), handle.value))
+    return _symbol_creator(%d, None, sym_kwargs, _keys, _vals, name, %s)"""%(
+        func_name.lower(), handle.value, str(is_np_op)))
 
     if signature_only:
         code.append("""
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 7be042c..96397f6 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -62,15 +62,11 @@ class Symbol(SymbolBase):
     __array_priority__ = 1000.0
 
     def as_np_ndarray(self):
-        """Convert mxnet.symbol.Symbol to _NumpySymbol."""
-        from .numpy import _NumpySymbol
+        """Convert mx.sym.Symbol to mx.sym.np._Symbol."""
+        from .numpy import _Symbol
         hdl = SymbolHandle()
         check_call(_LIB.MXShallowCopySymbol(self.handle, ctypes.byref(hdl)))
-        return _NumpySymbol(hdl)
-
-    def _is_np_compat(self):
-        """Always returns False except for mxnet.symbol.numpy._NumpySymbol."""
-        return False
+        return _Symbol(hdl)
 
     def __repr__(self):
         """Gets a string representation of the symbol."""
@@ -110,8 +106,6 @@ class Symbol(SymbolBase):
         Scalar input is supported.
         Broadcasting is not supported. Use `broadcast_add` instead. """
         if isinstance(other, Symbol):
-            if other._is_np_compat():
-                return other.__add__(self)
             return _internal._Plus(self, other)
         if isinstance(other, Number):
             return _internal._PlusScalar(self, scalar=other)
@@ -127,8 +121,6 @@ class Symbol(SymbolBase):
         raise NotImplementedForSymbol(self.__iadd__, '+=', other, 1)
 
     def __radd__(self, other):
-        if isinstance(other, Symbol) and other._is_np_compat():
-            return other.__add__(self)
         return self.__add__(other)
 
     def __sub__(self, other):
@@ -137,8 +129,6 @@ class Symbol(SymbolBase):
         Scalar input is supported.
         Broadcasting is not supported. Use `broadcast_sub` instead. """
         if isinstance(other, Symbol):
-            if other._is_np_compat():
-                return other.__rsub__(self)
             return _internal._Minus(self, other)
         if isinstance(other, Number):
             return _internal._MinusScalar(self, scalar=other)
@@ -161,7 +151,7 @@ class Symbol(SymbolBase):
         array([[-2., -2., -2.],
                [-2., -2., -2.]], dtype=float32)
         """
-        if isinstance(other, Symbol) and other._is_np_compat():
+        if isinstance(other, Symbol):
             return other.__sub__(self)
         if isinstance(other, Number):
             return _internal._RMinusScalar(self, scalar=other)
@@ -174,8 +164,6 @@ class Symbol(SymbolBase):
         Scalar input is supported.
         Broadcasting is not supported. Use `broadcast_mul` instead. """
         if isinstance(other, Symbol):
-            if other._is_np_compat():
-                return other.__mul__(self)
             return _internal._Mul(self, other)
         if isinstance(other, Number):
             return _internal._MulScalar(self, scalar=other)
@@ -186,8 +174,6 @@ class Symbol(SymbolBase):
         raise NotImplementedForSymbol(self.__imul__, '*=', other)
 
     def __rmul__(self, other):
-        if isinstance(other, Symbol) and other._is_np_compat():
-            return other.__mul__(self)
         return self.__mul__(other)
 
     def __div__(self, other):
@@ -196,8 +182,6 @@ class Symbol(SymbolBase):
         Scalar input is supported.
         Broadcasting is not supported. Use `broadcast_div` instead. """
         if isinstance(other, Symbol):
-            if other._is_np_compat():
-                return other.__rtruediv__(self)
             return _internal._Div(self, other)
         if isinstance(other, Number):
             return _internal._DivScalar(self, scalar=other)
@@ -217,7 +201,7 @@ class Symbol(SymbolBase):
         array([[ 0.33333334,  0.33333334,  0.33333334],
                [ 0.33333334,  0.33333334,  0.33333334]], dtype=float32)
         """
-        if isinstance(other, Symbol) and other._is_np_compat():
+        if isinstance(other, Symbol):
             return other.__truediv__(self)
         if isinstance(other, Number):
             return _internal._RDivScalar(self, scalar=other)
@@ -230,8 +214,6 @@ class Symbol(SymbolBase):
         Scalar input is supported.
         Broadcasting is not supported. Use `broadcast_mod` instead. """
         if isinstance(other, Symbol):
-            if other._is_np_compat():
-                return other.__rmod__(self)
             return _internal._Mod(self, other)
         if isinstance(other, Number):
             return _internal._ModScalar(self, scalar=other)
@@ -251,7 +233,7 @@ class Symbol(SymbolBase):
         array([[ 1.,  1.,  1.],
                [ 1.,  1.,  1.]], dtype=float32)
         """
-        if isinstance(other, Symbol) and other._is_np_compat():
+        if isinstance(other, Symbol):
             return other.__mod__(self)
         if isinstance(other, Number):
             return _internal._RModScalar(self, scalar=other)
@@ -276,8 +258,6 @@ class Symbol(SymbolBase):
         Scalar input is supported.
         Broadcasting is not supported. Use `broadcast_pow` instead. """
         if isinstance(other, Symbol):
-            if other._is_np_compat():
-                return other.__rpow__(self)
             return _internal._Power(self, other)
         if isinstance(other, Number):
             return _internal._PowerScalar(self, scalar=other)
@@ -287,8 +267,6 @@ class Symbol(SymbolBase):
     def __rpow__(self, other):
         """x.__rpow__(y) <=> y ** x"""
         if isinstance(other, Symbol):
-            if other._is_np_compat():
-                return other.__pow__(self)
             return other.__pow__(self)
         elif isinstance(other, Number):
             return _internal._rpower_scalar(self, scalar=other)
@@ -348,8 +326,6 @@ class Symbol(SymbolBase):
         Scalar input is supported.
         Broadcasting is not supported. Use `broadcast_equal` instead. """
         if isinstance(other, Symbol):
-            if other._is_np_compat():
-                return other.__eq__(self)
             return _internal._equal(self, other)
         if isinstance(other, numeric_types):
             return _internal._equal_scalar(self, scalar=other)
@@ -362,8 +338,6 @@ class Symbol(SymbolBase):
         Scalar input is supported.
         Broadcasting is not supported. Use `broadcast_not_equal` instead. """
         if isinstance(other, Symbol):
-            if other._is_np_compat():
-                return other.__ne__(self)
             return _internal._not_equal(self, other)
         if isinstance(other, numeric_types):
             return _internal._not_equal_scalar(self, scalar=other)
@@ -376,8 +350,6 @@ class Symbol(SymbolBase):
         Scalar input is supported.
         Broadcasting is not supported. Use `broadcast_greater` instead. """
         if isinstance(other, Symbol):
-            if other._is_np_compat():
-                return other.__lt__(self)
             return _internal._greater(self, other)
         if isinstance(other, numeric_types):
             return _internal._greater_scalar(self, scalar=other)
@@ -390,8 +362,6 @@ class Symbol(SymbolBase):
         Scalar input is supported.
         Broadcasting is not supported. Use `broadcast_greater_equal` instead. """
         if isinstance(other, Symbol):
-            if other._is_np_compat():
-                return other.__le__(self)
             return _internal._greater_equal(self, other)
         if isinstance(other, numeric_types):
             return _internal._greater_equal_scalar(self, scalar=other)
@@ -404,8 +374,6 @@ class Symbol(SymbolBase):
         Scalar input is supported.
         Broadcasting is not supported. Use `broadcast_lesser` instead. """
         if isinstance(other, Symbol):
-            if other._is_np_compat():
-                return other.__gt__(self)
             return _internal._lesser(self, other)
         if isinstance(other, numeric_types):
             return _internal._lesser_scalar(self, scalar=other)
@@ -418,8 +386,6 @@ class Symbol(SymbolBase):
         Scalar input is supported.
         Broadcasting is not supported. Use `broadcast_lesser_equal` instead. """
         if isinstance(other, Symbol):
-            if other._is_np_compat():
-                return other.__ge__(self)
             return _internal._lesser_equal(self, other)
         if isinstance(other, numeric_types):
             return _internal._lesser_equal_scalar(self, scalar=other)
@@ -2720,8 +2686,12 @@ def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None,
 Variable = var
 
 
-def Group(symbols):
+def Group(symbols, create_fn=Symbol):
     """Creates a symbol that contains a collection of other symbols, grouped together.
+    A classic symbol (`mx.sym.Symbol`) will be returned if all the symbols in the list
+    are of that type; a numpy symbol (`mx.sym.np._Symbol`) will be returned if all the
+    symbols in the list are of that type. A TypeError will be raised if the list
+    mixes classic and numpy symbols.
 
     Example
     -------
@@ -2735,6 +2705,9 @@ def Group(symbols):
     symbols : list
         List of symbols to be grouped.
 
+    create_fn : mx.sym.Symbol or mx.sym.np._Symbol
+        Symbol class for creating the grouped symbol.
+
     Returns
     -------
     sym : Symbol
@@ -2746,7 +2719,7 @@ def Group(symbols):
     check_call(_LIB.MXSymbolCreateGroup(
         mx_uint(len(symbols)),
         c_handle_array(symbols), ctypes.byref(handle)))
-    return Symbol(handle)
+    return create_fn(handle)
 
 
 def load(fname):
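
The new `create_fn` hook lets callers preserve the symbol flavor when grouping; the
type check for mixed lists lives in the caller. A sketch with homogeneous inputs
(the calling pattern is an assumption; only names from the hunks above are used):

    import mxnet as mx

    a = mx.sym.Variable('a').as_np_ndarray()
    b = mx.sym.Variable('b').as_np_ndarray()
    g = mx.sym.Group([a, b], create_fn=type(a))   # result keeps the numpy flavor
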
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index 2d10caf..7786cdb 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -47,6 +47,7 @@ from .context import Context, current_context
 from .ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
 from .ndarray import array
 from .symbol import Symbol
+from .symbol.numpy import _Symbol as np_symbol
 
 
 def default_context():
@@ -945,7 +946,12 @@ def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rto
     input_shape = {k: v.shape for k, v in location.items()}
     _, out_shape, _ = sym.infer_shape(**input_shape)
     proj = mx.sym.Variable("__random_proj")
+    is_np_sym = isinstance(sym, np_symbol)
+    if is_np_sym:  # convert proj to a np symbol so that element-wise multiplication works
+        proj = proj.as_np_ndarray()
     out = sym * proj
+    if is_np_sym:  # convert to classic symbol so that make_loss can be used
+        out = out.as_classic_ndarray()
     out = mx.sym.make_loss(out)
 
     location = dict(list(location.items()) +
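
The conversion dance above exists because `*` requires two symbols of the same
flavor, while `make_loss` is a classic-only op. Stripped to its essence (a
hypothetical standalone sketch):

    import mxnet as mx

    np_sym = mx.sym.Variable('x').as_np_ndarray()             # symbol under test
    proj = mx.sym.Variable('__random_proj').as_np_ndarray()   # match flavors for '*'
    out = np_sym * proj              # element-wise mul between numpy symbols
    out = out.as_classic_ndarray()   # back to classic so make_loss accepts it
    out = mx.sym.make_loss(out)
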
diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h
index 82fe28b..233acc8 100644
--- a/src/c_api/c_api_common.h
+++ b/src/c_api/c_api_common.h
@@ -163,21 +163,4 @@ inline void CopyAttr(const nnvm::IndexedGraph& idx,
 extern const std::vector<std::string> kHiddenKeys;
 }  // namespace mxnet
 
-/*!
- * An operator is considered as numpy compatible if it satisfies either one
- * of the following conditions.
- * 1. The op has the attribute mxnet::TIsNumpyCompatible> registered as True.
- * 2. The op's name starts with the prefix _numpy_.
- * The first condition is usually for the ops registered as internal ops, such
- * as _np_add, _true_divide, etc. They are wrapped by some user-facing op
- * APIs in the Python end.
- * The second condition is for the ops registered in the backend while exposed
- * directly to users as is, such as _numpy_sum etc.
- */
-inline bool IsNumpyCompatOp(const nnvm::Op* op) {
-  static const auto& is_np_compat =
-      nnvm::Op::GetAttr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible");
-  return is_np_compat.get(op, false);
-}
-
 #endif  // MXNET_C_API_C_API_COMMON_H_
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index f65c804..c9c6000 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -378,19 +378,3 @@ int MXAutogradGetSymbol(NDArrayHandle handle, SymbolHandle *out) {
   *out = reinterpret_cast<SymbolHandle>(sym);
   API_END();
 }
-
-int MXIsCachedOpOutputFromNumpyCompatOp(CachedOpHandle handle,
-                                        int output_idx,
-                                        int* is_from_np_op) {
-  API_BEGIN();
-  CachedOpPtr op = *static_cast<CachedOpPtr*>(handle);
-  const auto& output_entries = op->GetForwardSym().outputs;
-  CHECK_LT(output_idx, static_cast<int>(output_entries.size()));
-  const nnvm::NodePtr& node_ptr = output_entries[output_idx].node;
-  if (node_ptr->is_variable()) {
-    *is_from_np_op = 0;
-  } else {
-    *is_from_np_op = (IsNumpyCompatOp(node_ptr->op()) ? 1 : 0);
-  }
-  API_END();
-}
diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h
index 2c4d579..0f3d71d 100644
--- a/src/operator/numpy/np_broadcast_reduce_op.h
+++ b/src/operator/numpy/np_broadcast_reduce_op.h
@@ -169,6 +169,7 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
   if (param.initial.has_value()) {
     LOG(FATAL) << "initial is not supported yet";
   }
+  if (outputs[0].shape_.Size() == 0U) return;  // zero-size tensor
   if (param.axis.has_value() && param.axis.value().ndim() == 0) {
     UnaryOp::IdentityCompute<xpu>(attrs, ctx, inputs, req, outputs);
   }
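
The early return covers zero-size outputs, mirroring NumPy, which reduces over
empty axes without doing any work:

    import numpy as _np

    x = _np.zeros((0, 3))
    _np.sum(x, axis=1).shape   # (0,): zero-size output, nothing to compute
    _np.sum(x)                 # 0.0: full reduction of an empty array is still defined
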
diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc b/src/operator/numpy/np_broadcast_reduce_op_value.cc
index c1c1132..a72efd9 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_value.cc
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc
@@ -47,7 +47,7 @@ inline bool NumpySumType(const nnvm::NodeAttrs& attrs,
   return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
 }
 
-NNVM_REGISTER_OP(_numpy_sum)
+NNVM_REGISTER_OP(_np_sum)
 .describe(R"code()code" ADD_FILELINE)
 .set_num_inputs(1)
 .set_num_outputs(1)
@@ -61,14 +61,13 @@ NNVM_REGISTER_OP(_numpy_sum)
 .add_argument("a", "NDArray-or-Symbol", "The input")
 .add_arguments(NumpyReduceAxesParam::__FIELDS__())
 .set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesCompute<cpu, mshadow_op::sum, true>)
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)
 .set_attr<FResourceRequest>("FResourceRequest",
   [](const NodeAttrs& attrs) {
     return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
   })
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_numpy_sum"});
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_np_sum"});
 
-NNVM_REGISTER_OP(_backward_numpy_sum)
+NNVM_REGISTER_OP(_backward_np_sum)
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<NumpyReduceAxesParam>)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
@@ -102,7 +101,7 @@ inline bool NumpyMeanType(const nnvm::NodeAttrs& attrs,
   return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
 }
 
-NNVM_REGISTER_OP(_numpy_mean)
+NNVM_REGISTER_OP(_np_mean)
 .describe(R"code()code" ADD_FILELINE)
 .set_num_inputs(1)
 .set_num_outputs(1)
@@ -116,14 +115,13 @@ NNVM_REGISTER_OP(_numpy_mean)
 .add_argument("a", "NDArray-or-Symbol", "The input")
 .add_arguments(NumpyReduceAxesParam::__FIELDS__())
 .set_attr<FCompute>("FCompute<cpu>", NumpyReduceAxesCompute<cpu, mshadow_op::sum, true, true>)
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)
 .set_attr<FResourceRequest>("FResourceRequest",
   [](const NodeAttrs& attrs) {
     return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
   })
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_numpy_mean"});
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_np_mean"});
 
-NNVM_REGISTER_OP(_backward_numpy_mean)
+NNVM_REGISTER_OP(_backward_np_mean)
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<NumpyReduceAxesParam>)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cu b/src/operator/numpy/np_broadcast_reduce_op_value.cu
index f16745d..2f50738 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_value.cu
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.cu
@@ -27,16 +27,16 @@
 namespace mxnet {
 namespace op {
 
-NNVM_REGISTER_OP(_numpy_sum)
+NNVM_REGISTER_OP(_np_sum)
 .set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesCompute<gpu, mshadow_op::sum, true>);
 
-NNVM_REGISTER_OP(_backward_numpy_sum)
+NNVM_REGISTER_OP(_backward_np_sum)
 .set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesBackwardUseNone<gpu>);
 
-NNVM_REGISTER_OP(_numpy_mean)
+NNVM_REGISTER_OP(_np_mean)
 .set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesCompute<gpu, mshadow_op::sum, true, true>);
 
-NNVM_REGISTER_OP(_backward_numpy_mean)
+NNVM_REGISTER_OP(_backward_np_mean)
 .set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesBackwardUseNone<gpu, true>);
 
 
diff --git a/src/operator/numpy/np_dot-inl.h b/src/operator/numpy/np_dot-inl.h
index 8fc7d5d..2f7c589 100644
--- a/src/operator/numpy/np_dot-inl.h
+++ b/src/operator/numpy/np_dot-inl.h
@@ -95,6 +95,7 @@ inline void NumpyDotForward(const nnvm::NodeAttrs& attrs,
   const TBlob& a = inputs[0];
   const TBlob& b = inputs[1];
   const TBlob& out = outputs[0];
+  if (out.shape_.Size() == 0U) return;  // zero-size tensor, no need to launch kernel
   const mxnet::TShape a_shape = a.shape_;
   const mxnet::TShape b_shape = b.shape_;
 
@@ -107,7 +108,13 @@ inline void NumpyDotForward(const nnvm::NodeAttrs& attrs,
       (out.type_flag_ == kFloat16 && ctx.run_ctx.ctx.dev_mask() == mshadow::gpu::kDevMask))
       << "dot only supports float32/float64 for CPU, and float16/float32/float64 for GPU";
   MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, DType, {
-    if (a_shape.ndim() == 1 && b_shape.ndim() == 1) {
+    if (a_shape.Size() == 0U || b_shape.Size() == 0U) {
+      if (req[0] != kAddTo) {
+        Tensor<xpu, 1, DType> out_data = out.get_with_shape<xpu, 1, DType>(
+            Shape1(out.shape_.Size()), s);
+        out_data = static_cast<DType>(0);
+      }
+    } else if (a_shape.ndim() == 1 && b_shape.ndim() == 1) {
       // Case 1: both 1-D arrays, inner product of vectors
       if (out.type_flag_ == kFloat16) {
         MMImpl<xpu>(ctx, a, b, out, req[0]);
@@ -158,12 +165,14 @@ inline void NumpyDotBackward(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(outputs.size(), 2U);
 
   const TBlob& ograd = inputs[0];
+  if (ograd.shape_.Size() == 0U) return;
   const TBlob& a = inputs[1];
   const TBlob& b = inputs[2];
   const TBlob& grad_a = outputs[0];
   const TBlob& grad_b = outputs[1];
   const mxnet::TShape a_shape = a.shape_;
   const mxnet::TShape b_shape = b.shape_;
+  if (a_shape.Size() == 0U || b_shape.Size() == 0U) return;
 
   Stream<xpu> *s = ctx.get_stream<xpu>();
   MSHADOW_REAL_TYPE_SWITCH(ograd.type_flag_, DType, {
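
The zero-size forward path reproduces NumPy's convention that a contraction over
an empty axis yields a zero-filled array, hence the explicit fill above unless
`req[0]` is `kAddTo`:

    import numpy as _np

    a = _np.ones((2, 0))
    b = _np.ones((0, 3))
    _np.dot(a, b)   # (2, 3) array of zeros: the empty contraction axis sums to 0
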
diff --git a/src/operator/numpy/np_dot.cc b/src/operator/numpy/np_dot.cc
index c25953f..bcb310f 100644
--- a/src/operator/numpy/np_dot.cc
+++ b/src/operator/numpy/np_dot.cc
@@ -71,7 +71,7 @@ inline bool NumpyDotShape(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
-NNVM_REGISTER_OP(_numpy_dot)
+NNVM_REGISTER_OP(_np_dot)
 .describe(R"doc(Dot product of two arrays. Specifically,
 
 - If both a and b are 1-D arrays, it is inner product of vectors.
diff --git a/src/operator/numpy/np_dot.cu b/src/operator/numpy/np_dot.cu
index 2accd9d..9a9c69a 100644
--- a/src/operator/numpy/np_dot.cu
+++ b/src/operator/numpy/np_dot.cu
@@ -27,7 +27,7 @@
 namespace mxnet {
 namespace op {
 
-NNVM_REGISTER_OP(_numpy_dot)
+NNVM_REGISTER_OP(_np_dot)
 .set_attr<FCompute>("FCompute<gpu>", NumpyDotForward<gpu>);
 
 NNVM_REGISTER_OP(_backward_np_dot)
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc
index 5d36c29..2ffa3b8 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cc
@@ -57,12 +57,11 @@ bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs,
     [](const NodeAttrs& attrs){                                     \
       return std::vector<std::pair<int, int> >{{0, 0}};             \
     })                                                              \
-  .set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)  \
   .add_argument("data", "NDArray-or-Symbol", "source input")        \
   .add_argument("scalar", "float", "scalar input")
 
 
-MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_add)
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_add)
 .describe(R"code(Add arguments element-wise with broadcasting if necessary.
 
 Example::
@@ -78,10 +77,9 @@ Example::
 
 )code" ADD_FILELINE)
 .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::plus>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_add"})
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_add"});
 
-MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_subtract)
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_subtract)
 .describe(R"code(Subtract arguments element-wise with broadcasting if necessary.
 
 Example::
@@ -97,10 +95,9 @@ Example::
 
 )code" ADD_FILELINE)
 .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::minus>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"})
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"});
 
-MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_multiply)
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_multiply)
 .describe(R"code(Multiply arguments with broadcasting if necessary.
 
 Example::
@@ -116,10 +113,9 @@ Example::
 
 )code" ADD_FILELINE)
 .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::mul>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"})
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"});
 
-MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_mod)
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod)
 .describe(R"code(Return element-wise remainder of division.
 It is equivalent to the Python modulus operator ``x1 % x2`` and has the same sign as the divisor x2.
 
@@ -136,10 +132,9 @@ Example::
 
 )code" ADD_FILELINE)
 .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::mod>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mod"})
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mod"});
 
-MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_power)
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_power)
 .describe(R"code(First array elements raised to powers from second array, element-wise.
 
 Raise each base in x1 to the positionally-corresponding power in x2. x1 and x2 must be
@@ -158,56 +153,53 @@ Example::
 
 )code" ADD_FILELINE)
 .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::power>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_power"})
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_power"});
 
-MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_maximum)
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_maximum)
 .describe(R"code()code" ADD_FILELINE)
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::maximum>)
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::maximum>);
 
-MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_minimum)
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_minimum)
 .describe(R"code()code" ADD_FILELINE)
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::minimum>)
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::minimum>);
 
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_add_scalar)
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_add_scalar)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, op::mshadow_op::plus>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_copy"});
 
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_subtract_scalar)
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_subtract_scalar)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, op::mshadow_op::minus>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_copy"});
 
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_rsubtract_scalar)
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rsubtract_scalar)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rminus>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"negative"});
 
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_multiply_scalar)
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_multiply_scalar)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, op::mshadow_op::mul>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_mul_scalar"});
 
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_mod_scalar)
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_mod_scalar)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::mod>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_mod_scalar"});
 
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_rmod_scalar)
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rmod_scalar)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rmod>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_rmod_scalar"});
 
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_power_scalar)
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_power_scalar)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::power>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_power_scalar"});
 
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_rpower_scalar)
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rpower_scalar)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rpower>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_rpower_scalar"});
 
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_maximum_scalar)
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_maximum_scalar)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::maximum>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_maximum_scalar"});
 
-MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_minimum_scalar)
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_minimum_scalar)
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::minimum>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_minimum_scalar"});
 
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu
index 26e2fce..c858b3a 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.cu
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cu
@@ -27,55 +27,55 @@
 
 namespace mxnet {
 namespace op {
-NNVM_REGISTER_OP(_np_add)
+NNVM_REGISTER_OP(_npi_add)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::plus>);
 
-NNVM_REGISTER_OP(_np_subtract)
+NNVM_REGISTER_OP(_npi_subtract)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::minus>);
 
-NNVM_REGISTER_OP(_np_multiply)
+NNVM_REGISTER_OP(_npi_multiply)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::mul>);
 
-NNVM_REGISTER_OP(_np_mod)
+NNVM_REGISTER_OP(_npi_mod)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::mod>);
 
-NNVM_REGISTER_OP(_np_power)
+NNVM_REGISTER_OP(_npi_power)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::power>);
 
-NNVM_REGISTER_OP(_np_maximum)
+NNVM_REGISTER_OP(_npi_maximum)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::maximum>);
 
-NNVM_REGISTER_OP(_np_minimum)
+NNVM_REGISTER_OP(_npi_minimum)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::minimum>);
 
-NNVM_REGISTER_OP(_np_add_scalar)
+NNVM_REGISTER_OP(_npi_add_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, op::mshadow_op::plus>);
 
-NNVM_REGISTER_OP(_np_subtract_scalar)
+NNVM_REGISTER_OP(_npi_subtract_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, op::mshadow_op::minus>);
 
-NNVM_REGISTER_OP(_np_rsubtract_scalar)
+NNVM_REGISTER_OP(_npi_rsubtract_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rminus>);
 
-NNVM_REGISTER_OP(_np_multiply_scalar)
+NNVM_REGISTER_OP(_npi_multiply_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, op::mshadow_op::mul>);
 
-NNVM_REGISTER_OP(_np_mod_scalar)
+NNVM_REGISTER_OP(_npi_mod_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::mod>);
 
-NNVM_REGISTER_OP(_np_rmod_scalar)
+NNVM_REGISTER_OP(_npi_rmod_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rmod>);
 
-NNVM_REGISTER_OP(_np_power_scalar)
+NNVM_REGISTER_OP(_npi_power_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::power>);
 
-NNVM_REGISTER_OP(_np_rpower_scalar)
+NNVM_REGISTER_OP(_npi_rpower_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rpower>);
 
-NNVM_REGISTER_OP(_np_maximum_scalar)
+NNVM_REGISTER_OP(_npi_maximum_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::maximum>);
 
-NNVM_REGISTER_OP(_np_minimum_scalar)
+NNVM_REGISTER_OP(_npi_minimum_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::minimum>);
 
 }  // namespace op
diff --git a/src/operator/numpy/np_elemwise_unary_op_basic.cc b/src/operator/numpy/np_elemwise_unary_op_basic.cc
index f31ed5e..a64356e 100644
--- a/src/operator/numpy/np_elemwise_unary_op_basic.cc
+++ b/src/operator/numpy/np_elemwise_unary_op_basic.cc
@@ -27,7 +27,7 @@
 namespace mxnet {
 namespace op {
 
-MXNET_OPERATOR_REGISTER_UNARY(_numpy__ext_relu)
+MXNET_OPERATOR_REGISTER_UNARY(_npe_relu)
 .describe(R"code(Computes rectified linear activation.
 
 .. math::
@@ -35,10 +35,9 @@ MXNET_OPERATOR_REGISTER_UNARY(_numpy__ext_relu)
 
 )code" ADD_FILELINE)
 .set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::relu>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_relu"})
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_relu"});
 
-MXNET_OPERATOR_REGISTER_UNARY(_numpy__ext_sigmoid)
+MXNET_OPERATOR_REGISTER_UNARY(_npe_sigmoid)
 .describe(R"code(Computes sigmoid of x element-wise.
 
 .. math::
@@ -46,18 +45,29 @@ MXNET_OPERATOR_REGISTER_UNARY(_numpy__ext_sigmoid)
 
 )code" ADD_FILELINE)
 .set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::sigmoid>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_sigmoid"})
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_sigmoid"});
 
-MXNET_OPERATOR_REGISTER_UNARY(_np_copy)
-.MXNET_DESCRIBE("Returns a copy of the input.")
+NNVM_REGISTER_OP(_np_copy)
+.describe(R"code(Return an array copy of the given object.)code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
 .set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
 .set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
   [](const NodeAttrs& attrs){
     return std::vector<bool>{true};
   })
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_copy"})
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true);
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"a"};
+  })
+.add_argument("a", "NDArray-or-Symbol", "The input");
 
 }  // namespace op
 }  // namespace mxnet
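
The `_npe_` prefix marks ops with no counterpart in the official numpy API (relu,
sigmoid), which therefore live in the new numpy_extension namespace. The exact
frontend path below is an assumption based on the numpy_extension modules added
in this commit:

    import mxnet as mx

    x = mx.sym.Variable('x').as_np_ndarray()
    y = mx.sym.numpy_extension.relu(x)   # hypothetical surface for _npe_relu
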
diff --git a/src/operator/numpy/np_elemwise_unary_op_basic.cu b/src/operator/numpy/np_elemwise_unary_op_basic.cu
index 9f108f7..600f198 100644
--- a/src/operator/numpy/np_elemwise_unary_op_basic.cu
+++ b/src/operator/numpy/np_elemwise_unary_op_basic.cu
@@ -26,10 +26,10 @@
 namespace mxnet {
 namespace op {
 
-NNVM_REGISTER_OP(_numpy__ext_relu)
+NNVM_REGISTER_OP(_npe_relu)
 .set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::relu>);
 
-NNVM_REGISTER_OP(_numpy__ext_sigmoid)
+NNVM_REGISTER_OP(_npe_sigmoid)
 .set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::sigmoid>);
 
 NNVM_REGISTER_OP(_np_copy)
diff --git a/src/operator/numpy/np_init_op.cc b/src/operator/numpy/np_init_op.cc
index 0abd010..83a44c8 100644
--- a/src/operator/numpy/np_init_op.cc
+++ b/src/operator/numpy/np_init_op.cc
@@ -28,7 +28,7 @@
 namespace mxnet {
 namespace op {
 
-NNVM_REGISTER_OP(_np_zeros)
+NNVM_REGISTER_OP(_npi_zeros)
 .describe("Return a new array of given shape, type, and context, filled with zeros.")
 .set_num_inputs(0)
 .set_num_outputs(1)
@@ -37,10 +37,9 @@ NNVM_REGISTER_OP(_np_zeros)
 .set_attr<nnvm::FInferType>("FInferType", InitType<InitOpParam>)
 .set_attr<FInferStorageType>("FInferStorageType", InitStorageType<InitOpParam, true, true>)
 .set_attr<FCompute>("FCompute<cpu>", FillCompute<cpu, 0>)
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)
 .add_arguments(InitOpParam::__FIELDS__());
 
-NNVM_REGISTER_OP(_np_ones)
+NNVM_REGISTER_OP(_npi_ones)
 .describe("Return a new array of given shape, type, and context, filled with ones.")
 .set_num_inputs(0)
 .set_num_outputs(1)
@@ -48,8 +47,65 @@ NNVM_REGISTER_OP(_np_ones)
 .set_attr<mxnet::FInferShape>("FInferShape", InitShape<InitOpParam>)
 .set_attr<nnvm::FInferType>("FInferType", InitType<InitOpParam>)
 .set_attr<FCompute>("FCompute<cpu>", FillCompute<cpu, 1>)
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)
 .add_arguments(InitOpParam::__FIELDS__());
 
+NNVM_REGISTER_OP(_np_zeros_like)
+.describe(R"code(Return an array of zeros with the same shape and type as a given array.
+
+Examples::
+
+  x = [[ 1.,  1.,  1.],
+       [ 1.,  1.,  1.]]
+
+  zeros_like(x) = [[ 0.,  0.,  0.],
+                   [ 0.,  0.,  0.]]
+
+)code")
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FIgnoreInputs>("FIgnoreInputs",
+  [](const NodeAttrs& attrs) {
+    return std::vector<uint32_t>(1, 0);
+  })
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"a"};
+  })
+.set_attr<FCompute>("FCompute<cpu>", FillCompute<cpu, 0>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_argument("a", "NDArray-or-Symbol",
+              "The shape and data-type of a define these same attributes of the returned array.");
+
+NNVM_REGISTER_OP(_np_ones_like)
+.describe(R"code(Return an array of ones with the same shape and type as a given array.
+
+Examples::
+
+  x = [[ 0.,  0.,  0.],
+       [ 0.,  0.,  0.]]
+
+  ones_like(x) = [[ 1.,  1.,  1.],
+                  [ 1.,  1.,  1.]]
+
+)code")
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FIgnoreInputs>("FIgnoreInputs",
+  [](const NodeAttrs& attrs) {
+    return std::vector<uint32_t>(1, 0);
+  })
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"a"};
+  })
+.set_attr<FCompute>("FCompute<cpu>", FillCompute<cpu, 1>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_argument("a", "NDArray-or-Symbol",
+              "The shape and data-type of a define these same attributes of the returned array.");
+
 }  // namespace op
 }  // namespace mxnet
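
Unlike the `_npi_*` init ops above, `_np_zeros_like`/`_np_ones_like` carry no
internal prefix and are meant to surface as-is in the numpy namespaces; their
semantics follow NumPy's:

    import numpy as _np

    x = _np.ones((2, 3))
    _np.zeros_like(x)                  # same shape and dtype as x, filled with zeros
    _np.ones_like(_np.zeros((2, 3)))   # same shape and dtype, filled with ones
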
diff --git a/src/operator/numpy/np_init_op.cu b/src/operator/numpy/np_init_op.cu
index 4e6f81d..2eb8ed6 100644
--- a/src/operator/numpy/np_init_op.cu
+++ b/src/operator/numpy/np_init_op.cu
@@ -28,10 +28,16 @@
 namespace mxnet {
 namespace op {
 
-NNVM_REGISTER_OP(_np_zeros)
+NNVM_REGISTER_OP(_npi_zeros)
 .set_attr<FCompute>("FCompute<gpu>", FillCompute<gpu, 0>);
 
-NNVM_REGISTER_OP(_np_ones)
+NNVM_REGISTER_OP(_npi_ones)
+.set_attr<FCompute>("FCompute<gpu>", FillCompute<gpu, 1>);
+
+NNVM_REGISTER_OP(_np_zeros_like)
+.set_attr<FCompute>("FCompute<gpu>", FillCompute<gpu, 0>);
+
+NNVM_REGISTER_OP(_np_ones_like)
 .set_attr<FCompute>("FCompute<gpu>", FillCompute<gpu, 1>);
 
 }  // namespace op
diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc
index 215b1c5..6e93442 100644
--- a/src/operator/numpy/np_matrix_op.cc
+++ b/src/operator/numpy/np_matrix_op.cc
@@ -54,7 +54,7 @@ bool NumpyTransposeShape(const nnvm::NodeAttrs& attrs,
   return shape_is_known(ret);
 }
 
-NNVM_REGISTER_OP(_numpy_transpose)
+NNVM_REGISTER_OP(_np_transpose)
 .describe(R"code(Permute the dimensions of an array.
 
 Examples::
@@ -105,7 +105,6 @@ Examples::
     }
   })
 .set_attr<FCompute>("FCompute<cpu>", NumpyTranspose<cpu>)
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)
 .set_attr<nnvm::FListInputNames>("FListInputNames",
   [](const NodeAttrs& attrs) {
     return std::vector<std::string>{"a"};
@@ -189,7 +188,7 @@ bool NumpyReshapeShape(const nnvm::NodeAttrs& attrs,
   return success;
 }
 
-NNVM_REGISTER_OP(_numpy_reshape)
+NNVM_REGISTER_OP(_np_reshape)
 .describe(R"code()code" ADD_FILELINE)
 .set_num_inputs(1)
 .set_num_outputs(1)
@@ -210,7 +209,6 @@ NNVM_REGISTER_OP(_numpy_reshape)
   [](const NodeAttrs& attrs) {
     return std::vector<std::string>{"a"};
   })
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)
 .add_argument("a", "NDArray-or-Symbol", "Array to be reshaped.")
 .add_arguments(NumpyReshapeParam::__FIELDS__());
 
diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu
index 9753566..5bf36e5 100644
--- a/src/operator/numpy/np_matrix_op.cu
+++ b/src/operator/numpy/np_matrix_op.cu
@@ -27,10 +27,10 @@
 namespace mxnet {
 namespace op {
 
-NNVM_REGISTER_OP(_numpy_transpose)
+NNVM_REGISTER_OP(_np_transpose)
 .set_attr<FCompute>("FCompute<gpu>", NumpyTranspose<gpu>);
 
-NNVM_REGISTER_OP(_numpy_reshape)
+NNVM_REGISTER_OP(_np_reshape)
 .set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);
 
 }  // namespace op
diff --git a/src/operator/numpy/np_true_divide.cc b/src/operator/numpy/np_true_divide.cc
index 3bafa26..4297627 100644
--- a/src/operator/numpy/np_true_divide.cc
+++ b/src/operator/numpy/np_true_divide.cc
@@ -54,7 +54,7 @@ bool TrueDivideType(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
-NNVM_REGISTER_OP(_true_divide)
+NNVM_REGISTER_OP(_npi_true_divide)
 .describe(R"code(
 Returns a true division of the inputs, element-wise.
 
@@ -86,11 +86,10 @@ Example::
   })
 .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::div>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_div"})
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)
 .add_argument("lhs", "NDArray-or-Symbol", "Dividend array")
 .add_argument("rhs", "NDArray-or-Symbol", "Divisor array");
 
-NNVM_REGISTER_OP(_true_divide_scalar)
+NNVM_REGISTER_OP(_npi_true_divide_scalar)
 .set_num_inputs(1)
 .set_num_outputs(1)
 .set_attr_parser([](NodeAttrs* attrs) {
@@ -104,11 +103,10 @@ NNVM_REGISTER_OP(_true_divide_scalar)
   })
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, op::mshadow_op::div>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_div_scalar"})
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)
 .add_argument("data", "NDArray-or-Symbol", "source input")
 .add_argument("scalar", "float", "scalar input");
 
-NNVM_REGISTER_OP(_rtrue_divide_scalar)
+NNVM_REGISTER_OP(_npi_rtrue_divide_scalar)
 .set_num_inputs(1)
 .set_num_outputs(1)
 .set_attr_parser([](NodeAttrs* attrs) {
@@ -122,7 +120,6 @@ NNVM_REGISTER_OP(_rtrue_divide_scalar)
   })
 .set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rdiv>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_rdiv_scalar"})
-.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true)
 .add_argument("data", "NDArray-or-Symbol", "source input")
 .add_argument("scalar", "float", "scalar input");
 
diff --git a/src/operator/numpy/np_true_divide.cu b/src/operator/numpy/np_true_divide.cu
index cbc7cf9..be10c44 100644
--- a/src/operator/numpy/np_true_divide.cu
+++ b/src/operator/numpy/np_true_divide.cu
@@ -28,13 +28,13 @@
 namespace mxnet {
 namespace op {
 
-NNVM_REGISTER_OP(_true_divide)
+NNVM_REGISTER_OP(_npi_true_divide)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::div>);
 
-NNVM_REGISTER_OP(_true_divide_scalar)
+NNVM_REGISTER_OP(_npi_true_divide_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::div>);
 
-NNVM_REGISTER_OP(_rtrue_divide_scalar)
+NNVM_REGISTER_OP(_npi_rtrue_divide_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rdiv>);
 
 }  // namespace op
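
True division is what `/` means in the numpy world: integer inputs are promoted
and the result is floating point, matching NumPy:

    import numpy as _np

    _np.true_divide(_np.array([1, 2, 3]), 2)   # array([0.5, 1. , 1.5])
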
diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py
index 141d153..eb45234 100644
--- a/tests/python/unittest/test_numpy_ndarray.py
+++ b/tests/python/unittest/test_numpy_ndarray.py
@@ -20,7 +20,7 @@ from __future__ import absolute_import
 from __future__ import division
 import numpy as _np
 import mxnet as mx
-from mxnet import numpy as np
+from mxnet import np
 from mxnet.gluon import HybridBlock
 from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, assert_exception
 from common import with_seed
@@ -37,15 +37,15 @@ def test_array_creation():
             mx_arr = np.array(src, dtype=dtype)
             assert mx_arr.context == mx.current_context()
             if isinstance(src, mx.nd.NDArray):
-                np_arr = _np.array(src.asnumpy(), dtype=dtype)
+                np_arr = _np.array(src.asnumpy(), dtype=dtype if dtype is not None else _np.float32)
             else:
-                np_arr = _np.array(src, dtype=dtype)
-            assert same(mx_arr.asnumpy(), np_arr)
+                np_arr = _np.array(src, dtype=dtype if dtype is not None else _np.float32)
             assert mx_arr.dtype == np_arr.dtype
+            assert same(mx_arr.asnumpy(), np_arr)
 
 
 @with_seed()
-@mx.use_np_compat
+@np.use_np_compat
 def test_zeros():
     # test np.zeros in Gluon
     class TestZeros(HybridBlock):
@@ -76,7 +76,7 @@ def test_zeros():
     for shape in shapes:
         for dtype in dtypes:
             check_zero_array_creation(shape, dtype)
-            x = mx.nd.array(_np.random.uniform(size=shape), dtype=dtype)
+            x = np.array(_np.random.uniform(size=shape), dtype=dtype)
             if dtype is None:
                 x = x.astype('float32')
             for hybridize in [True, False]:
@@ -93,7 +93,7 @@ def test_zeros():
 
 
 @with_seed()
-@mx.use_np_compat
+@np.use_np_compat
 def test_ones():
     # test np.ones in Gluon
     class TestOnes(HybridBlock):
@@ -141,7 +141,7 @@ def test_ones():
 
 
 @with_seed()
-@mx.use_np_compat
+@np.use_np_compat
 def test_ndarray_binary_element_wise_ops():
     # Cannot test operators like >, because boolean arrays are not supported yet.
     np_op_map = {'+': _np.add, '*': _np.multiply, '-': _np.subtract, '/': _np.divide,
@@ -241,23 +241,22 @@ def test_ndarray_binary_element_wise_ops():
         np_out = get_np_ret(np_input1, np_input2, op)
         for hybridize in [True, False]:
             if scalar is None:
-                get_mx_ret = TestBinaryElementWiseOp(op)
+                get_mx_ret_np = TestBinaryElementWiseOp(op)
+                get_mx_ret_classic = TestBinaryElementWiseOp(op)
                 if hybridize:
-                    get_mx_ret.hybridize()
-                mx_out = get_mx_ret(mx_input1.as_np_ndarray(), mx_input2.as_np_ndarray())
+                    get_mx_ret_np.hybridize()
+                    get_mx_ret_classic.hybridize()
+                mx_out = get_mx_ret_np(mx_input1.as_np_ndarray(), mx_input2.as_np_ndarray())
                 assert type(mx_out) == np.ndarray
                 assert np_out.shape == mx_out.shape
                 assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5)
 
-                mx_out = get_mx_ret(mx_input1, mx_input2.as_np_ndarray())
-                assert type(mx_out) == np.ndarray
-                assert np_out.shape == mx_out.shape
-                assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5)
-
-                mx_out = get_mx_ret(mx_input1.as_np_ndarray(), mx_input2)
-                assert type(mx_out) == np.ndarray
-                assert np_out.shape == mx_out.shape
-                assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5)
+                if mx_input1.shape == mx_input2.shape:
+                    # classic symbols do not support element-wise binary broadcasting.
+                    mx_out = get_mx_ret_classic(mx_input1, mx_input2)
+                    assert type(mx_out) == mx.nd.NDArray
+                    assert np_out.shape == mx_out.shape
+                    assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5)
             else:
                 get_mx_ret = TestBinaryElementWiseOp(op, scalar=scalar, reverse=reverse)
                 if hybridize:
@@ -291,29 +290,42 @@ def test_ndarray_binary_element_wise_ops():
 
 
 @with_seed()
-def test_np_op_output_type():
-    # test imperative invoke
-    data = np.array([1., 3.], dtype='float32')
-    ret = np.sum(data)
-    assert type(ret) == np.ndarray
-    ret = mx.nd.sin(data)
-    assert type(ret) == mx.nd.NDArray
-
-    # test cached op
-    class TestCachedOpOutputType(HybridBlock):
-        @mx.use_np_compat
+def test_hybrid_block_multiple_outputs():
+    class TestAllNumpyOutputs(HybridBlock):
+        @np.use_np_compat
+        def hybrid_forward(self, F, x, *args, **kwargs):
+            return F.npe.relu(x), F.np.sum(x)
+
+    class TestAllClassicOutputs(HybridBlock):
+        @np.use_np_compat
+        def hybrid_forward(self, F, x, *args, **kwargs):
+            return F.relu(x.as_classic_ndarray()), F.sum(x.as_classic_ndarray())
+
+    class TestMixedTypeOutputsSuccess(HybridBlock):
+        @np.use_np_compat
+        def hybrid_forward(self, F, x, *args, **kwargs):
+            return F.relu(x.as_classic_ndarray()).as_np_ndarray(), F.np.sum(x)
+
+    data_np = np.ones((2, 3))
+    for block, expected_out_type in [(TestAllClassicOutputs, mx.nd.NDArray),
+                                      (TestAllNumpyOutputs, np.ndarray),
+                                      (TestMixedTypeOutputsSuccess, np.ndarray)]:
+        net = block()
+        for hybridize in [True, False]:
+            if hybridize:
+                net.hybridize()
+            out1, out2 = net(data_np)
+            assert type(out1) is expected_out_type
+            assert type(out2) is expected_out_type
+
+    class TestMixedTypeOutputsFailure(HybridBlock):
+        @np.use_np_compat
         def hybrid_forward(self, F, x, *args, **kwargs):
-            ret1 = F.sin(x)
-            ret2 = F.np.sum(x)
-            return ret1, ret2
+            return F.relu(x.as_classic_ndarray()), F.np.sum(x)
 
-    net = TestCachedOpOutputType()
-    for hybridize in [True, False]:
-        if hybridize:
-            net.hybridize()
-        ret1, ret2 = net(data)
-        assert type(ret1) == mx.nd.NDArray
-        assert type(ret2) == np.ndarray
+    net = TestMixedTypeOutputsFailure()
+    net.hybridize()
+    assert_exception(net, TypeError, data_np)
 
 
 @with_seed()
@@ -331,6 +343,7 @@ def test_np_ndarray_astype():
 
     def check_astype_equal(dtype, copy, expect_zero_copy=False):
         mx_ret = mx_data.astype(dtype=dtype, copy=copy)
+        assert type(mx_ret) is np.ndarray
         np_ret = np_data.astype(dtype=dtype, copy=copy)
         assert mx_ret.dtype == np_ret.dtype
         assert same(mx_ret.asnumpy(), np_ret)
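The reworked test above pins down the new frontend rule: a hybridized block must return outputs of a single ndarray flavor, and mixing classic and numpy outputs raises TypeError. A minimal sketch of the compliant pattern, mirroring TestMixedTypeOutputsSuccess (the class name and shapes here are illustrative):

    from mxnet import np
    from mxnet.gluon import HybridBlock

    class HomogeneousOutputs(HybridBlock):
        @np.use_np_compat
        def hybrid_forward(self, F, x, *args, **kwargs):
            # convert the classic output back, so both outputs are numpy-style
            return F.relu(x.as_classic_ndarray()).as_np_ndarray(), F.np.sum(x)

    net = HomogeneousOutputs()
    net.hybridize()
    out1, out2 = net(np.ones((2, 3)))  # both come back as np.ndarray
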
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 8c13227..34b2cbe 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -19,7 +19,7 @@
 from __future__ import absolute_import
 import numpy as _np
 import mxnet as mx
-from mxnet import numpy as np
+from mxnet import np, npe
 from mxnet.gluon import HybridBlock
 from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray
 from mxnet.test_utils import check_numeric_gradient
@@ -27,7 +27,7 @@ from common import with_seed
 import random
 
 
-@mx.use_np_compat
+@np.use_np_compat
 @with_seed()
 def test_np_sum():
     class TestSum(HybridBlock):
@@ -38,7 +38,7 @@ def test_np_sum():
             self._keepdims = keepdims
 
         def hybrid_forward(self, F, a, *args, **kwargs):
-            return F.numpy.sum(a, axis=self._axis, dtype=self._dtype, keepdims=self._keepdims)
+            return F.np.sum(a, axis=self._axis, dtype=self._dtype, keepdims=self._keepdims)
 
     def is_int(dtype):
         return 'int' in dtype
@@ -63,6 +63,7 @@ def test_np_sum():
                             x = mx.nd.array(x)
                         else:
                             x = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype=itype)
+                        x = x.as_np_ndarray()
                         x.attach_grad()
                         expected_ret = _np.sum(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims)
                         expected_ret = expected_ret.astype(dtype)
@@ -77,8 +78,8 @@ def test_np_sum():
 
                         # test numeric
                         if itype == 'float32' and dtype == 'float32':
-                            x_sym = mx.sym.Variable("x")
-                            mx_sym = mx.sym.numpy.sum(x_sym, axis=axis, dtype=dtype, keepdims=keepdims)
+                            x_sym = mx.sym.Variable("x").as_np_ndarray()
+                            mx_sym = mx.sym.np.sum(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_classic_ndarray()
                             check_numeric_gradient(mx_sym, [x], numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32)
 
                         # test imperative
@@ -87,10 +88,11 @@ def test_np_sum():
                         assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
 
 
-@mx.use_np_compat
+@np.use_np_compat
 @with_seed()
 def test_np_dot():
     shapes = [
+        ((3, 0), (0, 4)),
         ((3,), (3,)),        # Case 1
         ((3, 4), (4, 5)),    # Case 2
         ((), ()),            # Case 3
@@ -102,7 +104,6 @@ def test_np_dot():
     eps = 1e-3
 
     for shape_a, shape_b in shapes:
-        print(shape_a, shape_b)
         np_a = _np.random.uniform(-1.0, 1.0, shape_a)
         np_a[abs(np_a) < eps] = 2 * eps;
         np_b = _np.random.uniform(-1.0, 1.0, shape_b)
@@ -110,12 +111,12 @@ def test_np_dot():
         a = mx.nd.array(np_a)
         b = mx.nd.array(np_b)
         np_res = _np.dot(np_a, np_b)
-        mx_res = np.dot(a, b)
+        mx_res = np.dot(a.as_np_ndarray(), b.as_np_ndarray())
         assert mx_res.shape == np_res.shape
         assert_almost_equal(np_res, mx_res.asnumpy(), rtol=1e-5, atol=1e-5)
         mx_a = mx.sym.Variable("a")
         mx_b = mx.sym.Variable("b")
-        mx_sym = mx.sym.numpy.dot(mx_a, mx_b)
+        mx_sym = mx.sym.np.dot(mx_a.as_np_ndarray(), mx_b.as_np_ndarray()).as_classic_ndarray()
         check_numeric_gradient(mx_sym, {"a": a, "b": b}, numeric_eps=eps, rtol=1e-2, atol=1e-3)
 
     bad_shapes = [((4, 5), (2, 3)), ((3, 4, 5), (6, ))]
@@ -124,13 +125,13 @@ def test_np_dot():
         a = mx.nd.array(random.random()) if len(shape_a) == 0 else rand_ndarray(shape_a)
         b = mx.nd.array(random.random()) if len(shape_b) == 0 else rand_ndarray(shape_b)
         try:
-            mx_res = np.dot(a, b)
+            mx_res = np.dot(a.as_np_ndarray(), b.as_np_ndarray())
         except mx.base.MXNetError:
             continue
         assert False
 
 
-@mx.use_np_compat
+@np.use_np_compat
 @with_seed()
 def test_np_mean():
     class TestMean(HybridBlock):
@@ -141,7 +142,7 @@ def test_np_mean():
             self._keepdims = keepdims
 
         def hybrid_forward(self, F, a, *args, **kwargs):
-            return F.numpy.mean(a, axis=self._axis, dtype=self._dtype, keepdims=self._keepdims)
+            return F.np.mean(a, axis=self._axis, dtype=self._dtype, keepdims=self._keepdims)
 
     def is_int(dtype):
         return 'int' in dtype
@@ -167,6 +168,7 @@ def test_np_mean():
                             x = mx.nd.array(x, dtype=itype)
                         else:
                             x = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype=itype)
+                        x = x.as_np_ndarray()
                         x.attach_grad()
                         expected_ret = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims)
                         expected_ret = expected_ret.astype(dtype)
@@ -182,8 +184,8 @@ def test_np_mean():
 
                         # test numeric
                         if itype == 'float32' and dtype == 'float32':
-                            x_sym = mx.sym.Variable("x")
-                            mx_sym = mx.sym.numpy.mean(x_sym, axis=axis, dtype=dtype, keepdims=keepdims)
+                            x_sym = mx.sym.Variable("x").as_np_ndarray()
+                            mx_sym = mx.sym.np.mean(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_classic_ndarray()
                             check_numeric_gradient(mx_sym, [x], numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32)
 
                         # test imperative
@@ -193,12 +195,12 @@ def test_np_mean():
 
 
 @with_seed()
-@mx.use_np_compat
+@np.use_np_compat
 def test_np_transpose():
     # TODO(junwu): Add more test cases
-    data = mx.sym.var('a')
-    ret = mx.sym.np.transpose(data)
-    assert type(ret) == mx.sym.np._NumpySymbol
+    data = mx.sym.var('a').as_np_ndarray()
+    ret = data.transpose()
+    assert type(ret) == mx.sym.np._Symbol
 
     dtypes = ['float32', 'int32']
     for dtype in dtypes:
@@ -223,44 +225,44 @@ def test_np_transpose():
 
 
 @with_seed()
-@mx.use_np_compat
+@np.use_np_compat
 def test_relu():
     # TODO(junwu): Add more test cases
-    data = mx.sym.var('data')
-    ret = mx.sym.np.ext.relu(data)
-    assert type(ret) == mx.sym.np._NumpySymbol
+    data = mx.sym.var('data').as_np_ndarray()
+    ret = mx.sym.npe.relu(data)
+    assert type(ret) == mx.sym.np._Symbol
 
     shapes = [(), (0, 2, 0)]
     shapes.extend([rand_shape_nd(ndim, allow_zero_size=True) for ndim in range(5)])
     for shape in shapes:
         data = np.array(_np.random.uniform(size=shape).astype('float32'))
-        ret = np.ext.relu(data)
+        ret = npe.relu(data)
         assert type(ret) == np.ndarray
 
 
 @with_seed()
-@mx.use_np_compat
+@np.use_np_compat
 def test_sigmoid():
     # TODO(junwu): Add more test cases
-    data = mx.sym.var('data')
-    ret = mx.sym.np.ext.sigmoid(data)
-    assert type(ret) == mx.sym.np._NumpySymbol
+    data = mx.sym.var('data').as_np_ndarray()
+    ret = mx.sym.npe.sigmoid(data)
+    assert type(ret) == mx.sym.np._Symbol
 
     shapes = [(), (0, 2, 0)]
     shapes.extend([rand_shape_nd(ndim, allow_zero_size=True) for ndim in range(5)])
     for shape in shapes:
         data = np.array(_np.random.uniform(size=shape).astype('float32'))
-        ret = np.ext.sigmoid(data)
+        ret = npe.sigmoid(data)
         assert type(ret) == np.ndarray
 
 
 @with_seed()
-@mx.use_np_compat
+@np.use_np_compat
 def test_np_reshape():
     # TODO(junwu): Add more test cases
-    data = mx.sym.var('a')
-    ret = mx.sym.np.reshape(data, newshape=())
-    assert type(ret) == mx.sym.np._NumpySymbol
+    data = mx.sym.var('a').as_np_ndarray()
+    ret = data.reshape(shape=())
+    assert type(ret) == mx.sym.np._Symbol
 
     data = np.ones((1, 1, 1))
     ret = np.reshape(data, ())
@@ -271,12 +273,12 @@ def test_np_reshape():
 
 
 @with_seed()
-@mx.use_np_compat
+@np.use_np_compat
 def test_np_maximum():
     # TODO(junwu): Add more test cases
-    x1, x2 = mx.sym.var('x1'), mx.sym.var('x2')
+    x1, x2 = mx.sym.var('x1').as_np_ndarray(), mx.sym.var('x2').as_np_ndarray()
     ret = mx.sym.np.maximum(x1, x2)
-    assert type(ret) == mx.sym.np._NumpySymbol
+    assert type(ret) == mx.sym.np._Symbol
 
     def check_maximum(x1, x2):
         mx_out = np.maximum(x1, x2)
@@ -292,12 +294,12 @@ def test_np_maximum():
 
 
 @with_seed()
-@mx.use_np_compat
+@np.use_np_compat
 def test_np_minimum():
     # TODO(junwu): Add more test cases
-    x1, x2 = mx.sym.var('x1'), mx.sym.var('x2')
+    x1, x2 = mx.sym.var('x1').as_np_ndarray(), mx.sym.var('x2').as_np_ndarray()
     ret = mx.sym.np.minimum(x1, x2)
-    assert type(ret) == mx.sym.np._NumpySymbol
+    assert type(ret) == mx.sym.np._Symbol
 
     def check_minimum(x1, x2):
         mx_out = np.minimum(x1, x2)