You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/05/01 16:43:48 UTC
[incubator-mxnet] branch master updated: [MXNET-347] Logical
Operators AND, XOR, OR (#10679)
This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 61f86fc [MXNET-347] Logical Operators AND, XOR, OR (#10679)
61f86fc is described below
commit 61f86fcfdc2780403cbacb46903caac786bacb2b
Author: Anirudh <an...@gmail.com>
AuthorDate: Tue May 1 09:43:38 2018 -0700
[MXNET-347] Logical Operators AND, XOR, OR (#10679)
* logical and
* logical OR and XOR operators.
* better examples
* nits.
* elemwise operators
* non broadcast examples and tests.
* doc API
* rerun CI
---
docs/api/python/ndarray/ndarray.md | 12 ++
docs/api/python/symbol/symbol.md | 12 ++
python/mxnet/ndarray/ndarray.py | 181 ++++++++++++++++++++-
src/operator/mshadow_op.h | 6 +
src/operator/operator_tune.cc | 6 +
.../tensor/elemwise_binary_broadcast_op_logic.cc | 54 ++++++
.../tensor/elemwise_binary_broadcast_op_logic.cu | 9 +
src/operator/tensor/elemwise_binary_op_logic.cc | 15 ++
src/operator/tensor/elemwise_binary_op_logic.cu | 9 +
.../tensor/elemwise_binary_scalar_op_logic.cc | 15 ++
.../tensor/elemwise_binary_scalar_op_logic.cu | 9 +
tests/python/unittest/test_operator.py | 25 ++-
12 files changed, 348 insertions(+), 5 deletions(-)
diff --git a/docs/api/python/ndarray/ndarray.md b/docs/api/python/ndarray/ndarray.md
index fe8abd8..5bc3c52 100644
--- a/docs/api/python/ndarray/ndarray.md
+++ b/docs/api/python/ndarray/ndarray.md
@@ -555,6 +555,18 @@ The `ndarray` package provides several classes:
lesser_equal
```
+### Logical operators
+
+```eval_rst
+.. autosummary::
+ :nosignatures:
+
+ logical_and
+ logical_or
+ logical_xor
+ logical_not
+```
+
### Random sampling
```eval_rst
diff --git a/docs/api/python/symbol/symbol.md b/docs/api/python/symbol/symbol.md
index d18b2b2..f1e90a0 100644
--- a/docs/api/python/symbol/symbol.md
+++ b/docs/api/python/symbol/symbol.md
@@ -554,6 +554,18 @@ Composite multiple symbols into a new one by an operator.
broadcast_lesser_equal
```
+### Logical
+
+```eval_rst
+.. autosummary::
+ :nosignatures:
+
+ broadcast_logical_and
+ broadcast_logical_or
+ broadcast_logical_xor
+ logical_not
+```
+
### Random sampling
```eval_rst
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 361aa24..6b2ff23 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -44,9 +44,9 @@ from ._internal import NDArrayBase
__all__ = ["NDArray", "concatenate", "_DTYPE_NP_TO_MX", "_DTYPE_MX_TO_NP", "_GRAD_REQ_MAP",
"ones", "add", "arange", "eye", "divide", "equal", "full", "greater", "greater_equal",
- "imdecode", "lesser", "lesser_equal", "maximum", "minimum", "moveaxis", "modulo",
- "multiply", "not_equal", "onehot_encode", "power", "subtract", "true_divide",
- "waitall", "_new_empty_handle"]
+ "imdecode", "lesser", "lesser_equal", "logical_and", "logical_or", "logical_xor",
+ "maximum", "minimum", "moveaxis", "modulo", "multiply", "not_equal", "onehot_encode",
+ "power", "subtract", "true_divide", "waitall", "_new_empty_handle"]
_STORAGE_TYPE_UNDEFINED = -1
_STORAGE_TYPE_DEFAULT = 0
@@ -2485,7 +2485,7 @@ def add(lhs, rhs):
.. note::
If the corresponding dimensions of two arrays have the same size or one of them has size 1,
- then the arrays are broadcastable to a common shape.
+ then the arrays are broadcastable to a common shape.
Parameters
----------
@@ -3337,6 +3337,179 @@ def lesser_equal(lhs, rhs):
_internal._greater_equal_scalar)
# pylint: enable= no-member, protected-access
+def logical_and(lhs, rhs):
+ """Returns the result of element-wise **logical and** comparison
+ operation with broadcasting.
+
+ For each element in input arrays, return 1(true) if lhs elements and rhs elements
+ are true, otherwise return 0(false).
+
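+ Equivalent to ``lhs and rhs`` applied element-wise, and to ``mx.nd.broadcast_logical_and(lhs, rhs)``.
+ (Python's own ``and`` operator does not operate element-wise on NDArrays.)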
+
+ .. note::
+
+ If the corresponding dimensions of two arrays have the same size or one of them has size 1,
+ then the arrays are broadcastable to a common shape.
+
+ Parameters
+ ----------
+ lhs : scalar or array
+ First input of the function.
+ rhs : scalar or array
+ Second input of the function. If ``lhs.shape != rhs.shape``, they must be
+ broadcastable to a common shape.
+
+ Returns
+ -------
+ NDArray
+ Output array of boolean values.
+
+ Examples
+ --------
+ >>> x = mx.nd.ones((2,3))
+ >>> y = mx.nd.arange(2).reshape((2,1))
+ >>> z = mx.nd.arange(2).reshape((1,2))
+ >>> x.asnumpy()
+ array([[ 1., 1., 1.],
+ [ 1., 1., 1.]], dtype=float32)
+ >>> y.asnumpy()
+ array([[ 0.],
+ [ 1.]], dtype=float32)
+ >>> z.asnumpy()
+ array([[ 0., 1.]], dtype=float32)
+ >>> mx.nd.logical_and(x, 1).asnumpy()
+ array([[ 1., 1., 1.],
+ [ 1., 1., 1.]], dtype=float32)
+ >>> mx.nd.logical_and(x, y).asnumpy()
+ array([[ 0., 0., 0.],
+ [ 1., 1., 1.]], dtype=float32)
+ >>> mx.nd.logical_and(z, y).asnumpy()
+ array([[ 0., 0.],
+ [ 0., 1.]], dtype=float32)
+ """
+ # pylint: disable= no-member, protected-access
+ return _ufunc_helper(
+ lhs,
+ rhs,
+ op.broadcast_logical_and,
+ lambda x, y: 1 if x and y else 0,
+ _internal._logical_and_scalar,
+ None)
+ # pylint: enable= no-member, protected-access
+
+def logical_or(lhs, rhs):
+ """Returns the result of element-wise **logical or** comparison
+ operation with broadcasting.
+
+ For each element in input arrays, return 1(true) if lhs elements or rhs elements
+ are true, otherwise return 0(false).
+
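+ Equivalent to ``lhs or rhs`` applied element-wise, and to ``mx.nd.broadcast_logical_or(lhs, rhs)``.
+ (Python's own ``or`` operator does not operate element-wise on NDArrays.)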
+
+ .. note::
+
+ If the corresponding dimensions of two arrays have the same size or one of them has size 1,
+ then the arrays are broadcastable to a common shape.
+
+ Parameters
+ ----------
+ lhs : scalar or array
+ First input of the function.
+ rhs : scalar or array
+ Second input of the function. If ``lhs.shape != rhs.shape``, they must be
+ broadcastable to a common shape.
+
+ Returns
+ -------
+ NDArray
+ Output array of boolean values.
+
+ Examples
+ --------
+ >>> x = mx.nd.ones((2,3))
+ >>> y = mx.nd.arange(2).reshape((2,1))
+ >>> z = mx.nd.arange(2).reshape((1,2))
+ >>> x.asnumpy()
+ array([[ 1., 1., 1.],
+ [ 1., 1., 1.]], dtype=float32)
+ >>> y.asnumpy()
+ array([[ 0.],
+ [ 1.]], dtype=float32)
+ >>> z.asnumpy()
+ array([[ 0., 1.]], dtype=float32)
+ >>> mx.nd.logical_or(x, 1).asnumpy()
+ array([[ 1., 1., 1.],
+ [ 1., 1., 1.]], dtype=float32)
+ >>> mx.nd.logical_or(x, y).asnumpy()
+ array([[ 1., 1., 1.],
+ [ 1., 1., 1.]], dtype=float32)
+ >>> mx.nd.logical_or(z, y).asnumpy()
+ array([[ 0., 1.],
+ [ 1., 1.]], dtype=float32)
+ """
+ # pylint: disable= no-member, protected-access
+ return _ufunc_helper(
+ lhs,
+ rhs,
+ op.broadcast_logical_or,
+ lambda x, y: 1 if x or y else 0,
+ _internal._logical_or_scalar,
+ None)
+ # pylint: enable= no-member, protected-access
+
+def logical_xor(lhs, rhs):
+ """Returns the result of element-wise **logical xor** comparison
+ operation with broadcasting.
+
+ For each element in input arrays, return 1(true) if exactly one of the lhs element
+ and the rhs element is true, otherwise return 0(false).
+
+ Equivalent to ``bool(lhs) ^ bool(rhs)`` applied element-wise, and to ``mx.nd.broadcast_logical_xor(lhs, rhs)``.
+
+ .. note::
+
+ If the corresponding dimensions of two arrays have the same size or one of them has size 1,
+ then the arrays are broadcastable to a common shape.
+
+ Parameters
+ ----------
+ lhs : scalar or array
+ First input of the function.
+ rhs : scalar or array
+ Second input of the function. If ``lhs.shape != rhs.shape``, they must be
+ broadcastable to a common shape.
+
+ Returns
+ -------
+ NDArray
+ Output array of boolean values.
+
+ Examples
+ --------
+ >>> x = mx.nd.ones((2,3))
+ >>> y = mx.nd.arange(2).reshape((2,1))
+ >>> z = mx.nd.arange(2).reshape((1,2))
+ >>> x.asnumpy()
+ array([[ 1., 1., 1.],
+ [ 1., 1., 1.]], dtype=float32)
+ >>> y.asnumpy()
+ array([[ 0.],
+ [ 1.]], dtype=float32)
+ >>> z.asnumpy()
+ array([[ 0., 1.]], dtype=float32)
+ >>> mx.nd.logical_xor(x, y).asnumpy()
+ array([[ 1., 1., 1.],
+ [ 0., 0., 0.]], dtype=float32)
+ """
+ # pylint: disable= no-member, protected-access
+ return _ufunc_helper(
+ lhs,
+ rhs,
+ op.broadcast_logical_xor,
+ lambda x, y: 1 if bool(x) ^ bool(y) else 0,
+ _internal._logical_xor_scalar,
+ None)
+ # pylint: enable= no-member, protected-access
def true_divide(lhs, rhs):
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 2f5dd97..19fa4f8 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -317,6 +317,12 @@ MXNET_BINARY_MATH_OP_NC(eq, a == b ? DType(1) : DType(0));
MXNET_BINARY_MATH_OP_NC(ne, a != b ? DType(1) : DType(0));
+MXNET_BINARY_MATH_OP(logical_and, a && b ? DType(1) : DType(0));
+
+MXNET_BINARY_MATH_OP(logical_or, a || b ? DType(1) : DType(0));
+
+MXNET_BINARY_MATH_OP(logical_xor, (a || b) && !(a && b) ? DType(1) : DType(0));
+
MXNET_UNARY_MATH_OP(square_root, math::sqrt(a));
MXNET_UNARY_MATH_OP(square_root_grad, 0.5f / math::id(a));
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index 47db78b..de3c742 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -342,6 +342,12 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ne); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ne); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::eq); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::eq); // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_and); // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_and); // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_or); // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_or); // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_xor); // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_xor); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient); // NOLINT()
IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<0>); // NOLINT()
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_logic.cc b/src/operator/tensor/elemwise_binary_broadcast_op_logic.cc
index 31f34bb..3cb3ba3 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op_logic.cc
+++ b/src/operator/tensor/elemwise_binary_broadcast_op_logic.cc
@@ -137,5 +137,59 @@ Example::
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::le>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_logical_and)
+.describe(R"code(Returns the result of element-wise **logical and** with broadcasting.
+
+Example::
+
+ x = [[ 1., 1., 1.],
+ [ 1., 1., 1.]]
+
+ y = [[ 0.],
+ [ 1.]]
+
+ broadcast_logical_and(x, y) = [[ 0., 0., 0.],
+ [ 1., 1., 1.]]
+
+)code" ADD_FILELINE)
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::logical_and>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
+
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_logical_or)
+.describe(R"code(Returns the result of element-wise **logical or** with broadcasting.
+
+Example::
+
+ x = [[ 1., 1., 0.],
+ [ 1., 1., 0.]]
+
+ y = [[ 1.],
+ [ 0.]]
+
+ broadcast_logical_or(x, y) = [[ 1., 1., 1.],
+ [ 1., 1., 0.]]
+
+)code" ADD_FILELINE)
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::logical_or>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
+
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_logical_xor)
+.describe(R"code(Returns the result of element-wise **logical xor** with broadcasting.
+
+Example::
+
+ x = [[ 1., 1., 0.],
+ [ 1., 1., 0.]]
+
+ y = [[ 1.],
+ [ 0.]]
+
+ broadcast_logical_xor(x, y) = [[ 0., 0., 1.],
+ [ 1., 1., 0.]]
+
+)code" ADD_FILELINE)
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::logical_xor>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
+
} // namespace op
} // namespace mxnet
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_logic.cu b/src/operator/tensor/elemwise_binary_broadcast_op_logic.cu
index 4e80ae9..d6b01aa 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op_logic.cu
+++ b/src/operator/tensor/elemwise_binary_broadcast_op_logic.cu
@@ -47,5 +47,14 @@ NNVM_REGISTER_OP(broadcast_lesser)
NNVM_REGISTER_OP(broadcast_lesser_equal)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::le>);
+NNVM_REGISTER_OP(broadcast_logical_and)
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::logical_and>);
+
+NNVM_REGISTER_OP(broadcast_logical_or)
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::logical_or>);
+
+NNVM_REGISTER_OP(broadcast_logical_xor)
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::logical_xor>);
+
} // namespace op
} // namespace mxnet
diff --git a/src/operator/tensor/elemwise_binary_op_logic.cc b/src/operator/tensor/elemwise_binary_op_logic.cc
index 5d328b5..5c0b442 100644
--- a/src/operator/tensor/elemwise_binary_op_logic.cc
+++ b/src/operator/tensor/elemwise_binary_op_logic.cc
@@ -57,5 +57,20 @@ MXNET_OPERATOR_REGISTER_BINARY(_lesser_equal)
.set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu, mshadow_op::le>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
+MXNET_OPERATOR_REGISTER_BINARY(_logical_and)
+.add_alias("_Logical_And")
+.set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu, mshadow_op::logical_and>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
+
+MXNET_OPERATOR_REGISTER_BINARY(_logical_or)
+.add_alias("_Logical_Or")
+.set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu, mshadow_op::logical_or>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
+
+MXNET_OPERATOR_REGISTER_BINARY(_logical_xor)
+.add_alias("_Logical_Xor")
+.set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu, mshadow_op::logical_xor>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
+
} // namespace op
} // namespace mxnet
diff --git a/src/operator/tensor/elemwise_binary_op_logic.cu b/src/operator/tensor/elemwise_binary_op_logic.cu
index be5b722..456622d 100644
--- a/src/operator/tensor/elemwise_binary_op_logic.cu
+++ b/src/operator/tensor/elemwise_binary_op_logic.cu
@@ -45,5 +45,14 @@ NNVM_REGISTER_OP(_lesser)
NNVM_REGISTER_OP(_lesser_equal)
.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<gpu, mshadow_op::le>);
+NNVM_REGISTER_OP(_logical_and)
+.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<gpu, mshadow_op::logical_and>);
+
+NNVM_REGISTER_OP(_logical_or)
+.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<gpu, mshadow_op::logical_or>);
+
+NNVM_REGISTER_OP(_logical_xor)
+.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<gpu, mshadow_op::logical_xor>);
+
} // namespace op
} // namespace mxnet
diff --git a/src/operator/tensor/elemwise_binary_scalar_op_logic.cc b/src/operator/tensor/elemwise_binary_scalar_op_logic.cc
index 61f1dd0..fafd840 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op_logic.cc
+++ b/src/operator/tensor/elemwise_binary_scalar_op_logic.cc
@@ -59,5 +59,20 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_lesser_equal_scalar)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_alias("_LesserEqualScalar");
+MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_logical_and_scalar)
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::logical_and>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_alias("_LogicalAndScalar");
+
+MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_logical_or_scalar)
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::logical_or>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_alias("_LogicalOrScalar");
+
+MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_logical_xor_scalar)
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::logical_xor>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_alias("_LogicalXorScalar");
+
} // namespace op
} // namespace mxnet
diff --git a/src/operator/tensor/elemwise_binary_scalar_op_logic.cu b/src/operator/tensor/elemwise_binary_scalar_op_logic.cu
index 91bcaa8..92b01a0 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op_logic.cu
+++ b/src/operator/tensor/elemwise_binary_scalar_op_logic.cu
@@ -45,5 +45,14 @@ NNVM_REGISTER_OP(_lesser_scalar)
NNVM_REGISTER_OP(_lesser_equal_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::le>);
+NNVM_REGISTER_OP(_logical_and_scalar)
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::logical_and>);
+
+NNVM_REGISTER_OP(_logical_or_scalar)
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::logical_or>);
+
+NNVM_REGISTER_OP(_logical_xor_scalar)
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::logical_xor>);
+
} // namespace op
} // namespace mxnet
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 6e5c908..838d8d8 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1613,6 +1613,27 @@ def test_broadcast_binary_op():
data = gen_broadcast_data(idx=200)
check_bmaxmin_gradient(c, data[0], data[1], 0.001, 1e-2, 1e-3)
+ def test_band(a, b):
+ c = mx.sym.broadcast_logical_and(a, b)
+ check_binary_op_forward(c, lambda x, y: np.logical_and(x, y), gen_broadcast_data, mx_nd_func=mx.nd.logical_and)
+ # pass idx=200 to gen_broadcast_data so that generated ndarrays' sizes are not too big
+ data = gen_broadcast_data(idx=200)
+ check_bmaxmin_gradient(c, data[0], data[1], 0.001, 1e-2, 1e-3)
+
+ def test_bor(a, b):
+ c = mx.sym.broadcast_logical_or(a, b)
+ check_binary_op_forward(c, lambda x, y: np.logical_or(x, y), gen_broadcast_data, mx_nd_func=mx.nd.logical_or)
+ # pass idx=200 to gen_broadcast_data so that generated ndarrays' sizes are not too big
+ data = gen_broadcast_data(idx=200)
+ check_bmaxmin_gradient(c, data[0], data[1], 0.001, 1e-2, 1e-3)
+
+ def test_bxor(a, b):
+ c = mx.sym.broadcast_logical_xor(a, b)
+ check_binary_op_forward(c, lambda x, y: np.logical_xor(x, y), gen_broadcast_data, mx_nd_func=mx.nd.logical_xor)
+ # pass idx=200 to gen_broadcast_data so that generated ndarrays' sizes are not too big
+ data = gen_broadcast_data(idx=200)
+ check_bmaxmin_gradient(c, data[0], data[1], 0.001, 1e-2, 1e-3)
+
test_bplus(a, b)
test_bminus(a, b)
test_bmul(a, b)
@@ -1623,7 +1644,9 @@ def test_broadcast_binary_op():
test_bequal(a, b)
test_bmax(a, b)
test_bmin(a, b)
-
+ test_band(a, b)
+ test_bor(a, b)
+ test_bxor(a, b)
@with_seed()
def test_run_convolution_dilated_impulse_response(dil=(1,1), kernel_shape=(3,3), verbose=False):
--
To stop receiving notification emails like this one, please contact
haibin@apache.org.