You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by li...@apache.org on 2020/07/04 01:14:51 UTC
[incubator-mxnet] branch master updated: [numpy] Fix less/greater bug with scalar input (#18642)
This is an automated email from the ASF dual-hosted git repository.
liuyizhi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 6462887 [numpy] Fix less/greater bug with scalar input (#18642)
6462887 is described below
commit 646288716cbba482d4ede0fb4f6141b2ea505090
Author: Yiyan66 <57...@users.noreply.github.com>
AuthorDate: Sat Jul 4 09:13:41 2020 +0800
[numpy] Fix less/greater bug with scalar input (#18642)
* fix ffi
* fix less/greater error
* back
* submodule
* fixed
Co-authored-by: Ubuntu <ub...@ip-172-31-8-94.us-east-2.compute.internal>
---
python/mxnet/ndarray/numpy/_op.py | 5 ++--
.../numpy/np_elemwise_broadcast_logic_op.cc | 34 ++++++++++++++++++----
.../python/unittest/test_numpy_interoperability.py | 8 +++++
3 files changed, 39 insertions(+), 8 deletions(-)
diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index 45f885a..91fea5f 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -7171,8 +7171,9 @@ def greater(x1, x2, out=None):
>>> np.greater(1, np.ones(1))
array([False])
"""
- return _ufunc_helper(x1, x2, _npi.greater, _np.greater, _npi.greater_scalar,
- _npi.less_scalar, out)
+ if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
+ return _np.greater(x1, x2, out=out)
+ return _api_internal.greater(x1, x2, out)
@set_module('mxnet.ndarray.numpy')
diff --git a/src/api/operator/numpy/np_elemwise_broadcast_logic_op.cc b/src/api/operator/numpy/np_elemwise_broadcast_logic_op.cc
index f0ca408..2248433 100644
--- a/src/api/operator/numpy/np_elemwise_broadcast_logic_op.cc
+++ b/src/api/operator/numpy/np_elemwise_broadcast_logic_op.cc
@@ -44,13 +44,35 @@ MXNET_REGISTER_API("_npi.not_equal")
UFuncHelper(args, ret, op, op_scalar, nullptr);
});
+void SetUFuncHelper(runtime::MXNetArgs args, runtime::MXNetRetValue* ret,
+ const nnvm::Op* op, const nnvm::Op* op_scalar,
+ const nnvm::Op* op_rscalar) {
+ if (args[0].type_code() == kNDArrayHandle &&
+ args[1].type_code() == kNDArrayHandle) {
+ UFuncHelper(args, ret, op, nullptr, nullptr);
+ } else if (args[0].type_code() == kNDArrayHandle) {
+ UFuncHelper(args, ret, nullptr, op_scalar, nullptr);
+ } else {
+ UFuncHelper(args, ret, nullptr, nullptr, op_rscalar);
+ }
+}
+
+MXNET_REGISTER_API("_npi.greater")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+ using namespace runtime;
+ const nnvm::Op* op = Op::Get("_npi_greater");
+ const nnvm::Op* op_scalar = Op::Get("_npi_greater_scalar");
+ const nnvm::Op* op_rscalar = Op::Get("_npi_less_scalar");
+ SetUFuncHelper(args, ret, op, op_scalar, op_rscalar);
+});
+
MXNET_REGISTER_API("_npi.less")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_less");
const nnvm::Op* op_scalar = Op::Get("_npi_less_scalar");
- const nnvm::Op* op_rscalar = Op::Get("_npi_less_scalar");
- UFuncHelper(args, ret, op, op_scalar, op_rscalar);
+ const nnvm::Op* op_rscalar = Op::Get("_npi_greater_scalar");
+ SetUFuncHelper(args, ret, op, op_scalar, op_rscalar);
});
MXNET_REGISTER_API("_npi.greater_equal")
@@ -58,8 +80,8 @@ MXNET_REGISTER_API("_npi.greater_equal")
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_greater_equal");
const nnvm::Op* op_scalar = Op::Get("_npi_greater_equal_scalar");
- const nnvm::Op* op_rscalar = Op::Get("_npi_greater_equal_scalar");
- UFuncHelper(args, ret, op, op_scalar, op_rscalar);
+ const nnvm::Op* op_rscalar = Op::Get("_npi_less_equal_scalar");
+ SetUFuncHelper(args, ret, op, op_scalar, op_rscalar);
});
MXNET_REGISTER_API("_npi.less_equal")
@@ -67,8 +89,8 @@ MXNET_REGISTER_API("_npi.less_equal")
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_less_equal");
const nnvm::Op* op_scalar = Op::Get("_npi_less_equal_scalar");
- const nnvm::Op* op_rscalar = Op::Get("_npi_less_equal_scalar");
- UFuncHelper(args, ret, op, op_scalar, op_rscalar);
+ const nnvm::Op* op_rscalar = Op::Get("_npi_greater_equal_scalar");
+ SetUFuncHelper(args, ret, op, op_scalar, op_rscalar);
});
} // namespace mxnet
diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py
index 6a2845e..8b50fc4 100644
--- a/tests/python/unittest/test_numpy_interoperability.py
+++ b/tests/python/unittest/test_numpy_interoperability.py
@@ -1947,6 +1947,8 @@ def _add_workload_greater(array_pool):
# OpArgMngr.add_workload('greater', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
OpArgMngr.add_workload('greater', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32))
OpArgMngr.add_workload('greater', array_pool['4x1'], array_pool['1x2'])
+ OpArgMngr.add_workload('greater', array_pool['4x1'], 2)
+ OpArgMngr.add_workload('greater', 2, array_pool['4x1'])
# TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
# OpArgMngr.add_workload('greater', np.array([np.nan]), np.array([np.nan]))
@@ -1956,6 +1958,8 @@ def _add_workload_greater_equal(array_pool):
# OpArgMngr.add_workload('greater_equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
OpArgMngr.add_workload('greater_equal', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32))
OpArgMngr.add_workload('greater_equal', array_pool['4x1'], array_pool['1x2'])
+ OpArgMngr.add_workload('greater_equal', array_pool['4x1'], 2)
+ OpArgMngr.add_workload('greater_equal', 2, array_pool['4x1'])
# TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
# OpArgMngr.add_workload('greater_equal', np.array([np.nan]), np.array([np.nan]))
@@ -1965,6 +1969,8 @@ def _add_workload_less(array_pool):
# OpArgMngr.add_workload('less', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
OpArgMngr.add_workload('less', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32))
OpArgMngr.add_workload('less', array_pool['4x1'], array_pool['1x2'])
+ OpArgMngr.add_workload('less', array_pool['4x1'], 2)
+ OpArgMngr.add_workload('less', 2, array_pool['4x1'])
# TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
# OpArgMngr.add_workload('less', np.array([np.nan]), np.array([np.nan]))
@@ -1974,6 +1980,8 @@ def _add_workload_less_equal(array_pool):
# OpArgMngr.add_workload('less_equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
OpArgMngr.add_workload('less_equal', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32))
OpArgMngr.add_workload('less_equal', array_pool['4x1'], array_pool['1x2'])
+ OpArgMngr.add_workload('less_equal', array_pool['4x1'], 2)
+ OpArgMngr.add_workload('less_equal', 2, array_pool['4x1'])
# TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
# OpArgMngr.add_workload('less_equal', np.array([np.nan]), np.array([np.nan]))