Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/11/18 06:38:47 UTC

[incubator-mxnet] branch master updated: Fix sparse dot test failure due to launching kernel when nnr = 0 and bug of square_sum (#8470)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 33698ea  Fix sparse dot test failure due to launching kernel when nnr = 0 and bug of square_sum (#8470)
33698ea is described below

commit 33698ea9c1d641f418314c99333d3d08f668be1f
Author: reminisce <wu...@gmail.com>
AuthorDate: Fri Nov 17 22:38:44 2017 -0800

    Fix sparse dot test failure due to launching kernel when nnr = 0 and bug of square_sum (#8470)
    
    * Fix nnr = 0 crash for sparse dot
    
    * Change comment
    
    * Add sparse dot zero output unit test
    
    * Fix test case
    
    * Fix bug of square_sum
    
    * Remove code
    
    * Fix square_sum test case
    
    * Fix broadcast min/max unit test failure
    
    * Fix test
    
    * Fix doc for __getitem__
---
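
For context: the "nnr = 0" crash occurs when a sparse dot produces an output with
zero non-zero rows; the old code either launched a GPU kernel over that empty
output or returned without initializing the result. A minimal sketch of the
scenario, mirroring the unit test added in this commit (the shapes and the use
of mx.nd.random.uniform are illustrative, not taken from the patch):

    import mxnet as mx
    import numpy as np

    # lhs has a single non-zero at (3, 7); rhs is dense but zeroed along row 7,
    # so the product is exactly zero and a row_sparse output has nnr_out = 0.
    lhs = mx.nd.zeros((50, 200))
    lhs[3, 7] = 1.0
    rhs = mx.nd.random.uniform(shape=(200, 40))
    rhs[7, :] = 0

    dns_out = mx.nd.dot(lhs, rhs)  # dense reference result
    sps_out = mx.nd.sparse.dot(lhs.tostype('csr'), rhs.tostype('row_sparse'))
    assert mx.nd.sum(mx.nd.abs(dns_out)).asscalar() == 0
    assert np.array_equal(dns_out.asnumpy(), sps_out.asnumpy())
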
 python/mxnet/ndarray/ndarray.py               |  4 +-
 src/operator/tensor/dot-inl.cuh               | 29 +++++++--
 src/operator/tensor/dot-inl.h                 | 32 +++++++---
 src/operator/tensor/square_sum-inl.h          | 88 ++++++++++++++++++++-------
 tests/python/unittest/test_operator.py        | 18 +++++-
 tests/python/unittest/test_sparse_operator.py | 62 ++++++++++++++++---
 6 files changed, 186 insertions(+), 47 deletions(-)

diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 265adfe..b655ad8 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -432,13 +432,14 @@ fixed-size items.
             raise ValueError('Indexing NDArray with index=%s and type=%s is not supported'
                              % (str(key), str(type(key))))
 
+    # pylint: disable=line-too-long
     def __getitem__(self, key):
         """x.__getitem__(i) <=> x[i]
         Returns a sliced view of this array if the elements fetched are contiguous in memory;
         otherwise, returns a newly created NDArray.
         This function supports advanced indexing defined in the following reference with
         some limitations.
-        https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#combining-advanced-and-basic-indexing  # pylint: disable=line-too-long
+        https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#combining-advanced-and-basic-indexing
         The following features/functionality are not supported for now:
         1. If key is a list type, only a list of integers is supported,
            i.e. key=[1, 2] is okay, while not for key=[[1]].
@@ -489,6 +490,7 @@ fixed-size items.
         else:
             raise ValueError('Indexing NDArray with index=%s and type=%s is not supported'
                              % (str(key), str(type(key))))
+    # pylint: enable=line-too-long
 
     def _get_index_nd(self, key):
         """Returns an index array for use in scatter_nd and gather_nd."""
diff --git a/src/operator/tensor/dot-inl.cuh b/src/operator/tensor/dot-inl.cuh
index 2b346bf..c546c43 100644
--- a/src/operator/tensor/dot-inl.cuh
+++ b/src/operator/tensor/dot-inl.cuh
@@ -454,13 +454,16 @@ inline void DotCsrDnsDnsImpl(const OpContext& ctx,
                              TBlob* ret) {
   if (kNullOp == req) return;
   CHECK_EQ(lhs.storage_type(), kCSRStorage);
-  if (!lhs.storage_initialized()) return;
+  mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
+  if (!lhs.storage_initialized()) {
+    Fill(s, *ret, req, 0);
+    return;
+  }
 
   using mshadow::cuda::kBaseThreadNum;
   using mxnet_op::Kernel;
   using mxnet_op::set_zero;
   using nnvm::dim_t;
-  mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
 
   const dim_t num_rows_l = lhs.shape()[0];
   const dim_t num_cols_r = rhs.shape_[1];
@@ -587,13 +590,16 @@ inline void DotCsrDnsRspImpl(const OpContext& ctx,
   CHECK_EQ(lhs.storage_type(), kCSRStorage);
   CHECK_EQ(ret->storage_type(), kRowSparseStorage);
   CHECK_EQ(req, kWriteTo);
-  if (!lhs.storage_initialized()) return;
+  mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
+  if (!lhs.storage_initialized()) {
+    FillZerosRspImpl(s, *ret);
+    return;
+  }
 
   using mshadow::Shape1;
   using mxnet_op::Kernel;
   using mxnet_op::set_zero;
   using nnvm::dim_t;
-  mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
 
   const TBlob data_l = lhs.data();
   const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
@@ -648,6 +654,10 @@ inline void DotCsrDnsRspImpl(const OpContext& ctx,
           dim_t nnr_out = 0;
           CUDA_CALL(cudaMemcpy(&nnr_out, &row_flg_out[num_cols_l-1], sizeof(dim_t),
                                cudaMemcpyDeviceToHost));
+          if (0 == nnr_out) {
+            FillZerosRspImpl(s, *ret);
+            return;
+          }
 
           // Allocate output matrix space
           ret->CheckAndAlloc({Shape1(nnr_out)});
@@ -702,14 +712,17 @@ inline void DotCsrRspRspImpl(const OpContext& ctx,
   CHECK_EQ(lhs.storage_type(), kCSRStorage);
   CHECK_EQ(rhs.storage_type(), kRowSparseStorage);
   CHECK_EQ(ret->storage_type(), kRowSparseStorage);
-  if (!lhs.storage_initialized() || !rhs.storage_initialized()) return;
+  mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
+  if (!lhs.storage_initialized() || !rhs.storage_initialized()) {
+    FillZerosRspImpl(s, *ret);
+    return;
+  }
   CHECK_EQ(req, kWriteTo);
 
   using mshadow::Shape1;
   using mxnet_op::Kernel;
   using mxnet_op::set_zero;
   using nnvm::dim_t;
-  mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
 
   const TBlob data_l = lhs.data();
   const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
@@ -767,6 +780,10 @@ inline void DotCsrRspRspImpl(const OpContext& ctx,
             dim_t nnr_out = 0;
             CUDA_CALL(cudaMemcpy(&nnr_out, &row_flg_out[num_cols_l-1], sizeof(dim_t),
                                  cudaMemcpyDeviceToHost));
+            if (0 == nnr_out) {
+              FillZerosRspImpl(s, *ret);
+              return;
+            }
 
             // Allocate output matrix space
             ret->CheckAndAlloc({mshadow::Shape1(nnr_out)});
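
A note on the nnr_out read-back above: as far as one can tell from the
surrounding code (the prefix-sum step is outside these hunks), row_flg_out
holds an inclusive prefix sum over per-row non-zero flags, so its last element
equals the number of non-zero rows in the output. A rough numpy analogue of
that counting step:

    import numpy as np

    # row_flg marks which output rows receive at least one non-zero value.
    row_flg = np.array([0, 1, 0, 0, 1, 0], dtype=np.int64)
    row_flg_out = np.cumsum(row_flg)  # inclusive prefix sum
    nnr_out = row_flg_out[-1]         # 2 here; 0 means an all-zero output

When nnr_out is 0, the new early return fills the output with explicit zeros
instead of allocating zero rows and launching kernels over an empty output,
which is the failure the commit title describes.
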
diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h
index 26061cb..7ab4710 100644
--- a/src/operator/tensor/dot-inl.h
+++ b/src/operator/tensor/dot-inl.h
@@ -30,6 +30,7 @@
 #include <algorithm>
 #include <utility>
 #include <type_traits>
+#include "./init_op.h"
 #include "../mshadow_op.h"
 #include "../elemwise_op_common.h"
 #include "../mxnet_op.h"
@@ -535,11 +536,14 @@ inline void DotCsrDnsDnsImpl(const OpContext& ctx,
                              TBlob* ret) {
   if (kNullOp == req) return;
   CHECK_EQ(lhs.storage_type(), kCSRStorage);
-  if (!lhs.storage_initialized()) return;
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  if (!lhs.storage_initialized()) {
+    Fill(s, *ret, req, 0);
+    return;
+  }
 
   using nnvm::dim_t;
 
-  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
   const TBlob data_l = lhs.data();
   const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
   const TBlob col_idx_l = lhs.aux_data(csr::kIdx);
@@ -586,13 +590,16 @@ inline void DotCsrDnsRspImpl(const OpContext& ctx,
   if (kNullOp == req) return;
   CHECK_EQ(lhs.storage_type(), kCSRStorage);
   CHECK_EQ(ret->storage_type(), kRowSparseStorage);
-  if (!lhs.storage_initialized()) return;
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  if (!lhs.storage_initialized()) {
+    FillZerosRspImpl(s, *ret);
+    return;
+  }
   CHECK_EQ(req, kWriteTo);
 
   using mxnet_op::set_zero;
   using nnvm::dim_t;
 
-  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
   const TBlob data_l = lhs.data();
   const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
   const TBlob col_idx_l = lhs.aux_data(csr::kIdx);
@@ -621,8 +628,11 @@ inline void DotCsrDnsRspImpl(const OpContext& ctx,
                 seg_len, lhs.shape()[0], data_out.shape_[0], data_out.shape_[1]);
             dim_t nnr = 0;
             nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr);
+            if (0 == nnr) {
+              FillZerosRspImpl(s, *ret);
+              return;
+            }
             ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr));
-            if (0 == nnr) return;
             mshadow::Tensor<cpu, 2, DType> rsp_data = data_out.FlatTo2D<cpu, DType>(s);
             dim_t idx = 0;
             for (index_t i = 0; i < ret->shape()[0]; ++i) {
@@ -725,13 +735,16 @@ inline void DotCsrRspRspImpl(const OpContext& ctx,
   CHECK_EQ(lhs.storage_type(), kCSRStorage);
   CHECK_EQ(rhs.storage_type(), kRowSparseStorage);
   CHECK_EQ(ret->storage_type(), kRowSparseStorage);
-  if (!lhs.storage_initialized() || !rhs.storage_initialized()) return;
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  if (!lhs.storage_initialized() || !rhs.storage_initialized()) {
+    FillZerosRspImpl(s, *ret);
+    return;
+  }
   CHECK_EQ(req, kWriteTo);
 
   using mxnet_op::set_zero;
   using nnvm::dim_t;
 
-  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
   const TBlob data_l = lhs.data();
   const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
   const TBlob col_idx_l = lhs.aux_data(csr::kIdx);
@@ -764,8 +777,11 @@ inline void DotCsrRspRspImpl(const OpContext& ctx,
                 ret->shape()[0], ret->shape()[1], seg_len);
             dim_t nnr = 0;
             nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr);
+            if (0 == nnr) {
+              FillZerosRspImpl(s, *ret);
+              return;
+            }
             ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr));
-            if (0 == nnr) return;
             mshadow::Tensor<cpu, 2, DType> rsp_data = data_out.FlatTo2D<cpu, DType>(s);
             dim_t idx = 0;
             for (index_t i = 0; i < ret->shape()[0]; ++i) {
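
The CPU path above applies the same pattern: when the accumulated nnr is 0,
FillZerosRspImpl now writes a canonical all-zero row_sparse output before
returning, instead of setting a zero-sized aux shape and leaving the output
data untouched. From Python, such a result looks like a freshly created zeros
array (a small sketch):

    import mxnet as mx

    z = mx.nd.sparse.zeros('row_sparse', (3, 4))
    print(z.indices.shape)  # (0,): no stored rows
    print(z.asnumpy())      # a 3x4 array of zeros
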
diff --git a/src/operator/tensor/square_sum-inl.h b/src/operator/tensor/square_sum-inl.h
index 7ce5b1e..a052ad9 100644
--- a/src/operator/tensor/square_sum-inl.h
+++ b/src/operator/tensor/square_sum-inl.h
@@ -179,7 +179,7 @@ struct SquareSumRspKernel<req, 1, true> {
   }
 };
 
-template<int req, int axis, int ograd_stype = kDefaultStorage>
+template<int req, int axis, int ograd_stype = kDefaultStorage, bool is_data_full_rsp = false>
 struct SquareSumRspGradKernel;
 
 template<int req>
@@ -224,11 +224,10 @@ struct SquareSumRspGradKernel<req, 1> {
 
 /*!
  * Note: This kernel assumes that the ograd and in_data
- * are all rsp and have equal row_idx array, or
- * in_data is a full rsp.
+ * are all rsp and have equal row_idx array.
  */
 template<int req>
-struct SquareSumRspGradKernel<req, 1, kRowSparseStorage> {
+struct SquareSumRspGradKernel<req, 1, kRowSparseStorage, false> {
   /*!
    * \param i index of igrad.data()
    * \param in_grad_row_idx row_idx of the gradient of the op's input
@@ -243,10 +242,36 @@ struct SquareSumRspGradKernel<req, 1, kRowSparseStorage> {
                                   const DType* in_data, const int64_t num_cols) {
     const int64_t row = i / num_cols;
     in_grad_row_idx[row] = out_grad_row_idx[row];
-    KERNEL_ASSIGN(in_grad[i], req, 2*in_data[i]*out_grad[row]);
+    KERNEL_ASSIGN(in_grad[i], req, 2 * in_data[i] * out_grad[row]);
   }
 };
 
+/*!
+ * Note: This kernel assumes that the ograd and in_data
+ * are all rsp and in_data is a full rsp.
+ */
+template<int req>
+struct SquareSumRspGradKernel<req, 1, kRowSparseStorage, true> {
+  /*!
+   * \param i index of igrad.data()
+   * \param in_grad_row_idx row_idx of the gradient of the op's input
+   * \param in_grad gradient of the op's input
+   * \param out_grad_row_idx row_idx of the gradient of the op's output
+   * \param out_grad gradient of the op's output
+   * \param in_data op's input
+   */
+  template<typename IType, typename DType>
+  MSHADOW_XINLINE static void Map(int i, IType* in_grad_row_idx, DType* in_grad,
+                                  const IType* out_grad_row_idx, const DType* out_grad,
+                                  const DType* in_data, const int64_t num_cols) {
+    const int64_t row = i / num_cols;
+    const int64_t row_dns = out_grad_row_idx[row];
+    in_grad_row_idx[row] = row_dns;
+    KERNEL_ASSIGN(in_grad[i], req, 2 * in_data[row_dns * num_cols + i % num_cols] * out_grad[row]);
+  }
+};
+
+
 template<typename xpu>
 void SquareSumRspImpl(const nnvm::NodeAttrs& attrs,
                       mshadow::Stream<xpu>* s,
@@ -334,6 +359,12 @@ void SquareSumRspImpl(const nnvm::NodeAttrs& attrs,
   }
 }
 
+/*!\brief
+ * This function only supports the following three situations:
+ * 1. ograd is a dns and input is an rsp
+ * 2. ograd and input are both rsp and have the same row_idx array
+ * 3. ograd and input are both rsp and input is a full rsp
+ */
 template<typename xpu>
 void SquareSumRspGradImpl(const nnvm::NodeAttrs& attrs,
                           mshadow::Stream<xpu>* s,
@@ -350,23 +381,21 @@ void SquareSumRspGradImpl(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(input.storage_type(), kRowSparseStorage);
   CHECK_EQ(igrad->storage_type(), kRowSparseStorage);
   CHECK_EQ(req, kWriteTo);
-  if (!input.storage_initialized()) {
+  if (!input.storage_initialized()
+      || (ograd.storage_type() == kRowSparseStorage && !ograd.storage_initialized())) {
     FillZerosRspImpl(s, *igrad);
     return;
   }
 
   using namespace mxnet_op;
-  // TODO(junwu) change the input of CheckAndAlloc
-  // if we want to support differen row idx arrays
-  // for ograd and input when they are both row-sparse ndarrays
-  igrad->CheckAndAlloc({input.aux_shape(rowsparse::kIdx)});
   const int64_t num_cols = input.storage_shape()[1];
-  const TBlob& igrad_data = igrad->data();
-  const TBlob igrad_row_idx = igrad->aux_data(rowsparse::kIdx);
   const TBlob& ograd_data = ograd.data();
   const TBlob& in_data = input.data();
   const TBlob in_row_idx = input.aux_data(rowsparse::kIdx);
   if (ograd.storage_type() == kDefaultStorage) {
+    igrad->CheckAndAlloc({input.aux_shape(rowsparse::kIdx)});
+    const TBlob& igrad_data = igrad->data();
+    const TBlob igrad_row_idx = igrad->aux_data(rowsparse::kIdx);
     if (0 == param.axis[0]) {  // forward is sum per column
       MSHADOW_TYPE_SWITCH(igrad_data.type_flag_, DType, {
         MSHADOW_IDX_TYPE_SWITCH(igrad_row_idx.type_flag_, IType, {
@@ -396,18 +425,28 @@ void SquareSumRspGradImpl(const nnvm::NodeAttrs& attrs,
     CHECK_EQ(ograd.shape().ndim(), 2U);
     const TBlob ograd_row_idx = ograd.aux_data(rowsparse::kIdx);
     CHECK(ograd_row_idx.Size() == in_row_idx.Size() || in_row_idx.Size() == in_data.shape_[0]);
+    igrad->CheckAndAlloc({ograd.aux_shape(rowsparse::kIdx)});
+    const TBlob& igrad_data = igrad->data();
+    const TBlob igrad_row_idx = igrad->aux_data(rowsparse::kIdx);
     MSHADOW_IDX_TYPE_SWITCH(igrad_row_idx.type_flag_, IType, {
       if (std::is_same<xpu, cpu>::value) {
-        const IType* first1 = ograd_row_idx.dptr<IType>();
-        const IType* last1 = first1 + ograd_row_idx.Size();
-        const IType* first2 = in_row_idx.dptr<IType>();
         // when ograd_row_idx and in_row_idx have the same size and input is not a full rsp
         // ograd_row_idx and in_row_idx are expected to have the same elements
-        if (ograd_row_idx.Size() == in_row_idx.Size() && in_row_idx.Size() != in_data.shape_[0]) {
+        if (in_row_idx.Size() != input.shape()[0]) {  // if input data is not a full rsp
+          CHECK_EQ(ograd_row_idx.Size(), in_row_idx.Size()) << "SquareSumRspGradImpl only supports"
+                                                               " equal ograd_row_idx and"
+                                                               " input_row_idx when ograd and"
+                                                               " input are both row-sparse and"
+                                                               " input data is not a full"
+                                                               " row-sparse matrix";
+          const IType* first1 = ograd_row_idx.dptr<IType>();
+          const IType* last1 = first1 + ograd_row_idx.Size();
+          const IType* first2 = in_row_idx.dptr<IType>();
           CHECK(std::equal(first1, last1, first2)) << "SquareSumRspGradImpl only supports"
                                                       " equal ograd_row_idx and input_row_idx"
                                                       " when ograd and input are both"
-                                                      " row-sparse";
+                                                      " row-sparse and input data is not a full"
+                                                      " row-sparse matrix";
         }
       } else {
         LOG(FATAL) << "SquareSumRspGradImpl has not implemented GPU version when"
@@ -415,10 +454,17 @@ void SquareSumRspGradImpl(const nnvm::NodeAttrs& attrs,
       }
       MSHADOW_TYPE_SWITCH(igrad_data.type_flag_, DType, {
         MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
-          Kernel<SquareSumRspGradKernel<req_type, 1, kRowSparseStorage>, xpu>::Launch(
-              s, igrad_data.Size(), igrad_row_idx.dptr<IType>(),
-              igrad_data.dptr<DType>(), ograd_row_idx.dptr<IType>(),
-              ograd_data.dptr<DType>(), in_data.dptr<DType>(), num_cols);
+          if (in_row_idx.Size() != input.shape()[0]) {  // input data is not a full rsp
+            Kernel<SquareSumRspGradKernel<req_type, 1, kRowSparseStorage, false>, xpu>::Launch(
+                s, igrad_data.Size(), igrad_row_idx.dptr<IType>(),
+                igrad_data.dptr<DType>(), ograd_row_idx.dptr<IType>(),
+                ograd_data.dptr<DType>(), in_data.dptr<DType>(), num_cols);
+          } else {  // input data is a full rsp
+            Kernel<SquareSumRspGradKernel<req_type, 1, kRowSparseStorage, true>, xpu>::Launch(
+                s, igrad_data.Size(), igrad_row_idx.dptr<IType>(),
+                igrad_data.dptr<DType>(), ograd_row_idx.dptr<IType>(),
+                ograd_data.dptr<DType>(), in_data.dptr<DType>(), num_cols);
+          }
         })
       })
     })
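
To make the index arithmetic in the new full-rsp kernel concrete, here is a
small numpy emulation of SquareSumRspGradKernel<req, 1, kRowSparseStorage, true>
for the axis = 1 case (the values are made up; the loop body mirrors the Map
function above):

    import numpy as np

    num_cols = 3
    in_data = np.arange(12.0).reshape(4, 3)  # a "full" rsp: dense rows 0..3 all stored
    out_grad_row_idx = np.array([1, 3])      # ograd covers dense rows 1 and 3 only
    out_grad = np.array([10.0, 20.0])        # one scalar per ograd row (axis=1 sum)

    in_grad = np.zeros((out_grad_row_idx.size, num_cols))
    in_grad_row_idx = np.zeros(out_grad_row_idx.size, dtype=np.int64)
    for i in range(in_grad.size):            # one Map(i, ...) call per element
        row = i // num_cols
        row_dns = out_grad_row_idx[row]      # compact row -> dense row of in_data
        in_grad_row_idx[row] = row_dns
        in_grad.flat[i] = 2 * in_data[row_dns, i % num_cols] * out_grad[row]
    # in_grad now equals 2 * x * ograd on exactly the rows ograd covers
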
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index d322fa4..18544ba 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1267,6 +1267,18 @@ def test_binary_op():
 
 
 def test_broadcast_binary_op():
+    def check_bmaxmin_gradient(test_sym, x, y, delta, rtol, atol):
+        """This function ensures that checking the numerical gradient of
+        broadcast_max/min is not crossing the boundary y=x where there
+        is no gradient definition at those sigularities."""
+        x_max = np.max(x)
+        y = x_max + 2 * delta + np.random.random(y.shape)
+        check_numeric_gradient(test_sym, [x, y], numeric_eps=delta, rtol=rtol, atol=atol)
+
+        x_min = np.min(x)
+        y = x_min - 2 * delta - np.random.random(y.shape)
+        check_numeric_gradient(test_sym, [x, y], numeric_eps=delta, rtol=rtol, atol=atol)
+
     a = mx.sym.Variable('a')
     b = mx.sym.Variable('b')
 
@@ -1316,13 +1328,15 @@ def test_broadcast_binary_op():
         c = mx.sym.broadcast_maximum(a, b)
         check_binary_op_forward(c, lambda x, y: np.maximum(x, y), gen_broadcast_data, mx_nd_func=mx.nd.maximum)
         # pass idx=200 to gen_broadcast_data so that generated ndarrays' sizes are not too big
-        check_numeric_gradient(c, gen_broadcast_data(idx=200), rtol=1e-2, atol=1e-3)
+        data = gen_broadcast_data(idx=200)
+        check_bmaxmin_gradient(c, data[0], data[1], 0.001, 1e-2, 1e-3)
 
     def test_bmin(a, b):
         c = mx.sym.broadcast_minimum(a, b)
         check_binary_op_forward(c, lambda x, y: np.minimum(x, y), gen_broadcast_data, mx_nd_func=mx.nd.minimum)
         # pass idx=200 to gen_broadcast_data so that generated ndarrays' sizes are not too big
-        check_numeric_gradient(c, gen_broadcast_data(idx=200), rtol=1e-2, atol=1e-3)
+        data = gen_broadcast_data(idx=200)
+        check_bmaxmin_gradient(c, data[0], data[1], 0.001, 1e-2, 1e-3)
 
     test_bplus(a, b)
     test_bminus(a, b)
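
Why check_bmaxmin_gradient is needed: max(x, y) is not differentiable where
x == y, and a central finite difference straddling that kink returns a value
between the two one-sided slopes, so the numeric and symbolic gradients
disagree. A short numpy illustration:

    import numpy as np

    f = lambda y: np.maximum(1.0, y)
    delta = 1e-3
    numeric = (f(1.0 + delta) - f(1.0 - delta)) / (2 * delta)  # 0.5, not 0 or 1

The helper shifts y at least 2 * delta (plus a random offset) away from every
element of x, so the difference stencil never crosses the boundary.
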
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 31c3c46..a08b618 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -1195,6 +1195,7 @@ def test_cast_storage_ex():
             check_cast_storage((dim0, rnd.randint(512, 1024)), d, 'default', 'row_sparse',
                                check_numeric_grad=False)
 
+
 def test_sparse_dot():
     def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, lhs_density, rhs_density):
         lhs_nd = rand_ndarray(lhs_shape, 'csr', density=lhs_density, shuffle_csr_indices=False)
@@ -1222,18 +1223,38 @@ def test_sparse_dot():
                                 grad_req={'lhs': 'null', 'rhs': 'write'},
                                 rtol=1e-3, atol=1e-4)
 
+    def test_sparse_dot_zero_output(lhs_shape, trans_lhs, rhs_num_cols):
+        """Test for nnr_out = 0. Before the fix, the test would fail."""
+        lhs = mx.nd.zeros(lhs_shape)
+        irow = np.random.randint(0, lhs_shape[0])
+        icol = np.random.randint(0, lhs_shape[1])
+        lhs[irow, icol] = 1.0
+        if trans_lhs:
+            rhs = rand_ndarray(shape=(lhs_shape[0], rhs_num_cols), stype='default')
+            rhs[irow, :] = 0
+        else:
+            rhs = rand_ndarray(shape=(lhs_shape[1], rhs_num_cols), stype='default')
+            rhs[icol, :] = 0
+        dns_out = mx.nd.dot(lhs, rhs, transpose_a=trans_lhs)
+        assert mx.nd.sum(mx.nd.abs(dns_out)).asscalar() == 0
+        sps_out = mx.nd.sparse.dot(lhs.tostype('csr'), rhs.tostype('row_sparse'), transpose_a=trans_lhs)
+        assert same(dns_out.asnumpy(), sps_out.asnumpy())
+
     density = [1.00, 0.50, 0.01]
     for lhs_d in density:
         lhs_shape = rand_shape_2d(50, 200)
         rhs_d = 1
-        test_dot_csr(lhs_shape, (lhs_shape[1], 1), 'default', False, lhs_d, rhs_d) # test gpu SpMV
-        test_dot_csr(lhs_shape, (lhs_shape[0], 1), 'default', True , lhs_d, rhs_d) # (vector kernel)
-        test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(5, 10)), 'default', False, lhs_d, rhs_d) # test gpu SpMM
-        test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(5, 10)), 'default', True , lhs_d, rhs_d) # (scalar kernel)
+        test_dot_csr(lhs_shape, (lhs_shape[1], 1), 'default', False, lhs_d, rhs_d)  # test gpu SpMV
+        test_dot_csr(lhs_shape, (lhs_shape[0], 1), 'default', True,  lhs_d, rhs_d)  # (vector kernel)
+        test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(5, 10)), 'default', False, lhs_d, rhs_d)  # test gpu SpMM
+        test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(5, 10)), 'default', True, lhs_d, rhs_d)  # (scalar kernel)
         for rhs_d in density:
             test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False, lhs_d, rhs_d)
             test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True, lhs_d, rhs_d)
 
+    test_sparse_dot_zero_output(rand_shape_2d(50, 200), False, 40)
+    test_sparse_dot_zero_output(rand_shape_2d(50, 200), True, 40)
+
 
 def test_sparse_slice():
     def check_csr_slice(shape, slice_input):
@@ -1353,6 +1374,7 @@ def test_sparse_unary_with_numerics():
                           lambda output, outg: outg * assign_each(output, lambda x: x * (1.0 - x)),
                           backward_is_use_output=True)
 
+
 def test_sparse_nd_zeros():
     def check_sparse_nd_zeros(stype, shape):
         zero = mx.nd.zeros(shape)
@@ -1364,6 +1386,7 @@ def test_sparse_nd_zeros():
     check_sparse_nd_zeros('csr', shape)
     check_sparse_nd_zeros('default', shape)
 
+
 def test_sparse_nd_zeros_like():
     def check_sparse_nd_zeros_like(stype, shape):
         zero = mx.nd.zeros(shape, stype=stype)
@@ -1374,6 +1397,7 @@ def test_sparse_nd_zeros_like():
     check_sparse_nd_zeros_like('row_sparse', shape)
     check_sparse_nd_zeros_like('csr', shape)
 
+
 def test_sparse_axis_operations():
     def test_variations(func_name):
         dim0 = 30
@@ -1403,13 +1427,14 @@ def test_sparse_axis_operations():
     test_variations(mx.nd.mean)
     test_fallback(mx.nd.mean, axis=0, keepdims=True, exclude=True)
 
+
 def test_sparse_square_sum():
     if default_context().device_type == 'cpu':
         dim0 = 30
         dim1 = 30
         axes = [0, 1]
         keepdims = [False, True]
-        densities = [0, 0.01, 0.1, 0.2, 0.5]
+        densities = [0, 0.01, 0.2, 0.5, 1.0]
         for density in densities:
             shape = rand_shape_2d(dim0, dim1)
             rsp = rand_ndarray(shape, 'row_sparse', density)
@@ -1428,11 +1453,11 @@ def test_sparse_square_sum():
                     rsp_data = mx.sym.Variable('data', stype='row_sparse')
                     test = mx.symbol._internal._square_sum(rsp_data, axis=axis, keepdims=keepdim)
 
-                    # check symbolic backward since ograd can be a rsp
+                    # check symbolic backward since ograd can be an rsp
                     # and cannot be checked through check_numeric_gradient
                     # because it will add a loss layer as the output layer
                     # which makes ograd of the square_sum dense
-                    if axis == 1 and keepdims:
+                    if axis == 1 and keepdim:
                         dns_data = mx.sym.Variable('data')
                         baseline = mx.sym.sum(mx.sym.square(dns_data), axis=axis, keepdims=keepdim)
                         igrad_expected = mx.nd.empty(dns.shape)
@@ -1440,14 +1465,29 @@ def test_sparse_square_sum():
                                                       args_grad=[igrad_expected])
                         baseline_exec.forward(is_train=True)
                         baseline_exec.backward([ret_expected])
-                        check_symbolic_backward(test, [rsp], [ret], [igrad_expected.asnumpy()],
+                        # check backward when ograd is row sparse
+                        check_symbolic_backward(test, [rsp], [ret_expected.tostype('row_sparse')],
+                                                [igrad_expected.asnumpy()], grad_stypes={'data': 'row_sparse'})
+
+                        # check backward when ograd is dense
+                        # The stype of square_sum's output is determined at symbol binding stage,
+                        # and the ograd stype of the last layer matches that output stype. Add
+                        # one more layer after square_sum to trigger the kernel for ograd with
+                        # the default stype in the square_sum op.
+                        baseline1 = baseline + 1
+                        baseline_exec1 = baseline1.bind(default_context(), args=[dns],
+                                                        args_grad=[igrad_expected])
+                        baseline_exec1.forward(is_train=True)
+                        baseline_exec1.backward([ret_expected])
+                        test1 = test + 1
+                        check_symbolic_backward(test1, [rsp], [ret_expected], [igrad_expected.asnumpy()],
                                                 grad_stypes={'data': 'row_sparse'})
 
                     # check numeric gradient
                     check_numeric_gradient(test, [rsp], grad_stype_dict={'data': 'row_sparse'},
                                            atol=1e-2, rtol=0.1)
 
-                    
+
 def test_sparse_storage_fallback():
     """ test operators which don't implement FComputeEx or FStatefulComputeEx """
     def check_broadcast_add(shape, lhs_stype, rhs_stype):
@@ -1516,6 +1556,7 @@ def test_sparse_storage_fallback():
             check_softmax_with_shape(lhs, rhs, shape, preserve_shape=False)
             check_softmax_with_shape(rhs, rhs, shape, preserve_shape=True)
 
+
 def test_sparse_elementwise_sum():
     def check_sparse_elementwise_sum_with_shape(stype, shape, n):
         # forward
@@ -1546,6 +1587,7 @@ def test_sparse_elementwise_sum():
         shape = tuple(np.random.randint(5, 10, size=dim))
         check_sparse_elementwise_sum_with_shape('row_sparse', shape, np.random.randint(1, 9))
 
+
 def test_sparse_embedding():
     ''' test sparse embedding op on cpu '''
     def check_sparse_embedding(executor, weight_ref, data_onehot, grad, density):
@@ -1585,6 +1627,7 @@ def test_sparse_embedding():
     for density in densities:
         check_sparse_embedding(exe_test, weight, np_onehot, grad, density)
 
+
 def test_scatter_ops():
     def csr_get_seen_points(name, csr_array, verbose=False):
         """Get a unique list of points int he CSR array as well as a
@@ -1729,6 +1772,7 @@ def test_scatter_ops():
                           lambda l, r: l + r,
                           rhs_is_scalar=True, verbose=False, density=0.5)
 
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()

-- 
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].