You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/04/05 19:46:53 UTC

[GitHub] eric-haibin-lin closed pull request #10208: [MXNET-117] Sparse operator broadcast_mul/div(csr, dense) = csr

eric-haibin-lin closed pull request #10208: [MXNET-117] Sparse operator broadcast_mul/div(csr, dense) = csr
URL: https://github.com/apache/incubator-mxnet/pull/10208
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (GitHub does not otherwise display it):

diff --git a/docs/api/python/ndarray/sparse.md b/docs/api/python/ndarray/sparse.md
index b0cdd887d55..1f67e82194b 100644
--- a/docs/api/python/ndarray/sparse.md
+++ b/docs/api/python/ndarray/sparse.md
@@ -386,6 +386,8 @@ We summarize the interface for each class in the following sections.
     elemwise_add
     elemwise_sub
     elemwise_mul
+    broadcast_mul
+    broadcast_div
     negative
     dot
     add_n
diff --git a/perl-package/AI-MXNet/lib/AI/MXNet/NDArray/Sparse.pm b/perl-package/AI-MXNet/lib/AI/MXNet/NDArray/Sparse.pm
index 9b243758852..bb5171c238b 100644
--- a/perl-package/AI-MXNet/lib/AI/MXNet/NDArray/Sparse.pm
+++ b/perl-package/AI-MXNet/lib/AI/MXNet/NDArray/Sparse.pm
@@ -80,10 +80,100 @@ use overload '""' => sub {
                         my $shape_info = join('x', @{ $self->shape });
                         sprintf("\n<%s, %s @%s>", $self->_class_name, $shape_info, $self->context);
                      },
+             '+'  => \&add,
+             '-'  => \&subtract,
+             '*'  => \&multiply,
+             '/'  => \&divide,
              '+=' => \&not_implemented,
              '-=' => \&not_implemented,
              '*=' => \&not_implemented,
              '/=' => \&not_implemented;
+
+method add(AI::MXNet::NDArray|Num $other, $reverse=)
+{  # '+' overload. $reverse (the operands-swapped flag) is not forwarded: addition commutes.
+    if(blessed $other and join(',', @{ $self->shape }) eq join(',', @{ $other->shape }))
+    {   # $other is an NDArray with exactly $self's shape: elementwise kernel
+        return AI::MXNet::NDArray::_ufunc_helper(
+            $self,
+            $other,
+            qw/elemwise_add _plus_scalar/
+        );
+    }
+    else
+    {   # plain number or differently shaped NDArray: broadcast kernel (presumably _plus_scalar serves the number case)
+        return AI::MXNet::NDArray::_ufunc_helper(
+            $self,
+            $other,
+            qw/broadcast_add _plus_scalar/
+        );
+    }
+}
+
+
+method subtract(AI::MXNet::NDArray|Num $other, $reverse=)
+{  # '-' overload. $reverse is forwarded so a swapped "number - array" can use _rminus_scalar (presumably).
+    if(blessed $other and join(',', @{ $self->shape }) eq join(',', @{ $other->shape }))
+    {   # $other is an NDArray with exactly $self's shape: elementwise kernel
+        return AI::MXNet::NDArray::_ufunc_helper(
+            $self,
+            $other,
+            qw/elemwise_sub _minus_scalar _rminus_scalar/,
+            $reverse
+        );
+    }
+    else
+    {   # plain number or differently shaped NDArray: broadcast kernel
+        return AI::MXNet::NDArray::_ufunc_helper(
+            $self,
+            $other,
+            qw/broadcast_sub _minus_scalar _rminus_scalar/,
+            $reverse
+        );
+    }
+}
+
+method multiply(AI::MXNet::NDArray|Num $other, $reverse=)
+{  # '*' overload. $reverse (the operands-swapped flag) is not forwarded: multiplication commutes.
+    if(blessed $other and join(',', @{ $self->shape }) eq join(',', @{ $other->shape }))
+    {   # $other is an NDArray with exactly $self's shape: elementwise kernel
+        return AI::MXNet::NDArray::_ufunc_helper(
+            $self,
+            $other,
+            qw/elemwise_mul _mul_scalar/,
+        );
+    }
+    else
+    {   # plain number or differently shaped NDArray: broadcast kernel (csr * dense has a sparse CPU path)
+        return AI::MXNet::NDArray::_ufunc_helper(
+            $self,
+            $other,
+            qw/broadcast_mul _mul_scalar/,
+        );
+    }
+}
+
+method divide(AI::MXNet::NDArray|Num $other, $reverse=)
+{  # '/' overload. $reverse is forwarded so a swapped "number / array" can use _rdiv_scalar (presumably).
+    if(blessed $other and join(',', @{ $self->shape }) eq join(',', @{ $other->shape }))
+    {   # $other is an NDArray with exactly $self's shape: elementwise kernel
+        return AI::MXNet::NDArray::_ufunc_helper(
+            $self,
+            $other,
+            qw/elemwise_div _div_scalar _rdiv_scalar/,
+            $reverse
+        );
+    }
+    else
+    {   # plain number or differently shaped NDArray: broadcast kernel (csr / dense has a sparse CPU path)
+        return AI::MXNet::NDArray::_ufunc_helper(
+            $self,
+            $other,
+            qw/broadcast_div _div_scalar _rdiv_scalar/,
+            $reverse
+        );
+    }
+}
+
 {
     no warnings 'redefine';
     *_sync_copyfrom = *_at = *_slice = *reshape = *size = \&not_implemented;
diff --git a/python/mxnet/ndarray/sparse.py b/python/mxnet/ndarray/sparse.py
index 842e453b2dd..363ed9deb06 100644
--- a/python/mxnet/ndarray/sparse.py
+++ b/python/mxnet/ndarray/sparse.py
@@ -30,10 +30,12 @@
 
 import ctypes
 import warnings
+import operator
 from array import array as native_array
 
 __all__ = ["_ndarray_cls", "csr_matrix", "row_sparse_array",
-           "BaseSparseNDArray", "CSRNDArray", "RowSparseNDArray"]
+           "BaseSparseNDArray", "CSRNDArray", "RowSparseNDArray",
+           "add", "subtract", "multiply", "divide"]
 
 import numpy as np
 from ..base import NotSupportedForSparseNDArray
