Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/11/14 03:05:48 UTC

[GitHub] eric-haibin-lin closed pull request #8641: Doc updates for sparse operators

URL: https://github.com/apache/incubator-mxnet/pull/8641
 
 
   

This is a PR merged from a forked repository. Because GitHub hides the original diff once a
foreign (forked) pull request is merged, the diff is reproduced below for the sake of provenance:

diff --git a/docs/api/python/ndarray/sparse.md b/docs/api/python/ndarray/sparse.md
index 9b742f4fc5..23c3018c53 100644
--- a/docs/api/python/ndarray/sparse.md
+++ b/docs/api/python/ndarray/sparse.md
@@ -123,8 +123,8 @@ We summarize the interface for each class in the following sections.
     CSRNDArray.copy
     CSRNDArray.copyto
     CSRNDArray.as_in_context
-    CSRNDArray.asnumpy
     CSRNDArray.asscipy
+    CSRNDArray.asnumpy
     CSRNDArray.asscalar
     CSRNDArray.astype
     CSRNDArray.tostype
@@ -139,6 +139,41 @@ We summarize the interface for each class in the following sections.
     CSRNDArray.zeros_like
 ```
 
+### Arithmetic operations
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    CSRNDArray.__add__
+    CSRNDArray.__sub__
+    CSRNDArray.__rsub__
+    CSRNDArray.__neg__
+    CSRNDArray.__mul__
+    CSRNDArray.__div__
+    CSRNDArray.__rdiv__
+```
+
+
+### Array reduction
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    CSRNDArray.sum
+    CSRNDArray.mean
+```
+
+### Powers
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    CSRNDArray.square
+```
+
 ### Indexing
 
 ```eval_rst
@@ -213,6 +248,67 @@ We summarize the interface for each class in the following sections.
     RowSparseNDArray.trunc
 ```
 
+### Arithmetic operations
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    RowSparseNDArray.__add__
+    RowSparseNDArray.__sub__
+    RowSparseNDArray.__rsub__
+    RowSparseNDArray.__neg__
+    RowSparseNDArray.__mul__
+    RowSparseNDArray.__div__
+    RowSparseNDArray.__rdiv__
+```
+
+### Trigonometric functions
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    RowSparseNDArray.sin
+    RowSparseNDArray.tan
+    RowSparseNDArray.arcsin
+    RowSparseNDArray.arctan
+    RowSparseNDArray.degrees
+    RowSparseNDArray.radians
+```
+
+### Hyperbolic functions
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    RowSparseNDArray.sinh
+    RowSparseNDArray.tanh
+    RowSparseNDArray.arcsinh
+    RowSparseNDArray.arctanh
+```
+
+### Exponents and logarithms
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    RowSparseNDArray.expm1
+    RowSparseNDArray.log1p
+```
+
+### Powers
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    RowSparseNDArray.sqrt
+    RowSparseNDArray.square
+```
+
 ### Indexing
 
 ```eval_rst
@@ -221,6 +317,7 @@ We summarize the interface for each class in the following sections.
 
     RowSparseNDArray.__getitem__
     RowSparseNDArray.__setitem__
+    RowSparseNDArray.retain
 ```
 
 ### Lazy evaluation
@@ -232,6 +329,16 @@ We summarize the interface for each class in the following sections.
     RowSparseNDArray.wait_to_read
 ```
 
+### Miscellaneous
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    RowSparseNDArray.clip
+    RowSparseNDArray.sign
+```
+
 ## Array creation routines
 
 ```eval_rst
@@ -280,6 +387,7 @@ We summarize the interface for each class in the following sections.
     elemwise_add
     elemwise_sub
     elemwise_mul
+    elemwise_div
     negative
     dot
     add_n
@@ -311,6 +419,16 @@ We summarize the interface for each class in the following sections.
     arctanh
 ```
 
+### Reduce functions
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    sum
+    mean
+```
+
 ### Rounding
 
 ```eval_rst
@@ -363,6 +481,7 @@ We summarize the interface for each class in the following sections.
 
     make_loss
     stop_gradient
+    mxnet.ndarray.contrib.SparseEmbedding
 ```
 
 ## API Reference
@@ -372,10 +491,10 @@ We summarize the interface for each class in the following sections.
 ```eval_rst
 
 .. autoclass:: mxnet.ndarray.sparse.CSRNDArray
-    :members: shape, context, dtype, stype, data, indices, indptr, copy, copyto, as_in_context, asnumpy, asscalar, astype, tostype, slice, wait_to_read, zeros_like, __getitem__, __setitem__
+    :members: shape, context, dtype, stype, data, indices, indptr, copy, copyto, as_in_context, asscipy, asnumpy, asscalar, astype, tostype, slice, wait_to_read, zeros_like, __add__, __sub__, __rsub__, __neg__, __mul__, __div__, __rdiv__, sum, mean, square, __getitem__, __setitem__
 
 .. autoclass:: mxnet.ndarray.sparse.RowSparseNDArray
-    :members: shape, context, dtype, stype, data, indices, copy, copyto, as_in_context, asnumpy, asscalar, astype, tostype, wait_to_read, zeros_like, round, rint, fix, floor, ceil, trunc, __getitem__, __setitem__
+    :members: shape, context, dtype, stype, data, indices, copy, copyto, as_in_context, asnumpy, asscalar, astype, tostype, wait_to_read, zeros_like, round, rint, fix, floor, ceil, trunc, sin, tan, arcsin, arctan, degrees, radians, sinh, tanh, arcsinh, arctanh, expm1, log1p, sqrt, square, __add__, __sub__, __rsub__, __neg__, __mul__, __div__, __rdiv__, __getitem__, __setitem__, retain, clip, sign
 
 .. automodule:: mxnet.ndarray.sparse
     :members:
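
The CSRNDArray methods documented above can be exercised directly. A minimal sketch, mirroring
the docstring examples added in this PR (assumes an MXNet build with sparse support; repr
formatting may differ slightly between NumPy versions):

    >>> import mxnet as mx
    >>> x = mx.nd.ones((2,3)).tostype('csr')
    >>> y = mx.nd.arange(6).reshape((2,3)).tostype('csr')
    >>> (x + y).asnumpy()                # CSRNDArray.__add__
    array([[ 1.,  2.,  3.],
           [ 4.,  5.,  6.]], dtype=float32)
    >>> x.sum(axis=0).asnumpy()          # CSRNDArray.sum
    array([ 2.,  2.,  2.], dtype=float32)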
diff --git a/docs/api/python/symbol/sparse.md b/docs/api/python/symbol/sparse.md
index 5ebbfcd057..820b35b9c3 100644
--- a/docs/api/python/symbol/sparse.md
+++ b/docs/api/python/symbol/sparse.md
@@ -95,10 +95,108 @@ In the rest of this document, we list sparse related routines provided by the
     :nosignatures:
 
     elemwise_add
+    elemwise_sub
+    elemwise_mul
+    elemwise_div
+    negative
     dot
     add_n
 ```
 
+### Trigonometric functions
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    sin
+    tan
+    arcsin
+    arctan
+    degrees
+    radians
+```
+
+### Hyperbolic functions
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    sinh
+    tanh
+    arcsinh
+    arctanh
+```
+
+### Reduce functions
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    sum
+    mean
+```
+
+### Rounding
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    round
+    rint
+    fix
+    floor
+    ceil
+    trunc
+```
+
+### Exponents and logarithms
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    expm1
+    log1p
+```
+
+### Powers
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    sqrt
+    square
+```
+
+### Miscellaneous
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    clip
+    abs
+    sign
+```
+
+## Neural network
+
+### More
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    make_loss
+    stop_gradient
+    mxnet.symbol.contrib.SparseEmbedding
+```
+
 ## API Reference
 
 <script type="text/javascript" src='../../../_static/js/auto_module_index.js'></script>
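
The sparse symbol operators listed above compose like their dense counterparts. A minimal
sketch of building and executing a small graph; the Variable storage types, bind call, and
printed values here are illustrative assumptions rather than part of this PR:

    >>> import mxnet as mx
    >>> a = mx.sym.Variable('a', stype='csr')
    >>> b = mx.sym.Variable('b', stype='csr')
    >>> c = mx.sym.sparse.elemwise_add(a, b)
    >>> exe = c.bind(mx.cpu(), {'a': mx.nd.ones((2,3)).tostype('csr'),
    ...                         'b': mx.nd.ones((2,3)).tostype('csr')})
    >>> exe.forward()[0].asnumpy()
    array([[ 2.,  2.,  2.],
           [ 2.,  2.,  2.]], dtype=float32)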
diff --git a/python/mxnet/ndarray/sparse.py b/python/mxnet/ndarray/sparse.py
index 45a269a10d..75a3ada111 100644
--- a/python/mxnet/ndarray/sparse.py
+++ b/python/mxnet/ndarray/sparse.py
@@ -30,6 +30,7 @@
 
 import ctypes
 import warnings
+import operator
 
 __all__ = ["_ndarray_cls", "csr_matrix", "row_sparse_array",
            "BaseSparseNDArray", "CSRNDArray", "RowSparseNDArray"]
@@ -113,18 +114,58 @@ def __repr__(self):
         return '\n<%s %s @%s>' % (self.__class__.__name__,
                                   shape_info, self.context)
 
+    def __add__(self, other):
+        """x.__add__(y) <=> x+y <=> mx.nd.sparse.add(x, y) """
+        return add(self, other)
+
     def __iadd__(self, other):
         raise NotImplementedError()
 
+    def __radd__(self, other):
+        return self.__add__(other)
+
+    def __sub__(self, other):
+        """x.__sub__(y) <=> x-y <=> mx.nd.sparse.subtract(x, y) """
+        return subtract(self, other)
+
     def __isub__(self, other):
         raise NotImplementedError()
 
+    def __rsub__(self, other):
+        """x.__rsub__(y) <=> y-x <=> mx.nd.sparse.subtract(y, x) """
+        return subtract(other, self)
+
+    def __mul__(self, other):
+        """x.__mul__(y) <=> x*y <=> mx.nd.spares.multiply(x, y) """
+        return multiply(self, other)
+
+    def __neg__(self):
+        """x.__neg__(y) <=> -x """
+        return _internal._mul_scalar(self, -1.0)
+
     def __imul__(self, other):
         raise NotImplementedError()
 
+    def __rmul__(self, other):
+        return self.__mul__(other)
+
+    def __div__(self, other):
+        """x.__div__(y) <=> x/y <=> mx.nd.sparse.divide(x, y) """
+        return divide(self, other)
+
+    def __rdiv__(self, other):
+        """x.__rdiv__(y) <=> y/x <=> mx.nd.sparse.divide(y, x) """
+        return divide(other, self)
+
     def __idiv__(self, other):
         raise NotImplementedError()
 
+    def __truediv__(self, other):
+        return divide(self, other)
+
+    def __rtruediv__(self, other):
+        return divide(other, self)
+
     def __itruediv__(self, other):
         raise NotImplementedError()
 
@@ -735,6 +776,13 @@ def copyto(self, other):
         else:
             raise TypeError('copyto does not support type ' + str(type(other)))
 
+    def retain(self, *args, **kwargs):
+        """Convenience fluent method for :py:func:`retain`.
+
+        The arguments are the same as for :py:func:`retain`, with
+        this array as data.
+        """
+        return retain(self, *args, **kwargs)
 
 def _prepare_src_array(source_array, dtype):
     """Prepare `source_array` so that it can be used to construct NDArray.
@@ -1260,3 +1308,255 @@ def array(source_array, ctx=None, dtype=None):
                          type(source_array))
     else:
         raise ValueError("Unexpected source_array type: ", type(source_array))
+
+#pylint: disable= too-many-arguments, no-member, protected-access
+def _ufunc_helper(lhs, rhs, fn_elemwise_arr, fn_broadcast_arr, fn_scalar,
+                  lfn_scalar, rfn_scalar=None):
+    """ Helper function for element-wise operation.
+    The function will perform numpy-like broadcasting if needed and call different functions.
+
+    Parameters
+    --------
+    lhs : NDArray or numeric value
+        Left-hand side operand.
+
+    rhs : NDArray or numeric value
+        Right-hand side operand.
+
+    fn_elemwise_arr : function
+        Function to be called if both lhs and rhs are of ``NDArray`` type and their shapes match.
+
+    fn_broadcast_arr : function
+        Function to be called if both lhs and rhs are of ``NDArray`` type and \
+        their shapes do not match.
+
+    fn_scalar : function
+        Function to be called if both lhs and rhs are numeric values.
+
+    lfn_scalar : function
+        Function to be called if lhs is of ``NDArray`` type while rhs is a numeric value.
+
+    rfn_scalar : function
+        Function to be called if lhs is a numeric value while rhs is of ``NDArray`` type;
+        if None, the operation is assumed commutative and ``lfn_scalar`` is used instead.
+
+    Returns
+    --------
+    NDArray
+        result array
+    """
+    if isinstance(lhs, numeric_types):
+        if isinstance(rhs, numeric_types):
+            return fn_scalar(lhs, rhs)
+        else:
+            if rfn_scalar is None:
+                # commutative function
+                return lfn_scalar(rhs, float(lhs))
+            else:
+                return rfn_scalar(rhs, float(lhs))
+    elif isinstance(rhs, numeric_types):
+        return lfn_scalar(lhs, float(rhs))
+    elif isinstance(rhs, NDArray):
+        if lhs.shape == rhs.shape:
+            return fn_elemwise_arr(lhs, rhs)
+        else:
+            return fn_broadcast_arr(lhs, rhs)
+    else:
+        raise TypeError('type %s not supported' % str(type(rhs)))
+
+#pylint: enable= too-many-arguments, no-member, protected-access
+def add(lhs, rhs):
+    """Returns element-wise sum of the input arrays.
+
+    Equivalent to ``lhs + rhs``, ``mx.nd.sparse.elemwise_add(lhs, rhs)``,
+    ``mx.nd.broadcast_add(lhs, rhs)`` and ``mx.nd.broadcast_plus(lhs, rhs)``.
+
+    .. note::
+
+       If the corresponding dimensions of two arrays have the same size or one of them has size 1,
+       then the arrays are broadcastable to a common shape.
+       If the shape of `lhs` array and that of `rhs` array match,
+       then `elemwise_add` will be used instead of `broadcast_add`.
+
+    Parameters
+    ----------
+    lhs : scalar or array
+        First array to be added.
+    rhs : scalar or array
+        Second array to be added.
+        If ``lhs.shape != rhs.shape``, they must be
+        broadcastable to a common shape.
+
+    Returns
+    -------
+    NDArray, CSRNDArray or RowSparseNDArray
+        The element-wise sum of the input arrays.
+
+    Examples
+    --------
+    >>> x = mx.nd.ones((2,3)).tostype('csr')
+    >>> y = mx.nd.arange(6).reshape((2,3)).tostype('csr')
+    >>> x.asnumpy()
+    array([[ 1.,  1.,  1.],
+           [ 1.,  1.,  1.]], dtype=float32)
+    >>> y.asnumpy()
+    array([[ 0.,  1.,  2.],
+           [ 3.,  4.,  5.]], dtype=float32)
+    >>> (x+y).asnumpy()
+    array([[ 1.,  2.,  3.],
+           [ 4.,  5.,  6.]], dtype=float32)
+    """
+    # pylint: disable= no-member, protected-access
+    return _ufunc_helper(
+        lhs,
+        rhs,
+        elemwise_add,
+        op.broadcast_add,
+        operator.add,
+        _internal._plus_scalar,
+        None)
+    # pylint: enable= no-member, protected-access
+
+
+def subtract(lhs, rhs):
+    """Returns element-wise difference of the input arrays with broadcasting.
+
+    Equivalent to ``lhs - rhs``, ``mx.nd.broadcast_sub(lhs, rhs)`` and
+    ``mx.nd.broadcast_minus(lhs, rhs)``.
+
+    .. note::
+
+       If the corresponding dimensions of two arrays have the same size or one of them has size 1,
+       then the arrays are broadcastable to a common shape.
+       If the shape of `lhs` array and that of `rhs` array match,
+       then `elemwise_sub` will be used instead of `broadcast_sub`.
+
+    Parameters
+    ----------
+    lhs : scalar or array
+        First array to be subtracted.
+    rhs : scalar or array
+        Second array to be subtracted.
+        If ``lhs.shape != rhs.shape``, they must be
+        broadcastable to a common shape.
+
+    Returns
+    -------
+    NDArray, CSRNDArray or RowSparseNDArray
+        The element-wise difference of the input arrays.
+
+    Examples
+    --------
+    >>> x = mx.nd.ones((2,3)).tostype('csr')
+    >>> y = mx.nd.arange(6).reshape((2,3)).tostype('csr')
+    >>> x.asnumpy()
+    array([[ 1.,  1.,  1.],
+           [ 1.,  1.,  1.]], dtype=float32)
+    >>> y.asnumpy()
+    array([[ 0.,  1.,  2.],
+           [ 3.,  4.,  5.]], dtype=float32)
+    >>> (x-y).asnumpy()
+    array([[ 1.,  0., -1.],
+           [-2., -3., -4.]], dtype=float32)
+    """
+    # pylint: disable= no-member, protected-access
+    return _ufunc_helper(
+        lhs,
+        rhs,
+        elemwise_sub,
+        op.broadcast_sub,
+        operator.sub,
+        _internal._minus_scalar,
+        _internal._rminus_scalar)
+    # pylint: enable= no-member, protected-access
+
+
+def multiply(lhs, rhs):
+    """Returns element-wise product of the input arrays with broadcasting.
+
+    Equivalent to ``lhs * rhs`` and ``mx.nd.broadcast_mul(lhs, rhs)``.
+
+    .. note::
+
+       If the corresponding dimensions of two arrays have the same size or one of them has size 1,
+       then the arrays are broadcastable to a common shape.
+       If the shape of `lhs` array and that of `rhs` array match,
+       then `elemwise_mul` will be used instead of `broadcast_mul`.
+
+    Parameters
+    ----------
+    lhs : scalar or array
+        First array to be multiplied.
+    rhs : scalar or array
+        Second array to be multiplied.
+        If ``lhs.shape != rhs.shape``, they must be
+        broadcastable to a common shape.
+
+    Returns
+    -------
+    NDArray, CSRNDArray or RowSparseNDArray
+        The element-wise multiplication of the input arrays.
+
+    Examples
+    --------
+    >>> x = mx.nd.ones((2,3)).tostype('csr')
+    >>> y = mx.nd.arange(6).reshape((2,3)).tostype('csr')
+    >>> x.asnumpy()
+    array([[ 1.,  1.,  1.],
+           [ 1.,  1.,  1.]], dtype=float32)
+    >>> y.asnumpy()
+    array([[ 0.,  1.,  2.],
+           [ 3.,  4.,  5.]], dtype=float32)
+    >>> (x*y).asnumpy()
+    array([[ 0.,  1.,  2.],
+           [ 3.,  4.,  5.]], dtype=float32)
+    """
+    # pylint: disable= no-member, protected-access
+    return _ufunc_helper(
+        lhs,
+        rhs,
+        elemwise_mul,
+        op.broadcast_mul,
+        operator.mul,
+        _internal._mul_scalar,
+        None)
+    # pylint: enable= no-member, protected-access
+
+
+def divide(lhs, rhs):
+    """Returns element-wise division of the input arrays with broadcasting.
+
+    Equivalent to ``lhs / rhs`` and ``mx.nd.broadcast_div(lhs, rhs)``.
+
+    .. note::
+
+       If the corresponding dimensions of two arrays have the same size or one of them has size 1,
+       then the arrays are broadcastable to a common shape.
+       If the shape of `lhs` array and that of `rhs` array match,
+       then `elemwise_div` will be used instead of `broadcast_div`.
+
+    Parameters
+    ----------
+    lhs : scalar or array
+        First array in division.
+    rhs : scalar or array
+        Second array in division.
+        If ``lhs.shape != rhs.shape``, they must be
+        broadcastable to a common shape.
+
+    Returns
+    -------
+    NDArray
+        The element-wise division of the input arrays.
+
+    """
+    # pylint: disable= no-member, protected-access
+    return _ufunc_helper(
+        lhs,
+        rhs,
+        elemwise_div,
+        op.broadcast_div,
+        operator.truediv,
+        _internal._div_scalar,
+        _internal._rdiv_scalar)
+    # pylint: enable= no-member, protected-access
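
The module-level functions added above dispatch through _ufunc_helper: matching shapes go to
the elemwise_* operators, mismatched shapes to broadcast_*, and scalar operands to the
*_scalar internals. A minimal sketch (assumes an MXNet build with sparse support; the scalar
division values follow from float32 arithmetic, and repr formatting may differ):

    >>> import mxnet as mx
    >>> x = mx.nd.ones((2,3)).tostype('csr')
    >>> y = mx.nd.arange(6).reshape((2,3)).tostype('csr')
    >>> mx.nd.sparse.multiply(x, y).asnumpy()   # same shape -> elemwise_mul
    array([[ 0.,  1.,  2.],
           [ 3.,  4.,  5.]], dtype=float32)
    >>> mx.nd.sparse.divide(y, 2).asnumpy()     # NDArray and scalar -> _div_scalar
    array([[ 0. ,  0.5,  1. ],
           [ 1.5,  2. ,  2.5]], dtype=float32)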
diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc b/src/operator/tensor/elemwise_binary_op_basic.cc
index b3be9e4c20..8f1ae79ec0 100644
--- a/src/operator/tensor/elemwise_binary_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_op_basic.cc
@@ -35,6 +35,7 @@ MXNET_ADD_SPARSE_OP_ALIAS(elemwise_add)
 The storage type of ``elemwise_add`` output depends on storage types of inputs
 
    - elemwise_add(row_sparse, row_sparse) = row_sparse
+   - elemwise_add(csr, csr) = csr
    - otherwise, ``elemwise_add`` generates output with default storage
 
 )code")
@@ -69,7 +70,8 @@ MXNET_ADD_SPARSE_OP_ALIAS(elemwise_sub)
 The storage type of ``elemwise_sub`` output depends on storage types of inputs
 
    - elemwise_sub(row_sparse, row_sparse) = row_sparse
-   - otherwise, ``elemwise_add`` generates output with default storage
+   - elemwise_sub(csr, csr) = csr
+   - otherwise, ``elemwise_sub`` generates output with default storage
 
 )code")
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_sub"});
@@ -100,6 +102,7 @@ The storage type of ``elemwise_mul`` output depends on storage types of inputs
    - elemwise_mul(row_sparse, row_sparse) = row_sparse
    - elemwise_mul(default, row_sparse) = default
    - elemwise_mul(row_sparse, default) = default
+   - elemwise_mul(csr, csr) = csr
    - otherwise, ``elemwise_mul`` generates output with default storage
 
 )code")
@@ -138,7 +141,7 @@ MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(elemwise_div, mshadow::op::div
 MXNET_ADD_SPARSE_OP_ALIAS(elemwise_div)
 .describe(R"code(Divides arguments element-wise.
 
-The storage type of ``elemwise_dev`` output is always dense
+The storage type of ``elemwise_div`` output is always dense
 
 )code")
 .add_alias("_div").add_alias("_Div")
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index cba9efd1a9..97e688078c 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -247,7 +247,7 @@ will return a new array with shape ``(2,1,3,4)``.
 .add_arguments(ExpandDimParam::__FIELDS__());
 
 NNVM_REGISTER_OP(slice)
-.add_alias("_sparse_slice")
+MXNET_ADD_SPARSE_OP_ALIAS(slice)
 .add_alias("crop")
 .describe(R"code(Slices a region of the array.
 
@@ -395,6 +395,7 @@ NNVM_REGISTER_OP(_backward_slice_axis)
 .set_attr<FCompute>("FCompute<cpu>", SliceAxisGrad_<cpu>);
 
 NNVM_REGISTER_OP(clip)
+MXNET_ADD_SPARSE_OP_ALIAS(clip)
 .describe(R"code(Clips (limits) the values in an array.
 
 Given an interval, values outside the interval are clipped to the interval edges.
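
With the sparse aliases registered above, ``slice`` and ``clip`` are usable as fluent methods
on sparse arrays as well. A minimal sketch (assumes an MXNet build with sparse support; output
storage types may fall back to dense depending on the slice/clip arguments, so only values are
shown):

    >>> import mxnet as mx
    >>> x = mx.nd.arange(6).reshape((2,3)).tostype('csr')
    >>> x.slice(begin=(0,1), end=(2,3)).asnumpy()
    array([[ 1.,  2.],
           [ 4.,  5.]], dtype=float32)
    >>> r = mx.nd.arange(6).reshape((3,2)).tostype('row_sparse')
    >>> r.clip(a_min=0.0, a_max=3.0).asnumpy()
    array([[ 0.,  1.],
           [ 2.,  3.],
           [ 3.,  3.]], dtype=float32)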
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index 7576050f55..4a0a1ff1b5 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -22,6 +22,7 @@
 from mxnet.base import mx_real_t
 from numpy.testing import assert_allclose
 import numpy.random as rnd
+from common import assertRaises
 
 from mxnet.ndarray.sparse import RowSparseNDArray, CSRNDArray
 
@@ -233,18 +234,19 @@ def check_binary(fn, stype):
             oshape = np.random.randint(1, 6, size=(ndim,))
             bdim = 2
             lshape = list(oshape)
-            rshape = list(oshape[ndim-bdim:])
-            for i in range(bdim):
-                sep = np.random.uniform(0, 1)
-                if sep < 0.33:
-                    lshape[ndim-i-1] = 1
-                elif sep < 0.66:
-                    rshape[bdim-i-1] = 1
-            lhs = np.random.uniform(0, 1, size=lshape)
-            rhs = np.random.uniform(0, 1, size=rshape)
-            lhs_nd = mx.nd.array(lhs).tostype(stype)
-            rhs_nd = mx.nd.array(rhs).tostype(stype)
-            assert_allclose(fn(lhs, rhs), fn(lhs_nd, rhs_nd).asnumpy(), rtol=1e-4, atol=1e-4)
+            rshapes = [list(oshape[ndim-bdim:]), lshape]
+            for rshape in rshapes:
+                for i in range(bdim):
+                    sep = np.random.uniform(0, 1)
+                    if sep < 0.33:
+                        lshape[ndim-i-1] = 1
+                    elif sep < 0.66:
+                        rshape[bdim-i-1] = 1
+                lhs = np.random.uniform(0, 1, size=lshape)
+                rhs = np.random.uniform(0, 1, size=rshape)
+                lhs_nd = mx.nd.array(lhs).tostype(stype)
+                rhs_nd = mx.nd.array(rhs).tostype(stype)
+                assert_allclose(fn(lhs, rhs), fn(lhs_nd, rhs_nd).asnumpy(), rtol=1e-4, atol=1e-4)
 
     stypes = ['row_sparse', 'csr']
     for stype in stypes:
@@ -736,19 +738,46 @@ def test_powerlaw_generator(csr_arr, final_row=1):
     test_powerlaw_generator(csr_arr_big, final_row=4)
     test_powerlaw_generator(csr_arr_square, final_row=6)
 
+def test_sparse_nd_fluent():
+    def check_fluent_regular(stype, func, kwargs, shape=(5, 17), equal_nan=False):
+        with mx.name.NameManager():
+            data = mx.nd.random_uniform(shape=shape, ctx=default_context()).tostype(stype)
+            regular = getattr(mx.ndarray, func)(data, **kwargs)
+            fluent = getattr(data, func)(**kwargs)
+            if isinstance(regular, list):
+                for r, f in zip(regular, fluent):
+                    assert almost_equal(r.asnumpy(), f.asnumpy(), equal_nan=equal_nan)
+            else:
+                assert almost_equal(regular.asnumpy(), fluent.asnumpy(), equal_nan=equal_nan)
+
+    common_func = ['zeros_like', 'square']
+    rsp_func = ['round', 'rint', 'fix', 'floor', 'ceil', 'trunc',
+                'abs', 'sign', 'sin', 'degrees', 'radians', 'expm1']
+    for func in common_func:
+        check_fluent_regular('csr', func, {})
+    for func in common_func + rsp_func:
+        check_fluent_regular('row_sparse', func, {})
+
+    rsp_func = ['arcsin', 'arctan', 'tan', 'sinh', 'tanh',
+                'arcsinh', 'arctanh', 'log1p', 'sqrt', 'relu']
+    for func in rsp_func:
+        check_fluent_regular('row_sparse', func, {}, equal_nan=True)
+
+    check_fluent_regular('csr', 'slice', {'begin': (2, 5), 'end': (4, 7)}, shape=(5, 17))
+    check_fluent_regular('row_sparse', 'clip', {'a_min': -0.25, 'a_max': 0.75})
+
+    for func in ['sum', 'mean']:
+        check_fluent_regular('csr', func, {'axis': 0})
+
+
 def test_sparse_nd_exception():
     """ test invalid sparse operator will throw a exception """
     a = mx.nd.ones((2,2))
-    assert_exception(mx.nd.sparse.retain, mx.base.MXNetError,
-                     a, invalid_arg="garbage_value")
-    assert_exception(mx.nd.sparse.csr_matrix, ValueError,
-                     a, shape=(3,2))
-    assert_exception(mx.nd.sparse.csr_matrix, ValueError,
-                     (2,2), shape=(3,2))
-    assert_exception(mx.nd.sparse.row_sparse_array, ValueError,
-                     (2,2), shape=(3,2))
-    assert_exception(mx.nd.sparse.zeros, ValueError,
-                     "invalid_stype", (2,2))
+    assertRaises(mx.base.MXNetError, mx.nd.sparse.retain, a, invalid_arg="garbage_value")
+    assertRaises(ValueError, mx.nd.sparse.csr_matrix, a, shape=(3,2))
+    assertRaises(ValueError, mx.nd.sparse.csr_matrix, (2,2), shape=(3,2))
+    assertRaises(ValueError, mx.nd.sparse.row_sparse_array, (2,2), shape=(3,2))
+    assertRaises(ValueError, mx.nd.sparse.zeros, "invalid_stype", (2,2))
 
 
 if __name__ == '__main__':
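
The fluent methods exercised by test_sparse_nd_fluent simply forward to the module-level
operators with the array as the first argument. A minimal illustration of that equivalence
(assumes an MXNet build with sparse support):

    >>> import mxnet as mx
    >>> data = mx.nd.ones((5, 17)).tostype('row_sparse')
    >>> fluent = data.square()
    >>> regular = mx.nd.square(data)
    >>> bool((fluent.asnumpy() == regular.asnumpy()).all())
    True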


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services