You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by re...@apache.org on 2019/07/10 06:47:16 UTC

[incubator-mxnet] branch numpy updated: [Numpy] Numpy compatible argsort (#15501)

This is an automated email from the ASF dual-hosted git repository.

reminisce pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/numpy by this push:
     new 1bff19b  [Numpy] Numpy compatible argsort (#15501)
1bff19b is described below

commit 1bff19ba8edb9b8a7756217ef6a2ad65c5584e98
Author: Mike <ma...@connect.hku.hk>
AuthorDate: Wed Jul 10 14:46:47 2019 +0800

    [Numpy] Numpy compatible argsort (#15501)
    
    * Add numpy compatible argsort
    
    * Minor syntax fix
---
 python/mxnet/ndarray/numpy/_op.py      |  62 +++++++++++++++++-
 python/mxnet/numpy/multiarray.py       |  70 +++++++++++++++++++-
 python/mxnet/symbol/numpy/_symbol.py   | 115 +++++++++++++++++++++++++++++++--
 src/operator/tensor/ordering_op-inl.h  |  63 +++++++++++++-----
 src/operator/tensor/ordering_op.cc     |   1 +
 tests/python/unittest/test_numpy_op.py |  41 ++++++++++++
 6 files changed, 325 insertions(+), 27 deletions(-)

diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index 282c08a..7f710a0 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -32,7 +32,8 @@ __all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax',
            'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate',
            'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace',
            'sin', 'cos', 'sinh', 'cosh', 'log10', 'sqrt', 'abs', 'exp', 'arctan', 'sign', 'log',
-           'degrees', 'log2', 'rint', 'radians', 'mean', 'reciprocal', 'square', 'arcsin']
+           'degrees', 'log2', 'rint', 'radians', 'mean', 'reciprocal', 'square', 'arcsin',
+           'argsort']
 
 
 @set_module('mxnet.ndarray.numpy')
@@ -426,6 +427,65 @@ def argmax(a, axis=None, out=None):
 
 
 @set_module('mxnet.ndarray.numpy')
+def argsort(a, axis=-1, kind='quicksort', order=None):
+    """
+    Returns the indices that would sort an input array along the given axis.
+    This function performs sorting along the given axis and returns an array
+    of indices having same shape as an input array that index data in sorted order.
+
+    Parameters
+    ----------
+    a : ndarray
+        Input array
+    axis : int, optional
+        The axis along which to sort the input tensor.
+        If not given, the last dimension (-1) is used by default.
+        If None, the flattened array is used.
+    kind: {'quicksort'}
+        Currently not supported.
+    order: None
+        Currently not supported.
+
+    Returns
+    -------
+    output : ndarray
+        Array of indices that sort a along the specified axis.
+        If a is one-dimensional, a[index_array] yields a sorted a.
+        More generally, np.take_along_axis(a, index_array, axis=axis) always yields the sorted a,
+        irrespective of dimensionality.
+
+    Examples
+    --------
+    >>> x = np.array([3, 1, 2])
+    >>> np.argsort(x)
+    array([1., 2., 0.])
+    >>> x = np.array([[0, 3], [2, 2]])
+    >>> x
+    array([[0., 3.],
+           [2., 2.]])
+    >>> np.argsort(x, axis=0)  # sorts along first axis (down)
+    array([[0., 1.],
+           [1., 0.]])
+    >>> np.argsort(x, axis=1)  # sorts along last axis (across)
+    array([[0., 1.],
+           [0., 1.]])
+
+    Notes
+    -----
+    This function differs from the original `numpy.argsort
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.argsort.html>`_ in
+    the following way(s):
+
+    - kind and order are currently not supported
+    """
+    if kind != 'quicksort':
+        raise AttributeError('mxnet.numpy.argsort does not support other sorting methods')
+    if order is not None:
+        raise AttributeError('mxnet.numpy.argsort does not support sorting with fields ordering')
+    return _npi.argsort(a, axis)
+
+
+@set_module('mxnet.ndarray.numpy')
 def concatenate(seq, axis=0, out=None):
     """Join a sequence of arrays along an existing axis.
 
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 513700c..cafc656 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -47,7 +47,8 @@ __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum', '
            'argmax', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate',
            'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace', 'sin', 'cos',
            'sin', 'cos', 'sinh', 'cosh', 'log10', 'sqrt', 'abs', 'exp', 'arctan', 'sign', 'log',
-           'degrees', 'log2', 'rint', 'radians', 'mean', 'reciprocal', 'square', 'arcsin']
+           'degrees', 'log2', 'rint', 'radians', 'mean', 'reciprocal', 'square', 'arcsin',
+           'argsort']
 
 
 # This function is copied from ndarray.py since pylint
@@ -779,13 +780,17 @@ class ndarray(NDArray):
         """
         raise AttributeError('mxnet.numpy.ndarray object has no attribute topk')
 
-    def argsort(self, *args, **kwargs):
+    def argsort(self, axis=-1, kind='quicksort', order=None):   # pylint: disable=arguments-differ
         """Convenience fluent method for :py:func:`argsort`.
 
         The arguments are the same as for :py:func:`argsort`, with
         this array as data.
         """
-        raise NotImplementedError
+        if kind != 'quicksort':
+            raise AttributeError('mxnet.numpy.argsort does not support other sorting methods')
+        if order is not None:
+            raise AttributeError('mxnet.numpy.argsort does not support sorting with fields ordering')
+        return _npi.argsort(self, axis)
 
     def argmax_channel(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`argmax_channel`.
@@ -1671,6 +1676,65 @@ def argmax(a, axis=None, out=None):
 
 
 @set_module('mxnet.numpy')
+def argsort(a, axis=-1, kind='quicksort', order=None):
+    """
+    Returns the indices that would sort an input array along the given axis.
+    This function performs sorting along the given axis and returns an array
+    of indices having same shape as an input array that index data in sorted order.
+
+    Parameters
+    ----------
+    a : ndarray
+        Input array
+    axis : int, optional
+        The axis along which to sort the input tensor.
+        If not given, the last dimension (-1) is used by default.
+        If None, the flattened array is used.
+    kind: {'quicksort'}
+        Currently not supported.
+    order: None
+        Currently not supported.
+
+    Returns
+    -------
+    output : ndarray
+        Array of indices that sort a along the specified axis.
+        If a is one-dimensional, a[index_array] yields a sorted a.
+        More generally, np.take_along_axis(a, index_array, axis=axis) always yields the sorted a,
+        irrespective of dimensionality.
+
+    Examples
+    --------
+    >>> x = np.array([3, 1, 2])
+    >>> np.argsort(x)
+    array([1., 2., 0.])
+    >>> x = np.array([[0, 3], [2, 2]])
+    >>> x
+    array([[0., 3.],
+           [2., 2.]])
+    >>> np.argsort(x, axis=0)  # sorts along first axis (down)
+    array([[0., 1.],
+           [1., 0.]])
+    >>> np.argsort(x, axis=1)  # sorts along last axis (across)
+    array([[0., 1.],
+           [0., 1.]])
+
+    Notes
+    -----
+    This function differs from the original `numpy.argsort
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.argsort.html>`_ in
+    the following way(s):
+
+    - kind and order are currently not supported
+    """
+    if kind != 'quicksort':
+        raise AttributeError('mxnet.numpy.argsort does not support other sorting methods')
+    if order is not None:
+        raise AttributeError('mxnet.numpy.argsort does not support sorting with fields ordering')
+    return _npi.argsort(a, axis)
+
+
+@set_module('mxnet.numpy')
 def concatenate(seq, axis=0, out=None):
     """Join a sequence of arrays along an existing axis.
 
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index 233f671..fa47d8d 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -33,7 +33,7 @@ __all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'concatenate', 'arang
            'clip', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'split', 'swapaxes',
            'expand_dims', 'tile', 'linspace', 'sin', 'cos', 'sinh', 'cosh', 'log10', 'sqrt',
            'abs', 'exp', 'arctan', 'sign', 'log', 'degrees', 'log2', 'rint', 'radians', 'mean',
-           'reciprocal', 'square', 'arcsin']
+           'reciprocal', 'square', 'arcsin', 'argsort']
 
 
 def _num_outputs(sym):
@@ -379,13 +379,62 @@ class _Symbol(Symbol):
         """
         raise AttributeError('_Symbol object has no attribute topk')
 
-    def argsort(self, *args, **kwargs):
-        """Convenience fluent method for :py:func:`argsort`.
+    def argsort(self, axis=-1, kind='quicksort', order=None):   # pylint: disable=arguments-differ
+        """
+        Returns the indices that would sort an input array along the given axis.
+        This function performs sorting along the given axis and returns an array
+        of indices having same shape as an input array that index data in sorted order.
+
+        Parameters
+        ----------
+        a : _Symbol
+            Input array
+        axis : int, optional
+            The axis along which to sort the input tensor.
+            If not given, the last dimension (-1) is used by default.
+            If None, the flattened array is used.
+        kind: {'quicksort'}
+            Currently not supported.
+        order: None
+            Currently not supported.
+
+        Returns
+        -------
+        output : _Symbol
+            Array of indices that sort a along the specified axis.
+            If a is one-dimensional, a[index_array] yields a sorted a.
+            More generally, np.take_along_axis(a, index_array, axis=axis) always yields the sorted a,
+            irrespective of dimensionality.
+
+        Examples
+        --------
+        >>> x = np.array([3, 1, 2])
+        >>> np.argsort(x)
+        array([1., 2., 0.])
+        >>> x = np.array([[0, 3], [2, 2]])
+        >>> x
+        array([[0., 3.],
+            [2., 2.]])
+        >>> np.argsort(x, axis=0)  # sorts along first axis (down)
+        array([[0., 1.],
+            [1., 0.]])
+        >>> np.argsort(x, axis=1)  # sorts along last axis (across)
+        array([[0., 1.],
+            [0., 1.]])
 
-        The arguments are the same as for :py:func:`argsort`, with
-        this array as data.
+        Notes
+        -----
+        This function differs from the original `numpy.argsort
+        <https://docs.scipy.org/doc/numpy/reference/generated/numpy.argsort.html>`_ in
+        the following way(s):
+
+        - kind and order are currently not supported
         """
-        raise NotImplementedError
+        if kind != 'quicksort':
+            raise AttributeError('mxnet.numpy.argsort does not support other sorting methods')
+        if order is not None:
+            raise AttributeError('mxnet.numpy.argsort does not support sorting with fields ordering')
+        return _npi.argsort(self, axis)
 
     def argmax_channel(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`argmax_channel`.
@@ -1261,6 +1310,60 @@ def argmax(a, axis=None, out=None):
 
 
 @set_module('mxnet.symbol.numpy')
+def argsort(a, axis=-1, kind='quicksort', order=None):
+    """
+    Returns the indices that would sort an input array along the given axis.
+    This function performs sorting along the given axis and returns an array
+    of indices having same shape as an input array that index data in sorted order.
+    Parameters
+    ----------
+    a : _Symbol
+        Input array
+    axis : int, optional
+        The axis along which to sort the input tensor.
+        If not given, the last dimension (-1) is used by default.
+        If None, the flattened array is used.
+    kind: {'quicksort'}
+        Currently not supported.
+    order: None
+        Currently not supported.
+    Returns
+    -------
+    output : _Symbol
+        Array of indices that sort a along the specified axis.
+        If a is one-dimensional, a[index_array] yields a sorted a.
+        More generally, np.take_along_axis(a, index_array, axis=axis) always yields the sorted a,
+        irrespective of dimensionality.
+    Examples
+    --------
+    >>> x = np.array([3, 1, 2])
+    >>> np.argsort(x)
+    array([1., 2., 0.])
+    >>> x = np.array([[0, 3], [2, 2]])
+    >>> x
+    array([[0., 3.],
+           [2., 2.]])
+    >>> np.argsort(x, axis=0)  # sorts along first axis (down)
+    array([[0., 1.],
+           [1., 0.]])
+    >>> np.argsort(x, axis=1)  # sorts along last axis (across)
+    array([[0., 1.],
+           [0., 1.]])
+    Notes
+    -----
+    This function differs from the original `numpy.argsort
+    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.argsort.html>`_ in
+    the following way(s):
+    - kind and order are currently not supported
+    """
+    if kind != 'quicksort':
+        raise AttributeError('mxnet.numpy.argsort does not support other sorting methods')
+    if order is not None:
+        raise AttributeError('mxnet.numpy.argsort does not support sorting with fields ordering')
+    return _npi.argsort(a, axis)
+
+
+@set_module('mxnet.symbol.numpy')
 def clip(a, a_min, a_max, out=None):
     """clip(a, a_min, a_max, out=None)
 
diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h
index 1dda901..74589e5 100644
--- a/src/operator/tensor/ordering_op-inl.h
+++ b/src/operator/tensor/ordering_op-inl.h
@@ -580,18 +580,38 @@ void ArgSort(const nnvm::NodeAttrs& attrs,
              const std::vector<OpReqType>& req,
              const std::vector<TBlob>& outputs) {
   const ArgSortParam& param = nnvm::get<ArgSortParam>(attrs.parsed);
-  TopKParam topk_param;
-  topk_param.axis = param.axis;
-  topk_param.is_ascend = param.is_ascend;
-  topk_param.k = 0;
-  topk_param.dtype = param.dtype;
-  topk_param.ret_typ = topk_enum::kReturnIndices;
-  MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-    MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
-      TopKImpl<xpu, DType, IDType>(ctx.run_ctx,
-                                   ctx.requested[0], req, inputs[0], outputs, topk_param);
+
+  if (inputs[0].shape_.ndim() == 0) {
+  // A scalar tensor only accepts an axis value of 0, -1 or None
+    CHECK(!static_cast<bool>(param.axis) || param.axis.value() == -1 || param.axis.value() == 0)
+      << "Axis can only be -1 or 0 for scalor tensor";
+    MSHADOW_TYPE_SWITCH(param.dtype, DType, {
+      Stream<xpu> *s = ctx.get_stream<xpu>();
+      Tensor<xpu, 1, DType> outdata = outputs[0].get_with_shape<xpu, 1, DType>(Shape1(1), s);
+      ASSIGN_DISPATCH(outdata, OpReqType::kWriteTo, 0);
     });
-  });
+  } else if (inputs[0].shape_.Size() == 0) {
+    // If the input tensor is zero size, only a check on axis is needed
+    if (static_cast<bool>(param.axis)) {
+      int axis = param.axis.value();
+      if (axis < 0) axis += inputs[0].shape_.ndim();
+      CHECK(axis >= 0 && axis < inputs[0].shape_.ndim())
+        << "Axis must be within the range of input tensor's dimension";
+    }
+  } else {
+    TopKParam topk_param;
+    topk_param.axis = param.axis;
+    topk_param.is_ascend = param.is_ascend;
+    topk_param.k = 0;
+    topk_param.dtype = param.dtype;
+    topk_param.ret_typ = topk_enum::kReturnIndices;
+    MXNET_NO_FLOAT16_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      MSHADOW_TYPE_SWITCH(param.dtype, IDType, {
+        TopKImpl<xpu, DType, IDType>(ctx.run_ctx,
+                                     ctx.requested[0], req, inputs[0], outputs, topk_param);
+      });
+    });
+  }
 }
 
 template<typename xpu, typename DType, typename IDType>