@@ -53,6 +55,7 @@
 from .ndarray import _STORAGE_TYPE_UNDEFINED, _STORAGE_TYPE_DEFAULT
 from .ndarray import zeros as _zeros_ndarray
 from .ndarray import array as _array
+from .ndarray import _ufunc_helper
 
 
 try:
@@ -114,6 +117,24 @@ def __repr__(self):
         return '\n<%s %s @%s>' % (self.__class__.__name__,
                                   shape_info, self.context)
 
+    # Binary arithmetic delegates to the module-level add/subtract/multiply/divide,
+    # which choose elemwise_* for same-shape NDArrays and broadcast_* otherwise.
+    def __add__(self, other):
+        return add(self, other)
+
+    def __sub__(self, other):
+        return subtract(self, other)
+
+    def __mul__(self, other):
+        return multiply(self, other)
+
+    def __div__(self, other):
+        return divide(self, other)
+
+    # Python 3 dispatches `/` to __truediv__ only; alias it so `/` uses the sparse-aware
+    # divide on both Python 2 and 3 instead of whatever the base class defines.
+    __truediv__ = __div__
+
     def __iadd__(self, other):
         raise NotImplementedError()
 
@@ -218,6 +233,7 @@ def copyto(self, other):
         NDArray or CSRNDArray or RowSparseNDArray
             The copied array.
         """
+        # pylint: disable= no-member, protected-access
         if isinstance(other, NDArray):
             if other.handle is self.handle:
                 warnings.warn('You are attempting to copy an array to itself', RuntimeWarning)
@@ -229,6 +245,7 @@ def copyto(self, other):
             return _internal._copyto(self, out=hret)
         else:
             raise TypeError('copyto does not support type ' + str(type(other)))
+        # pylint: enable= no-member, protected-access
 
     def check_format(self, full_check=True):
         """Check whether the NDArray format is valid.
@@ -342,6 +359,7 @@ def __getitem__(self, key):
         >>> a[-1].asnumpy()
         array([[ 4.,  5.,  6.]], dtype=float32)
         """
+        # pylint: disable= no-member, protected-access
         if isinstance(key, int):
             if key == -1:
                 begin = self.shape[0] - 1
@@ -360,6 +378,7 @@ def __getitem__(self, key):
         if isinstance(key, tuple):
             raise ValueError('Multi-dimension indexing is not supported')
         raise ValueError('Undefined behaviour for {}'.format(key))
+        # pylint: enable= no-member, protected-access
 
     def __setitem__(self, key, value):
         """x.__setitem__(i, y) <=> x[i]=y
@@ -477,9 +496,11 @@ def tostype(self, stype):
         NDArray or CSRNDArray
             A copy of the array with the chosen storage stype
         """
+        # pylint: disable= no-member, protected-access
         if stype == 'row_sparse':
             raise ValueError("cast_storage from csr to row_sparse is not supported")
         return op.cast_storage(self, stype=stype)
+        # pylint: enable= no-member, protected-access
 
     def copyto(self, other):
         """Copies the value of this array to another array.
