You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/29 15:06:13 UTC

[GitHub] eric-haibin-lin closed pull request #11389: [MXNET-566] Fix flaky test_operator_gpu.test_sparse_dot

eric-haibin-lin closed pull request #11389: [MXNET-566] Fix flaky test_operator_gpu.test_sparse_dot
URL: https://github.com/apache/incubator-mxnet/pull/11389
 
 
   

This is a pull request merged from a forked repository. Because GitHub
hides the original diff once a foreign (forked) pull request is merged,
the diff is reproduced below for the sake of provenance:

diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index faffe1bdea9..d3f44404fd8 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -202,7 +202,7 @@ class NDArray {
   /*! returns the dtypes of all aux data */
   const std::vector<int>& aux_types() const {
     CHECK_NE(storage_type(), kDefaultStorage)
-             << "aux_types() is not intended for kDefaultStorage.";
+      << "aux_types() is not intended for kDefaultStorage.";
     return ptr_->aux_types;
   }
 
@@ -214,6 +214,8 @@ class NDArray {
    * the shape is known and need to be reset using this function.
    */
   inline void set_aux_shape(size_t index, const TShape& shape) const {
+    CHECK_NE(storage_type(), kDefaultStorage)
+      << "set_aux_shape() is not intended for kDefaultStorage.";
     ptr_->set_aux_shape(index, shape);
   }
 
diff --git a/src/operator/tensor/dot-inl.cuh b/src/operator/tensor/dot-inl.cuh
index c507a9aa302..8aedec066ba 100644
--- a/src/operator/tensor/dot-inl.cuh
+++ b/src/operator/tensor/dot-inl.cuh
@@ -1053,7 +1053,7 @@ inline void DotDnsCsrDnsImpl(const OpContext& ctx, const gpu& gpu_dev,
   TBlob csr_indices = rhs.aux_data(csr::kIdx);
   TBlob csr_indptr = rhs.aux_data(csr::kIndPtr);
   if (!rhs.storage_initialized()) {
-    FillZerosCsrImpl(s, *ret);
+    Fill(s, ret->data(), req, 0);
     return;
   }
 
diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h
index 675cbe8b238..5e469108eda 100644
--- a/src/operator/tensor/dot-inl.h
+++ b/src/operator/tensor/dot-inl.h
@@ -1139,7 +1139,7 @@ inline void DotDnsCsrDnsImpl(const OpContext& ctx, const cpu& cpu_dev,
   CHECK_EQ(rhs.storage_type(), kCSRStorage);
   mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
   if (!rhs.storage_initialized()) {
-    FillZerosCsrImpl(s, *ret);
+    Fill(s, ret->data(), req, 0);
     return;
   }
 
diff --git a/src/operator/tensor/init_op.cu b/src/operator/tensor/init_op.cu
index 5841408fa1a..81d835ee3bd 100644
--- a/src/operator/tensor/init_op.cu
+++ b/src/operator/tensor/init_op.cu
@@ -34,6 +34,7 @@ namespace op {
  * \param dst - NDArray which is to be set to "all zeroes"
  */
 void FillZerosCsrImpl(mshadow::Stream<mshadow::gpu> *s, const NDArray& dst) {
+  CHECK_EQ(dst.storage_type(), kCSRStorage) << "dst is not a CSR NDArray";
   dst.set_aux_shape(csr::kIdx, mshadow::Shape1(0));
   dst.CheckAndAllocAuxData(csr::kIndPtr, mshadow::Shape1(dst.shape()[0] + 1));
   TBlob indptr_data = dst.aux_data(csr::kIndPtr);
diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h
index 6c966055637..4af3a40f42a 100644
--- a/src/operator/tensor/init_op.h
+++ b/src/operator/tensor/init_op.h
@@ -344,6 +344,7 @@ inline void FillDnsZerosRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
  */
 template<typename xpu>
 void FillZerosRspImpl(mshadow::Stream<xpu> *, const NDArray& dst) {
+  CHECK_EQ(dst.storage_type(), kRowSparseStorage) << "dst should be an RSP NDArray";
   if (dst.storage_initialized()) {
     // reset the shapes if it's not zeros (set_aux_shape() will set storage_shape to zero as well)
     dst.set_aux_shape(rowsparse::kIdx, TShape(mshadow::Shape1(0)));
@@ -356,6 +357,7 @@ void FillZerosRspImpl(mshadow::Stream<xpu> *, const NDArray& dst) {
  * \param dst - NDArray which is to be set to "all zeroes"
  */
 inline void FillZerosCsrImpl(mshadow::Stream<mshadow::cpu> *s, const NDArray& dst) {
+  CHECK_EQ(dst.storage_type(), kCSRStorage) << "dst is not a CSR NDArray";
   dst.set_aux_shape(csr::kIdx, mshadow::Shape1(0));
   dst.CheckAndAllocAuxData(csr::kIndPtr, mshadow::Shape1(dst.shape()[0] + 1));
   TBlob indptr_data = dst.aux_data(csr::kIndPtr);
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 70af2bc8ce5..09546ac2416 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -1278,7 +1278,7 @@ def test_infer_forward_stype(lhs_shape, rhs_shape, lhs_density, rhs_density, tra
                     rhs = rhs_nd.tostype(rhs_stype)
                     out = mx.nd.dot(lhs, rhs, forward_stype=forward_stype,
                                     transpose_a=trans_a, transpose_b=trans_b)
-                    assert_almost_equal(out.tostype('default').asnumpy(), out_np, rtol=1e-4, atol=1e-5)
+                    assert_almost_equal(out.tostype('default').asnumpy(), out_np, rtol=1e-3, atol=1e-5)
                     lhs_var = mx.symbol.Variable('lhs', stype=lhs_stype)
                     rhs_var = mx.symbol.Variable('rhs', stype=rhs_stype)
                     out = mx.symbol.sparse.dot(lhs_var, rhs_var,
@@ -1295,7 +1295,7 @@ def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, lhs_density, rhs_de
         out = mx.nd.dot(lhs_nd, rhs_nd, transpose_a=trans_lhs)
         out_dns = mx.nd.dot(lhs_dns, rhs_dns, transpose_a=trans_lhs)
         out_np = out_dns.asnumpy()
-        assert_almost_equal(out.asnumpy(), out_np, rtol=1e-4, atol=1e-5)
+        assert_almost_equal(out.asnumpy(), out_np, rtol=1e-3, atol=1e-5)
 
         # test symbolic forward
         lhs = mx.symbol.Variable('lhs', stype='csr')
@@ -1324,7 +1324,7 @@ def test_dot_dns_csr(lhs_shape, rhs_shape, lhs_density, rhs_density, trans_lhs=F
         out = mx.nd.sparse.dot(lhs_nd, rhs_nd, transpose_a=trans_lhs, transpose_b=trans_rhs, forward_stype=forward_stype)
         out_dns = mx.nd.dot(lhs_nd, rhs_dns, transpose_a=trans_lhs, transpose_b=trans_rhs, forward_stype=forward_stype)
         out_np = out_dns.asnumpy()
-        assert_almost_equal(out.asnumpy(), out_np, rtol=1e-4, atol=1e-5)
+        assert_almost_equal(out.asnumpy(), out_np, rtol=1e-3, atol=1e-5)
 
         # test symbolic forward
         lhs = mx.symbol.Variable('lhs', stype='default')


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services