Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/01/03 23:56:21 UTC

[incubator-mxnet] branch master updated: sparse output for binary scalar op with zero (#9227)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 9d51e1a  sparse output for binary scalar op with zero (#9227)
9d51e1a is described below

commit 9d51e1af3778418a65a82f8a35a4032bedc6cbe1
Author: Ziyue Huang <zy...@gmail.com>
AuthorDate: Thu Jan 4 07:56:18 2018 +0800

    sparse output for binary scalar op with zero (#9227)
    
    * sparse output for binary scalar op with zero
    
    * same out stype for cpu/gpu
    
    * update
    
    * update
    
    * address comments
---
 .../tensor/elemwise_binary_scalar_op_basic.cc      | 29 +++++++++++++++-------
 tests/python/gpu/test_operator_gpu.py              |  2 +-
 tests/python/unittest/test_sparse_ndarray.py       |  8 ++++--
 3 files changed, 27 insertions(+), 12 deletions(-)
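
For readers skimming the patch: scalar ops whose result cannot introduce new
non-zeros now keep the input's storage type instead of densifying it. A minimal
sketch of the behavior this commit enables, assuming an MXNet build that
includes it (array values are arbitrary):

    import mxnet as mx

    x = mx.nd.array([[0, 1], [2, 0]]).tostype('csr')
    print((x + 0).stype)  # 'csr'     -- zero scalar, sparsity preserved
    print((x - 0).stype)  # 'csr'     -- likewise for subtraction
    print((x / 2).stype)  # 'csr'     -- dividing by a nonzero scalar keeps zeros zero
    print((x + 1).stype)  # 'default' -- a nonzero addend densifies the result

The same holds for 'row_sparse' inputs, per the inference logic in the diff below.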

diff --git a/src/operator/tensor/elemwise_binary_scalar_op_basic.cc b/src/operator/tensor/elemwise_binary_scalar_op_basic.cc
index 9a278d8..6792379 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_scalar_op_basic.cc
@@ -47,25 +47,36 @@
 namespace mxnet {
 namespace op {
 
+/*!
+ * \brief FInferStorageType for binary operator with scalar,
+ *   csr -> csr and row_sparse -> row_sparse if the scalar is zero,
+ *   otherwise the output is of default storage.
+ */
 static bool BinaryScalarStorageTypeWithDenseResultStorageType(const NodeAttrs& attrs,
                                                               const int dev_mask,
                                                               DispatchMode* dispatch_mode,
                                                               std::vector<int>* in_attrs,
                                                               std::vector<int>* out_attrs)  {
+  CHECK_EQ(in_attrs->size(), 1);
+  CHECK_EQ(out_attrs->size(), 1);
   bool dispatched = false;
-  if (common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
+  const bool invalid_ctx = dev_mask != kCPU;
+  const NDArrayStorageType instype = static_cast<NDArrayStorageType>(in_attrs->at(0));
+  const auto dispatch_ex = invalid_ctx ? DispatchMode::kFComputeFallback
+                                       : DispatchMode::kFComputeEx;
+  const double alpha = nnvm::get<double>(attrs.parsed);
+  if (instype == kDefaultStorage) {
     dispatched = storage_type_assign(&out_attrs[0],
-                                     kDefaultStorage,
-                                     dispatch_mode,
-                                     DispatchMode::kFCompute);
-  } else if (dev_mask == kCPU) {
-    dispatched = storage_type_assign(&out_attrs[0],
-                                     kDefaultStorage,
-                                     dispatch_mode,
-                                     DispatchMode::kFComputeEx);
+      kDefaultStorage, dispatch_mode, DispatchMode::kFCompute);
+  }
+  if (!dispatched && (instype == kCSRStorage || instype == kRowSparseStorage)) {
+    dispatched = storage_type_assign(&out_attrs[0], alpha == 0 ? instype : kDefaultStorage,
+      dispatch_mode, dispatch_ex);
   }
   if (!dispatched) {
     dispatch_fallback(out_attrs, dispatch_mode);
+  }
+  if (static_cast<DispatchMode>(*dispatch_mode) == DispatchMode::kFComputeFallback) {
     LogStorageFallback(attrs, dev_mask, in_attrs, out_attrs);
   }
   return true;
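
In words: a dense input stays on the dense FCompute path; a csr or row_sparse
input yields an output of the same storage type only when the scalar is zero
(dense otherwise), served by FComputeEx on CPU and by the dense fallback on
other devices. A rough Python model of that decision table (illustrative only;
infer_out_stype is not an MXNet API):

    def infer_out_stype(in_stype, scalar, on_cpu):
        # Mirrors BinaryScalarStorageTypeWithDenseResultStorageType above.
        if in_stype == 'default':
            return 'default', 'FCompute'
        # Sparse input: the output stays sparse only for a zero scalar.
        out = in_stype if scalar == 0 else 'default'
        mode = 'FComputeEx' if on_cpu else 'FComputeFallback'
        return out, mode
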
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 31e888b..52aca09 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -38,7 +38,7 @@ from test_sparse_ndarray import test_create_csr, test_create_row_sparse, test_sp
 from test_sparse_ndarray import test_create_sparse_nd_empty, test_create_sparse_nd_from_sparse
 from test_sparse_ndarray import test_create_sparse_nd_from_dense, test_create_sparse_nd_infer_shape
 from test_sparse_ndarray import test_sparse_nd_check_format, test_sparse_nd_copy
-from test_sparse_ndarray import test_sparse_nd_setitem
+from test_sparse_ndarray import test_sparse_nd_setitem, test_sparse_nd_binary_scalar_op
 from test_sparse_operator import *
 from test_ndarray import *
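
Importing the test into test_operator_gpu.py is what makes "same out stype for
cpu/gpu" testable: that file switches the default context to a GPU before
re-running the imported unittest functions, roughly like this (schematic; the
actual file wires this up via its own module-level setup):

    import mxnet as mx
    from mxnet.test_utils import set_default_context

    set_default_context(mx.gpu(0))     # imported tests now run on the GPU
    test_sparse_nd_binary_scalar_op()  # same assertions, GPU context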
 
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index e404997..185ce7f 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -264,12 +264,14 @@ def test_sparse_nd_binary():
 
 def test_sparse_nd_binary_scalar_op():
     N = 3
-    def check(fn, stype):
+    def check(fn, stype, out_stype=None):
         for _ in range(N):
             ndim = 2
             shape = np.random.randint(1, 6, size=(ndim,))
             npy = np.random.normal(0, 1, size=shape)
             nd = mx.nd.array(npy).tostype(stype)
+            if out_stype is not None:
+                assert(fn(nd).stype == out_stype)
             assert_allclose(fn(npy), fn(nd).asnumpy(), rtol=1e-4, atol=1e-4)
 
     stypes = ['row_sparse', 'csr']
@@ -285,7 +287,9 @@ def test_sparse_nd_binary_scalar_op():
         check(lambda x: 0.5 >= x, stype)
         check(lambda x: 0.5 <= x, stype)
         check(lambda x: 0.5 == x, stype)
-        check(lambda x: x / 2, stype)
+        check(lambda x: x / 2, stype, out_stype=stype)
+        check(lambda x: x + 0, stype, out_stype=stype)
+        check(lambda x: x - 0, stype, out_stype=stype)
 
 def test_sparse_nd_binary_iop():
     N = 3
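
To exercise the new assertions locally, one option is to import the unittest
module directly; this sketch assumes it is run from the repository root with a
matching MXNet build importable:

    import sys
    sys.path.insert(0, 'tests/python/unittest')  # so the test module resolves
    from test_sparse_ndarray import test_sparse_nd_binary_scalar_op
    test_sparse_nd_binary_scalar_op()  # raises AssertionError on failure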

-- 
To stop receiving notification emails like this one, please contact
"commits@mxnet.apache.org" <co...@mxnet.apache.org>.