You are viewing a plain text version of this content. (The hyperlink to the canonical version was not preserved in this plain-text rendering.)
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/01/03 23:56:21 UTC
[incubator-mxnet] branch master updated: sparse output for binary scalar op with zero (#9227)
This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 9d51e1a sparse output for binary scalar op with zero (#9227)
9d51e1a is described below
commit 9d51e1af3778418a65a82f8a35a4032bedc6cbe1
Author: Ziyue Huang <zy...@gmail.com>
AuthorDate: Thu Jan 4 07:56:18 2018 +0800
sparse output for binary scalar op with zero (#9227)
* sparse output for binary scalar op with zero
* same out stype for cpu/gpu
* update
* update
* address comments
---
.../tensor/elemwise_binary_scalar_op_basic.cc | 29 +++++++++++++++-------
tests/python/gpu/test_operator_gpu.py | 2 +-
tests/python/unittest/test_sparse_ndarray.py | 8 ++++--
3 files changed, 27 insertions(+), 12 deletions(-)
diff --git a/src/operator/tensor/elemwise_binary_scalar_op_basic.cc b/src/operator/tensor/elemwise_binary_scalar_op_basic.cc
index 9a278d8..6792379 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_scalar_op_basic.cc
@@ -47,25 +47,36 @@
namespace mxnet {
namespace op {
+/*!
+ * \brief FInferStorageType for binary operator with scalar,
+ * csr -> csr and row_sparse -> row_sparse if the scalar is zero,
+ * otherwise the output is of default storage.
+ */
static bool BinaryScalarStorageTypeWithDenseResultStorageType(const NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
+ CHECK_EQ(in_attrs->size(), 1);
+ CHECK_EQ(out_attrs->size(), 1);
bool dispatched = false;
- if (common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
+ const bool invalid_ctx = dev_mask != kCPU;
+ const NDArrayStorageType instype = static_cast<NDArrayStorageType>(in_attrs->at(0));
+ const auto dispatch_ex = invalid_ctx ? DispatchMode::kFComputeFallback
+ : DispatchMode::kFComputeEx;
+ const double alpha = nnvm::get<double>(attrs.parsed);
+ if (instype == kDefaultStorage) {
dispatched = storage_type_assign(&out_attrs[0],
- kDefaultStorage,
- dispatch_mode,
- DispatchMode::kFCompute);
- } else if (dev_mask == kCPU) {
- dispatched = storage_type_assign(&out_attrs[0],
- kDefaultStorage,
- dispatch_mode,
- DispatchMode::kFComputeEx);
+ kDefaultStorage, dispatch_mode, DispatchMode::kFCompute);
+ }
+ if (!dispatched && (instype == kCSRStorage || instype == kRowSparseStorage)) {
+ dispatched = storage_type_assign(&out_attrs[0], alpha == 0 ? instype : kDefaultStorage,
+ dispatch_mode, dispatch_ex);
}
if (!dispatched) {
dispatch_fallback(out_attrs, dispatch_mode);
+ }
+ if (static_cast<DispatchMode>(*dispatch_mode) == DispatchMode::kFComputeFallback) {
LogStorageFallback(attrs, dev_mask, in_attrs, out_attrs);
}
return true;
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 31e888b..52aca09 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -38,7 +38,7 @@ from test_sparse_ndarray import test_create_csr, test_create_row_sparse, test_sp
from test_sparse_ndarray import test_create_sparse_nd_empty, test_create_sparse_nd_from_sparse
from test_sparse_ndarray import test_create_sparse_nd_from_dense, test_create_sparse_nd_infer_shape
from test_sparse_ndarray import test_sparse_nd_check_format, test_sparse_nd_copy
-from test_sparse_ndarray import test_sparse_nd_setitem
+from test_sparse_ndarray import test_sparse_nd_setitem, test_sparse_nd_binary_scalar_op
from test_sparse_operator import *
from test_ndarray import *
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index e404997..185ce7f 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -264,12 +264,14 @@ def test_sparse_nd_binary():
def test_sparse_nd_binary_scalar_op():
N = 3
- def check(fn, stype):
+ def check(fn, stype, out_stype=None):
for _ in range(N):
ndim = 2
shape = np.random.randint(1, 6, size=(ndim,))
npy = np.random.normal(0, 1, size=shape)
nd = mx.nd.array(npy).tostype(stype)
+ if out_stype is not None:
+ assert(nd.stype == out_stype)
assert_allclose(fn(npy), fn(nd).asnumpy(), rtol=1e-4, atol=1e-4)
stypes = ['row_sparse', 'csr']
@@ -285,7 +287,9 @@ def test_sparse_nd_binary_scalar_op():
check(lambda x: 0.5 >= x, stype)
check(lambda x: 0.5 <= x, stype)
check(lambda x: 0.5 == x, stype)
- check(lambda x: x / 2, stype)
+ check(lambda x: x / 2, stype, out_stype=stype)
+ check(lambda x: x + 0, stype, out_stype=stype)
+ check(lambda x: x - 0, stype, out_stype=stype)
def test_sparse_nd_binary_iop():
N = 3
--
To stop receiving notification emails like this one, please contact commits@mxnet.apache.org.