You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/07/18 00:01:10 UTC

[incubator-mxnet] 34/42: [numpy][doc-fix] mean, transpose, stack, split, log2, rint and radians (#15370)

This is an automated email from the ASF dual-hosted git repository.

haoj pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 30b036e0e99e98d3c457ac8c903b7027202970e8
Author: Mike <ma...@connect.hku.hk>
AuthorDate: Wed Jul 3 17:16:32 2019 +0800

    [numpy][doc-fix] mean, transpose, stack, split, log2, rint and radians (#15370)
    
    * Doc fix for split, stack, transpose, mean, rint, radians, log2.
    
    * Minor syntax fix
    
    * Add some disable=line-too-long to pass pylint test
    
    * Add Notes following the guide of example PR by Mu Li
    
    * Minor syntax fix
    
    * Fix a non-ascii character
    
    * Fix issues mentioned in review by @reminisce
    
    * Register mean into npi namespace and wrap it to have same signature as
    standard numpy
    
    * Add mean to __all__ list
    
    * Note the incompatibility of broadcasting to output
    
    * Specify out must have the same type
    
    * Minor syntax fix
    
    * Clarify that the `out` in symbol is only a dummy variable
    
    Fix the mess due to pull rebase
    
    Correct the wrong return statement in multiarray
    
    Again, syntax fix
    
    Syntax fix one more time
---
 python/mxnet/_numpy_op_doc.py                      |  43 ++++
 python/mxnet/ndarray/numpy/_op.py                  | 239 ++++++++++++++++++-
 python/mxnet/numpy/multiarray.py                   | 254 +++++++++++++++++++--
 python/mxnet/symbol/numpy/_symbol.py               | 241 +++++++++++++++++--
 src/operator/numpy/np_broadcast_reduce_op_value.cc |   2 +-
 src/operator/numpy/np_broadcast_reduce_op_value.cu |   2 +-
 src/operator/numpy/np_elemwise_unary_op_basic.cc   |   6 +-
 src/operator/numpy/np_elemwise_unary_op_basic.cu   |   6 +-
 8 files changed, 750 insertions(+), 43 deletions(-)

diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py
index b285346..a27f209 100644
--- a/python/mxnet/_numpy_op_doc.py
+++ b/python/mxnet/_numpy_op_doc.py
@@ -366,3 +366,46 @@ def  _np_copy(a, out=None):
     array([0.])
     """
     pass
+
+
+def _np_transpose(a, axes=None):
+    """
+    transpose(a, axes=None)
+
+    Permute the dimensions of an array.
+
+    Parameters
+    ----------
+    a : ndarray
+        Input array.
+    axes : tuple or list of ints, optional
+        By default, reverse the dimensions,
+        otherwise permute the axes according to the values given.
+
+    Returns
+    -------
+    p : ndarray
+        `a` with its axes permuted.
+
+    Notes
+    -----
+    This function differs from the original `numpy.transpose
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.transpose.html>`_ in
+    the following way(s):
+
+    - only ndarray is accepted as valid input, python iterables are not supported
+
+    Examples
+    --------
+    >>> x = np.arange(4).reshape((2,2))
+    >>> x
+    array([[0., 1.],
+           [2., 3.]])
+    >>> np.transpose(x)
+    array([[0., 2.],
+           [1., 3.]])
+    >>> x = np.ones((1, 2, 3))
+    >>> np.transpose(x, (1, 0, 2)).shape
+    (2, 1, 3)
+    """
+    pass
diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index 054d9b8..7aaba1a 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -18,6 +18,7 @@
 
 """Namespace for numpy operators used in Gluon dispatched by F=ndarray."""
 
+# pylint: disable=too-many-lines
 from __future__ import absolute_import
 import numpy as _np
 from ...base import numeric_types
@@ -30,7 +31,7 @@ __all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax',
            'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate',
            'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace',
            'sin', 'cos', 'sinh', 'cosh', 'log10', 'sqrt', 'abs', 'exp', 'arctan', 'sign', 'log',
-           'degrees']
+           'degrees', 'log2', 'rint', 'radians', 'mean']
 
 
 @set_module('mxnet.ndarray.numpy')
@@ -180,6 +181,68 @@ def minimum(x1, x2, out=None):
 
 
 @set_module('mxnet.ndarray.numpy')
+def mean(a, axis=None, dtype=None, out=None, keepdims=False):  # pylint: disable=arguments-differ
+    """
+    mean(a, axis=None, dtype=None, out=None, keepdims=False)
+
+    Compute the arithmetic mean along the specified axis.
+    Returns the average of the array elements.
+    The average is taken over the flattened array by default, otherwise over the specified axis.
+
+    Parameters
+    ----------
+    a : ndarray
+        ndarray containing numbers whose mean is desired.
+    axis : None or int or tuple of ints, optional
+        Axis or axes along which the means are computed. The default is to compute the mean of the flattened array.
+        If this is a tuple of ints, a mean is performed over multiple axes,
+        instead of a single axis or all the axes as before.
+    dtype : data-type, optional
+        Type to use in computing the mean. For integer inputs, the default is float32;
+        for floating point inputs, it is the same as the input dtype.
+    out : ndarray, optional
+        Alternate output array in which to place the result. The default is None; if provided,
+        it must have the same shape and type as the expected output.
+    keepdims : bool, optional
+        If this is set to True, the axes which are reduced are left in the result
+        as dimensions with size one. With this option, the result will broadcast correctly
+        against the input array.
+        The value (including the default False) is always forwarded to the underlying
+        operator, so numpy's rule of not forwarding a default `keepdims` to the mean
+        method of ndarray sub-classes does not apply here.
+
+    Returns
+    -------
+    m : ndarray, see dtype parameter above
+        If out=None, returns a new array containing the mean values,
+        otherwise a reference to the output array is returned.
+
+    Notes
+    -----
+    This function differs from the original `numpy.mean
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.mean.html>`_ in
+    the following way(s):
+
+    - only ndarray is accepted as valid input, python iterables or scalars are not supported
+    - default data type for integer input is float32
+
+    Examples
+    --------
+    >>> a = np.array([[1, 2], [3, 4]])
+    >>> np.mean(a)
+    array(2.5)
+    >>> a = np.zeros((2, 512*512), dtype=np.float32)
+    >>> a[0,:] = 1.0
+    >>> a[1,:] = 0.1
+    >>> np.mean(a)
+    array(0.55)
+    >>> np.mean(a, dtype=np.float64)
+    array(0.55)
+    """
+    return _npi.mean(a, axis=axis, dtype=dtype, keepdims=keepdims, out=out)
+
+
+@set_module('mxnet.ndarray.numpy')
 def stack(arrays, axis=0, out=None):
     """Join a sequence of arrays along a new axis.
 
@@ -188,7 +251,7 @@ def stack(arrays, axis=0, out=None):
 
     Parameters
     ----------
-    arrays : sequence of array_like
+    arrays : sequence of ndarrays
         Each array must have the same shape.
     axis : int, optional
         The axis in the result array along which the input arrays are stacked.
@@ -198,8 +261,36 @@ def stack(arrays, axis=0, out=None):
 
     Returns
     -------
-    stacked : ndarray
-        The stacked array has one more dimension than the input arrays."""
+    out : ndarray
+        The stacked array has one more dimension than the input arrays.
+
+    Notes
+    -----
+    This function differs from the original `numpy.stack
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.stack.html>`_ in
+    the following ways:
+
+    - only sequence of ndarray is accepted as valid input
+
+    Examples
+    --------
+    >>> arrays = [np.random.uniform(size=(3, 4)) for _ in range(10)]
+    >>> np.stack(arrays, axis=0).shape
+    (10, 3, 4)
+    >>> np.stack(arrays, axis=1).shape
+    (3, 10, 4)
+    >>> np.stack(arrays, axis=2).shape
+    (3, 4, 10)
+    >>> a = np.array([1, 2, 3])
+    >>> b = np.array([2, 3, 4])
+    >>> np.stack((a, b))
+    array([[1., 2., 3.],
+           [2., 3., 4.]])
+    >>> np.stack((a, b), axis=-1)
+    array([[1., 2.],
+           [2., 3.],
+           [3., 4.]])
+    """
     def get_list(arrays):
         if not hasattr(arrays, '__getitem__') and hasattr(arrays, '__iter__'):
             raise ValueError("expected iterable for arrays but got {}".format(type(arrays)))
@@ -607,6 +698,7 @@ def expand_dims(a, axis):
     return _npi.expand_dims(a, axis)
 
 
+# pylint: disable=line-too-long
 @set_module('mxnet.ndarray.numpy')
 def split(ary, indices_or_sections, axis=0):
     """Split an array into multiple sub-arrays.
@@ -628,8 +720,7 @@ def split(ary, indices_or_sections, axis=0):
           - ary[2:3]
           - ary[3:]
 
-        If an index exceeds the dimension of the array along `axis`,
-        an empty sub-array is returned correspondingly.
+        Indices must be within the dimension of the array along `axis`.
     axis : int, optional
         The axis along which to split, default is 0.
 
@@ -643,6 +734,22 @@ def split(ary, indices_or_sections, axis=0):
     ValueError
         If `indices_or_sections` is given as an integer, but
         a split does not result in equal division.
+
+    Notes
+    -----
+    This function differs from the original `numpy.split
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.split.html>`_ in
+    the following ways:
+
+    - Indices exceeding the dimension of the array along `axis` are currently not supported.
+
+    Examples
+    --------
+    >>> x = np.arange(9.0)
+    >>> np.split(x, 3)
+    [array([0., 1., 2.]), array([3., 4., 5.]), array([6., 7., 8.])]
+    >>> np.split(x, (3, 5, 6))
+    [array([0., 1., 2.]), array([3., 4.]), array([5.]), array([6., 7.])]
     """
     indices = []
     axis_size = ary.shape[axis]
@@ -660,6 +767,7 @@ def split(ary, indices_or_sections, axis=0):
     if not isinstance(ret, list):
         raise NotImplementedError('single output from split is not supported yet...')
     return ret
+# pylint: enable=line-too-long
 
 
 @set_module('mxnet.ndarray.numpy')
@@ -1267,3 +1375,122 @@ def degrees(x, out=None, **kwargs):
 
     """
     return _unary_func_helper(x, _npi.degrees, _np.degrees, out=out, **kwargs)
+
+
+@set_module('mxnet.ndarray.numpy')
+def rint(x, out=None, **kwargs):
+    """
+    Round elements of the array to the nearest integer.
+
+    Parameters
+    ----------
+    x : ndarray or scalar
+        Input array.
+    out : ndarray or None
+        A location into which the result is stored.
+        If provided, it must have the same shape and type as the input.
+        If not provided or None, a freshly-allocated array is returned.
+
+    Returns
+    -------
+    out : ndarray or scalar
+        Output array has the same shape and type as `x`. This is a scalar if `x` is a scalar.
+
+    Notes
+    -----
+    This function differs from the original `numpy.rint
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.rint.html>`_ in
+    the following way(s):
+
+    - only ndarray or scalar is accepted as valid input, tuple of ndarray is not supported
+    - broadcasting to `out` of different shape is currently not supported
+    - when input is plain python numerics, the result will not be stored in the `out` param
+
+    Examples
+    --------
+    >>> a = np.array([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0])
+    >>> np.rint(a)
+    array([-2., -2., -0.,  0.,  1.,  2.,  2.])
+    """
+    return _unary_func_helper(x, _npi.rint, _np.rint, out=out, **kwargs)
+
+
+@set_module('mxnet.ndarray.numpy')
+def log2(x, out=None, **kwargs):
+    """
+    Base-2 logarithm of x.
+
+    Parameters
+    ----------
+    x : ndarray or scalar
+        Input values.
+    out : ndarray or None
+        A location into which the result is stored.
+        If provided, it must have the same shape and type as the input.
+        If not provided or None, a freshly-allocated array is returned.
+
+    Returns
+    -------
+    y : ndarray or scalar
+        The logarithm base two of `x`, element-wise.
+        This is a scalar if `x` is a scalar.
+
+    Notes
+    -----
+    This function differs from the original `numpy.log2
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.log2.html>`_ in
+    the following way(s):
+
+    - only ndarray or scalar is accepted as valid input, tuple of ndarray is not supported
+    - broadcasting to `out` of different shape is currently not supported
+    - when input is plain python numerics, the result will not be stored in the `out` param
+
+    Examples
+    --------
+    >>> x = np.array([0, 1, 2, 2**4])
+    >>> np.log2(x)
+    array([-inf,   0.,   1.,   4.])
+
+    """
+    return _unary_func_helper(x, _npi.log2, _np.log2, out=out, **kwargs)
+
+
+@set_module('mxnet.ndarray.numpy')
+def radians(x, out=None, **kwargs):
+    """
+    Convert angles from degrees to radians.
+
+    Parameters
+    ----------
+    x : ndarray or scalar
+        Input array in degrees.
+    out : ndarray or None
+        A location into which the result is stored.
+        If provided, it must have the same shape and type as the input.
+        If not provided or None, a freshly-allocated array is returned.
+
+    Returns
+    -------
+    y : ndarray or scalar
+        The corresponding radian values. This is a scalar if `x` is a scalar.
+
+    Notes
+    -----
+    This function differs from the original `numpy.radians
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.radians.html>`_ in
+    the following way(s):
+
+    - only ndarray or scalar is accepted as valid input, tuple of ndarray is not supported
+    - broadcasting to `out` of different shape is currently not supported
+    - when input is plain python numerics, the result will not be stored in the `out` param
+
+    Examples
+    --------
+    >>> deg = np.arange(12.) * 30.
+    >>> np.radians(deg)
+    array([0.       , 0.5235988, 1.0471976, 1.5707964, 2.0943952, 2.6179938,
+           3.1415927, 3.6651914, 4.1887903, 4.712389 , 5.2359877, 5.7595863],
+           dtype=float32)
+
+    """
+    return _unary_func_helper(x, _npi.radians, _np.radians, out=out, **kwargs)
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index db7b084..5e26ff6 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -47,7 +47,7 @@ __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum', '
            'argmax', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate',
            'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace', 'sin', 'cos',
            'sin', 'cos', 'sinh', 'cosh', 'log10', 'sqrt', 'abs', 'exp', 'arctan', 'sign', 'log',
-           'degrees']
+           'degrees', 'log2', 'rint', 'radians', 'mean']
 
 
 # This function is copied from ndarray.py since pylint
@@ -927,7 +927,7 @@ class ndarray(NDArray):
 
     def mean(self, axis=None, dtype=None, out=None, keepdims=False):  # pylint: disable=arguments-differ
         """Returns the average of the array elements along given axis."""
-        return _mx_np_op.mean(self, axis=axis, dtype=dtype, keepdims=keepdims, out=out)
+        return _npi.mean(self, axis=axis, dtype=dtype, keepdims=keepdims, out=out)
 
     # TODO(junwu): Use mxnet std op instead of onp.std
     def std(self, axis=None, dtype=None, out=None, ddof=0, keepdims=False):  # pylint: disable=arguments-differ
@@ -1447,6 +1447,68 @@ def minimum(x1, x2, out=None):
 
 
 @set_module('mxnet.numpy')
+def mean(a, axis=None, dtype=None, out=None, keepdims=False):  # pylint: disable=arguments-differ
+    """
+    mean(a, axis=None, dtype=None, out=None, keepdims=False)
+
+    Compute the arithmetic mean along the specified axis.
+    Returns the average of the array elements.
+    The average is taken over the flattened array by default, otherwise over the specified axis.
+
+    Parameters
+    ----------
+    a : ndarray
+        ndarray containing numbers whose mean is desired.
+    axis : None or int or tuple of ints, optional
+        Axis or axes along which the means are computed. The default is to compute the mean of the flattened array.
+        If this is a tuple of ints, a mean is performed over multiple axes,
+        instead of a single axis or all the axes as before.
+    dtype : data-type, optional
+        Type to use in computing the mean. For integer inputs, the default is float32;
+        for floating point inputs, it is the same as the input dtype.
+    out : ndarray, optional
+        Alternate output array in which to place the result. The default is None; if provided,
+        it must have the same shape and type as the expected output.
+    keepdims : bool, optional
+        If this is set to True, the axes which are reduced are left in the result
+        as dimensions with size one. With this option, the result will broadcast correctly
+        against the input array.
+        The value (including the default False) is always forwarded to the underlying
+        operator, so numpy's rule of not forwarding a default `keepdims` to the mean
+        method of ndarray sub-classes does not apply here.
+
+    Returns
+    -------
+    m : ndarray, see dtype parameter above
+        If out=None, returns a new array containing the mean values,
+        otherwise a reference to the output array is returned.
+
+    Notes
+    -----
+    This function differs from the original `numpy.mean
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.mean.html>`_ in
+    the following way(s):
+
+    - only ndarray is accepted as valid input, python iterables or scalars are not supported
+    - default data type for integer input is float32
+
+    Examples
+    --------
+    >>> a = np.array([[1, 2], [3, 4]])
+    >>> np.mean(a)
+    array(2.5)
+    >>> a = np.zeros((2, 512*512), dtype=np.float32)
+    >>> a[0,:] = 1.0
+    >>> a[1,:] = 0.1
+    >>> np.mean(a)
+    array(0.55)
+    >>> np.mean(a, dtype=np.float64)
+    array(0.55)
+    """
+    return _npi.mean(a, axis=axis, dtype=dtype, keepdims=keepdims, out=out)
+
+
+@set_module('mxnet.numpy')
 def stack(arrays, axis=0, out=None):
     """Join a sequence of arrays along a new axis.
 
@@ -1455,18 +1517,46 @@ def stack(arrays, axis=0, out=None):
 
     Parameters
     ----------
-    arrays : sequence of array_like
+    arrays : sequence of ndarrays
         Each array must have the same shape.
     axis : int, optional
         The axis in the result array along which the input arrays are stacked.
     out : ndarray, optional
-        If provided, the destination to place the result. The shape must be correct,
-        matching that of what stack would have returned if no out argument were specified.
+        If provided, the destination to place the result. The shape and type must be the
+        same with that of what stack would have returned if no out argument were specified.
 
     Returns
     -------
-    stacked : ndarray
-        The stacked array has one more dimension than the input arrays."""
+    out : ndarray
+        The stacked array has one more dimension than the input arrays.
+
+    Notes
+    -----
+    This function differs from the original `numpy.stack
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.stack.html>`_ in
+    the following way(s):
+
+    - only sequence of ndarray is accepted as valid input
+
+    Examples
+    --------
+    >>> arrays = [np.random.uniform(size=(3, 4)) for _ in range(10)]
+    >>> np.stack(arrays, axis=0).shape
+    (10, 3, 4)
+    >>> np.stack(arrays, axis=1).shape
+    (3, 10, 4)
+    >>> np.stack(arrays, axis=2).shape
+    (3, 4, 10)
+    >>> a = np.array([1, 2, 3])
+    >>> b = np.array([2, 3, 4])
+    >>> np.stack((a, b))
+    array([[1., 2., 3.],
+           [2., 3., 4.]])
+    >>> np.stack((a, b), axis=-1)
+    array([[1., 2.],
+           [2., 3.],
+           [3., 4.]])
+    """
     return _mx_nd_np.stack(arrays, axis=axis, out=out)
 
 
@@ -1845,6 +1935,7 @@ def expand_dims(a, axis):
     return _npi.expand_dims(a, axis)
 
 
+# pylint: disable=line-too-long
 @set_module('mxnet.numpy')
 def split(ary, indices_or_sections, axis=0):
     """Split an array into multiple sub-arrays.
@@ -1866,8 +1957,7 @@ def split(ary, indices_or_sections, axis=0):
           - ary[2:3]
           - ary[3:]
 
-        If an index exceeds the dimension of the array along `axis`,
-        an empty sub-array is returned correspondingly.
+        Indices must be within the dimension of the array along `axis`.
     axis : int, optional
         The axis along which to split, default is 0.
 
@@ -1880,8 +1970,26 @@ def split(ary, indices_or_sections, axis=0):
     ------
     ValueError
         If `indices_or_sections` is given as an integer, but
-        a split does not result in equal division."""
+        a split does not result in equal division.
+
+    Notes
+    -----
+    This function differs from the original `numpy.split
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.split.html>`_ in
+    the following ways:
+
+    - Indices exceeding the dimension of the array along `axis` are currently not supported.
+
+    Examples
+    --------
+    >>> x = np.arange(9.0)
+    >>> np.split(x, 3)
+    [array([0., 1., 2.]), array([3., 4., 5.]), array([6., 7., 8.])]
+    >>> np.split(x, (3, 5, 6))
+    [array([0., 1., 2.]), array([3., 4.]), array([5.]), array([6., 7.])]
+    """
     return _mx_nd_np.split(ary, indices_or_sections, axis=axis)
+# pylint: enable=line-too-long
 
 
 @set_module('mxnet.numpy')
@@ -2089,7 +2197,6 @@ def sqrt(x, out=None, **kwargs):
     return _mx_nd_np.sqrt(x, out=out, **kwargs)
 
 
-
 @set_module('mxnet.numpy')
 def tile(A, reps):
     r"""
@@ -2271,6 +2378,7 @@ def arctan(x, out=None, **kwargs):
     """
     return _mx_nd_np.arctan(x, out=out, **kwargs)
 
+
 @set_module('mxnet.numpy')
 def sign(x, out=None):
     """
@@ -2328,7 +2436,7 @@ def sign(x, out=None):
     return _mx_nd_np.sign(x, out=out)
 
 
-@set_module('mxnet.symbol.numpy')
+@set_module('mxnet.numpy')
 def log(x, out=None, **kwargs):
     """
     log(x, out=None)
@@ -2375,6 +2483,7 @@ def log(x, out=None, **kwargs):
     >>> np.log(a)
     array([  0.,   1.,   2., -inf], dtype=float64)
 
+
     Due to internal calculation mechanism, using default float32 dtype may cause some special behavior:
 
     >>> a = np.array([1, np.exp(1), np.exp(2), 0])
@@ -2390,7 +2499,85 @@ def log(x, out=None, **kwargs):
     return _mx_nd_np.log(x, out=out, **kwargs)
 
 
-@set_module('mxnet.symbol.numpy')
+@set_module('mxnet.numpy')
+def rint(x, out=None, **kwargs):
+    """
+    Round elements of the array to the nearest integer.
+
+    Parameters
+    ----------
+    x : ndarray or scalar
+        Input array.
+    out : ndarray or None
+        A location into which the result is stored.
+        If provided, it must have the same shape and type as the input.
+        If not provided or None, a freshly-allocated array is returned.
+
+    Returns
+    -------
+    out : ndarray or scalar
+        Output array has the same shape and type as `x`. This is a scalar if `x` is a scalar.
+
+    Notes
+    -----
+    This function differs from the original `numpy.rint
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.rint.html>`_ in
+    the following way(s):
+
+    - only ndarray or scalar is accepted as valid input, tuple of ndarray is not supported
+    - broadcasting to `out` of different shape is currently not supported
+    - when input is plain python numerics, the result will not be stored in the `out` param
+
+    Examples
+    --------
+    >>> a = np.array([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0])
+    >>> np.rint(a)
+    array([-2., -2., -0.,  0.,  1.,  2.,  2.])
+    """
+    return _mx_nd_np.rint(x, out=out, **kwargs)
+
+
+@set_module('mxnet.numpy')
+def log2(x, out=None, **kwargs):
+    """
+    Base-2 logarithm of x.
+
+    Parameters
+    ----------
+    x : ndarray or scalar
+        Input values.
+    out : ndarray or None
+        A location into which the result is stored.
+        If provided, it must have the same shape and type as the input.
+        If not provided or None, a freshly-allocated array is returned.
+
+    Returns
+    -------
+    y : ndarray or scalar
+        The logarithm base two of `x`, element-wise.
+        This is a scalar if `x` is a scalar.
+
+    Notes
+    -----
+    This function differs from the original `numpy.log2
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.log2.html>`_ in
+    the following way(s):
+
+    - only ndarray or scalar is accepted as valid input, tuple of ndarray is not supported
+    - broadcasting to `out` of different shape is currently not supported
+    - when input is plain python numerics, the result will not be stored in the `out` param
+
+    Examples
+    --------
+    >>> x = np.array([0, 1, 2, 2**4])
+    >>> np.log2(x)
+    array([-inf,   0.,   1.,   4.])
+
+    """
+    return _mx_nd_np.log2(x, out=out, **kwargs)
+
+
+@set_module('mxnet.numpy')
 def degrees(x, out=None, **kwargs):
     """
     degrees(x, out=None)
@@ -2442,3 +2629,44 @@ def degrees(x, out=None, **kwargs):
 
     """
     return _mx_nd_np.degrees(x, out=out, **kwargs)
+
+
+@set_module('mxnet.numpy')
+def radians(x, out=None, **kwargs):
+    """
+    Convert angles from degrees to radians.
+
+    Parameters
+    ----------
+    x : ndarray or scalar
+        Input array in degrees.
+    out : ndarray or None
+        A location into which the result is stored.
+        If provided, it must have the same shape and type as the input.
+        If not provided or None, a freshly-allocated array is returned.
+
+    Returns
+    -------
+    y : ndarray or scalar
+        The corresponding radian values. This is a scalar if `x` is a scalar.
+
+    Notes
+    -----
+    This function differs from the original `numpy.radians
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.radians.html>`_ in
+    the following way(s):
+
+    - only ndarray or scalar is accepted as valid input, tuple of ndarray is not supported
+    - broadcasting to `out` of different shape is currently not supported
+    - when input is plain python numerics, the result will not be stored in the `out` param
+
+    Examples
+    --------
+    >>> deg = np.arange(12.) * 30.
+    >>> np.radians(deg)
+    array([0.       , 0.5235988, 1.0471976, 1.5707964, 2.0943952, 2.6179938,
+           3.1415927, 3.6651914, 4.1887903, 4.712389 , 5.2359877, 5.7595863],
+           dtype=float32)
+
+    """
+    return _mx_nd_np.radians(x, out=out, **kwargs)
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index efdbf51..e499d8e 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -32,7 +32,7 @@ from . import _internal as _npi
 __all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'concatenate', 'arange', 'argmax',
            'clip', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'split', 'swapaxes',
            'expand_dims', 'tile', 'linspace', 'sin', 'cos', 'sinh', 'cosh', 'log10', 'sqrt',
-           'abs', 'exp', 'arctan', 'sign', 'log', 'degrees']
+           'abs', 'exp', 'arctan', 'sign', 'log', 'degrees', 'log2', 'rint', 'radians', 'mean']
 
 
 def _num_outputs(sym):
@@ -534,7 +534,7 @@ class _Symbol(Symbol):
         The arguments are the same as for :py:func:`mean`, with
         this array as data.
         """
-        return _mx_np_op.mean(self, axis=axis, dtype=dtype, keepdims=keepdims, out=out)
+        return _npi.mean(self, axis=axis, dtype=dtype, keepdims=keepdims, out=out)
 
     def cumsum(self, axis=None, dtype=None, out=None):
         """Return the cumulative sum of the elements along the given axis."""
@@ -1022,27 +1022,115 @@ def power(x1, x2, out=None):
 
 
 @set_module('mxnet.symbol.numpy')
+def mean(a, axis=None, dtype=None, out=None, keepdims=False):  # pylint: disable=arguments-differ
+    """
+    mean(a, axis=None, dtype=None, out=None, keepdims=False)
+
+    Compute the arithmetic mean along the specified axis.
+    Returns the average of the array elements.
+    The average is taken over the flattened array by default, otherwise over the specified axis.
+
+    Parameters
+    ----------
+    a : `_Symbol`
+        _Symbol containing numbers whose mean is desired.
+    axis : None or int or tuple of ints, optional
+        Axis or axes along which the means are computed. The default is to compute the mean of the flattened array.
+        If this is a tuple of ints, a mean is performed over multiple axes,
+        instead of a single axis or all the axes as before.
+    dtype : data-type, optional
+        Type to use in computing the mean. For integer inputs, the default is float32;
+        for floating point inputs, it is the same as the input dtype.
+    out : _Symbol, optional
+        Dummy parameter to keep the consistency with the ndarray counterpart.
+    keepdims : bool, optional
+        If this is set to True, the axes which are reduced are left in the result
+        as dimensions with size one. With this option, the result will broadcast correctly
+        against the input array.
+        The value (including the default False) is always forwarded to the underlying
+        operator, so numpy's rule of not forwarding a default `keepdims` to the mean
+        method of sub-classes does not apply here.
+
+    Returns
+    -------
+    m : _Symbol, see dtype parameter above
+        A new _Symbol holding the mean values; `out` is only a dummy parameter
+        kept for consistency with the ndarray counterpart.
+
+    Notes
+    -----
+    This function differs from the original `numpy.mean
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.mean.html>`_ in
+    the following way(s):
+
+    - only _Symbol is accepted as valid input, python iterables or scalars are not supported
+    - default data type for integer input is float32
+
+    Examples
+    --------
+    >>> a = np.array([[1, 2], [3, 4]])
+    >>> np.mean(a)
+    array(2.5)
+    >>> a = np.zeros((2, 512*512), dtype=np.float32)
+    >>> a[0,:] = 1.0
+    >>> a[1,:] = 0.1
+    >>> np.mean(a)
+    array(0.55)
+    >>> np.mean(a, dtype=np.float64)
+    array(0.55)
+    """
+    return _npi.mean(a, axis=axis, dtype=dtype, keepdims=keepdims, out=out)
+
+
+@set_module('mxnet.symbol.numpy')
 def stack(arrays, axis=0, out=None):
-    """Join a sequence of arrays along a new axis.
+    """
+    Join a sequence of arrays along a new axis.
 
     The axis parameter specifies the index of the new axis in the dimensions of the result.
-    For example, if `axis=0` it will be the first dimension and if `axis=-1` it will be the last
-    dimension.
+    For example, if `axis=0` it will be the first dimension and if `axis=-1` it will be the last dimension.
 
     Parameters
     ----------
-    arrays : sequence of array_like
+    arrays : sequence of _Symbols
         Each array must have the same shape.
     axis : int, optional
         The axis in the result array along which the input arrays are stacked.
-    out : ndarray, optional
-        If provided, the destination to place the result. The shape must be correct,
-        matching that of what stack would have returned if no out argument were specified.
+    out : _Symbol, optional
+        Dummy parameter to keep the consistency with the ndarray counterpart.
 
     Returns
     -------
-    stacked : ndarray
-        The stacked array has one more dimension than the input arrays."""
+    out : _Symbol
+        The stacked array has one more dimension than the input arrays.
+
+    Notes
+    -----
+    This function differs from the original `numpy.stack
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.stack.html>`_ in
+    the following ways:
+
+    - only sequence of _Symbol is accepted as valid input
+
+    Examples
+    --------
+    >>> arrays = [np.random.uniform(size=(3, 4)) for _ in range(10)]
+    >>> np.stack(arrays, axis=0).shape
+    (10, 3, 4)
+    >>> np.stack(arrays, axis=1).shape
+    (3, 10, 4)
+    >>> np.stack(arrays, axis=2).shape
+    (3, 4, 10)
+    >>> a = np.array([1, 2, 3])
+    >>> b = np.array([2, 3, 4])
+    >>> np.stack((a, b))
+    array([[1., 2., 3.],
+           [2., 3., 4.]])
+    >>> np.stack((a, b), axis=-1)
+    array([[1., 2.],
+           [2., 3.],
+           [3., 4.]])
+    """
     def get_list(arrays):
         if not hasattr(arrays, '__getitem__') and hasattr(arrays, '__iter__'):
             raise ValueError("expected iterable for arrays but got {}".format(type(arrays)))
@@ -1261,13 +1349,14 @@ def expand_dims(a, axis):
     return _npi.expand_dims(a, axis)
 
 
+# pylint: disable=line-too-long
 @set_module('mxnet.symbol.numpy')
 def split(ary, indices_or_sections, axis=0):
     """Split an array into multiple sub-arrays.
 
     Parameters
     ----------
-    ary : ndarray
+    ary : _Symbol
         Array to be divided into sub-arrays.
     indices_or_sections : int or 1-D array
         If `indices_or_sections` is an integer, N, the array will be divided
@@ -1282,21 +1371,37 @@ def split(ary, indices_or_sections, axis=0):
           - ary[2:3]
           - ary[3:]
 
-        If an index exceeds the dimension of the array along `axis`,
-        an empty sub-array is returned correspondingly.
+        Each index must lie within the dimension of the array along `axis`.
     axis : int, optional
         The axis along which to split, default is 0.
 
     Returns
     -------
-    sub-arrays : list of ndarrays
+    sub-arrays : list of _Symbols
         A list of sub-arrays.
 
     Raises
     ------
     ValueError
         If `indices_or_sections` is given as an integer, but
-        a split does not result in equal division."""
+        a split does not result in equal division.
+
+    Notes
+    -----
+    This function differs from the original `numpy.split
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.split.html>`_ in
+    the following ways:
+
+    - Indices exceeding the dimension of the array along `axis` are currently not supported.
+
+    Examples
+    --------
+    >>> x = np.arange(9.0)
+    >>> np.split(x, 3)
+    [array([0., 1., 2.]), array([3., 4., 5.]), array([6., 7., 8.])]
+    >>> np.split(x, (3, 5, 6))
+    [array([0., 1., 2.]), array([3., 4.]), array([5.]), array([6., 7.])]
+    """
     indices = []
     sections = 0
     if isinstance(indices_or_sections, int):
@@ -1307,6 +1412,7 @@ def split(ary, indices_or_sections, axis=0):
         raise ValueError('indices_or_sections must either int or tuple of ints')
     ret = _npi.split(ary, indices, axis, False, sections)
     return ret
+# pylint: enable=line-too-long
 
 
 @set_module('mxnet.symbol.numpy')
@@ -1771,4 +1877,107 @@ def degrees(x, out=None, **kwargs):
     return _unary_func_helper(x, _npi.degrees, _np.degrees, out=out, **kwargs)
 
 
+def rint(x, out=None, **kwargs):
+    """
+    Round elements of the array to the nearest integer.
+
+    Parameters
+    ----------
+    x : _Symbol or scalar
+        Input array.
+    out : _Symbol or None
+        Dummy parameter to keep the consistency with the ndarray counterpart.
+
+    Returns
+    -------
+    out : _Symbol or scalar
+        Output array has the same shape and type as x. This is a scalar if x is a scalar.
+
+    Notes
+    -----
+    This function differs from the original `numpy.rint
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.rint.html>`_ in
+    the following way(s):
+
+    - only _Symbol or scalar is accepted as valid input, tuple of _Symbol is not supported
+    - broadcasting to `out` of different shape is currently not supported
+    - when input is plain python numerics, the result will not be stored in the `out` param
+
+    """
+    return _unary_func_helper(x, _npi.rint, _np.rint, out=out, **kwargs)
+
+
+@set_module('mxnet.symbol.numpy')
+def log2(x, out=None, **kwargs):
+    """
+    Base-2 logarithm of x.
+
+    Parameters
+    ----------
+    x : _Symbol or scalar
+        Input values.
+    out : _Symbol or None
+        Dummy parameter to keep the consistency with the ndarray counterpart.
+        It is accepted only so that the signature matches the ndarray version
+        of this function.
+
+    Returns
+    -------
+    y : _Symbol or scalar
+        The logarithm base two of `x`, element-wise.
+        This is a scalar if `x` is a scalar.
+
+    Notes
+    -----
+    This function differs from the original `numpy.log2
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.log2.html>`_ in
+    the following way(s):
+
+    - only _Symbol or scalar is accepted as valid input, tuple of _Symbol is not supported
+    - broadcasting to `out` of different shape is currently not supported
+    - when input is plain python numerics, the result will not be stored in the `out` param
+
+    """
+    return _unary_func_helper(x, _npi.log2, _np.log2, out=out, **kwargs)
+
+
+@set_module('mxnet.symbol.numpy')
+def radians(x, out=None, **kwargs):
+    """
+    Convert angles from degrees to radians.
+
+    Parameters
+    ----------
+    x : _Symbol or scalar
+        Input array in degrees.
+    out : _Symbol or None
+        Dummy parameter to keep the consistency with the ndarray counterpart.
+
+    Returns
+    -------
+    y : _Symbol or scalar
+        The corresponding radian values. This is a scalar if x is a scalar.
+
+    Notes
+    -----
+    This function differs from the original `numpy.radians
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.radians.html>`_ in
+    the following way(s):
+
+    - only _Symbol or scalar is accepted as valid input, tuple of _Symbol is not supported
+    - broadcasting to `out` of different shape is currently not supported
+    - when input is plain python numerics, the result will not be stored in the `out` param
+
+    Examples
+    --------
+    >>> deg = np.arange(12.) * 30.
+    >>> np.radians(deg)
+    array([0.       , 0.5235988, 1.0471976, 1.5707964, 2.0943952, 2.6179938,
+           3.1415927, 3.6651914, 4.1887903, 4.712389 , 5.2359877, 5.7595863],
+           dtype=float32)
+
+    """
+    return _unary_func_helper(x, _npi.radians, _np.radians, out=out, **kwargs)
+
+
 _set_np_symbol_class(_Symbol)
diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc b/src/operator/numpy/np_broadcast_reduce_op_value.cc
index d8234c5..9cf5c21 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_value.cc
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc
@@ -102,7 +102,7 @@ inline bool NumpyMeanType(const nnvm::NodeAttrs& attrs,
   return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
 }
 
-NNVM_REGISTER_OP(_np_mean)
+NNVM_REGISTER_OP(_npi_mean)
 .set_num_inputs(1)
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<NumpyReduceAxesParam>)
diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cu b/src/operator/numpy/np_broadcast_reduce_op_value.cu
index a0a6472..6e18ebc 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_value.cu
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.cu
@@ -33,7 +33,7 @@ NNVM_REGISTER_OP(_np_sum)
 NNVM_REGISTER_OP(_backward_np_sum)
 .set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesBackwardUseNone<gpu>);
 
-NNVM_REGISTER_OP(_np_mean)
+NNVM_REGISTER_OP(_npi_mean)
 .set_attr<FCompute>("FCompute<gpu>", NumpyReduceAxesCompute<gpu, mshadow_op::sum, true, true>);
 
 NNVM_REGISTER_OP(_backward_np_mean)
diff --git a/src/operator/numpy/np_elemwise_unary_op_basic.cc b/src/operator/numpy/np_elemwise_unary_op_basic.cc
index 3ff4400..f98f7df 100644
--- a/src/operator/numpy/np_elemwise_unary_op_basic.cc
+++ b/src/operator/numpy/np_elemwise_unary_op_basic.cc
@@ -121,7 +121,7 @@ Example::
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_sign"});
 
 // rint
-MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_np_rint, "x", mshadow_op::rint)
+MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_npi_rint, "x", mshadow_op::rint)
 .describe(R"code(Round elements of the array to the nearest integer.
 Example::
    rint([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) = [-2., -2., -0.,  0.,  2.,  2.,  2.]
@@ -227,7 +227,7 @@ MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_npi_log10, "x", mshadow_op::log10)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_log10"});
 
 // log2
-MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_np_log2, "x", mshadow_op::log2)
+MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_npi_log2, "x", mshadow_op::log2)
 .describe(R"code(Returns element-wise Base-2 logarithmic value of the input.
 ``2**log2(x) = x``
 )code" ADD_FILELINE)
@@ -314,7 +314,7 @@ MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_npi_degrees, "x", mshadow_op::degrees)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_degrees" });
 
 // radians
-MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_np_radians, "x", mshadow_op::radians)
+MXNET_OPERATOR_REGISTER_NUMPY_UNARY(_npi_radians, "x", mshadow_op::radians)
 .describe(R"code(Converts each element of the input array from degrees to radians.
 .. math::
    radians([0, 90, 180, 270, 360]) = [0, \pi/2, \pi, 3\pi/2, 2\pi]
diff --git a/src/operator/numpy/np_elemwise_unary_op_basic.cu b/src/operator/numpy/np_elemwise_unary_op_basic.cu
index de9416e..bc04b38 100644
--- a/src/operator/numpy/np_elemwise_unary_op_basic.cu
+++ b/src/operator/numpy/np_elemwise_unary_op_basic.cu
@@ -47,7 +47,7 @@ MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_absolute, mshadow_op::abs);
 
 MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_sign, mshadow_op::sign);
 
-MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_rint, mshadow_op::rint);
+MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_rint, mshadow_op::rint);
 
 MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_ceil, mshadow_op::ceil);
 
@@ -70,7 +70,7 @@ NNVM_REGISTER_OP(_npi_log)
 
 MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_log10, mshadow_op::log10);
 
-MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_log2, mshadow_op::log2);
+MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_log2, mshadow_op::log2);
 
 MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_log1p, mshadow_op::log1p);
 
@@ -92,7 +92,7 @@ MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_arctan, mshadow_op::arctan);
 
 MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_degrees, mshadow_op::degrees);
 
-MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_np_radians, mshadow_op::radians);
+MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_radians, mshadow_op::radians);
 
 MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_sinh, mshadow_op::sinh);