@@ -657,6 +678,7 @@ def __setitem__(self, key, value):
                [ 1.,  1.,  1.],
                [ 1.,  1.,  1.]], dtype=float32)
         """
+        # pylint: disable= no-member, protected-access
         if not self.writable:
             raise ValueError('Failed to assign to a readonly RowSparseNDArray')
         if isinstance(key, py_slice):
@@ -679,6 +701,7 @@ def __setitem__(self, key, value):
         else:
             assert(isinstance(key, (int, tuple)))
             raise TypeError('RowSparseNDArray only supports [:] for assignment')
+        # pylint: enable= no-member, protected-access
 
     @property
     def indices(self):
@@ -720,9 +743,11 @@ def tostype(self, stype):
         NDArray or RowSparseNDArray
             A copy of the array with the chosen storage stype
         """
+        # pylint: disable= no-member, protected-access
         if stype == 'csr':
             raise ValueError("cast_storage from row_sparse to csr is not supported")
         return op.cast_storage(self, stype=stype)
+        # pylint: enable= no-member, protected-access
 
     def copyto(self, other):
         """Copies the value of this array to another array.
@@ -949,6 +974,7 @@ def csr_matrix(arg1, shape=None, ctx=None, dtype=None):
 def _csr_matrix_from_definition(data, indices, indptr, shape=None, ctx=None,
                                 dtype=None, indices_type=None, indptr_type=None):
     """Create a `CSRNDArray` based on data, indices and indptr"""
+    # pylint: disable= no-member, protected-access
     storage_type = 'csr'
     # context
     ctx = Context.default_ctx if ctx is None else ctx
@@ -985,6 +1011,7 @@ def _csr_matrix_from_definition(data, indices, indptr, shape=None, ctx=None,
     check_call(_LIB.MXNDArraySyncCopyFromNDArray(result.handle, indptr.handle, ctypes.c_int(0)))
     check_call(_LIB.MXNDArraySyncCopyFromNDArray(result.handle, indices.handle, ctypes.c_int(1)))
     return result
+    # pylint: enable= no-member, protected-access
 
 def row_sparse_array(arg1, shape=None, ctx=None, dtype=None):
     """Creates a `RowSparseNDArray`, a multidimensional row sparse array with a set of \
@@ -1159,6 +1186,320 @@ def _ndarray_cls(handle, writable=True, stype=_STORAGE_TYPE_UNDEFINED):
 _set_ndarray_class(_ndarray_cls)
 
 
