Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/11/18 06:38:48 UTC

[GitHub] piiswrong closed pull request #8470: Fix sparse dot test failure due to launching kernel when nnr = 0 and bug of square_sum

URL: https://github.com/apache/incubator-mxnet/pull/8470

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

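In short: the GPU dot(csr, dns) and dot(csr, rsp) paths launched kernels and skipped output initialization when the result had zero non-zero rows (nnr = 0), and square_sum's backward mishandled a row-sparse ograd combined with a full row-sparse input. A minimal sketch of the dot failure mode, assuming MXNet's sparse API exactly as exercised by the new test_sparse_dot_zero_output below (shapes and indices are illustrative):

    import numpy as np
    import mxnet as mx

    # lhs has a single non-zero entry; zeroing the matching rhs row makes
    # dot(lhs.T, rhs) exactly zero, so the row-sparse output has nnr = 0.
    lhs = mx.nd.zeros((50, 100))
    lhs[3, 7] = 1.0
    rhs = mx.nd.ones((50, 40))
    rhs[3, :] = 0
    dns = mx.nd.dot(lhs, rhs, transpose_a=True)      # dense baseline: all zeros
    sps = mx.nd.sparse.dot(lhs.tostype('csr'), rhs.tostype('row_sparse'),
                           transpose_a=True)         # failed before this fix
    assert np.array_equal(dns.asnumpy(), sps.asnumpy())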

diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 265adfe734..b655ad88c3 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -432,13 +432,14 @@ def __setitem__(self, key, value):
             raise ValueError('Indexing NDArray with index=%s and type=%s is not supported'
                              % (str(key), str(type(key))))
 
+    # pylint: disable=line-too-long
     def __getitem__(self, key):
         """x.__getitem__(i) <=> x[i]
         Returns a sliced view of this array if the elements fetched are contiguous in memory;
         otherwise, returns a newly created NDArray.
        This function supports advanced indexing defined in the following reference with
         some limitations.
-        https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#combining-advanced-and-basic-indexing  # pylint: disable=line-too-long
+        https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#combining-advanced-and-basic-indexing
         The following features/functionality are not supported for now:
         1. If key is a list type, only a list of integers is supported,
            i.e. key=[1, 2] is okay, while not for key=[[1]].
@@ -489,6 +490,7 @@ def __getitem__(self, key):
         else:
             raise ValueError('Indexing NDArray with index=%s and type=%s is not supported'
                              % (str(key), str(type(key))))
+    # pylint: enable=line-too-long
 
     def _get_index_nd(self, key):
         """Returns an index array for use in scatter_nd and gather_nd."""
diff --git a/src/operator/tensor/dot-inl.cuh b/src/operator/tensor/dot-inl.cuh
index 2b346bfaf2..c546c4351a 100644
--- a/src/operator/tensor/dot-inl.cuh
+++ b/src/operator/tensor/dot-inl.cuh
@@ -454,13 +454,16 @@ inline void DotCsrDnsDnsImpl(const OpContext& ctx,
                              TBlob* ret) {
   if (kNullOp == req) return;
   CHECK_EQ(lhs.storage_type(), kCSRStorage);
-  if (!lhs.storage_initialized()) return;
+  mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
+  if (!lhs.storage_initialized()) {
+    Fill(s, *ret, req, 0);
+    return;
+  }
 
   using mshadow::cuda::kBaseThreadNum;
   using mxnet_op::Kernel;
   using mxnet_op::set_zero;
   using nnvm::dim_t;
-  mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
 
   const dim_t num_rows_l = lhs.shape()[0];
   const dim_t num_cols_r = rhs.shape_[1];
@@ -587,13 +590,16 @@ inline void DotCsrDnsRspImpl(const OpContext& ctx,
   CHECK_EQ(lhs.storage_type(), kCSRStorage);
   CHECK_EQ(ret->storage_type(), kRowSparseStorage);
   CHECK_EQ(req, kWriteTo);
-  if (!lhs.storage_initialized()) return;
+  mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
+  if (!lhs.storage_initialized()) {
+    FillZerosRspImpl(s, *ret);
+    return;
+  }
 
   using mshadow::Shape1;
   using mxnet_op::Kernel;
   using mxnet_op::set_zero;
   using nnvm::dim_t;
-  mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
 
   const TBlob data_l = lhs.data();
   const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
@@ -648,6 +654,10 @@ inline void DotCsrDnsRspImpl(const OpContext& ctx,
           dim_t nnr_out = 0;
           CUDA_CALL(cudaMemcpy(&nnr_out, &row_flg_out[num_cols_l-1], sizeof(dim_t),
                                cudaMemcpyDeviceToHost));
+          if (0 == nnr_out) {
+            FillZerosRspImpl(s, *ret);
+            return;
+          }
 
           // Allocate output matrix space
           ret->CheckAndAlloc({Shape1(nnr_out)});
@@ -702,14 +712,17 @@ inline void DotCsrRspRspImpl(const OpContext& ctx,
   CHECK_EQ(lhs.storage_type(), kCSRStorage);
   CHECK_EQ(rhs.storage_type(), kRowSparseStorage);
   CHECK_EQ(ret->storage_type(), kRowSparseStorage);
-  if (!lhs.storage_initialized() || !rhs.storage_initialized()) return;
+  mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
+  if (!lhs.storage_initialized() || !rhs.storage_initialized()) {
+    FillZerosRspImpl(s, *ret);
+    return;
+  }
   CHECK_EQ(req, kWriteTo);
 
   using mshadow::Shape1;
   using mxnet_op::Kernel;
   using mxnet_op::set_zero;
   using nnvm::dim_t;
-  mshadow::Stream<gpu>* s = ctx.get_stream<gpu>();
 
   const TBlob data_l = lhs.data();
   const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
@@ -767,6 +780,10 @@ inline void DotCsrRspRspImpl(const OpContext& ctx,
             dim_t nnr_out = 0;
             CUDA_CALL(cudaMemcpy(&nnr_out, &row_flg_out[num_cols_l-1], sizeof(dim_t),
                                  cudaMemcpyDeviceToHost));
+            if (0 == nnr_out) {
+              FillZerosRspImpl(s, *ret);
+              return;
+            }
 
             // Allocate output matrix space
             ret->CheckAndAlloc({mshadow::Shape1(nnr_out)});
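A note on the guard added in both GPU row-sparse paths above: nnr_out is read back from the last element of row_flg_out, which holds an inclusive prefix sum over per-output-row occupancy flags, so when that count is zero the output is now zero-filled and both the allocation and the kernel launch are skipped. A numpy analogue of the counting step, under the assumption that the flags are built the way the CUDA kernels build them:

    import numpy as np

    def nnr_from_row_flags(col_idx_l, num_rows_out):
        """Flag every output row touched by dot(csr.T, ...), prefix-sum
        the flags, and read nnr_out off the last entry."""
        row_flg = np.zeros(num_rows_out, dtype=np.int64)
        row_flg[np.unique(col_idx_l)] = 1        # occupancy flags
        row_flg_out = np.cumsum(row_flg)         # inclusive prefix sum
        return int(row_flg_out[-1])              # == nnr_out

    assert nnr_from_row_flags(np.array([1, 1, 4]), 5) == 2
    assert nnr_from_row_flags(np.array([], dtype=np.int64), 5) == 0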
diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h
index 26061cbab6..7ab4710090 100644
--- a/src/operator/tensor/dot-inl.h
+++ b/src/operator/tensor/dot-inl.h
@@ -30,6 +30,7 @@
 #include <algorithm>
 #include <utility>
 #include <type_traits>
+#include "./init_op.h"
 #include "../mshadow_op.h"
 #include "../elemwise_op_common.h"
 #include "../mxnet_op.h"
@@ -535,11 +536,14 @@ inline void DotCsrDnsDnsImpl(const OpContext& ctx,
                              TBlob* ret) {
   if (kNullOp == req) return;
   CHECK_EQ(lhs.storage_type(), kCSRStorage);
-  if (!lhs.storage_initialized()) return;
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  if (!lhs.storage_initialized()) {
+    Fill(s, *ret, req, 0);
+    return;
+  }
 
   using nnvm::dim_t;
 
-  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
   const TBlob data_l = lhs.data();
   const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
   const TBlob col_idx_l = lhs.aux_data(csr::kIdx);
@@ -586,13 +590,16 @@ inline void DotCsrDnsRspImpl(const OpContext& ctx,
   if (kNullOp == req) return;
   CHECK_EQ(lhs.storage_type(), kCSRStorage);
   CHECK_EQ(ret->storage_type(), kRowSparseStorage);
-  if (!lhs.storage_initialized()) return;
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  if (!lhs.storage_initialized()) {
+    FillZerosRspImpl(s, *ret);
+    return;
+  }
   CHECK_EQ(req, kWriteTo);
 
   using mxnet_op::set_zero;
   using nnvm::dim_t;
 
-  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
   const TBlob data_l = lhs.data();
   const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
   const TBlob col_idx_l = lhs.aux_data(csr::kIdx);
@@ -621,8 +628,11 @@ inline void DotCsrDnsRspImpl(const OpContext& ctx,
                 seg_len, lhs.shape()[0], data_out.shape_[0], data_out.shape_[1]);
             dim_t nnr = 0;
             nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr);
+            if (0 == nnr) {
+              FillZerosRspImpl(s, *ret);
+              return;
+            }
             ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr));
-            if (0 == nnr) return;
             mshadow::Tensor<cpu, 2, DType> rsp_data = data_out.FlatTo2D<cpu, DType>(s);
             dim_t idx = 0;
             for (index_t i = 0; i < ret->shape()[0]; ++i) {
@@ -725,13 +735,16 @@ inline void DotCsrRspRspImpl(const OpContext& ctx,
   CHECK_EQ(lhs.storage_type(), kCSRStorage);
   CHECK_EQ(rhs.storage_type(), kRowSparseStorage);
   CHECK_EQ(ret->storage_type(), kRowSparseStorage);
-  if (!lhs.storage_initialized() || !rhs.storage_initialized()) return;
+  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  if (!lhs.storage_initialized() || !rhs.storage_initialized()) {
+    FillZerosRspImpl(s, *ret);
+    return;
+  }
   CHECK_EQ(req, kWriteTo);
 
   using mxnet_op::set_zero;
   using nnvm::dim_t;
 
-  mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
   const TBlob data_l = lhs.data();
   const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
   const TBlob col_idx_l = lhs.aux_data(csr::kIdx);
@@ -764,8 +777,11 @@ inline void DotCsrRspRspImpl(const OpContext& ctx,
                 ret->shape()[0], ret->shape()[1], seg_len);
             dim_t nnr = 0;
             nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr);
+            if (0 == nnr) {
+              FillZerosRspImpl(s, *ret);
+              return;
+            }
             ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr));
-            if (0 == nnr) return;
             mshadow::Tensor<cpu, 2, DType> rsp_data = data_out.FlatTo2D<cpu, DType>(s);
             dim_t idx = 0;
             for (index_t i = 0; i < ret->shape()[0]; ++i) {
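The CPU paths mirror the GPU fix with one ordering subtlety: the nnr == 0 check now runs before set_aux_shape, so an all-zero result is materialized via FillZerosRspImpl instead of returning with a half-initialized output. The count itself is a plain flag sum; a numpy stand-in for what ParallelAccumulate computes over row_idx here (values are hypothetical):

    import numpy as np

    # row_idx holds a 0/1 flag per output row after the segment kernels run.
    row_idx = np.array([0, 1, 0, 0, 1], dtype=np.int64)
    nnr = int(row_idx.sum())     # what ParallelAccumulate accumulates
    if nnr == 0:
        # FillZerosRspImpl(s, *ret): leave ret as an all-zero rsp and return,
        # before any aux shape is set or data is compacted.
        pass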
diff --git a/src/operator/tensor/square_sum-inl.h b/src/operator/tensor/square_sum-inl.h
index 7ce5b1e1b0..a052ad96cf 100644
--- a/src/operator/tensor/square_sum-inl.h
+++ b/src/operator/tensor/square_sum-inl.h
@@ -179,7 +179,7 @@ struct SquareSumRspKernel<req, 1, true> {
   }
 };
 
-template<int req, int axis, int ograd_stype = kDefaultStorage>
+template<int req, int axis, int ograd_stype = kDefaultStorage, bool is_data_full_rsp = false>
 struct SquareSumRspGradKernel;
 
 template<int req>
@@ -224,11 +224,10 @@ struct SquareSumRspGradKernel<req, 1> {
 
 /*!
  * Note: This kernel assumes that the ograd and in_data
- * are all rsp and have equal row_idx array, or
- * in_data is a full rsp.
+ * are all rsp and have equal row_idx array.
  */
 template<int req>
-struct SquareSumRspGradKernel<req, 1, kRowSparseStorage> {
+struct SquareSumRspGradKernel<req, 1, kRowSparseStorage, false> {
   /*!
    * \param i index of igrad.data()
    * \param in_grad_row_idx row_idx of the gradient of the op's input
@@ -243,10 +242,36 @@ struct SquareSumRspGradKernel<req, 1, kRowSparseStorage> {
                                   const DType* in_data, const int64_t num_cols) {
     const int64_t row = i / num_cols;
     in_grad_row_idx[row] = out_grad_row_idx[row];
-    KERNEL_ASSIGN(in_grad[i], req, 2*in_data[i]*out_grad[row]);
+    KERNEL_ASSIGN(in_grad[i], req, 2 * in_data[i] * out_grad[row]);
   }
 };
 
+/*!
+ * Note: This kernel assumes that the ograd and in_data
+ * are all rsp and in_data is a full rsp.
+ */
+template<int req>
+struct SquareSumRspGradKernel<req, 1, kRowSparseStorage, true> {
+  /*!
+   * \param i index of igrad.data()
+   * \param in_grad_row_idx row_idx of the gradient of the op's input
+   * \param in_grad gradient of the op's input
+   * \param out_grad_row_idx row_idx of the gradient of the op's output
+   * \param out_grad gradient of the op's output
+   * \param in_data op's input
+   */
+  template<typename IType, typename DType>
+  MSHADOW_XINLINE static void Map(int i, IType* in_grad_row_idx, DType* in_grad,
+                                  const IType* out_grad_row_idx, const DType* out_grad,
+                                  const DType* in_data, const int64_t num_cols) {
+    const int64_t row = i / num_cols;
+    const int64_t row_dns = out_grad_row_idx[row];
+    in_grad_row_idx[row] = row_dns;
+    KERNEL_ASSIGN(in_grad[i], req, 2 * in_data[row_dns*num_cols+i%num_cols] * out_grad[row]);
+  }
+};
+
+
 template<typename xpu>
 void SquareSumRspImpl(const nnvm::NodeAttrs& attrs,
                       mshadow::Stream<xpu>* s,
@@ -334,6 +359,12 @@ void SquareSumRspImpl(const nnvm::NodeAttrs& attrs,
   }
 }
 
+/*!\brief
+ * This function only supports the following three situations:
+ * 1. ograd is a dns and input is an rsp
+ * 2. ograd and input are both rsp and have the same row_idx array
+ * 3. ograd and input are both rsp and input is a full rsp
+ */
 template<typename xpu>
 void SquareSumRspGradImpl(const nnvm::NodeAttrs& attrs,
                           mshadow::Stream<xpu>* s,
@@ -350,23 +381,21 @@ void SquareSumRspGradImpl(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(input.storage_type(), kRowSparseStorage);
   CHECK_EQ(igrad->storage_type(), kRowSparseStorage);
   CHECK_EQ(req, kWriteTo);
-  if (!input.storage_initialized()) {
+  if (!input.storage_initialized()
+      || (ograd.storage_type() == kRowSparseStorage && !ograd.storage_initialized())) {
     FillZerosRspImpl(s, *igrad);
     return;
   }
 
   using namespace mxnet_op;
-  // TODO(junwu) change the input of CheckAndAlloc
-  // if we want to support differen row idx arrays
-  // for ograd and input when they are both row-sparse ndarrays
-  igrad->CheckAndAlloc({input.aux_shape(rowsparse::kIdx)});
   const int64_t num_cols = input.storage_shape()[1];
-  const TBlob& igrad_data = igrad->data();
-  const TBlob igrad_row_idx = igrad->aux_data(rowsparse::kIdx);
   const TBlob& ograd_data = ograd.data();
   const TBlob& in_data = input.data();
   const TBlob in_row_idx = input.aux_data(rowsparse::kIdx);
   if (ograd.storage_type() == kDefaultStorage) {
+    igrad->CheckAndAlloc({input.aux_shape(rowsparse::kIdx)});
+    const TBlob& igrad_data = igrad->data();
+    const TBlob igrad_row_idx = igrad->aux_data(rowsparse::kIdx);
     if (0 == param.axis[0]) {  // forward is sum per column
       MSHADOW_TYPE_SWITCH(igrad_data.type_flag_, DType, {
         MSHADOW_IDX_TYPE_SWITCH(igrad_row_idx.type_flag_, IType, {
@@ -396,18 +425,28 @@ void SquareSumRspGradImpl(const nnvm::NodeAttrs& attrs,
     CHECK_EQ(ograd.shape().ndim(), 2U);
     const TBlob ograd_row_idx = ograd.aux_data(rowsparse::kIdx);
     CHECK(ograd_row_idx.Size() == in_row_idx.Size() || in_row_idx.Size() == in_data.shape_[0]);
+    igrad->CheckAndAlloc({ograd.aux_shape(rowsparse::kIdx)});
+    const TBlob& igrad_data = igrad->data();
+    const TBlob igrad_row_idx = igrad->aux_data(rowsparse::kIdx);
     MSHADOW_IDX_TYPE_SWITCH(igrad_row_idx.type_flag_, IType, {
       if (std::is_same<xpu, cpu>::value) {
-        const IType* first1 = ograd_row_idx.dptr<IType>();
-        const IType* last1 = first1 + ograd_row_idx.Size();
-        const IType* first2 = in_row_idx.dptr<IType>();
         // when ograd_row_idx and in_row_idx have the same size and input is not a full rsp
         // ograd_row_idx and in_row_idx are expected to have the same elements
-        if (ograd_row_idx.Size() == in_row_idx.Size() && in_row_idx.Size() != in_data.shape_[0]) {
+        if (in_row_idx.Size() != input.shape()[0]) {  // if input data is not a full rsp
+          CHECK_EQ(ograd_row_idx.Size(), in_row_idx.Size()) << "SquareSumRspGradImpl only supports"
+                                                               " equal ograd_row_idx and"
+                                                               " input_row_idx when ograd and"
+                                                               " input are both row-sparse and"
+                                                               " input data is not a full"
+                                                               " row-sparse matrix";
+          const IType* first1 = ograd_row_idx.dptr<IType>();
+          const IType* last1 = first1 + ograd_row_idx.Size();
+          const IType* first2 = in_row_idx.dptr<IType>();
           CHECK(std::equal(first1, last1, first2)) << "SquareSumRspGradImpl only supports"
                                                       " equal ograd_row_idx and input_row_idx"
                                                       " when ograd and input are both"
-                                                      " row-sparse";
+                                                      " row-sparse and input data is not a full"
+                                                      " row-sparse matrix";
         }
       } else {
         LOG(FATAL) << "SquareSumRspGradImpl has not implemented GPU version when"
@@ -415,10 +454,17 @@ void SquareSumRspGradImpl(const nnvm::NodeAttrs& attrs,
       }
       MSHADOW_TYPE_SWITCH(igrad_data.type_flag_, DType, {
         MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
-          Kernel<SquareSumRspGradKernel<req_type, 1, kRowSparseStorage>, xpu>::Launch(
-              s, igrad_data.Size(), igrad_row_idx.dptr<IType>(),
-              igrad_data.dptr<DType>(), ograd_row_idx.dptr<IType>(),
-              ograd_data.dptr<DType>(), in_data.dptr<DType>(), num_cols);
+          if (in_row_idx.Size() != input.shape()[0]) {  // input data is not a full rsp
+            Kernel<SquareSumRspGradKernel<req_type, 1, kRowSparseStorage, false>, xpu>::Launch(
+                s, igrad_data.Size(), igrad_row_idx.dptr<IType>(),
+                igrad_data.dptr<DType>(), ograd_row_idx.dptr<IType>(),
+                ograd_data.dptr<DType>(), in_data.dptr<DType>(), num_cols);
+          } else {  // input data is a full rsp
+            Kernel<SquareSumRspGradKernel<req_type, 1, kRowSparseStorage, true>, xpu>::Launch(
+                s, igrad_data.Size(), igrad_row_idx.dptr<IType>(),
+                igrad_data.dptr<DType>(), ograd_row_idx.dptr<IType>(),
+                ograd_data.dptr<DType>(), in_data.dptr<DType>(), num_cols);
+          }
         })
       })
     })
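For reference, both gradient kernels implement d(sum_j x_ij^2)/dx_ij = 2 * x_ij * ograd_i; the new is_data_full_rsp specialization only changes how x's row is located: when in_data is a full rsp, its rows are addressed through the dense row ids stored in ograd's row_idx. A numpy sketch of the two variants (function names are invented; the indexing mirrors the kernels above):

    import numpy as np

    def igrad_aligned(in_data, out_grad):
        # ograd and in_data share the same row_idx array (is_data_full_rsp=false)
        return 2.0 * in_data * out_grad[:, None]

    def igrad_full_rsp(in_data_dns, out_grad, out_grad_row_idx):
        # in_data is a full rsp: pick its rows by ograd's dense row ids
        return 2.0 * in_data_dns[out_grad_row_idx] * out_grad[:, None]

    x = np.arange(6.0).reshape(3, 2)   # full rsp: all 3 rows stored
    og = np.array([10.0, 20.0])        # rsp ograd with 2 stored rows
    og_idx = np.array([0, 2])          # their dense row ids
    print(igrad_full_rsp(x, og, og_idx))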
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index d322fa4c2a..18544bade5 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1267,6 +1267,18 @@ def test_bneq(a, b):
 
 
 def test_broadcast_binary_op():
+    def check_bmaxmin_gradient(test_sym, x, y, delta, rtol, atol):
+        """This function ensures that checking the numerical gradient of
+        broadcast_max/min is not crossing the boundary y=x where there
+        is no gradient definition at those sigularities."""
+        x_max = np.max(x)
+        y = x_max + 2 * delta + np.random.random(y.shape)
+        check_numeric_gradient(test_sym, [x, y], numeric_eps=delta, rtol=rtol, atol=atol)
+
+        x_min = np.min(x)
+        y = x_min - 2 * delta - np.random.random(y.shape)
+        check_numeric_gradient(test_sym, [x, y], numeric_eps=delta, rtol=rtol, atol=atol)
+
     a = mx.sym.Variable('a')
     b = mx.sym.Variable('b')
 
@@ -1316,13 +1328,15 @@ def test_bmax(a, b):
         c = mx.sym.broadcast_maximum(a, b)
         check_binary_op_forward(c, lambda x, y: np.maximum(x, y), gen_broadcast_data, mx_nd_func=mx.nd.maximum)
         # pass idx=200 to gen_broadcast_data so that generated ndarrays' sizes are not too big
-        check_numeric_gradient(c, gen_broadcast_data(idx=200), rtol=1e-2, atol=1e-3)
+        data = gen_broadcast_data(idx=200)
+        check_bmaxmin_gradient(c, data[0], data[1], 0.001, 1e-2, 1e-3)
 
     def test_bmin(a, b):
         c = mx.sym.broadcast_minimum(a, b)
         check_binary_op_forward(c, lambda x, y: np.minimum(x, y), gen_broadcast_data, mx_nd_func=mx.nd.minimum)
         # pass idx=200 to gen_broadcast_data so that generated ndarrays' sizes are not too big
-        check_numeric_gradient(c, gen_broadcast_data(idx=200), rtol=1e-2, atol=1e-3)
+        data = gen_broadcast_data(idx=200)
+        check_bmaxmin_gradient(c, data[0], data[1], 0.001, 1e-2, 1e-3)
 
     test_bplus(a, b)
     test_bminus(a, b)
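Why check_bmaxmin_gradient shifts y by at least 2 * delta: check_numeric_gradient uses finite differences, and if a perturbation straddles y = x, the measured slope of max/min lands between the two one-sided derivatives instead of on either of them. A toy illustration in plain Python:

    def central_diff(f, x, eps):
        return (f(x + eps) - f(x - eps)) / (2 * eps)

    y = 1.0
    eps = 1e-3
    # d/dx max(x, y) is 1 for x > y and 0 for x < y, undefined at x == y.
    print(central_diff(lambda x: max(x, y), 2.0, eps))  # 1.0: safely away from y
    print(central_diff(lambda x: max(x, y), y, eps))    # 0.5: straddles x == y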
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 31c3c46889..a08b6187bc 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -1195,6 +1195,7 @@ def check_cast_storage(shape, density, from_stype, to_stype, check_numeric_grad=
             check_cast_storage((dim0, rnd.randint(512, 1024)), d, 'default', 'row_sparse',
                                check_numeric_grad=False)
 
+
 def test_sparse_dot():
     def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, lhs_density, rhs_density):
         lhs_nd = rand_ndarray(lhs_shape, 'csr', density=lhs_density, shuffle_csr_indices=False)
@@ -1222,18 +1223,38 @@ def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, lhs_density, rhs_de
                                 grad_req={'lhs': 'null', 'rhs': 'write'},
                                 rtol=1e-3, atol=1e-4)
 
+    def test_sparse_dot_zero_output(lhs_shape, trans_lhs, rhs_num_cols):
+        """Test for nnr_out = 0. Before the fix, the test would fail."""
+        lhs = mx.nd.zeros(lhs_shape)
+        irow = np.random.randint(0, lhs_shape[0])
+        icol = np.random.randint(0, lhs_shape[1])
+        lhs[irow, icol] = 1.0
+        if trans_lhs:
+            rhs = rand_ndarray(shape=(lhs_shape[0], rhs_num_cols), stype='default')
+            rhs[irow, :] = 0
+        else:
+            rhs = rand_ndarray(shape=(lhs_shape[1], rhs_num_cols), stype='default')
+            rhs[icol, :] = 0
+        dns_out = mx.nd.dot(lhs, rhs, transpose_a=trans_lhs)
+        assert mx.nd.sum(mx.nd.abs(dns_out)).asscalar() == 0
+        sps_out = mx.nd.sparse.dot(lhs.tostype('csr'), rhs.tostype('row_sparse'), transpose_a=trans_lhs)
+        assert same(dns_out.asnumpy(), sps_out.asnumpy())
+
     density = [1.00, 0.50, 0.01]
     for lhs_d in density:
         lhs_shape = rand_shape_2d(50, 200)
         rhs_d = 1
-        test_dot_csr(lhs_shape, (lhs_shape[1], 1), 'default', False, lhs_d, rhs_d) # test gpu SpMV
-        test_dot_csr(lhs_shape, (lhs_shape[0], 1), 'default', True , lhs_d, rhs_d) # (vector kernel)
-        test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(5, 10)), 'default', False, lhs_d, rhs_d) # test gpu SpMM
-        test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(5, 10)), 'default', True , lhs_d, rhs_d) # (scalar kernel)
+        test_dot_csr(lhs_shape, (lhs_shape[1], 1), 'default', False, lhs_d, rhs_d)  # test gpu SpMV
+        test_dot_csr(lhs_shape, (lhs_shape[0], 1), 'default', True,  lhs_d, rhs_d)  # (vector kernel)
+        test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(5, 10)), 'default', False, lhs_d, rhs_d)  # test gpu SpMM
+        test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(5, 10)), 'default', True, lhs_d, rhs_d)  # (scalar kernel)
         for rhs_d in density:
             test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False, lhs_d, rhs_d)
             test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True, lhs_d, rhs_d)
 
+    test_sparse_dot_zero_output(rand_shape_2d(50, 200), False, 40)
+    test_sparse_dot_zero_output(rand_shape_2d(50, 200), True, 40)
+
 
 def test_sparse_slice():
     def check_csr_slice(shape, slice_input):
@@ -1353,6 +1374,7 @@ def check_sparse_function(name, mxnet_func, forward_numpy_call, backward_numpy_c
                           lambda output, outg: outg * assign_each(output, lambda x: x * (1.0 - x)),
                           backward_is_use_output=True)
 
+
 def test_sparse_nd_zeros():
     def check_sparse_nd_zeros(stype, shape):
         zero = mx.nd.zeros(shape)
@@ -1364,6 +1386,7 @@ def check_sparse_nd_zeros(stype, shape):
     check_sparse_nd_zeros('csr', shape)
     check_sparse_nd_zeros('default', shape)
 
+
 def test_sparse_nd_zeros_like():
     def check_sparse_nd_zeros_like(stype, shape):
         zero = mx.nd.zeros(shape, stype=stype)
@@ -1374,6 +1397,7 @@ def check_sparse_nd_zeros_like(stype, shape):
     check_sparse_nd_zeros_like('row_sparse', shape)
     check_sparse_nd_zeros_like('csr', shape)
 
+
 def test_sparse_axis_operations():
     def test_variations(func_name):
         dim0 = 30
@@ -1403,13 +1427,14 @@ def test_fallback(func_name, axis=0, keepdims=True, exclude=True):
     test_variations(mx.nd.mean)
     test_fallback(mx.nd.mean, axis=0, keepdims=True, exclude=True)
 
+
 def test_sparse_square_sum():
     if default_context().device_type == 'cpu':
         dim0 = 30
         dim1 = 30
         axes = [0, 1]
         keepdims = [False, True]
-        densities = [0, 0.01, 0.1, 0.2, 0.5]
+        densities = [0, 0.01, 0.2, 0.5, 1.0]
         for density in densities:
             shape = rand_shape_2d(dim0, dim1)
             rsp = rand_ndarray(shape, 'row_sparse', density)
@@ -1428,11 +1453,11 @@ def test_sparse_square_sum():
                     rsp_data = mx.sym.Variable('data', stype='row_sparse')
                     test = mx.symbol._internal._square_sum(rsp_data, axis=axis, keepdims=keepdim)
 
-                    # check symbolic backward since ograd can be a rsp
+                    # check symbolic backward since ograd can be an rsp
                     # and cannot be checked through check_numeric_gradient
                     # because it will add a loss layer as the output layer
                     # which makes ograd of the square_sum dense
-                    if axis == 1 and keepdims:
+                    if axis == 1 and keepdim:
                         dns_data = mx.sym.Variable('data')
                         baseline = mx.sym.sum(mx.sym.square(dns_data), axis=axis, keepdims=keepdim)
                         igrad_expected = mx.nd.empty(dns.shape)
@@ -1440,14 +1465,29 @@ def test_sparse_square_sum():
                                                       args_grad=[igrad_expected])
                         baseline_exec.forward(is_train=True)
                         baseline_exec.backward([ret_expected])
-                        check_symbolic_backward(test, [rsp], [ret], [igrad_expected.asnumpy()],
+                        # check backward when ograd is row sparse
+                        check_symbolic_backward(test, [rsp], [ret_expected.tostype('row_sparse')],
+                                                [igrad_expected.asnumpy()], grad_stypes={'data': 'row_sparse'})
+
+                        # check backward when ograd is dense
+                        # The stype of the output of square_sum is determined at symbol binding stage.
+                        # The ograd stype of the last layer is the same as the output stype of the last layer.
+                        # Need to add one more layer after square_sum to trigger the kernel for ograd
+                        # with default stype in square_sum op.
+                        baseline1 = baseline + 1
+                        baseline_exec1 = baseline1.bind(default_context(), args=[dns],
+                                                        args_grad=[igrad_expected])
+                        baseline_exec1.forward(is_train=True)
+                        baseline_exec1.backward([ret_expected])
+                        test1 = test + 1
+                        check_symbolic_backward(test1, [rsp], [ret_expected], [igrad_expected.asnumpy()],
                                                 grad_stypes={'data': 'row_sparse'})
 
                     # check numeric gradient
                     check_numeric_gradient(test, [rsp], grad_stype_dict={'data': 'row_sparse'},
                                            atol=1e-2, rtol=0.1)
 
-                    
+
 def test_sparse_storage_fallback():
     """ test operators which don't implement FComputeEx or FStatefulComputeEx """
     def check_broadcast_add(shape, lhs_stype, rhs_stype):
@@ -1516,6 +1556,7 @@ def check_operator_with_temp_resource(shape, stype):
             check_softmax_with_shape(lhs, rhs, shape, preserve_shape=False)
             check_softmax_with_shape(rhs, rhs, shape, preserve_shape=True)
 
+
 def test_sparse_elementwise_sum():
     def check_sparse_elementwise_sum_with_shape(stype, shape, n):
         # forward
@@ -1546,6 +1587,7 @@ def check_sparse_elementwise_sum_with_shape(stype, shape, n):
         shape = tuple(np.random.randint(5, 10, size=dim))
         check_sparse_elementwise_sum_with_shape('row_sparse', shape, np.random.randint(1, 9))
 
+
 def test_sparse_embedding():
     ''' test sparse embedding op on cpu '''
     def check_sparse_embedding(executor, weight_ref, data_onehot, grad, density):
@@ -1585,6 +1627,7 @@ def check_sparse_embedding(executor, weight_ref, data_onehot, grad, density):
     for density in densities:
         check_sparse_embedding(exe_test, weight, np_onehot, grad, density)
 
+
 def test_scatter_ops():
     def csr_get_seen_points(name, csr_array, verbose=False):
         """Get a unique list of points int he CSR array as well as a
@@ -1729,6 +1772,7 @@ def check_scatter_ops(name, shape, lhs_stype, rhs_stype, forward_mxnet_call, for
                           lambda l, r: l + r,
                           rhs_is_scalar=True, verbose=False, density=0.5)
 
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
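One non-obvious device in test_sparse_square_sum above is the "+ 1" layer: an executor's ograd stype is fixed at bind time by the output stype of the last layer, so binding square_sum directly yields a row-sparse ograd, while appending a dense elementwise op forces a dense one. A sketch of the idea, assuming the internal _square_sum symbol used in the test:

    import mxnet as mx

    data = mx.sym.Variable('data', stype='row_sparse')
    ss = mx.symbol._internal._square_sum(data, axis=1, keepdims=True)
    # Bound as-is, square_sum's ograd arrives row-sparse.
    ss_dense_ograd = ss + 1   # dense output => square_sum sees a dense ograd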

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services