Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/07/18 00:00:55 UTC

[incubator-mxnet] 19/42: [numpy] Fix d2l chapter8 (#15237)

This is an automated email from the ASF dual-hosted git repository.

haoj pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit d5537d41460cdbee8010f1c48d9acb41f2a983e0
Author: reminisce <wu...@gmail.com>
AuthorDate: Thu Jun 13 14:42:34 2019 -0700

    [numpy] Fix d2l chapter8 (#15237)
    
    * Add np op doc
    
    * Fix several issues
    
    * Add support for N-D a dot 2-D b
    
    * Simplify array creation api
    
    * Add swapaxes
    
    * Fix rnn gluon
    
    * More fix
    
    * Fix pylint
    
    * Delete
    
    * Fix mp windows
---
 python/mxnet/_numpy_op_doc.py          |  88 +++++++++++++++++++++++
 python/mxnet/base.py                   |   4 ++
 python/mxnet/gluon/data/dataloader.py  |   3 -
 python/mxnet/gluon/nn/basic_layers.py  |   3 +-
 python/mxnet/gluon/rnn/rnn_layer.py    |  33 +++++----
 python/mxnet/ndarray/ndarray.py        |   4 +-
 python/mxnet/ndarray/numpy/_op.py      |  45 +++++++++++-
 python/mxnet/numpy/__init__.py         |   2 -
 python/mxnet/numpy/multiarray.py       | 126 ++++++++++++++++++++++++---------
 python/mxnet/symbol/numpy/_symbol.py   |  86 +++++++++++++++++-----
 src/ndarray/ndarray.cc                 |   2 +-
 src/operator/nn/concat.cc              |   1 +
 src/operator/numpy/np_dot-inl.h        |  32 +++++++--
 src/operator/numpy/np_dot.cc           |  18 ++++-
 src/operator/numpy/np_matrix_op.cc     |  63 +++++++++++++++++
 src/operator/numpy/np_matrix_op.cu     |   3 +
 src/operator/rnn.cc                    |   1 +
 src/operator/sequence_mask.cc          |   3 +
 src/operator/swapaxis-inl.h            |  42 +++++++++--
 src/operator/swapaxis.cc               |   2 +-
 src/operator/tensor/indexing_op.cc     |   2 +
 src/operator/tensor/matrix_op-inl.h    |   8 +--
 src/operator/tensor/matrix_op.cc       |   1 +
 tests/python/unittest/test_numpy_op.py |  68 ++++++++++++++++++
 24 files changed, 549 insertions(+), 91 deletions(-)
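
Before diving into the per-file changes, here is a brief, hypothetical usage sketch of the front-end behavior this commit targets. Module paths, default dtype/ctx, and op exposure are assumptions read off the diff below (and the tests at the end), not verified output of this patch; numpy-shape semantics are assumed to be enabled, e.g. via the npx.use_np_shape decorator used in the new tests.

    from mxnet import numpy as np

    a = np.array([[1, 2, 3], [4, 5, 6]])             # simplified creation API: dtype/ctx now optional keywords
    b = np.swapaxes(a, 0, 1)                         # new np.swapaxes front end -> shape (3, 2)
    c = np.expand_dims(a, axis=0)                    # new np.expand_dims front end -> shape (1, 2, 3)
    d = np.dot(np.ones((3, 4, 5)), np.ones((5, 2)))  # N-D a dot 2-D b now supported -> shape (3, 4, 2)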

diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py
new file mode 100644
index 0000000..17f92ce
--- /dev/null
+++ b/python/mxnet/_numpy_op_doc.py
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: skip-file
+
+"""Doc placeholder for numpy ops with prefix _np."""
+
+
+def _np_reshape(a, newshape, order='C'):
+    """Gives a new shape to an array without changing its data.
+
+    Parameters
+    ----------
+    a : ndarray
+        Array to be reshaped.
+    newshape : int or tuple of ints
+        The new shape should be compatible with the original shape. If
+        an integer, then the result will be a 1-D array of that length.
+        One shape dimension can be -1. In this case, the value is
+        inferred from the length of the array and remaining dimensions.
+    order : {'C'}, optional
+        Read the elements of `a` using this index order, and place the
+        elements into the reshaped array using this index order.  'C'
+        means to read / write the elements using C-like index order,
+        with the last axis index changing fastest, back to the first
+        axis index changing slowest. Other order types such as 'F'/'A'
+        may be added in the future.
+
+    Returns
+    -------
+    reshaped_array : ndarray
+        It will always be a copy of the original array. This behavior is different
+        from the official NumPy package where views of the original array may be
+        generated.
+
+    See Also
+    --------
+    ndarray.reshape : Equivalent method.
+    """
+    pass
+
+
+def _np_ones_like(a):
+    """Return an array of ones with the same shape and type as a given array.
+
+    Parameters
+    ----------
+    a : ndarray
+        The shape and data-type of `a` define these same attributes of
+        the returned array.
+
+    Returns
+    -------
+    out : ndarray
+        Array of ones with the same shape and type as `a`.
+    """
+    pass
+
+
+def _np_zeros_like(a):
+    """Return an array of zeros with the same shape and type as a given array.
+
+    Parameters
+    ----------
+    a : ndarray
+        The shape and data-type of `a` define these same attributes of
+        the returned array.
+
+    Returns
+    -------
+    out : ndarray
+        Array of zeros with the same shape and type as `a`.
+    """
+    pass
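
A quick illustration of the reshape semantics documented above (one dimension may be -1 and is inferred; the result is always a copy in mxnet.numpy, unlike official NumPy, which may return a view). This is a sketch assuming the op is exposed as mxnet.numpy.reshape by the registration machinery touched below:

    from mxnet import numpy as np

    x = np.ones((2, 3, 4))
    y = np.reshape(x, (6, -1))   # -1 is inferred as 4
    assert y.shape == (6, 4)     # y is a copy of x's data, not a view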
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index a348326..a4f75c6 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -778,6 +778,7 @@ def _init_np_op_module(root_module_name, np_module_name, mx_module_name, make_op
     make_op_func : function
         Function for creating op functions.
     """
+    from . import _numpy_op_doc as _np_op_doc
     if np_module_name == 'numpy':
         op_name_prefix = _NP_OP_PREFIX
         submodule_name_list = _NP_OP_SUBMODULE_LIST
@@ -839,3 +840,6 @@ def _init_np_op_module(root_module_name, np_module_name, mx_module_name, make_op
         function.__module__ = module_name_local
         setattr(cur_module, function.__name__, function)
         cur_module.__all__.append(function.__name__)
+
+        if hasattr(_np_op_doc, name):
+            function.__doc__ = getattr(_np_op_doc, name).__doc__
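
The hook above copies hand-written docstrings from _numpy_op_doc.py onto the generated numpy op functions. Isolated from the module-init loop, the mechanism is roughly the following sketch (not the actual initialization code):

    from mxnet import _numpy_op_doc as _np_op_doc

    def _attach_np_doc(function, registered_name):
        # e.g. registered_name == '_np_reshape' matches the placeholder defined above
        if hasattr(_np_op_doc, registered_name):
            function.__doc__ = getattr(_np_op_doc, registered_name).__doc__
        return function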
diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index 59b1582..9f0939e 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -470,9 +470,6 @@ class _MultiWorkerIter(object):
             batch = _as_in_context(batch, context.cpu_pinned(self._pin_device_id))
         batch = batch[0] if len(batch) == 1 else batch
         self._rcvd_idx += 1
-        if is_np_array():
-            new_batch = [member.as_np_ndarray() for member in batch]
-            batch = new_batch
         return batch
 
     def next(self):
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index 512863a..eb4d55b 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -413,8 +413,9 @@ class Embedding(HybridBlock):
                                       init=weight_initializer, dtype=dtype,
                                       allow_deferred_init=True, grad_stype=grad_stype)
 
-    @_adapt_np_array
     def hybrid_forward(self, F, x, weight):
+        if is_np_array():
+            F = F.npx
         return F.Embedding(x, weight, name='fwd', **self._kwargs)
 
     def __repr__(self):
diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index b3cc596..1104b1e 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -28,6 +28,8 @@ __all__ = ['RNN', 'LSTM', 'GRU']
 from ... import ndarray, symbol
 from .. import HybridBlock, tensor_types
 from . import rnn_cell
+from ...util import is_np_array
+
 
 class _RNNLayer(HybridBlock):
     """Implementation of recurrent layers."""
@@ -217,7 +219,10 @@ class _RNNLayer(HybridBlock):
                 info.update(kwargs)
             else:
                 info = kwargs
-            states.append(func(name='%sh0_%d'%(self.prefix, i), **info))
+            state = func(name='%sh0_%d' % (self.prefix, i), **info)
+            if is_np_array():
+                state = state.as_np_ndarray()
+            states.append(state)
         return states
 
     def __call__(self, inputs, states=None, sequence_length=None, **kwargs):
@@ -236,7 +241,6 @@ class _RNNLayer(HybridBlock):
         else:
             return super(_RNNLayer, self).__call__(inputs, states, **kwargs)
 
-
     def hybrid_forward(self, F, inputs, states, sequence_length=None, **kwargs):
         if F is ndarray:
             batch_size = inputs.shape[self._layout.find('N')]
@@ -254,8 +258,9 @@ class _RNNLayer(HybridBlock):
 
     def _forward_kernel(self, F, inputs, states, sequence_length, **kwargs):
         """ forward using CUDNN or CPU kenrel"""
+        swapaxes = F.np.swapaxes if is_np_array() else F.swapaxes
         if self._layout == 'NTC':
-            inputs = F.swapaxes(inputs, dim1=0, dim2=1)
+            inputs = swapaxes(inputs, 0, 1)
         if self._projection_size is None:
             params = (kwargs['{}{}_{}_{}'.format(d, l, g, t)].reshape(-1)
                       for t in ['weight', 'bias']
@@ -270,21 +275,23 @@ class _RNNLayer(HybridBlock):
                       for g in ['i2h', 'h2h', 'h2r']
                       if g != 'h2r' or t != 'bias')
 
-        params = F._internal._rnn_param_concat(*params, dim=0)
+        rnn_param_concat = F.np._internal.rnn_param_concat if is_np_array()\
+            else F._internal._rnn_param_concat
+        params = rnn_param_concat(*params, dim=0)
 
         if self._use_sequence_length:
             rnn_args = states + [sequence_length]
         else:
             rnn_args = states
 
-        rnn = F.RNN(inputs, params, *rnn_args, use_sequence_length=self._use_sequence_length,
-                    state_size=self._hidden_size, projection_size=self._projection_size,
-                    num_layers=self._num_layers, bidirectional=self._dir == 2,
-                    p=self._dropout, state_outputs=True, mode=self._mode,
-                    lstm_state_clip_min=self._lstm_state_clip_min,
-                    lstm_state_clip_max=self._lstm_state_clip_max,
-                    lstm_state_clip_nan=self._lstm_state_clip_nan)
-
+        rnn_fn = F.npx.RNN if is_np_array() else F.RNN
+        rnn = rnn_fn(inputs, params, *rnn_args, use_sequence_length=self._use_sequence_length,
+                     state_size=self._hidden_size, projection_size=self._projection_size,
+                     num_layers=self._num_layers, bidirectional=self._dir == 2,
+                     p=self._dropout, state_outputs=True, mode=self._mode,
+                     lstm_state_clip_min=self._lstm_state_clip_min,
+                     lstm_state_clip_max=self._lstm_state_clip_max,
+                     lstm_state_clip_nan=self._lstm_state_clip_nan)
 
         if self._mode == 'lstm':
             outputs, states = rnn[0], [rnn[1], rnn[2]]
@@ -292,7 +299,7 @@ class _RNNLayer(HybridBlock):
             outputs, states = rnn[0], [rnn[1]]
 
         if self._layout == 'NTC':
-            outputs = F.swapaxes(outputs, dim1=0, dim2=1)
+            outputs = swapaxes(outputs, 0, 1)
 
         return outputs, states
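
The rnn_layer.py changes follow one small dispatch pattern: inside hybrid_forward, select the numpy-compatible operator when numpy semantics are active. A minimal sketch of that pattern (F is the ndarray or symbol module handed to hybrid_forward; _swap_time_batch is a hypothetical helper, not part of this patch):

    from mxnet.util import is_np_array

    def _swap_time_batch(F, x):
        # the numpy front end exposes F.np.swapaxes; the classic one keeps F.swapaxes
        swapaxes = F.np.swapaxes if is_np_array() else F.swapaxes
        return swapaxes(x, 0, 1)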
 
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 1ba7bce..5ddc9f7 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -928,7 +928,7 @@ fixed-size items.
 
         check_call(_LIB.MXNDArraySlice(
             self.handle, mx_uint(start), mx_uint(stop), ctypes.byref(handle)))
-        return NDArray(handle=handle, writable=self.writable)
+        return self.__class__(handle=handle, writable=self.writable)
 
     def _at(self, idx):
         """Returns a view of the array sliced at `idx` in the first dim.
@@ -1085,7 +1085,7 @@ fixed-size items.
                                            c_array(ctypes.c_int64, shape),
                                            reverse,
                                            ctypes.byref(handle)))
-        return NDArray(handle=handle, writable=self.writable)
+        return self.__class__(handle=handle, writable=self.writable)
 
     def reshape_like(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`reshape_like`.
diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index f3f4d74..22ca5b7 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -26,7 +26,7 @@ from . import _internal as _npi
 
 __all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax',
            'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate',
-           'clip']
+           'clip', 'swapaxes', 'expand_dims']
 
 
 @set_module('mxnet.ndarray.numpy')
@@ -495,3 +495,46 @@ def clip(a, a_min, a_max, out=None):
     if a_max is None:
         a_max = float('inf')
     return _npi.clip(a, a_min, a_max, out=out)
+
+
+@set_module('mxnet.ndarray.numpy')
+def swapaxes(a, axis1, axis2):
+    """Interchange two axes of an array.
+
+    Parameters
+    ----------
+    a : ndarray
+        Input array.
+    axis1 : int
+        First axis.
+    axis2 : int
+        Second axis.
+
+    Returns
+    -------
+    a_swapped : ndarray
+        Swapped array. This is always a copy of the input array.
+    """
+    return _npi.swapaxes(a, dim1=axis1, dim2=axis2)
+
+
+@set_module('mxnet.ndarray.numpy')
+def expand_dims(a, axis):
+    """Expand the shape of an array.
+
+    Insert a new axis that will appear at the `axis` position in the expanded array shape.
+
+    Parameters
+    ----------
+    a : ndarray
+        Input array.
+    axis : int
+        Position in the expanded axes where the new axis is placed.
+
+    Returns
+    -------
+    res : ndarray
+        Output array. The number of dimensions is one greater than that of
+        the input array.
+    """
+    return _npi.expand_dims(a, axis)
diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/numpy/__init__.py
index 344483d..e1c9d90 100644
--- a/python/mxnet/numpy/__init__.py
+++ b/python/mxnet/numpy/__init__.py
@@ -1,5 +1,3 @@
-#!/usr/bin/env python
-
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 409cbf4..29a7686 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -22,6 +22,12 @@
 
 from __future__ import absolute_import
 from __future__ import division
+
+try:
+    from __builtin__ import slice as py_slice
+except ImportError:
+    from builtins import slice as py_slice
+
 from array import array as native_array
 import sys
 import ctypes
@@ -39,7 +45,7 @@ from ..ndarray.numpy import _internal as _npi
 
 __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange',
            'argmax', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate',
-           'clip']
+           'clip', 'swapaxes', 'expand_dims']
 
 
 # This function is copied from ndarray.py since pylint
@@ -97,25 +103,38 @@ class ndarray(NDArray):
     floating point number, or something else, etc.). Arrays should be constructed using
     `array`, `zeros` or `empty`. Currently, only c-contiguous arrays are supported."""
 
+    # pylint: disable=too-many-return-statements
     def __getitem__(self, key):
-        # TODO(junwu): calling base class __setitem__ is a temp solution
-        if self.ndim == 0:
+        # TODO(junwu): calling base class __getitem__ is a temp solution
+        ndim = self.ndim
+        shape = self.shape
+        if ndim == 0:
             if key != ():
                 raise IndexError('scalar tensor can only accept `()` as index')
         if isinstance(key, tuple) and len(key) == 0:
             return self
-        if isinstance(key, tuple) and len(key) == self.ndim\
+        elif isinstance(key, tuple) and len(key) == ndim\
                 and all(isinstance(idx, integer_types) for idx in key):
-            out = self._as_nd_ndarray()
+            out = self
             for idx in key:
                 out = out[idx]
-            return out.reshape(()).as_np_ndarray()
-        if isinstance(key, integer_types):
-            if key > self.shape[0] - 1:
+            return out
+        elif isinstance(key, integer_types):
+            if key > shape[0] - 1:
                 raise IndexError(
                     'index {} is out of bounds for axis 0 with size {}'.format(
-                        key, self.shape[0]))
+                        key, shape[0]))
             return self._at(key)
+        elif isinstance(key, py_slice):
+            if key.step is not None and key.step != 1:
+                if key.step == 0:
+                    raise ValueError("slice step cannot be zero")
+                return self.as_nd_ndarray()._get_nd_basic_indexing(key).as_np_ndarray()
+            elif key.start is not None or key.stop is not None:
+                return self._slice(key.start, key.stop)
+            else:
+                return self
+
         if isinstance(key, ndarray):
             key = key._as_nd_ndarray()
         elif isinstance(key, tuple):
@@ -126,6 +145,7 @@ class ndarray(NDArray):
         elif sys.version_info[0] > 2 and isinstance(key, range):
             key = _get_index(key)
         return self._as_nd_ndarray().__getitem__(key).as_np_ndarray()
+    # pylint: enable=too-many-return-statements
 
     def __setitem__(self, key, value):
         # TODO(junwu): calling base class __setitem__ is a temp solution
@@ -369,9 +389,6 @@ class ndarray(NDArray):
         return self.transpose()
     # pylint: enable= invalid-name, undefined-variable
 
-    def _slice(self, start, stop):
-        raise NotImplementedError
-
     def all(self, axis=None, out=None, keepdims=False):
         raise NotImplementedError
 
@@ -606,13 +623,11 @@ class ndarray(NDArray):
         """
         raise AttributeError('mxnet.numpy.ndarray object has no attribute pad')
 
-    def swapaxes(self, *args, **kwargs):
-        """Convenience fluent method for :py:func:`swapaxes`.
-
-        The arguments are the same as for :py:func:`swapaxes`, with
-        this array as data.
+    def swapaxes(self, axis1, axis2):  # pylint: disable=arguments-differ
+        """Return a copy of the array with axis1 and axis2 interchanged.
+        Refer to `mxnet.numpy.swapaxes` for full documentation.
         """
-        raise NotImplementedError
+        return swapaxes(self, axis1, axis2)
 
     def split(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`split`.
@@ -1180,13 +1195,10 @@ class ndarray(NDArray):
         """
         raise AttributeError('mxnet.numpy.ndarray object has no attribute softmin')
 
-    def squeeze(self, *args, **kwargs):
-        """Convenience fluent method for :py:func:`squeeze`.
-
-        The arguments are the same as for :py:func:`squeeze`, with
-        this array as data.
+    def squeeze(self, axis=None):  # pylint: disable=arguments-differ
+        """Remove single-dimensional entries from the shape of a.
         """
-        raise NotImplementedError
+        return _mx_np_op.squeeze(self, axis=axis)
 
     def broadcast_to(self, shape):
         raise AttributeError('mxnet.numpy.ndarray object has no attribute broadcast_to')
@@ -1245,13 +1257,13 @@ def empty(shape, dtype=None, **kwargs):
 
 
 @set_module('mxnet.numpy')
-def array(object, dtype=None, **kwargs):
+def array(object, dtype=None, ctx=None):
     """
     Create an array.
 
     Parameters
     ----------
-    object : array_like or `mxnet.ndarray.NDArray` or `mxnet.numpy.ndarray`
+    object : array_like or `numpy.ndarray` or `mxnet.numpy.ndarray`
         An array, any object exposing the array interface, an object whose
         __array__ method returns an array, or any (nested) sequence.
     dtype : data-type, optional
@@ -1265,17 +1277,18 @@ def array(object, dtype=None, **kwargs):
     out : ndarray
         An array object satisfying the specified requirements.
     """
-    _sanity_check_params('array', ['copy', 'order', 'subok', 'ndim'], kwargs)
-    ctx = kwargs.get('ctx', current_context())
     if ctx is None:
         ctx = current_context()
-    if dtype is None:
-        dtype = _np.float32
-    if not isinstance(object, (ndarray, NDArray, _np.ndarray)):
-        try:
-            object = _np.array(object, dtype=dtype)
-        except:
-            raise TypeError('source array must be an array like object')
+    if isinstance(object, ndarray):
+        dtype = object.dtype if dtype is None else dtype
+    else:
+        dtype = mx_real_t if dtype is None else dtype
+        if not isinstance(object, (ndarray, _np.ndarray)):
+            try:
+                object = _np.array(object, dtype=dtype)
+            except Exception as e:
+                print(e)
+                raise TypeError('source array must be an array like object')
     ret = empty(object.shape, dtype=dtype, ctx=ctx)
     if len(object.shape) == 0:
         ret[()] = object
@@ -1662,3 +1675,46 @@ def clip(a, a_min, a_max, out=None):
         with `a_max`.
     """
     return _mx_nd_np.clip(a, a_min, a_max, out=out)
+
+
+@set_module('mxnet.numpy')
+def swapaxes(a, axis1, axis2):
+    """Interchange two axes of an array.
+
+    Parameters
+    ----------
+    a : ndarray
+        Input array.
+    axis1 : int
+        First axis.
+    axis2 : int
+        Second axis.
+
+    Returns
+    -------
+    a_swapped : ndarray
+        Swapped array. This is always a copy of the input array.
+    """
+    return _npi.swapaxes(a, dim1=axis1, dim2=axis2)
+
+
+@set_module('mxnet.numpy')
+def expand_dims(a, axis):
+    """Expand the shape of an array.
+
+    Insert a new axis that will appear at the `axis` position in the expanded array shape.
+
+    Parameters
+    ----------
+    a : ndarray
+        Input array.
+    axis : int
+        Position in the expanded axes where the new axis is placed.
+
+    Returns
+    -------
+    res : ndarray
+        Output array. The number of dimensions is one greater than that of
+        the input array.
+    """
+    return _npi.expand_dims(a, axis)
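
The new __getitem__ branches above map index types to implementations roughly as follows; a hypothetical walk-through assuming numpy-shape semantics are enabled so that full integer indexing can yield a 0-d array:

    from mxnet import numpy as np

    x = np.array([[1., 2., 3.], [4., 5., 6.]])
    row  = x[1]      # single integer      -> _at, shape (3,)
    head = x[0:2]    # step-1 basic slice  -> _slice, shape (2, 3)
    elem = x[0, 2]   # full tuple of ints  -> repeated integer indexing, 0-d result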
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index e333a62..f24c2aa 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -22,7 +22,7 @@ from __future__ import absolute_import
 import ctypes
 import numpy as _np
 from . import _op as _mx_np_op
-from ...base import _LIB, SymbolHandle, numeric_types
+from ...base import _LIB, SymbolHandle, numeric_types, mx_uint
 from ...util import _sanity_check_params, check_call, set_module
 from ...context import current_context
 from ..symbol import Symbol
@@ -30,13 +30,29 @@ from .._internal import _set_np_symbol_class
 from . import _internal as _npi
 
 __all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'concatenate', 'arange', 'argmax',
-           'clip', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power']
+           'clip', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'swapaxes',
+           'expand_dims']
+
+
+def _num_outputs(sym):
+    return len(sym.as_nd_ndarray())
 
 
 @set_module('mxnet.symbol.numpy')
 class _Symbol(Symbol):
-    def __getitem__(self, item):
-        raise NotImplementedError
+    def __getitem__(self, key):
+        num_outputs = _num_outputs(self)
+        if num_outputs == 1:
+            raise NotImplementedError
+        if not isinstance(key, int):
+            raise NotImplementedError
+        if key >= num_outputs:
+            # Important, python determines the end by this exception
+            raise IndexError
+        handle = SymbolHandle()
+        check_call(_LIB.MXSymbolGetOutput(
+            self.handle, mx_uint(key), ctypes.byref(handle)))
+        return _Symbol(handle=handle)
 
     def __setitem__(self, key, value):
         raise NotImplementedError
@@ -257,13 +273,11 @@ class _Symbol(Symbol):
         """
         raise AttributeError('_Symbol object has no attribute pad')
 
-    def swapaxes(self, *args, **kwargs):
-        """Convenience fluent method for :py:func:`swapaxes`.
-
-        The arguments are the same as for :py:func:`swapaxes`, with
-        this array as data.
+    def swapaxes(self, axis1, axis2):  # pylint: disable=arguments-differ
+        """Return a copy of the array with axis1 and axis2 interchanged.
+        Refer to `mxnet.numpy.swapaxes` for full documentation.
         """
-        raise NotImplementedError
+        return swapaxes(self, axis1, axis2)
 
     def split(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`split`.
@@ -831,13 +845,10 @@ class _Symbol(Symbol):
         """
         raise AttributeError('_Symbol object has no attribute softmin')
 
-    def squeeze(self, *args, **kwargs):
-        """Convenience fluent method for :py:func:`squeeze`.
-
-        The arguments are the same as for :py:func:`squeeze`, with
-        this array as data.
+    def squeeze(self, axis=None):  # pylint: disable=arguments-differ
+        """Remove single-dimensional entries from the shape of a.
         """
-        raise NotImplementedError
+        return _mx_np_op.squeeze(self, axis=axis)
 
     def broadcast_to(self, *args, **kwargs):
         raise AttributeError('_Symbol object has no attribute broadcast_to')
@@ -1173,4 +1184,47 @@ def clip(a, a_min, a_max, out=None):
     return _npi.clip(a, a_min, a_max, out=out)
 
 
+@set_module('mxnet.symbol.numpy')
+def swapaxes(a, axis1, axis2):
+    """Interchange two axes of an array.
+
+    Parameters
+    ----------
+    a : _Symbol
+        Input array.
+    axis1 : int
+        First axis.
+    axis2 : int
+        Second axis.
+
+    Returns
+    -------
+    a_swapped : _Symbol
+        Swapped array symbol.
+    """
+    return _npi.swapaxes(a, dim1=axis1, dim2=axis2)
+
+
+@set_module('mxnet.symbol.numpy')
+def expand_dims(a, axis):
+    """Expand the shape of an array.
+
+    Insert a new axis that will appear at the `axis` position in the expanded array shape.
+
+    Parameters
+    ----------
+    a : _Symbol
+        Input array.
+    axis : int
+        Position in the expanded axes where the new axis is placed.
+
+    Returns
+    -------
+    res : _Symbol
+        Output array. The number of dimensions is one greater than that of
+        the input array.
+    """
+    return _npi.expand_dims(a, axis)
+
+
 _set_np_symbol_class(_Symbol)
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index f883a35..f10f5db 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -312,7 +312,7 @@ NDArray NDArray::AtWithRecord(index_t idx) {
   CHECK(storage_type() == kDefaultStorage)
       << "Storage type " << storage_type() << " doesn't support At()";
   NDArray ret = this->SliceWithRecord(idx, idx+1);
-  if (shape_.ndim() > 1) {
+  if (shape_.ndim() > 1 || Imperative::Get()->is_np_shape()) {
     return ret.ReshapeWithRecord(mxnet::TShape(shape_.data()+1, shape_.data()+shape_.ndim()));
   } else {
     return ret;
diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc
index cda9c9a..80469b5 100644
--- a/src/operator/nn/concat.cc
+++ b/src/operator/nn/concat.cc
@@ -403,6 +403,7 @@ NNVM_REGISTER_OP(_backward_Concat)
 // which handles the case where the first one or two inputs may have
 // unknown shape that can be inferred from output shape.
 NNVM_REGISTER_OP(_rnn_param_concat)
+.add_alias("_npi_rnn_param_concat")
 #if MXNET_USE_MKLDNN == 1
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
diff --git a/src/operator/numpy/np_dot-inl.h b/src/operator/numpy/np_dot-inl.h
index 2f7c589..fa67c07 100644
--- a/src/operator/numpy/np_dot-inl.h
+++ b/src/operator/numpy/np_dot-inl.h
@@ -140,14 +140,17 @@ inline void NumpyDotForward(const nnvm::NodeAttrs& attrs,
         Kernel<scalar_mul_kernel<Req>, xpu>::Launch(
           s, out.Size(), out.dptr<DType>(), tensor, scalar);
       });
-    } else if (b_shape.ndim() == 1) {
+    } else if (a_shape.ndim() == 1 || b_shape.ndim() == 1) {
       // Case 4: a is N-D array and b is 1-D array, sum product over the last axis
       MMImpl<xpu>(ctx, a, b, out, req[0]);
     } else {
-      // TODO(haojin2): To be implemented...
       // Case 5: a is N-D array and b is M-D array, sum product over the last axis
       //         of a and the 2nd-to-last axis of b
-      LOG(FATAL) << "Case 5 not implemented yet...";
+      // TODO(haojin2): To be implemented...
+      if (b_shape.ndim() != 2) {
+        LOG(FATAL) << "Only support case 5 when b.ndim = 2";
+      }
+      MMImpl<xpu>(ctx, a, b, out, req[0]);
     }
   });
 }
@@ -239,10 +242,29 @@ inline void NumpyDotBackward(const nnvm::NodeAttrs& attrs,
       MMImpl<xpu>(ctx, TBlob(a_), TBlob(ograd_), TBlob(grad_b_), req[1], true, false);
       MMImpl<xpu>(ctx, TBlob(ograd_), TBlob(b_), TBlob(grad_a_), req[0], false, true);
     } else {
-      // TODO(haojin2): To be implemented...
       // Case 5: a is N-D array and b is M-D array, sum product over the last axis
       //         of a and the 2nd-to-last axis of b
-      LOG(FATAL) << "Case 5 not implemented yet...";
+      // TODO(haojin2): To be implemented...
+      if (b_shape.ndim() != 2) {
+        LOG(FATAL) << "Only support case 5 when b.ndim = 2";
+      } else {  // a is N-D, b is 2D
+        index_t na = a_shape[a_shape.ndim() - 1];
+        index_t ma = a_shape.Size() / na;
+        index_t nograd = ograd.shape_[ograd.shape_.ndim() - 1];
+        index_t mograd = ograd.shape_.Size() / nograd;
+
+        Tensor<xpu, 2, DType> a_2d =
+            a.get_with_shape<xpu, 2, DType>(Shape2(ma, na), s);
+        Tensor<xpu, 2, DType> grad_a_2d =
+            grad_a.get_with_shape<xpu, 2, DType>(Shape2(ma, na), s);
+        Tensor<xpu, 2, DType> b_2d = b.FlatTo2D<xpu, DType>(s);
+        Tensor<xpu, 2, DType> grad_b_2d = grad_b.FlatTo2D<xpu, DType>(s);
+        Tensor<xpu, 2, DType> ograd_2d =
+            ograd.get_with_shape<xpu, 2, DType>(Shape2(mograd, nograd), s);
+
+        MMImpl<xpu>(ctx, TBlob(a_2d), TBlob(ograd_2d), TBlob(grad_b_2d), req[1], true, false);
+        MMImpl<xpu>(ctx, TBlob(ograd_2d), TBlob(b_2d), TBlob(grad_a_2d), req[0], false, true);
+      }
     }
   });
 }
diff --git a/src/operator/numpy/np_dot.cc b/src/operator/numpy/np_dot.cc
index 992bef0..627e688 100644
--- a/src/operator/numpy/np_dot.cc
+++ b/src/operator/numpy/np_dot.cc
@@ -80,7 +80,23 @@ inline bool NumpyDotShape(const nnvm::NodeAttrs& attrs,
   } else {
     // Case 5: a is N-D array and b is M-D array, sum product over the last axis
     //         of a and the 2nd-to-last axis of b
-    LOG(FATAL) << "Case 5 not implemented yet...";
+    TShape tmp_shape(a_shape.ndim(), -1);
+    tmp_shape[a_shape.ndim() - 1] = b_shape[b_shape.ndim() - 2];
+    SHAPE_ASSIGN_CHECK(*in_attrs, 0, tmp_shape);
+
+    tmp_shape = TShape(b_shape.ndim(), -1);
+    tmp_shape[b_shape.ndim() - 2] = a_shape[a_shape.ndim() - 1];
+    SHAPE_ASSIGN_CHECK(*in_attrs, 1, tmp_shape);
+
+    tmp_shape = TShape(a_shape.ndim() + b_shape.ndim() - 2, -1);
+    for (int i = 0; i < a_shape.ndim() - 1; ++i) {
+      tmp_shape[i] = a_shape[i];
+    }
+    for (int i = 0; i < b_shape.ndim() - 2; ++i) {
+      tmp_shape[i + a_shape.ndim() - 1] = b_shape[i];
+    }
+    tmp_shape[tmp_shape.ndim() - 1] = b_shape[b_shape.ndim() - 1];
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, tmp_shape);
   }
   return shape_is_known(*in_attrs) && shape_is_known(*out_attrs);
 }
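
The output-shape rule implemented above for case 5 (currently restricted to 2-D b) can be sanity-checked against official NumPy: the product is summed over the last axis of a and the second-to-last axis of b, so the result shape is a.shape[:-1] + b.shape[-1:].

    import numpy as _np

    a = _np.ones((3, 4, 5))
    b = _np.ones((5, 2))
    assert _np.dot(a, b).shape == (3, 4, 2)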
diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc
index 80d70e5..1323447 100644
--- a/src/operator/numpy/np_matrix_op.cc
+++ b/src/operator/numpy/np_matrix_op.cc
@@ -310,5 +310,68 @@ NNVM_REGISTER_OP(_backward_np_concat)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FCompute>("FCompute<cpu>", ConcatGradCompute<cpu>);
 
+bool NumpySqueezeShape(const nnvm::NodeAttrs& attrs,
+                       mxnet::ShapeVector *in_attrs,
+                       mxnet::ShapeVector *out_attrs) {
+  const SqueezeParam& param = nnvm::get<SqueezeParam>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), 1U) << "Input: [a]";
+  CHECK_EQ(out_attrs->size(), 1U);
+  const mxnet::TShape& dshape = in_attrs->at(0);
+  const int dndim = dshape.ndim();
+  if (!shape_is_known(dshape)) return false;
+  mxnet::TShape oshape = dshape;
+  // special case, scalar tensor
+  if (dshape.ndim() == 0) {
+    if (param.axis.has_value()) {
+      mxnet::Tuple<int> axes = param.axis.value();
+      CHECK_EQ(axes.ndim(), 1) << "cannot specify more than one axis for a scalar tensor";
+      CHECK(axes[0] == 0 || axes[0] == -1) << "axis " << axes[0]
+                                           << " is out of bounds of array of dimension 0";
+    }
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(0, -1));
+    return true;
+  }
+  if (param.axis.has_value()) {
+    // preprocess axis
+    mxnet::Tuple<int> axes = param.axis.value();
+    for (int i = 0; i < axes.ndim(); ++i) {
+      if (axes[i] < 0) {
+        axes[i] += dndim;
+        CHECK_GE(axes[i], 0)
+            << "axis " << axes[i] - dndim << " is out of bounds for array of dimension " << dndim;
+      }
+      CHECK_LT(axes[i], dndim)
+          << "axis " << axes[i] << " is out of bounds for array of dimension " << dndim;
+      CHECK_EQ(dshape[axes[i]], 1)
+          << "cannot select an axis to squeeze out which has size="
+          << dshape[axes[i]] << " not equal to one";
+      CHECK_NE(oshape[axes[i]], 0) << "duplicate value in axis";
+      oshape[axes[i]] = -1;
+    }
+  } else {
+    for (int i = 0; i < oshape.ndim(); ++i) {
+      if (oshape[i] == 1) oshape[i] = -1;
+    }
+  }
+  size_t oshape_size = SqueezeShapeHelper(&oshape);
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(oshape.data(), oshape.data()+oshape_size));
+  return true;
+}
+
+NNVM_REGISTER_OP(_np_squeeze)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<SqueezeParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"a"};
+  })
+.set_attr<mxnet::FInferShape>("FInferShape", NumpySqueezeShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_squeeze"})
+.add_argument("a", "NDArray-or-Symbol[]", "data to squeeze")
+.add_arguments(SqueezeParam::__FIELDS__());
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu
index 4cccf59..5354820 100644
--- a/src/operator/numpy/np_matrix_op.cu
+++ b/src/operator/numpy/np_matrix_op.cu
@@ -43,5 +43,8 @@ NNVM_REGISTER_OP(_npi_concatenate)
 NNVM_REGISTER_OP(_backward_np_concat)
 .set_attr<FCompute>("FCompute<gpu>", ConcatGradCompute<gpu>);
 
+NNVM_REGISTER_OP(_np_squeeze)
+.set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc
index 6a0dbd7..58f190a 100644
--- a/src/operator/rnn.cc
+++ b/src/operator/rnn.cc
@@ -634,6 +634,7 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr,
 #endif
 
 NNVM_REGISTER_OP(RNN)
+.add_alias("_npx_RNN")
 .describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are
 implemented, with both multi-layer and bidirectional support.
 
diff --git a/src/operator/sequence_mask.cc b/src/operator/sequence_mask.cc
index f4f81a8..ca58be1 100644
--- a/src/operator/sequence_mask.cc
+++ b/src/operator/sequence_mask.cc
@@ -191,5 +191,8 @@ Example::
                   "vector of sequence lengths of the form [batch_size]")
     .add_arguments(SequenceMaskParam::__FIELDS__());
 
+NNVM_REGISTER_OP(SequenceMask)
+.add_alias("_npx_SequenceMask");
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/swapaxis-inl.h b/src/operator/swapaxis-inl.h
index b17a81f..fd9872d 100644
--- a/src/operator/swapaxis-inl.h
+++ b/src/operator/swapaxis-inl.h
@@ -47,7 +47,7 @@ enum SwapAxisOpOutputs {kOut};
 
 struct SwapAxisParam : public dmlc::Parameter<SwapAxisParam> {
   // use int for enumeration
-  uint32_t dim1, dim2;
+  int dim1, dim2;
   DMLC_DECLARE_PARAMETER(SwapAxisParam) {
     DMLC_DECLARE_FIELD(dim1)
     .set_default(0)
@@ -106,8 +106,6 @@ class SwapAxisOp : public Operator {
                 const std::vector<OpReqType> &req) {
     using namespace mshadow;
     using namespace mshadow::expr;
-    int dim1 = param_.dim1;
-    int dim2 = param_.dim2;
 
     TBlob data_in = in_data[swapaxisenum::kData];
     TBlob data_out = out_data[swapaxisenum::kData];
@@ -115,10 +113,27 @@ class SwapAxisOp : public Operator {
 
     mxnet::TShape shape_in = data_in.shape_;
     mxnet::TShape shape_out = data_out.shape_;
+    int axis1 = param_.dim1;
+    if (axis1 < 0) {
+      axis1 += shape_in.ndim();
+    }
+    CHECK(axis1 >= 0 && axis1 < shape_in.ndim())
+        << "axis1: axis " << param_.dim1 << " is out of bounds for array of ndim "
+        << shape_in.ndim();
+
+    int axis2 = param_.dim2;
+    if (axis2 < 0) {
+      axis2 += shape_in.ndim();
+    }
+    CHECK(axis2 >= 0 && axis2 < shape_in.ndim())
+        << "axis2: axis " << param_.dim2 << " is out of bounds for array of ndim "
+        << shape_in.ndim();
+
+    if (shape_in.Size() == 0U) return;
 
     Shape<5> inter_shape;
 
-    Reshape2Five(&inter_shape, shape_in, dim1, dim2);
+    Reshape2Five(&inter_shape, shape_in, axis1, axis2);
 
     Tensor<xpu, 5, DType> inter_data_in = data_in.get_with_shape<xpu, 5, DType>(inter_shape, s);
 
@@ -187,13 +202,28 @@ class SwapAxisProp : public OperatorProperty {
     CHECK_EQ(in_shape->size(), 1U);
 
     mxnet::TShape &shape0 = (*in_shape)[swapaxisenum::kData];
+    if (!ndim_is_known(shape0)) return false;
+    int axis1 = param_.dim1;
+    if (axis1 < 0) {
+      axis1 += shape0.ndim();
+    }
+    CHECK(axis1 >= 0 && axis1 < shape0.ndim())
+        << "axis1: axis " << param_.dim1 << " is out of bounds for array of ndim " << shape0.ndim();
+
+    int axis2 = param_.dim2;
+    if (axis2 < 0) {
+      axis2 += shape0.ndim();
+    }
+    CHECK(axis2 >= 0 && axis2 < shape0.ndim())
+        << "axis2: axis " << param_.dim2 << " is out of bounds for array of ndim " << shape0.ndim();
+
     out_shape->clear();
     out_shape->push_back(shape0);
     mxnet::TShape &shape1 = (*out_shape)[swapaxisenum::kOut];
 
-    std::swap(shape1[param_.dim1], shape1[param_.dim2]);
+    std::swap(shape1[axis1], shape1[axis2]);
 
-    return true;
+    return shape_is_known(*out_shape);
   }
 
   bool InferType(std::vector<int> *in_type,
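
The negative-axis normalization added above follows NumPy's convention (an axis of -k counts from the end); for reference, in official NumPy:

    import numpy as _np

    a = _np.zeros((4, 5, 6, 7))
    # -2 and -3 normalize to axes 2 and 1, so both calls agree
    assert _np.swapaxes(a, -2, -3).shape == _np.swapaxes(a, 1, 2).shape == (4, 6, 5, 7)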
diff --git a/src/operator/swapaxis.cc b/src/operator/swapaxis.cc
index 45bcca4..32b26cc 100644
--- a/src/operator/swapaxis.cc
+++ b/src/operator/swapaxis.cc
@@ -69,6 +69,6 @@ Examples::
                        [ 3, 7]]]
 )code" ADD_FILELINE);
 
-NNVM_REGISTER_OP(SwapAxis).add_alias("swapaxes");
+NNVM_REGISTER_OP(SwapAxis).add_alias("swapaxes").add_alias("_npi_swapaxes");
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc
index 396d1c6..f229fef 100644
--- a/src/operator/tensor/indexing_op.cc
+++ b/src/operator/tensor/indexing_op.cc
@@ -466,6 +466,7 @@ DMLC_REGISTER_PARAMETER(ScatterNDParam);
 
 NNVM_REGISTER_OP(Embedding)
 MXNET_ADD_SPARSE_OP_ALIAS(Embedding)
+.add_alias("_npx_Embedding")
 .describe(R"code(Maps integer indices to vector representations (embeddings).
 
 This operator maps words to real-valued vectors in a high-dimensional space,
@@ -764,6 +765,7 @@ Examples::
 .add_argument("indices", "NDArray-or-Symbol", "The index array");
 
 NNVM_REGISTER_OP(one_hot)
+.add_alias("_npx_one_hot")
 .describe(R"code(Returns a one-hot array.
 
 The locations represented by `indices` take value `on_value`, while all
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 4e13354..cf3d8e6 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -2183,7 +2183,7 @@ inline size_t SqueezeShapeHelper(mxnet::TShape* shape) {
   CHECK(shape != nullptr);
   size_t count = 0;
   for (int i = 0; i < shape->ndim(); ++i) {
-    if ((*shape)[i] == 0) {
+    if ((*shape)[i] == -1) {
       ++count;
     } else {
       std::swap((*shape)[i], (*shape)[i-count]);
@@ -2216,12 +2216,12 @@ inline bool SqueezeShape(const nnvm::NodeAttrs& attrs,
       CHECK_EQ(dshape[axes[i]], 1)
         << "cannot select an axis to squeeze out which has size="
         << dshape[axes[i]] << " not equal to one";
-      CHECK_NE(oshape[axes[i]], 0) << "duplicate value in axis";
-      oshape[axes[i]] = 0;
+      CHECK_NE(oshape[axes[i]], -1) << "duplicate value in axis";
+      oshape[axes[i]] = -1;
     }
   } else {
     for (int i = 0; i < oshape.ndim(); ++i) {
-      if (oshape[i] == 1) oshape[i] = 0;
+      if (oshape[i] == 1) oshape[i] = -1;
     }
   }
   size_t oshape_size = SqueezeShapeHelper(&oshape);
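
The sentinel change above (marking squeezed axes with -1 instead of 0) matters once zero-size dimensions are legal under numpy shape semantics, since a genuine 0-length axis must not be mistaken for an axis to remove; official NumPy keeps such axes, as the new test config below also exercises:

    import numpy as _np

    x = _np.zeros((1, 0, 1, 5))
    assert _np.squeeze(x, axis=2).shape == (1, 0, 5)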
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index b1165c5..df43bc6 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -409,6 +409,7 @@ Examples::
 
 
 NNVM_REGISTER_OP(expand_dims)
+.add_alias("_npi_expand_dims")
 .describe(R"code(Inserts a new axis of size 1 into the array shape
 
 For example, given ``x`` with shape ``(2,3,4)``, then ``expand_dims(x, axis=1)``
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 4e80166..8a80444 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -100,6 +100,8 @@ def test_np_dot():
         ((3, 4, 5), ()),     # Case 3.5.1
         ((), (3, 4, 5)),     # Case 3.5.2
         ((3, 4, 5), (5, )),  # Case 4
+        ((3, 4, 5), (5, 2)),
+        ((5,), (5, 2))
     ]
 
     eps = 1e-3
@@ -699,6 +701,72 @@ def test_np_concat():
                 assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
 
 
+@with_seed()
+@npx.use_np_shape
+def test_np_swapaxes():
+    config = [((0, 1, 2), 0, 1),
+              ((0, 1, 2), -1, -2),
+              ((4, 5, 6, 7), 2, 3),
+              ((4, 5, 6, 7), -2, -3)]
+
+    class TestSwapaxes(HybridBlock):
+        def __init__(self, axis1, axis2):
+            super(TestSwapaxes, self).__init__()
+            self._axis1 = axis1
+            self._axis2 = axis2
+
+        def hybrid_forward(self, F, x):
+            return F.np.swapaxes(x, self._axis1, self._axis2)
+
+    for shape, axis1, axis2 in config:
+        data_np = _np.random.uniform(size=shape)
+        data_mx = np.array(data_np, dtype=data_np.dtype)
+        ret_np = _np.swapaxes(data_np, axis1=axis1, axis2=axis2)
+        ret_mx = np.swapaxes(data_mx, axis1=axis1, axis2=axis2)
+        assert same(ret_mx.asnumpy(), ret_np)
+
+        net = TestSwapaxes(axis1, axis2)
+        for hybrid in [False, True]:
+            if hybrid:
+                net.hybridize()
+            ret_mx = net(data_mx)
+            assert same(ret_mx.asnumpy(), ret_np)
+
+
+@with_seed()
+@npx.use_np_shape
+def test_np_squeeze():
+    config = [((), None),
+              ((), -1),
+              ((), 0),
+              ((4, 1, 2), None),
+              ((1, 1, 1), None),
+              ((1, 0, 1, 5), 2),
+              ((1, 0, 1, 1), (-1, -4))]
+
+    class TestSqueeze(HybridBlock):
+        def __init__(self, axis):
+            super(TestSqueeze, self).__init__()
+            self._axis = axis
+
+        def hybrid_forward(self, F, x):
+            return F.np.squeeze(x, axis=self._axis)
+
+    for shape, axis in config:
+        data_np = _np.random.uniform(size=shape)
+        data_mx = np.array(data_np, dtype=data_np.dtype)
+        ret_np = _np.squeeze(data_np, axis=axis)
+        ret_mx = np.squeeze(data_mx, axis=axis)
+        assert same(ret_mx.asnumpy(), ret_np)
+
+        net = TestSqueeze(axis)
+        for hybrid in [False, True]:
+            if hybrid:
+                net.hybridize()
+            ret_mx = net(data_mx)
+            assert same(ret_mx.asnumpy(), ret_np)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()