+def add(lhs, rhs):
+    """Returns element-wise sum of the input arrays with broadcasting.
+
+    Equivalent to ``lhs + rhs``, ``mx.nd.broadcast_add(lhs, rhs)`` and
+    ``mx.nd.broadcast_plus(lhs, rhs)`` when shapes of lhs and rhs do not
+    match. If lhs.shape == rhs.shape, this is equivalent to
+    ``mx.nd.elemwise_add(lhs, rhs)``.
+
+    .. note::
+
+        If the corresponding dimensions of two arrays have the same size or one of them has size 1,
+        then the arrays are broadcastable to a common shape.
+
+    Parameters
+    ----------
+    lhs : scalar or array
+        First array to be added.
+    rhs : scalar or array
+        Second array to be added.
+        If ``lhs.shape != rhs.shape``, they must be
+        broadcastable to a common shape.
+
+    Returns
+    -------
+    NDArray
+        The element-wise sum of the input arrays.
+
+    Examples
+    --------
+    >>> a = mx.nd.ones((2,3)).tostype('csr')
+    >>> b = mx.nd.ones((2,3)).tostype('csr')
+    >>> a.asnumpy()
+    array([[ 1.,  1.,  1.],
+           [ 1.,  1.,  1.]], dtype=float32)
+    >>> b.asnumpy()
+    array([[ 1.,  1.,  1.],
+           [ 1.,  1.,  1.]], dtype=float32)
+    >>> (a+b).asnumpy()
+    array([[ 2.,  2.,  2.],
+           [ 2.,  2.,  2.]], dtype=float32)
+    >>> c = mx.nd.ones((2,3)).tostype('row_sparse')
+    >>> d = mx.nd.ones((2,3)).tostype('row_sparse')
+    >>> c.asnumpy()
+    array([[ 1.,  1.,  1.],
+           [ 1.,  1.,  1.]], dtype=float32)
+    >>> d.asnumpy()
+    array([[ 1.,  1.,  1.],
+           [ 1.,  1.,  1.]], dtype=float32)
+    >>> (c+d).asnumpy()
+    array([[ 2.,  2.,  2.],
+           [ 2.,  2.,  2.]], dtype=float32)
+    """
+    # pylint: disable= no-member, protected-access
+    if isinstance(lhs, NDArray) and isinstance(rhs, NDArray) and lhs.shape == rhs.shape:
+        return _ufunc_helper(
+            lhs,
+            rhs,
+            op.elemwise_add,
+            operator.add,
+            _internal._plus_scalar,
+            None)  # addition commutes: no reversed-scalar kernel needed
+
+    return _ufunc_helper(
+        lhs,
+        rhs,
+        op.broadcast_add,
+        operator.add,
+        _internal._plus_scalar,
+        None)  # addition commutes: no reversed-scalar kernel needed
+    # pylint: enable= no-member, protected-access
+
+
+def subtract(lhs, rhs):
+    """Returns element-wise difference of the input arrays with broadcasting.
+
+    Equivalent to ``lhs - rhs``, ``mx.nd.broadcast_sub(lhs, rhs)`` and
+    ``mx.nd.broadcast_minus(lhs, rhs)`` when shapes of lhs and rhs do not
+    match. If lhs.shape == rhs.shape, this is equivalent to
+    ``mx.nd.elemwise_sub(lhs, rhs)``.
+
+    .. note::
+
+        If the corresponding dimensions of two arrays have the same size or one of them has size 1,
+        then the arrays are broadcastable to a common shape.
+
+    Parameters
+    ----------
+    lhs : scalar or array
+        Array (or scalar) to subtract from.
+    rhs : scalar or array
+        Array (or scalar) to be subtracted.
+        If ``lhs.shape != rhs.shape``, they must be
+        broadcastable to a common shape.
+
+    Returns
+    -------
+    NDArray
+        The element-wise difference of the input arrays.
+
+    Examples
+    --------
+    >>> a = mx.nd.ones((2,3)).tostype('csr')
+    >>> b = mx.nd.ones((2,3)).tostype('csr')
+    >>> a.asnumpy()
+    array([[ 1.,  1.,  1.],
+           [ 1.,  1.,  1.]], dtype=float32)
+    >>> b.asnumpy()
+    array([[ 1.,  1.,  1.],
+           [ 1.,  1.,  1.]], dtype=float32)
+    >>> (a-b).asnumpy()
+    array([[ 0.,  0.,  0.],
+           [ 0.,  0.,  0.]], dtype=float32)
+    >>> c = mx.nd.ones((2,3)).tostype('row_sparse')
+    >>> d = mx.nd.ones((2,3)).tostype('row_sparse')
+    >>> c.asnumpy()
+    array([[ 1.,  1.,  1.],
+           [ 1.,  1.,  1.]], dtype=float32)
+    >>> d.asnumpy()
+    array([[ 1.,  1.,  1.],
+           [ 1.,  1.,  1.]], dtype=float32)
+    >>> (c-d).asnumpy()
+    array([[ 0.,  0.,  0.],
+           [ 0.,  0.,  0.]], dtype=float32)
+    """
+    # pylint: disable= no-member, protected-access
+    if isinstance(lhs, NDArray) and isinstance(rhs, NDArray) and lhs.shape == rhs.shape:
+        return _ufunc_helper(
+            lhs,
+            rhs,
+            op.elemwise_sub,
+            operator.sub,
+            _internal._minus_scalar,
+            _internal._rminus_scalar)  # scalar lhs: computes lhs - rhs, not rhs - lhs
+
+    return _ufunc_helper(
+        lhs,
+        rhs,
+        op.broadcast_sub,
+        operator.sub,
+        _internal._minus_scalar,
+        _internal._rminus_scalar)  # scalar lhs: computes lhs - rhs, not rhs - lhs
+    # pylint: enable= no-member, protected-access
+
+
+def multiply(lhs, rhs):
+    """Returns element-wise product of the input arrays with broadcasting.
+
+    Equivalent to ``lhs * rhs`` and ``mx.nd.broadcast_mul(lhs, rhs)``
+    when shapes of lhs and rhs do not match. If lhs.shape == rhs.shape,
+    this is equivalent to ``mx.nd.elemwise_mul(lhs, rhs)``.
+
+    .. note::
+
+        If the corresponding dimensions of two arrays have the same size or one of them has size 1,
+        then the arrays are broadcastable to a common shape.
+
+    Parameters
+    ----------
+    lhs : scalar or array
+        First array to be multiplied.
+    rhs : scalar or array
+        Second array to be multiplied.
+        If ``lhs.shape != rhs.shape``, they must be
+        broadcastable to a common shape.
+
+    Returns
+    -------
+    NDArray
+        The element-wise multiplication of the input arrays.
+
+    Examples
+    --------
+    >>> x = mx.nd.ones((2,3)).tostype('csr')
+    >>> y = mx.nd.arange(2).reshape((2,1))
+    >>> z = mx.nd.arange(3)
+    >>> x.asnumpy()
+    array([[ 1.,  1.,  1.],
+           [ 1.,  1.,  1.]], dtype=float32)
+    >>> y.asnumpy()
+    array([[ 0.],
+           [ 1.]], dtype=float32)
+    >>> z.asnumpy()
+    array([ 0.,  1.,  2.], dtype=float32)
+    >>> (x*2).asnumpy()
+    array([[ 2.,  2.,  2.],
+           [ 2.,  2.,  2.]], dtype=float32)
+    >>> (x*y).asnumpy()
+    array([[ 0.,  0.,  0.],
+           [ 1.,  1.,  1.]], dtype=float32)
+    >>> mx.nd.sparse.multiply(x, y).asnumpy()
+    array([[ 0.,  0.,  0.],
+           [ 1.,  1.,  1.]], dtype=float32)
+    >>> (x*z).asnumpy()
+    array([[ 0.,  1.,  2.],
+           [ 0.,  1.,  2.]], dtype=float32)
+    >>> mx.nd.sparse.multiply(x, z).asnumpy()
+    array([[ 0.,  1.,  2.],
+           [ 0.,  1.,  2.]], dtype=float32)
+    >>> z = z.reshape((1, 3))
+    >>> z.asnumpy()
+    array([[ 0.,  1.,  2.]], dtype=float32)
+    >>> (x*z).asnumpy()
+    array([[ 0.,  1.,  2.],
+           [ 0.,  1.,  2.]], dtype=float32)
+    >>> mx.nd.sparse.multiply(x, z).asnumpy()
+    array([[ 0.,  1.,  2.],
+           [ 0.,  1.,  2.]], dtype=float32)
+    """
+    # pylint: disable= no-member, protected-access
+    if isinstance(lhs, NDArray) and isinstance(rhs, NDArray) and lhs.shape == rhs.shape:
+        return _ufunc_helper(
+            lhs,
+            rhs,
+            op.elemwise_mul,
+            operator.mul,
+            _internal._mul_scalar,
+            None)  # multiplication commutes: no reversed-scalar kernel needed
+
+    return _ufunc_helper(
+        lhs,
+        rhs,
+        op.broadcast_mul,
+        operator.mul,
+        _internal._mul_scalar,
+        None)  # multiplication commutes: no reversed-scalar kernel needed
+    # pylint: enable= no-member, protected-access
+
+
+def divide(lhs, rhs):
+    """Returns element-wise division of the input arrays with broadcasting.
+
+    Equivalent to ``lhs / rhs`` and ``mx.nd.broadcast_div(lhs, rhs)``
+    when shapes of lhs and rhs do not match. If lhs.shape == rhs.shape,
+    this is equivalent to ``mx.nd.elemwise_div(lhs, rhs)``.
+
+    .. note::
+
+        If the corresponding dimensions of two arrays have the same size or one of them has size 1,
+        then the arrays are broadcastable to a common shape.
+
+    Parameters
+    ----------
+    lhs : scalar or array
+        First array in division (the dividend).
+    rhs : scalar or array
+        Second array in division (the divisor).
+        If ``lhs.shape != rhs.shape``, they must be
+        broadcastable to a common shape.
+
+    Returns
+    -------
+    NDArray
+        The element-wise division of the input arrays.
+
+    Examples
+    --------
+    >>> x = (mx.nd.ones((2,3))*6).tostype('csr')
+    >>> y = mx.nd.arange(2).reshape((2,1)) + 1
+    >>> z = mx.nd.arange(3) + 1
+    >>> x.asnumpy()
+    array([[ 6.,  6.,  6.],
+           [ 6.,  6.,  6.]], dtype=float32)
+    >>> y.asnumpy()
+    array([[ 1.],
+           [ 2.]], dtype=float32)
+    >>> z.asnumpy()
+    array([ 1.,  2.,  3.], dtype=float32)
+    >>> x/2
+    <NDArray 2x3 @cpu(0)>
+    >>> (x/3).asnumpy()
+    array([[ 2.,  2.,  2.],
+           [ 2.,  2.,  2.]], dtype=float32)
+    >>> (x/y).asnumpy()
+    array([[ 6.,  6.,  6.],
+           [ 3.,  3.,  3.]], dtype=float32)
+    >>> mx.nd.sparse.divide(x,y).asnumpy()
+    array([[ 6.,  6.,  6.],
+           [ 3.,  3.,  3.]], dtype=float32)
+    >>> (x/z).asnumpy()
+    array([[ 6.,  3.,  2.],
+           [ 6.,  3.,  2.]], dtype=float32)
+    >>> mx.nd.sparse.divide(x,z).asnumpy()
+    array([[ 6.,  3.,  2.],
+           [ 6.,  3.,  2.]], dtype=float32)
+    >>> z = z.reshape((1,3))
+    >>> z.asnumpy()
+    array([[ 1.,  2.,  3.]], dtype=float32)
+    >>> (x/z).asnumpy()
+    array([[ 6.,  3.,  2.],
+           [ 6.,  3.,  2.]], dtype=float32)
+    >>> mx.nd.sparse.divide(x,z).asnumpy()
+    array([[ 6.,  3.,  2.],
+           [ 6.,  3.,  2.]], dtype=float32)
+    """
+    # pylint: disable= no-member, protected-access
+    if isinstance(lhs, NDArray) and isinstance(rhs, NDArray) and lhs.shape == rhs.shape:
+        return _ufunc_helper(
+            lhs,
+            rhs,
+            op.elemwise_div,
+            operator.truediv,
+            _internal._div_scalar,
+            _internal._rdiv_scalar)  # scalar lhs: computes lhs / rhs, not rhs / lhs
+
+    return _ufunc_helper(
+        lhs,
+        rhs,
+        op.broadcast_div,
+        operator.truediv,
+        _internal._div_scalar,
+        _internal._rdiv_scalar)  # scalar lhs: computes lhs / rhs, not rhs / lhs
+    # pylint: enable= no-member, protected-access
+
+
 def zeros(stype, shape, ctx=None, dtype=None, **kwargs):
     """Return a new array of given shape and type, filled with zeros.
 
@@ -1184,6 +1525,7 @@ def zeros(stype, shape, ctx=None, dtype=None, **kwargs):
     >>> mx.nd.sparse.zeros('row_sparse', (1,2), ctx=mx.cpu(), dtype='float16').asnumpy()
     array([[ 0.,  0.]], dtype=float16)
     """