@@ -824,12 +844,21 @@ inline bool ArgSortShape(const nnvm::NodeAttrs& attrs,
                          mxnet::ShapeVector *in_attrs,
                          mxnet::ShapeVector *out_attrs) {
   const ArgSortParam& param = nnvm::get<ArgSortParam>(attrs.parsed);
-  TopKParam topk_param;
-  topk_param.axis = param.axis;
-  topk_param.is_ascend = param.is_ascend;
-  topk_param.k = 0;
-  topk_param.ret_typ = topk_enum::kReturnIndices;
-  return TopKShapeImpl(topk_param, in_attrs, out_attrs);
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  mxnet::TShape& in_shape = (*in_attrs)[0];
+
+  if (in_shape.ndim() == 0) {
+    mxnet::TShape target_shape({1});
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, target_shape);
+  } else if (!static_cast<bool>(param.axis)) {
+    mxnet::TShape target_shape(Shape1(in_shape.Size()));
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, target_shape);
+  } else {
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_shape);
+  }
+
+  return true;
 }
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/tensor/ordering_op.cc b/src/operator/tensor/ordering_op.cc
index f693601..7b5a145 100644
--- a/src/operator/tensor/ordering_op.cc
+++ b/src/operator/tensor/ordering_op.cc
@@ -176,6 +176,7 @@ Examples::
   // flatten and then sort
   argsort(x) = [ 3.,  1.,  5.,  0.,  4.,  2.]
 )code" ADD_FILELINE)
