You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/11/18 21:35:29 UTC

[incubator-mxnet] branch master updated: Doc updates for sparse operators (#8641)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 257ad1d  Doc updates for sparse operators (#8641)
257ad1d is described below

commit 257ad1d90d1a68fb3a611e5e556cc9b2c2aa07ec
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Sat Nov 18 13:35:26 2017 -0800

    Doc updates for sparse operators (#8641)
    
    * doc updates
    
    * arithmetics and symbols
    
    * add sparse embed to sym
    
    * fix lint
    
    * minor updates
    
    * remove arithmetic changes
    
    * remove unused test
    
    * fix lint
    
    * add note for optimizers
    
    * misc-updates
    
    * add check_fmt
---
 docs/api/python/ndarray/sparse.md               | 125 +++++++++++++++++++++++-
 docs/api/python/symbol/sparse.md                |  97 ++++++++++++++++++
 docs/tutorials/sparse/row_sparse.md             |   2 +-
 example/sparse/linear_classification.py         |   4 +-
 python/mxnet/ndarray/sparse.py                  |   7 ++
 python/mxnet/optimizer.py                       |  26 ++++-
 src/operator/operator_common.h                  |   4 +
 src/operator/optimizer_op.cc                    |   4 +
 src/operator/tensor/elemwise_binary_op_basic.cc |   7 +-
 src/operator/tensor/elemwise_unary_op.h         |   4 -
 src/operator/tensor/matrix_op.cc                |   3 +-
 tests/python/unittest/test_sparse_ndarray.py    |  53 +++++++---
 12 files changed, 307 insertions(+), 29 deletions(-)

diff --git a/docs/api/python/ndarray/sparse.md b/docs/api/python/ndarray/sparse.md
index 9b742f4..dd0286d 100644
--- a/docs/api/python/ndarray/sparse.md
+++ b/docs/api/python/ndarray/sparse.md
@@ -123,13 +123,22 @@ We summarize the interface for each class in the following sections.
     CSRNDArray.copy
     CSRNDArray.copyto
     CSRNDArray.as_in_context
-    CSRNDArray.asnumpy
     CSRNDArray.asscipy
+    CSRNDArray.asnumpy
     CSRNDArray.asscalar
     CSRNDArray.astype
     CSRNDArray.tostype
 ```
 
+### Array inspection
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    CSRNDArray.check_format
+```
+
 ### Array creation
 
 ```eval_rst
@@ -139,6 +148,25 @@ We summarize the interface for each class in the following sections.
     CSRNDArray.zeros_like
 ```
 
+### Array reduction
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    CSRNDArray.sum
+    CSRNDArray.mean
+```
+
+### Powers
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    CSRNDArray.square
+```
+
 ### Indexing
 
 ```eval_rst
@@ -190,6 +218,15 @@ We summarize the interface for each class in the following sections.
     RowSparseNDArray.tostype
 ```
 
+### Array inspection
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    RowSparseNDArray.check_format
+```
+
 ### Array creation
 
 ```eval_rst
@@ -213,6 +250,52 @@ We summarize the interface for each class in the following sections.
     RowSparseNDArray.trunc
 ```
 
+### Trigonometric functions
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    RowSparseNDArray.sin
+    RowSparseNDArray.tan
+    RowSparseNDArray.arcsin
+    RowSparseNDArray.arctan
+    RowSparseNDArray.degrees
+    RowSparseNDArray.radians
+```
+
+### Hyperbolic functions
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    RowSparseNDArray.sinh
+    RowSparseNDArray.tanh
+    RowSparseNDArray.arcsinh
+    RowSparseNDArray.arctanh
+```
+
+### Exponents and logarithms
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    RowSparseNDArray.expm1
+    RowSparseNDArray.log1p
+```
+
+### Powers
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    RowSparseNDArray.sqrt
+    RowSparseNDArray.square
+```
+
 ### Indexing
 
 ```eval_rst
@@ -221,6 +304,7 @@ We summarize the interface for each class in the following sections.
 
     RowSparseNDArray.__getitem__
     RowSparseNDArray.__setitem__
+    RowSparseNDArray.retain
 ```
 
 ### Lazy evaluation
@@ -232,6 +316,16 @@ We summarize the interface for each class in the following sections.
     RowSparseNDArray.wait_to_read
 ```
 
+### Miscellaneous
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    RowSparseNDArray.clip
+    RowSparseNDArray.sign
+```
+
 ## Array creation routines
 
 ```eval_rst
@@ -311,6 +405,16 @@ We summarize the interface for each class in the following sections.
     arctanh
 ```
 
+### Reduce functions
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    sum
+    mean
+```
+
 ### Rounding
 
 ```eval_rst
@@ -355,6 +459,20 @@ We summarize the interface for each class in the following sections.
     sign
 ```
 
+## Neural network
+
+### Updater
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    sgd_update
+    sgd_mom_update
+    adam_update
+    ftrl_update
+```
+
 ### More
 
 ```eval_rst
@@ -363,6 +481,7 @@ We summarize the interface for each class in the following sections.
 
     make_loss
     stop_gradient
+    mxnet.ndarray.contrib.SparseEmbedding
 ```
 
 ## API Reference
@@ -372,10 +491,10 @@ We summarize the interface for each class in the following sections.
 ```eval_rst
 
 .. autoclass:: mxnet.ndarray.sparse.CSRNDArray
-    :members: shape, context, dtype, stype, data, indices, indptr, copy, copyto, as_in_context, asnumpy, asscalar, astype, tostype, slice, wait_to_read, zeros_like, __getitem__, __setitem__
+    :members: shape, context, dtype, stype, data, indices, indptr, copy, copyto, as_in_context, asscipy, asnumpy, asscalar, astype, tostype, slice, wait_to_read, zeros_like, __neg__, sum, mean, square, __getitem__, __setitem__, check_format
 
 .. autoclass:: mxnet.ndarray.sparse.RowSparseNDArray
-    :members: shape, context, dtype, stype, data, indices, copy, copyto, as_in_context, asnumpy, asscalar, astype, tostype, wait_to_read, zeros_like, round, rint, fix, floor, ceil, trunc, __getitem__, __setitem__
+    :members: shape, context, dtype, stype, data, indices, copy, copyto, as_in_context, asnumpy, asscalar, astype, tostype, wait_to_read, zeros_like, round, rint, fix, floor, ceil, trunc, sin, tan, arcsin, arctan, degrees, radians, sinh, tanh, arcsinh, arctanh, expm1, log1p, sqrt, square, __neg__, __getitem__, __setitem__, check_format, retain, clip, sign
 
 .. automodule:: mxnet.ndarray.sparse
     :members:
diff --git a/docs/api/python/symbol/sparse.md b/docs/api/python/symbol/sparse.md
index 5ebbfcd..b40276b 100644
--- a/docs/api/python/symbol/sparse.md
+++ b/docs/api/python/symbol/sparse.md
@@ -95,10 +95,107 @@ In the rest of this document, we list sparse related routines provided by the
     :nosignatures:
 
     elemwise_add
+    elemwise_sub
+    elemwise_mul
+    negative
     dot
     add_n
 ```
 
+### Trigonometric functions
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    sin
+    tan
+    arcsin
+    arctan
+    degrees
+    radians
+```
+
+### Hyperbolic functions
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    sinh
+    tanh
+    arcsinh
+    arctanh
+```
+
+### Reduce functions
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    sum
+    mean
+```
+
+### Rounding
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    round
+    rint
+    fix
+    floor
+    ceil
+    trunc
+```
+
+### Exponents and logarithms
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    expm1
+    log1p
+```
+
+### Powers
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    sqrt
+    square
+```
+
+### Miscellaneous
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    clip
+    abs
+    sign
+```
+
+## Neural network
+
+### More
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    make_loss
+    stop_gradient
+    mxnet.symbol.contrib.SparseEmbedding
+```
+
 ## API Reference
 
 <script type="text/javascript" src='../../../_static/js/auto_module_index.js'></script>
diff --git a/docs/tutorials/sparse/row_sparse.md b/docs/tutorials/sparse/row_sparse.md
index 6a69341..55f8a7d 100644
--- a/docs/tutorials/sparse/row_sparse.md
+++ b/docs/tutorials/sparse/row_sparse.md
@@ -6,7 +6,7 @@
 Many real world datasets deal with high dimensional sparse feature vectors. When learning
 the weights of models with sparse datasets, the derived gradients of the weights could be sparse.
 
-Let's say we perform a matrix multiplication of ``X``  and ``W``, where ``X`` is a 2x2 matrix, and ``W`` is a 2x1 matrix. Let ``Y`` be the matrix multiplication of the two matrices:
+Let's say we perform a matrix multiplication of ``X``  and ``W``, where ``X`` is a 1x2 matrix, and ``W`` is a 2x3 matrix. Let ``Y`` be the matrix multiplication of the two matrices:
 
 ```python
 import mxnet as mx
diff --git a/example/sparse/linear_classification.py b/example/sparse/linear_classification.py
index 70f8963..1d63c55 100644
--- a/example/sparse/linear_classification.py
+++ b/example/sparse/linear_classification.py
@@ -126,8 +126,8 @@ if __name__ == '__main__':
         # evaluate metric on validation dataset
         score = mod.score(eval_data, ['nll_loss'])
         logging.info('epoch %d, eval nll = %s ' % (epoch, score[0][1]))
-        save_optimizer_states = 'dist' not in kv.type
-        mod.save_checkpoint("checkpoint", epoch, save_optimizer_states=False)
+        save_optimizer_states = 'dist' not in kv.type if kv else True
+        mod.save_checkpoint("checkpoint", epoch, save_optimizer_states=save_optimizer_states)
         # reset the iterator for next pass of data
         data_iter.reset()
     logging.info('Training completed.')
diff --git a/python/mxnet/ndarray/sparse.py b/python/mxnet/ndarray/sparse.py
index 044ce81..229044e 100644
--- a/python/mxnet/ndarray/sparse.py
+++ b/python/mxnet/ndarray/sparse.py
@@ -747,6 +747,13 @@ class RowSparseNDArray(BaseSparseNDArray):
         else:
             raise TypeError('copyto does not support type ' + str(type(other)))
 
+    def retain(self, *args, **kwargs):
+        """Convenience fluent method for :py:func:`retain`.
+
+        The arguments are the same as for :py:func:`retain`, with
+        this array as data.
+        """
+        return retain(self, *args, **kwargs)
 
 def _prepare_src_array(source_array, dtype):
     """Prepare `source_array` so that it can be used to construct NDArray.
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index eaaf521..5eb4f05 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -442,13 +442,20 @@ class SGD(Optimizer):
         weight = weight - state
 
     If the storage types of weight, state and grad are all ``row_sparse``, \
-    sparse updates are applied by::
+    **sparse updates** are applied by::
 
         for row in grad.indices:
             rescaled_grad[row] = lr * rescale_grad * clip(grad[row], clip_gradient) + wd * weight[row]
             state[row] = momentum[row] * state[row] + rescaled_grad[row]
             weight[row] = weight[row] - state[row]
 
+    The sparse update only updates the momentum for the weights whose row_sparse
+    gradient indices appear in the current batch, rather than updating it for all
+    indices. Compared with the original update, it can provide large
+    improvements in model training throughput for some applications. However, it
+    provides slightly different semantics than the original update, and
+    may lead to different empirical results.
+
     For details of the update algorithm see
     :class:`~mxnet.ndarray.sgd_update` and :class:`~mxnet.ndarray.sgd_mom_update`.
 
@@ -667,7 +674,7 @@ class Adam(Optimizer):
         w = w - learning_rate * m / (sqrt(v) + epsilon)
 
     If the storage types of weight, state and grad are all ``row_sparse``, \
-    sparse updates are applied by::
+    **sparse updates** are applied by::
 
         for row in grad.indices:
             rescaled_grad[row] = clip(grad[row] * rescale_grad + wd * weight[row], clip_gradient)
@@ -675,6 +682,12 @@ class Adam(Optimizer):
             v[row] = beta2 * v[row] + (1 - beta2) * (rescaled_grad[row]**2)
             w[row] = w[row] - learning_rate * m[row] / (sqrt(v[row]) + epsilon)
 
+    The sparse update only updates the mean and var for the weights whose row_sparse
+    gradient indices appear in the current batch, rather than updating them for all indices.
+    Compared with the original update, it can provide large improvements in model training
+    throughput for some applications. However, it provides slightly different semantics than
+    the original update, and may lead to different empirical results.
+
     This optimizer accepts the following parameters in addition to those accepted
     by :class:`.Optimizer`.
 
@@ -936,7 +949,7 @@ class Ftrl(Optimizer):
         w = (sign(z) * lamda1 - z) / ((beta + sqrt(n)) / learning_rate + wd) * (abs(z) > lamda1)
 
     If the storage types of weight, state and grad are all ``row_sparse``, \
-    sparse updates are applied by::
+    **sparse updates** are applied by::
 
         for row in grad.indices:
             rescaled_grad[row] = clip(grad[row] * rescale_grad, clip_gradient)
@@ -944,6 +957,13 @@ class Ftrl(Optimizer):
             n[row] += rescaled_grad[row]**2
             w[row] = (sign(z[row]) * lamda1 - z[row]) / ((beta + sqrt(n[row])) / learning_rate + wd) * (abs(z[row]) > lamda1)
 
+    The sparse update only updates the z and n for the weights whose row_sparse
+    gradient indices appear in the current batch, rather than updating them for all
+    indices. Compared with the original update, it can provide large
+    improvements in model training throughput for some applications. However, it
+    provides slightly different semantics than the original update, and
+    may lead to different empirical results.
+
     For details of the update algorithm, see :class:`~mxnet.ndarray.ftrl_update`.
 
     This optimizer accepts the following parameters in addition to those accepted
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index d036355..24fdd3c 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -207,6 +207,10 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) {
   return true;
 }
 
+/*! \brief Register op name as an alias */
+#define MXNET_ADD_SPARSE_OP_ALIAS(__name$) \
+  .add_alias("_sparse_" #__name$)
+
 /*!
  * \brief macro assign shape to out if out is unknown otherwise check consistency
  *  Use macro so we can see the error file more clearly
diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc
index 0382070..3d8774c 100644
--- a/src/operator/optimizer_op.cc
+++ b/src/operator/optimizer_op.cc
@@ -36,6 +36,7 @@ DMLC_REGISTER_PARAMETER(RMSPropAlexParam);
 DMLC_REGISTER_PARAMETER(FtrlParam);
 
 NNVM_REGISTER_OP(sgd_update)
+MXNET_ADD_SPARSE_OP_ALIAS(sgd_update)
 .describe(R"code(Update function for Stochastic Gradient Descent (SGD) optimizer.
 
 It updates the weights using::
@@ -62,6 +63,7 @@ only the row slices whose indices appear in grad.indices are updated::
 .add_arguments(SGDParam::__FIELDS__());
 
 NNVM_REGISTER_OP(sgd_mom_update)
+MXNET_ADD_SPARSE_OP_ALIAS(sgd_mom_update)
 .describe(R"code(Momentum update function for Stochastic Gradient Descent (SGD) optimizer.
 
 Momentum update has better convergence rates on neural networks. Mathematically it looks
@@ -141,6 +143,7 @@ NNVM_REGISTER_OP(mp_sgd_mom_update)
 .add_arguments(SGDMomParam::__FIELDS__());
 
 NNVM_REGISTER_OP(adam_update)
+MXNET_ADD_SPARSE_OP_ALIAS(adam_update)
 .describe(R"code(Update function for Adam optimizer. Adam is seen as a generalization
 of AdaGrad.
 
@@ -280,6 +283,7 @@ to be 0.9 and the learning rate :math:`\eta` to be 0.0001.
 .add_arguments(RMSPropAlexParam::__FIELDS__());
 
 NNVM_REGISTER_OP(ftrl_update)
+MXNET_ADD_SPARSE_OP_ALIAS(ftrl_update)
 .describe(R"code(Update function for Ftrl optimizer.
 Referenced from *Ad Click Prediction: a View from the Trenches*, available at
 http://dl.acm.org/citation.cfm?id=2488200.
diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc b/src/operator/tensor/elemwise_binary_op_basic.cc
index b3be9e4..8f1ae79 100644
--- a/src/operator/tensor/elemwise_binary_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_op_basic.cc
@@ -35,6 +35,7 @@ MXNET_ADD_SPARSE_OP_ALIAS(elemwise_add)
 The storage type of ``elemwise_add`` output depends on storage types of inputs
 
    - elemwise_add(row_sparse, row_sparse) = row_sparse
+   - elemwise_add(csr, csr) = csr
    - otherwise, ``elemwise_add`` generates output with default storage
 
 )code")
@@ -69,7 +70,8 @@ MXNET_ADD_SPARSE_OP_ALIAS(elemwise_sub)
 The storage type of ``elemwise_sub`` output depends on storage types of inputs
 
    - elemwise_sub(row_sparse, row_sparse) = row_sparse
-   - otherwise, ``elemwise_add`` generates output with default storage
+   - elemwise_sub(csr, csr) = csr
+   - otherwise, ``elemwise_sub`` generates output with default storage
 
 )code")
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_sub"});
@@ -100,6 +102,7 @@ The storage type of ``elemwise_mul`` output depends on storage types of inputs
    - elemwise_mul(row_sparse, row_sparse) = row_sparse
    - elemwise_mul(default, row_sparse) = default
    - elemwise_mul(row_sparse, default) = default
+   - elemwise_mul(csr, csr) = csr
    - otherwise, ``elemwise_mul`` generates output with default storage
 
 )code")
@@ -138,7 +141,7 @@ MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(elemwise_div, mshadow::op::div
 MXNET_ADD_SPARSE_OP_ALIAS(elemwise_div)
 .describe(R"code(Divides arguments element-wise.
 
-The storage type of ``elemwise_dev`` output is always dense
+The storage type of ``elemwise_div`` output is always dense
 
 )code")
 .add_alias("_div").add_alias("_Div")
diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h
index 6fbde05..6e635d9 100644
--- a/src/operator/tensor/elemwise_unary_op.h
+++ b/src/operator/tensor/elemwise_unary_op.h
@@ -416,10 +416,6 @@ void CastCompute(const nnvm::NodeAttrs& attrs,
     })                                                              \
   .add_argument("data", "NDArray-or-Symbol", "The input array.")
 
-/*! \brief Register scalar op name as an alias */
-#define MXNET_ADD_SPARSE_OP_ALIAS(__name$) \
-  .add_alias("_sparse_" #__name$)
-
 /*! \brief Unary compute, with FComputeEx for csr and rsp available  */
 #define MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(__name$, __xpu$, __kernel$)                     \
   MXNET_OPERATOR_REGISTER_UNARY(__name$)                                                           \
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index cba9efd..97e6880 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -247,7 +247,7 @@ will return a new array with shape ``(2,1,3,4)``.
 .add_arguments(ExpandDimParam::__FIELDS__());
 
 NNVM_REGISTER_OP(slice)
-.add_alias("_sparse_slice")
+MXNET_ADD_SPARSE_OP_ALIAS(slice)
 .add_alias("crop")
 .describe(R"code(Slices a region of the array.
 
@@ -395,6 +395,7 @@ NNVM_REGISTER_OP(_backward_slice_axis)
 .set_attr<FCompute>("FCompute<cpu>", SliceAxisGrad_<cpu>);
 
 NNVM_REGISTER_OP(clip)
+MXNET_ADD_SPARSE_OP_ALIAS(clip)
 .describe(R"code(Clips (limits) the values in an array.
 
 Given an interval, values outside the interval are clipped to the interval edges.
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index c7ac1ca..e59e476 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -22,9 +22,8 @@ from mxnet.test_utils import *
 from mxnet.base import mx_real_t
 from numpy.testing import assert_allclose
 import numpy.random as rnd
-
-from mxnet.ndarray.sparse import RowSparseNDArray, CSRNDArray
 from common import assertRaises
+from mxnet.ndarray.sparse import RowSparseNDArray, CSRNDArray
 
 
 def sparse_nd_ones(shape, stype):
@@ -234,6 +233,7 @@ def test_sparse_nd_binary():
             oshape = np.random.randint(1, 6, size=(ndim,))
             bdim = 2
             lshape = list(oshape)
+            # one for broadcast op, another for elemwise op
             rshape = list(oshape[ndim-bdim:])
             for i in range(bdim):
                 sep = np.random.uniform(0, 1)
@@ -737,20 +737,47 @@ def test_synthetic_dataset_generator():
     test_powerlaw_generator(csr_arr_big, final_row=4)
     test_powerlaw_generator(csr_arr_square, final_row=6)
 
+def test_sparse_nd_fluent():
+    def check_fluent_regular(stype, func, kwargs, shape=(5, 17), equal_nan=False):
+        with mx.name.NameManager():
+            data = mx.nd.random_uniform(shape=shape, ctx=default_context()).tostype(stype)
+            regular = getattr(mx.ndarray, func)(data, **kwargs)
+            fluent = getattr(data, func)(**kwargs)
+            if isinstance(regular, list):
+                for r, f in zip(regular, fluent):
+                    assert almost_equal(r.asnumpy(), f.asnumpy(), equal_nan=equal_nan)
+            else:
+                assert almost_equal(regular.asnumpy(), fluent.asnumpy(), equal_nan=equal_nan)
+
+    common_func = ['zeros_like', 'square']
+    rsp_func = ['round', 'rint', 'fix', 'floor', 'ceil', 'trunc',
+                'abs', 'sign', 'sin', 'degrees', 'radians', 'expm1']
+    for func in common_func:
+        check_fluent_regular('csr', func, {})
+    for func in common_func + rsp_func:
+        check_fluent_regular('row_sparse', func, {})
+
+    rsp_func = ['arcsin', 'arctan', 'tan', 'sinh', 'tanh',
+                'arcsinh', 'arctanh', 'log1p', 'sqrt', 'relu']
+    for func in rsp_func:
+        check_fluent_regular('row_sparse', func, {}, equal_nan=True)
+
+    check_fluent_regular('csr', 'slice', {'begin': (2, 5), 'end': (4, 7)}, shape=(5, 17))
+    check_fluent_regular('row_sparse', 'clip', {'a_min': -0.25, 'a_max': 0.75})
+
+    for func in ['sum', 'mean']:
+        check_fluent_regular('csr', func, {'axis': 0})
+
+
 def test_sparse_nd_exception():
     """ test invalid sparse operator will throw a exception """
     a = mx.nd.ones((2,2))
-    assert_exception(mx.nd.sparse.retain, mx.base.MXNetError,
-                     a, invalid_arg="garbage_value")
-    assert_exception(mx.nd.sparse.csr_matrix, ValueError,
-                     a, shape=(3,2))
-    assert_exception(mx.nd.sparse.csr_matrix, ValueError,
-                     (2,2), shape=(3,2))
-    assert_exception(mx.nd.sparse.row_sparse_array, ValueError,
-                     (2,2), shape=(3,2))
-    assert_exception(mx.nd.sparse.zeros, ValueError,
-                     "invalid_stype", (2,2))
-    
+    assertRaises(mx.base.MXNetError, mx.nd.sparse.retain, a, invalid_arg="garbage_value")
+    assertRaises(ValueError, mx.nd.sparse.csr_matrix, a, shape=(3,2))
+    assertRaises(ValueError, mx.nd.sparse.csr_matrix, (2,2), shape=(3,2))
+    assertRaises(ValueError, mx.nd.sparse.row_sparse_array, (2,2), shape=(3,2))
+    assertRaises(ValueError, mx.nd.sparse.zeros, "invalid_stype", (2,2))
+
 def test_sparse_nd_check_format():
     """ test check_format for sparse ndarray """
     shape = rand_shape_2d()

-- 
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].