+    # pylint: disable= no-member, protected-access
     if stype == 'default':
         return _zeros_ndarray(shape, ctx=ctx, dtype=dtype, **kwargs)
     if ctx is None:
@@ -1195,6 +1537,7 @@ def zeros(stype, shape, ctx=None, dtype=None, **kwargs):
         raise ValueError("unknown storage type" + stype)
     out = _ndarray_cls(_new_alloc_handle(stype, shape, ctx, True, dtype, aux_types))
     return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, out=out, **kwargs)
+    # pylint: enable= no-member, protected-access
 
 
 def empty(stype, shape, ctx=None, dtype=None):
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h
index a2e63fefad5..42a8f0f2c15 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -26,6 +26,7 @@
 #define MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_BROADCAST_OP_H_
 
 #include <mxnet/operator_util.h>
+#include <mxnet/op_attr_types.h>
 #include <algorithm>
 #include <vector>
 #include <string>
@@ -76,6 +77,34 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+inline bool BinaryBroadcastMulStorageType(const nnvm::NodeAttrs& attrs,
+                                          const int dev_mask,
+                                          DispatchMode* dispatch_mode,
+                                          std::vector<int>* in_attrs,
+                                          std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2U);  // storage-type inference shared by broadcast_mul and broadcast_div
+  CHECK_EQ(out_attrs->size(), 1U);
+  const int lhs_stype = in_attrs->at(0);
+  const int rhs_stype = in_attrs->at(1);
+  int& out_stype = out_attrs->at(0);
+  bool dispatched = false;
+  // Only FComputeEx<cpu> is registered for these ops, so csr inputs on GPU use the dense fallback.
+  const auto dispatch_ex = (dev_mask == mshadow::gpu::kDevMask)? DispatchMode::kFComputeFallback :
+                           DispatchMode::kFComputeEx;
+  if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
+    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
+                                     dispatch_mode, DispatchMode::kFCompute);
+  }  // all-dense inputs: ordinary dense kernel, dense output
+  if (!dispatched && lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) {
+    dispatched = storage_type_assign(&out_stype, kCSRStorage,
+                                     dispatch_mode, dispatch_ex);
+  }  // csr lhs, dense rhs: csr output; implicit zeros stay zero, even for 0 / 0 in div
+  if (!dispatched) {
+    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
+  }  // any other storage combination: dense fallback
+  return dispatched;
+}
+
 #define BROADCAST_NDIM_SWITCH(ndim, NDim, ...)  \
   if (ndim <= 2) {                    \
     const int NDim = 2;               \
@@ -155,6 +184,22 @@ struct binary_broadcast_kernel {
     }
   }
 };