+.add_alias("_npi_argsort")
 .set_num_inputs(1)
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<ArgSortParam>)
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 403ac07..d373419 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -780,6 +780,47 @@ def test_np_argmax():
 
 @with_seed()
 @npx.use_np_shape
+def test_np_argsort():
+    @npx.use_np_shape
+    class TestArgsort(HybridBlock):
+        def __init__(self, axis=-1):
+            super(TestArgsort, self).__init__()
+            self._axis = axis
+
+        def hybrid_forward(self, F, a):
+            return F.np.argsort(a, self._axis)
+
+    shapes = [
+        (), 
+        (1,), 
+        (5,4),
+        (5,0,4),
+        (5,0,0),
+        (0,0,5),
+        (0,0,0),
+        (5,3,4)
+    ] 
+    for hybridize in [True, False]:
+        for shape in shapes:
+            for ax in list(range(len(shape))) + [-1, None]:
+                test_argsort = TestArgsort(ax)
+                if hybridize:
+                    test_argsort.hybridize()
+
+                x = np.random.uniform(size=shape)
+                np_out = _np.argsort(x.asnumpy(), axis=ax)
+                mx_out = test_argsort(x)
+                assert mx_out.shape == np_out.shape
+                assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+
+                # Test imperative once again
+                mx_out = np.argsort(x, axis=ax)
+                np_out = _np.argsort(x.asnumpy(), axis=ax)
+                assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+
+
+@with_seed()
+@npx.use_np_shape
 def test_np_linalg_norm():
     @npx.use_np
     class TestLinalgNorm(HybridBlock):