You are viewing a plain-text version of this content. The canonical link for it (the hyperlink on "here") was not preserved in this copy.
Posted to commits@mxnet.apache.org by li...@apache.org on 2020/07/04 01:14:51 UTC

[incubator-mxnet] branch master updated: [numpy] Fix less/greater bug with scalar input (#18642)

This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 6462887  [numpy] Fix less/greater bug with scalar input (#18642)
6462887 is described below

commit 646288716cbba482d4ede0fb4f6141b2ea505090
Author: Yiyan66 <57...@users.noreply.github.com>
AuthorDate: Sat Jul 4 09:13:41 2020 +0800

    [numpy] Fix less/greater bug with scalar input (#18642)
    
    * fix ffi
    
    * fix less/greater error
    
    * back
    
    * submodule
    
    * fixed
    
    Co-authored-by: Ubuntu <ub...@ip-172-31-8-94.us-east-2.compute.internal>
---
 python/mxnet/ndarray/numpy/_op.py                  |  5 ++--
 .../numpy/np_elemwise_broadcast_logic_op.cc        | 34 ++++++++++++++++++----
 .../python/unittest/test_numpy_interoperability.py |  8 +++++
 3 files changed, 39 insertions(+), 8 deletions(-)

diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index 45f885a..91fea5f 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -7171,8 +7171,9 @@ def greater(x1, x2, out=None):
     >>> np.greater(1, np.ones(1))
     array([False])
     """
-    return _ufunc_helper(x1, x2, _npi.greater, _np.greater, _npi.greater_scalar,
-                         _npi.less_scalar, out)
+    if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
+        return _np.greater(x1, x2, out=out)
+    return _api_internal.greater(x1, x2, out)
 
 
 @set_module('mxnet.ndarray.numpy')
diff --git a/src/api/operator/numpy/np_elemwise_broadcast_logic_op.cc b/src/api/operator/numpy/np_elemwise_broadcast_logic_op.cc
index f0ca408..2248433 100644
--- a/src/api/operator/numpy/np_elemwise_broadcast_logic_op.cc
+++ b/src/api/operator/numpy/np_elemwise_broadcast_logic_op.cc
@@ -44,13 +44,35 @@ MXNET_REGISTER_API("_npi.not_equal")
   UFuncHelper(args, ret, op, op_scalar, nullptr);
 });
 
+void SetUFuncHelper(runtime::MXNetArgs args, runtime::MXNetRetValue* ret,  // dispatch a non-commutative comparison by operand kinds
+                     const nnvm::Op* op, const nnvm::Op* op_scalar,
+                     const nnvm::Op* op_rscalar) {  // op_rscalar: mirrored op for scalar-first calls (e.g. less_scalar for greater)
+  if (args[0].type_code() == kNDArrayHandle &&
+      args[1].type_code() == kNDArrayHandle) {  // ndarray OP ndarray -> op
+    UFuncHelper(args, ret, op, nullptr, nullptr);
+  } else if (args[0].type_code() == kNDArrayHandle) {  // ndarray OP scalar -> op_scalar
+    UFuncHelper(args, ret, nullptr, op_scalar, nullptr);
+  } else {  // scalar OP ndarray -> op_rscalar; presumably UFuncHelper applies it with operands swapped, TODO confirm
+    UFuncHelper(args, ret, nullptr, nullptr, op_rscalar);  // NOTE(review): also reached if both args are scalars; Python greater() handles that case before calling
+  }
+}
+
+MXNET_REGISTER_API("_npi.greater")  // presumably backs _api_internal.greater used by greater() in _op.py
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  const nnvm::Op* op = Op::Get("_npi_greater");
+  const nnvm::Op* op_scalar = Op::Get("_npi_greater_scalar");
+  const nnvm::Op* op_rscalar = Op::Get("_npi_less_scalar");  // scalar-first: s > x  <=>  x < s
+  SetUFuncHelper(args, ret, op, op_scalar, op_rscalar);  // non-commutative, so dispatch on operand kinds
+});
+
 MXNET_REGISTER_API("_npi.less")
 .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
   using namespace runtime;
   const nnvm::Op* op = Op::Get("_npi_less");
   const nnvm::Op* op_scalar = Op::Get("_npi_less_scalar");
-  const nnvm::Op* op_rscalar = Op::Get("_npi_less_scalar");
-  UFuncHelper(args, ret, op, op_scalar, op_rscalar);
+  const nnvm::Op* op_rscalar = Op::Get("_npi_greater_scalar");  // scalar-first: s < x  <=>  x > s (was wrongly less_scalar)
+  SetUFuncHelper(args, ret, op, op_scalar, op_rscalar);  // non-commutative, so dispatch on operand kinds
 });
 
 MXNET_REGISTER_API("_npi.greater_equal")
@@ -58,8 +80,8 @@ MXNET_REGISTER_API("_npi.greater_equal")
   using namespace runtime;
   const nnvm::Op* op = Op::Get("_npi_greater_equal");
   const nnvm::Op* op_scalar = Op::Get("_npi_greater_equal_scalar");
-  const nnvm::Op* op_rscalar = Op::Get("_npi_greater_equal_scalar");
-  UFuncHelper(args, ret, op, op_scalar, op_rscalar);
+  const nnvm::Op* op_rscalar = Op::Get("_npi_less_equal_scalar");
+  SetUFuncHelper(args, ret, op, op_scalar, op_rscalar);
 });
 
 MXNET_REGISTER_API("_npi.less_equal")
@@ -67,8 +89,8 @@ MXNET_REGISTER_API("_npi.less_equal")
   using namespace runtime;
   const nnvm::Op* op = Op::Get("_npi_less_equal");
   const nnvm::Op* op_scalar = Op::Get("_npi_less_equal_scalar");
-  const nnvm::Op* op_rscalar = Op::Get("_npi_less_equal_scalar");
-  UFuncHelper(args, ret, op, op_scalar, op_rscalar);
+  const nnvm::Op* op_rscalar = Op::Get("_npi_greater_equal_scalar");
+  SetUFuncHelper(args, ret, op, op_scalar, op_rscalar);
 });
 
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py
index 6a2845e..8b50fc4 100644
--- a/tests/python/unittest/test_numpy_interoperability.py
+++ b/tests/python/unittest/test_numpy_interoperability.py
@@ -1947,6 +1947,8 @@ def _add_workload_greater(array_pool):
     # OpArgMngr.add_workload('greater', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
     OpArgMngr.add_workload('greater', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32))
     OpArgMngr.add_workload('greater', array_pool['4x1'], array_pool['1x2'])
+    OpArgMngr.add_workload('greater', array_pool['4x1'], 2)
+    OpArgMngr.add_workload('greater', 2, array_pool['4x1'])
     # TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
     # OpArgMngr.add_workload('greater', np.array([np.nan]), np.array([np.nan]))
 
@@ -1956,6 +1958,8 @@ def _add_workload_greater_equal(array_pool):
     # OpArgMngr.add_workload('greater_equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
     OpArgMngr.add_workload('greater_equal', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32))
     OpArgMngr.add_workload('greater_equal', array_pool['4x1'], array_pool['1x2'])
+    OpArgMngr.add_workload('greater_equal', array_pool['4x1'], 2)
+    OpArgMngr.add_workload('greater_equal', 2, array_pool['4x1'])
     # TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
     # OpArgMngr.add_workload('greater_equal', np.array([np.nan]), np.array([np.nan]))
 
@@ -1965,6 +1969,8 @@ def _add_workload_less(array_pool):
     # OpArgMngr.add_workload('less', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
     OpArgMngr.add_workload('less', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32))
     OpArgMngr.add_workload('less', array_pool['4x1'], array_pool['1x2'])
+    OpArgMngr.add_workload('less', array_pool['4x1'], 2)
+    OpArgMngr.add_workload('less', 2, array_pool['4x1'])
     # TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
     # OpArgMngr.add_workload('less', np.array([np.nan]), np.array([np.nan]))
 
@@ -1974,6 +1980,8 @@ def _add_workload_less_equal(array_pool):
     # OpArgMngr.add_workload('less_equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
     OpArgMngr.add_workload('less_equal', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32))
     OpArgMngr.add_workload('less_equal', array_pool['4x1'], array_pool['1x2'])
+    OpArgMngr.add_workload('less_equal', array_pool['4x1'], 2)
+    OpArgMngr.add_workload('less_equal', 2, array_pool['4x1'])
     # TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
     # OpArgMngr.add_workload('less_equal', np.array([np.nan]), np.array([np.nan]))