+
+template<int req, typename OP>
+struct csr_dns_csr_broadcast_kernel {  // Map runs once per csr row; out[k] = OP(stored value k, dns value)
+  template <typename DType, typename CType, typename RType>
+  MSHADOW_XINLINE static void Map(int row, const DType *csr_data, const CType *csr_indices,
+                                  const RType *csr_indptr, const DType *dns,
+                                  DType *out, const nnvm::dim_t row_length, bool col_vec) {
+    const nnvm::dim_t curr_row_i = csr_indptr[row];  // row_length is not read by this kernel
+    const nnvm::dim_t next_row_i = csr_indptr[row + 1];
+    for (nnvm::dim_t iter = curr_row_i; iter < next_row_i; iter++) {  // stored entries of `row`
+      KERNEL_ASSIGN(out[iter], req, OP::Map(csr_data[iter],
+                    (col_vec)? dns[row] : dns[csr_indices[iter]]));  // dns[row] or dns[col]
+    }
+  }
+};
+
 }  // namespace mxnet_op
 
 template<typename xpu, typename OP>
@@ -185,6 +230,93 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
   }
 }
 
+template<typename xpu, typename OP>
+void BinaryBroadcastCsrDnsCsrImpl(const OpContext& ctx,  // output = OP(csr, dns), csr's pattern
+                                  const NDArray& csr,
+                                  const NDArray& dns,
+                                  const OpReqType req,
+                                  const NDArray& output) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  using namespace csr;
+  CHECK(req != kAddTo && req != kWriteInplace);
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  // Broadcast direction for the non-scalar case. dns is a column vector exactly
+  // when it is 2-D with a single column, so every stored element of row r is
+  // combined with dns[r]. Checking dns.shape()[0] == csr.shape()[0] instead would
+  // misread a (1, ncols) row vector as a column vector whenever csr has one row.
+  // A 1-D dns is always a row vector; (1,) and (1, 1) take the scalar path below.
+  const bool col_vec = (dns.shape().ndim() == 2 && dns.shape()[1] == 1);
+
+  if (csr.storage_initialized()) {
+    const nnvm::dim_t nnz = csr.storage_shape()[0];
+    const nnvm::dim_t num_rows = output.shape()[0];
+    output.CheckAndAlloc({Shape1(num_rows + 1), Shape1(nnz)});
+
+    MSHADOW_TYPE_SWITCH(output.dtype(), DType, {
+      MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIdx), CType, {
+        MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIndPtr), RType, {
+          MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
+            if ((dns.shape().ndim() == 2 && dns.shape()[0] == 1 && dns.shape()[1] == 1) ||
+                (dns.shape().ndim() == 1 && dns.shape()[0] == 1)) {
+              Kernel<op_with_req<OP, req_type>, xpu>::Launch(
+                s, nnz, output.data().dptr<DType>(), csr.data().dptr<DType>(),
+                dns.data().dptr<DType>()[0]);  // host-side read: only valid on the CPU path
+            } else {  // row or column vector: one kernel invocation per csr row
+              Kernel<csr_dns_csr_broadcast_kernel<req_type, OP>, xpu>::Launch(
+                s, num_rows, csr.data().dptr<DType>(), csr.aux_data(kIdx).dptr<CType>(),
+                csr.aux_data(kIndPtr).dptr<RType>(), dns.data().dptr<DType>(),
+                output.data().dptr<DType>(), csr.shape()[1], col_vec);
+            }
+            Copy(output.aux_data(kIdx).FlatTo1D<xpu, CType>(),
+                 csr.aux_data(kIdx).FlatTo1D<xpu, CType>());
+            Copy(output.aux_data(kIndPtr).FlatTo1D<xpu, RType>(),
+                 csr.aux_data(kIndPtr).FlatTo1D<xpu, RType>());
+          });
+        });
+      });
+    });
+  // If input csr is an empty matrix, fill zeros and return
+  } else {
+    FillZerosCsrImpl(s, output);
+    return;
+  }
+}
+
+template<typename xpu, typename OP>
+void BinaryBroadcastComputeEx(const nnvm::NodeAttrs& attrs,  // FComputeEx<cpu> of broadcast_mul/div
+                              const OpContext& ctx,
+                              const std::vector<NDArray>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<NDArray>& outputs) {
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  CHECK_LE(inputs[1].shape().ndim(), 2U)
+    << "input dense matrix should have less than or equal to 2 dimensions";
+  if (req[0] == kNullOp) return;
+  const NDArray& lhs = inputs[0];
+  const NDArray& rhs = inputs[1];
+  const NDArray& out = outputs[0];
+  const auto lhs_stype = lhs.storage_type();
+  const auto rhs_stype = rhs.storage_type();
+  const auto out_stype = out.storage_type();
+  // A 2-D rhs with no dimension of size 1 is not a row/column/scalar operand (presumably full shape)
+  if ((rhs.shape().ndim() != 1U) && (rhs.shape()[0] != 1) && (rhs.shape()[1] != 1)) {
+    // elementwise_mul/div(csr, dense) = csr is not handled here: log and abort
+    using common::operator_string;
+    LOG(FATAL) << operator_string(attrs, ctx, inputs, req, outputs)
+               << "\nIf shape of lhs and rhs match, please explicitly use elemwise_mul/div\n";
+  } else {
+    // broadcast(CSR, dense row vector / column vector / scalar) = CSR
+    if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage && out_stype == kCSRStorage) {
+      BinaryBroadcastCsrDnsCsrImpl<xpu, OP>(ctx, lhs, rhs, req[0], out);
+    } else {
+      LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
+    }
+  }
+}
+
 template<typename xpu, typename LOP, typename ROP>
 void BinaryBroadcastBackwardUseNone(const nnvm::NodeAttrs& attrs,
                                     const OpContext& ctx,
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
index 634e90557ef..6be4c265b9e 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
@@ -120,8 +120,13 @@ Example::
    broadcast_mul(x, y) = [[ 0.,  0.,  0.],
                           [ 1.,  1.,  1.]]
 
+Supported sparse operations:
+   broadcast_mul(csr, dense(1D)) = csr (CPU only)
+
 )code" ADD_FILELINE)
 .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::mul>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryBroadcastComputeEx<cpu, op::mshadow_op::mul>)
+.set_attr<FInferStorageType>("FInferStorageType", BinaryBroadcastMulStorageType)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"});
 
 
@@ -154,8 +159,13 @@ Example::
    broadcast_div(x, y) = [[ 3.,  3.,  3.],
                           [ 2.,  2.,  2.]]
 
+Supported sparse operations:
+   broadcast_div(csr, dense(1D)) = csr (CPU only)
+
 )code" ADD_FILELINE)
 .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::div>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryBroadcastComputeEx<cpu, op::mshadow_op::div>)
+.set_attr<FInferStorageType>("FInferStorageType", BinaryBroadcastMulStorageType)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_div"});
 
 NNVM_REGISTER_OP(_backward_broadcast_div)
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 5ad5215036d..34794866546 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -1678,6 +1678,29 @@ def check_sparse_embedding(in_dim, out_dim, batch, densities, deterministic):
     check_sparse_embedding(in_dim, out_dim, batch, densities, False)
 
 
+@with_seed()
+def test_sparse_broadcast_mul_div():
+    def check(mx_fn, np_fn, mx_lhs, mx_rhs):
+        expected = np_fn(mx_lhs.asnumpy(), mx_rhs.asnumpy())
+        assert_almost_equal(mx_fn(mx_lhs, mx_rhs).asnumpy(), expected, atol=1e-4)
+    # Call broadcast_* directly too: sparse.multiply/divide send same-shape inputs to elemwise_*.
+    ops = [(mx.nd.sparse.multiply, np.multiply), (mx.nd.broadcast_mul, np.multiply),
+           (mx.nd.sparse.divide, np.divide), (mx.nd.broadcast_div, np.divide)]
+    for num_rows, num_cols in [rand_shape_2d(), (1, rand_shape_2d()[1])]:  # incl. single-row csr
+        for density in [0.1 * i for i in range(10)]:
+            mx_lhs = rand_ndarray((num_rows, num_cols), 'csr', density)
+            mx_rhs_row_2D = rand_ndarray((1, num_cols), 'default')
+            mx_rhs_scalar_2D = rand_ndarray((1, 1), 'default')
+            for mx_rhs in [mx_rhs_row_2D, mx_rhs_row_2D.reshape((num_cols,)),
+                           rand_ndarray((num_rows, 1), 'default'),
+                           mx_rhs_scalar_2D, mx_rhs_scalar_2D.reshape((1,))]:
+                for mx_fn, np_fn in ops:
+                    check(mx_fn, np_fn, mx_lhs, mx_rhs)
+    # With a scalar on the left, subtract/divide must compute 2 - x and 2 / x.
+    dense = mx.nd.ones(rand_shape_2d()) * 4
+    assert_almost_equal(mx.nd.sparse.subtract(2, dense.tostype('csr')).asnumpy(), 2 - dense.asnumpy())
+    assert_almost_equal(mx.nd.sparse.divide(2, dense.tostype('csr')).asnumpy(), 2 / dense.asnumpy())
+
 @with_seed()
 def test_scatter_ops():
     def csr_get_seen_points(name, csr_array, verbose=False):


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services