Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/05/01 16:43:48 UTC

[incubator-mxnet] branch master updated: [MXNET-347] Logical Operators AND, XOR, OR (#10679)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 61f86fc  [MXNET-347] Logical Operators AND, XOR, OR (#10679)
61f86fc is described below

commit 61f86fcfdc2780403cbacb46903caac786bacb2b
Author: Anirudh <an...@gmail.com>
AuthorDate: Tue May 1 09:43:38 2018 -0700

    [MXNET-347] Logical Operators AND, XOR, OR (#10679)
    
    * logical and
    
    * logical OR and XOR operators.
    
    * better examples
    
    * nits.
    
    * elemwise operators
    
    * non broadcast examples and tests.
    
    * doc API
    
    * rerun CI
---
 docs/api/python/ndarray/ndarray.md                 |  12 ++
 docs/api/python/symbol/symbol.md                   |  12 ++
 python/mxnet/ndarray/ndarray.py                    | 181 ++++++++++++++++++++-
 src/operator/mshadow_op.h                          |   6 +
 src/operator/operator_tune.cc                      |   6 +
 .../tensor/elemwise_binary_broadcast_op_logic.cc   |  54 ++++++
 .../tensor/elemwise_binary_broadcast_op_logic.cu   |   9 +
 src/operator/tensor/elemwise_binary_op_logic.cc    |  15 ++
 src/operator/tensor/elemwise_binary_op_logic.cu    |   9 +
 .../tensor/elemwise_binary_scalar_op_logic.cc      |  15 ++
 .../tensor/elemwise_binary_scalar_op_logic.cu      |   9 +
 tests/python/unittest/test_operator.py             |  25 ++-
 12 files changed, 348 insertions(+), 5 deletions(-)
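
The commit adds logical_and, logical_or, and logical_xor to the ndarray
API, with broadcasting. A minimal usage sketch, assuming a build that
includes this revision (expected outputs taken from the new docstring
examples):

    import mxnet as mx

    x = mx.nd.ones((2, 3))
    y = mx.nd.arange(2).reshape((2, 1))       # [[0.], [1.]]

    # Element-wise logical ops, broadcasting y across the second axis.
    print(mx.nd.logical_and(x, y).asnumpy())  # [[0. 0. 0.] [1. 1. 1.]]
    print(mx.nd.logical_or(x, y).asnumpy())   # [[1. 1. 1.] [1. 1. 1.]]
    print(mx.nd.logical_xor(x, y).asnumpy())  # [[1. 1. 1.] [0. 0. 0.]]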

diff --git a/docs/api/python/ndarray/ndarray.md b/docs/api/python/ndarray/ndarray.md
index fe8abd8..5bc3c52 100644
--- a/docs/api/python/ndarray/ndarray.md
+++ b/docs/api/python/ndarray/ndarray.md
@@ -555,6 +555,18 @@ The `ndarray` package provides several classes:
     lesser_equal
 ```
 
+### Logical operators
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    logical_and
+    logical_or
+    logical_xor
+    logical_not
+```
+
 ### Random sampling
 
 ```eval_rst
diff --git a/docs/api/python/symbol/symbol.md b/docs/api/python/symbol/symbol.md
index d18b2b2..f1e90a0 100644
--- a/docs/api/python/symbol/symbol.md
+++ b/docs/api/python/symbol/symbol.md
@@ -554,6 +554,18 @@ Composite multiple symbols into a new one by an operator.
     broadcast_lesser_equal
 ```
 
+### Logical
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    broadcast_logical_and
+    broadcast_logical_or
+    broadcast_logical_xor
+    broadcast_logical_not
+```
+
 ### Random sampling
 
 ```eval_rst
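
The symbolic API exposes the same semantics through the broadcast_*
names registered in this commit. A minimal sketch (variable names are
illustrative):

    import mxnet as mx

    a = mx.sym.Variable('a')
    b = mx.sym.Variable('b')
    c = mx.sym.broadcast_logical_xor(a, b)
    out = c.eval(a=mx.nd.array([[1, 1, 0]]), b=mx.nd.array([[0, 1, 0]]))
    print(out[0].asnumpy())  # [[1. 0. 0.]]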
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 361aa24..6b2ff23 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -44,9 +44,9 @@ from ._internal import NDArrayBase
 
 __all__ = ["NDArray", "concatenate", "_DTYPE_NP_TO_MX", "_DTYPE_MX_TO_NP", "_GRAD_REQ_MAP",
            "ones", "add", "arange", "eye", "divide", "equal", "full", "greater", "greater_equal",
-           "imdecode", "lesser", "lesser_equal", "maximum", "minimum", "moveaxis", "modulo",
-           "multiply", "not_equal", "onehot_encode", "power", "subtract", "true_divide",
-           "waitall", "_new_empty_handle"]
+           "imdecode", "lesser", "lesser_equal", "logical_and", "logical_or", "logical_xor",
+           "maximum", "minimum", "moveaxis", "modulo", "multiply", "not_equal", "onehot_encode",
+           "power", "subtract", "true_divide", "waitall", "_new_empty_handle"]
 
 _STORAGE_TYPE_UNDEFINED = -1
 _STORAGE_TYPE_DEFAULT = 0
@@ -2485,7 +2485,7 @@ def add(lhs, rhs):
     .. note::
 
        If the corresponding dimensions of two arrays have the same size or one of them has size 1,
-       then the arrays are broadcastable to a common shape.
+       then the arrays are broadcastable to a common shape
 
     Parameters
     ----------
@@ -3337,6 +3337,179 @@ def lesser_equal(lhs, rhs):
         _internal._greater_equal_scalar)
     # pylint: enable= no-member, protected-access
 
+def logical_and(lhs, rhs):
+    """Returns the result of element-wise **logical and** comparison
+    operation with broadcasting.
+
+    For each element in input arrays, return 1(true) if lhs elements and rhs elements
+    are true, otherwise return 0(false).
+
+    Equivalent to ``lhs and rhs`` and ``mx.nd.broadcast_logical_and(lhs, rhs)``.
+
+    .. note::
+
+       If the corresponding dimensions of two arrays have the same size or one of them has size 1,
+       then the arrays are broadcastable to a common shape.
+
+    Parameters
+    ----------
+    lhs : scalar or array
+        First input of the function.
+    rhs : scalar or array
+        Second input of the function. If ``lhs.shape != rhs.shape``, they must be
+        broadcastable to a common shape.
+
+    Returns
+    -------
+    NDArray
+        Output array of boolean values.
+
+    Examples
+    --------
+    >>> x = mx.nd.ones((2,3))
+    >>> y = mx.nd.arange(2).reshape((2,1))
+    >>> z = mx.nd.arange(2).reshape((1,2))
+    >>> x.asnumpy()
+    array([[ 1.,  1.,  1.],
+           [ 1.,  1.,  1.]], dtype=float32)
+    >>> y.asnumpy()
+    array([[ 0.],
+           [ 1.]], dtype=float32)
+    >>> z.asnumpy()
+    array([[ 0.,  1.]], dtype=float32)
+    >>> mx.nd.logical_and(x, 1).asnumpy()
+    array([[ 1.,  1.,  1.],
+           [ 1.,  1.,  1.]], dtype=float32)
+    >>> mx.nd.logical_and(x, y).asnumpy()
+    array([[ 0.,  0.,  0.],
+           [ 1.,  1.,  1.]], dtype=float32)
+    >>> mx.nd.logical_and(z, y).asnumpy()
+    array([[ 0.,  0.],
+           [ 0.,  1.]], dtype=float32)
+    """
+    # pylint: disable= no-member, protected-access
+    return _ufunc_helper(
+        lhs,
+        rhs,
+        op.broadcast_logical_and,
+        lambda x, y: 1 if x and y else 0,
+        _internal._logical_and_scalar,
+        None)
+    # pylint: enable= no-member, protected-access
+
+def logical_or(lhs, rhs):
+    """Returns the result of element-wise **logical or** comparison
+    operation with broadcasting.
+
+    For each element in input arrays, return 1(true) if lhs elements or rhs elements
+    are true, otherwise return 0(false).
+
+    Equivalent to ``lhs or rhs`` and ``mx.nd.broadcast_logical_or(lhs, rhs)``.
+
+    .. note::
+
+       If the corresponding dimensions of two arrays have the same size or one of them has size 1,
+       then the arrays are broadcastable to a common shape.
+
+    Parameters
+    ----------
+    lhs : scalar or array
+        First input of the function.
+    rhs : scalar or array
+        Second input of the function. If ``lhs.shape != rhs.shape``, they must be
+        broadcastable to a common shape.
+
+    Returns
+    -------
+    NDArray
+        Output array of boolean values.
+
+    Examples
+    --------
+    >>> x = mx.nd.ones((2,3))
+    >>> y = mx.nd.arange(2).reshape((2,1))
+    >>> z = mx.nd.arange(2).reshape((1,2))
+    >>> x.asnumpy()
+    array([[ 1.,  1.,  1.],
+           [ 1.,  1.,  1.]], dtype=float32)
+    >>> y.asnumpy()
+    array([[ 0.],
+           [ 1.]], dtype=float32)
+    >>> z.asnumpy()
+    array([[ 0.,  1.]], dtype=float32)
+    >>> mx.nd.logical_or(x, 1).asnumpy()
+    array([[ 1.,  1.,  1.],
+           [ 1.,  1.,  1.]], dtype=float32)
+    >>> mx.nd.logical_or(x, y).asnumpy()
+    array([[ 1.,  1.,  1.],
+           [ 1.,  1.,  1.]], dtype=float32)
+    >>> mx.nd.logical_or(z, y).asnumpy()
+    array([[ 0.,  1.],
+           [ 1.,  1.]], dtype=float32)
+    """
+    # pylint: disable= no-member, protected-access
+    return _ufunc_helper(
+        lhs,
+        rhs,
+        op.broadcast_logical_or,
+        lambda x, y: 1 if x or y else 0,
+        _internal._logical_or_scalar,
+        None)
+    # pylint: enable= no-member, protected-access
+
+def logical_xor(lhs, rhs):
+    """Returns the result of element-wise **logical xor** comparison
+    operation with broadcasting.
+
+    For each element in input arrays, return 1(true) if lhs elements or rhs elements
+    are true, otherwise return 0(false).
+
+    Equivalent to ``bool(lhs) ^ bool(rhs)`` and ``mx.nd.broadcast_logical_xor(lhs, rhs)``.
+
+    .. note::
+
+       If the corresponding dimensions of two arrays have the same size or one of them has size 1,
+       then the arrays are broadcastable to a common shape.
+
+    Parameters
+    ----------
+    lhs : scalar or array
+        First input of the function.
+    rhs : scalar or array
+        Second input of the function. If ``lhs.shape != rhs.shape``, they must be
+        broadcastable to a common shape.
+
+    Returns
+    -------
+    NDArray
+        Output array of boolean values.
+
+    Examples
+    --------
+    >>> x = mx.nd.ones((2,3))
+    >>> y = mx.nd.arange(2).reshape((2,1))
+    >>> z = mx.nd.arange(2).reshape((1,2))
+    >>> x.asnumpy()
+    array([[ 1.,  1.,  1.],
+           [ 1.,  1.,  1.]], dtype=float32)
+    >>> y.asnumpy()
+    array([[ 0.],
+           [ 1.]], dtype=float32)
+    >>> z.asnumpy()
+    array([[ 0.,  1.]], dtype=float32)
+    >>> mx.nd.logical_xor(x, y).asnumpy()
+    array([[ 1.,  1.,  1.],
+           [ 0.,  0.,  0.]], dtype=float32)
+    """
+    # pylint: disable= no-member, protected-access
+    return _ufunc_helper(
+        lhs,
+        rhs,
+        op.broadcast_logical_xor,
+        lambda x, y: 1 if bool(x) ^ bool(y) else 0,
+        _internal._logical_xor_scalar,
+        None)
+    # pylint: enable= no-member, protected-access
 
 def true_divide(lhs, rhs):
 
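
Each Python wrapper delegates to _ufunc_helper, passing a broadcast
operator for the array-array case, a plain Python lambda for the
scalar-scalar case, and a *_scalar operator for the array-scalar case;
the trailing None marks the op as commutative, so a scalar lhs reuses
the same scalar kernel with the arguments swapped. A rough sketch of
that dispatch (the function name is illustrative, and the behavior is
assumed from the usual _ufunc_helper contract rather than quoted from
its implementation):

    from numbers import Number
    from mxnet.ndarray import _internal, op

    def dispatch_logical_and(lhs, rhs):
        if isinstance(lhs, Number) and isinstance(rhs, Number):
            return 1 if lhs and rhs else 0                          # scalar op scalar
        if isinstance(rhs, Number):
            return _internal._logical_and_scalar(lhs, float(rhs))   # array op scalar
        if isinstance(lhs, Number):
            return _internal._logical_and_scalar(rhs, float(lhs))   # commutative swap
        return op.broadcast_logical_and(lhs, rhs)                   # array op array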
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 2f5dd97..19fa4f8 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -317,6 +317,12 @@ MXNET_BINARY_MATH_OP_NC(eq, a == b ? DType(1) : DType(0));
 
 MXNET_BINARY_MATH_OP_NC(ne, a != b ? DType(1) : DType(0));
 
+MXNET_BINARY_MATH_OP(logical_and, a && b ? DType(1) : DType(0));
+
+MXNET_BINARY_MATH_OP(logical_or, a || b ? DType(1) : DType(0));
+
+MXNET_BINARY_MATH_OP(logical_xor, (a || b) && !(a && b) ? DType(1) : DType(0));
+
 MXNET_UNARY_MATH_OP(square_root, math::sqrt(a));
 
 MXNET_UNARY_MATH_OP(square_root_grad, 0.5f / math::id(a));
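
The xor predicate relies on the identity that (a || b) && !(a && b) holds
exactly when the two truth values differ. A standalone check of that
identity over a few representative operands:

    # Verify (a or b) and not (a and b)  <=>  bool(a) != bool(b)
    for a in (0.0, 1.0, 2.0):
        for b in (0.0, 1.0, 2.0):
            lhs = (bool(a) or bool(b)) and not (bool(a) and bool(b))
            assert lhs == (bool(a) != bool(b))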
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index 47db78b..de3c742 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -342,6 +342,12 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ne);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ne);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::eq);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::eq);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_and);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_and);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_or);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_or);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_xor);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_xor);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient);  // NOLINT()
 IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<0>);  // NOLINT()
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_logic.cc b/src/operator/tensor/elemwise_binary_broadcast_op_logic.cc
index 31f34bb..3cb3ba3 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op_logic.cc
+++ b/src/operator/tensor/elemwise_binary_broadcast_op_logic.cc
@@ -137,5 +137,59 @@ Example::
 .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::le>)
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
 
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_logical_and)
+.describe(R"code(Returns the result of element-wise **logical and** with broadcasting.
+
+Example::
+
+   x = [[ 1.,  1.,  1.],
+        [ 1.,  1.,  1.]]
+
+   y = [[ 0.],
+        [ 1.]]
+
+   broadcast_logical_and(x, y) = [[ 0.,  0.,  0.],
+                                  [ 1.,  1.,  1.]]
+
+)code" ADD_FILELINE)
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::logical_and>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
+
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_logical_or)
+.describe(R"code(Returns the result of element-wise **logical or** with broadcasting.
+
+Example::
+
+   x = [[ 1.,  1.,  0.],
+        [ 1.,  1.,  0.]]
+
+   y = [[ 1.],
+        [ 0.]]
+
+   broadcast_logical_or(x, y) = [[ 1.,  1.,  1.],
+                                 [ 1.,  1.,  0.]]
+
+)code" ADD_FILELINE)
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::logical_or>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
+
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_logical_xor)
+.describe(R"code(Returns the result of element-wise **logical xor** with broadcasting.
+
+Example::
+
+   x = [[ 1.,  1.,  0.],
+        [ 1.,  1.,  0.]]
+
+   y = [[ 1.],
+        [ 0.]]
+
+   broadcast_logical_xor(x, y) = [[ 0.,  0.,  1.],
+                                  [ 1.,  1.,  0.]]
+
+)code" ADD_FILELINE)
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::logical_xor>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_logic.cu b/src/operator/tensor/elemwise_binary_broadcast_op_logic.cu
index 4e80ae9..d6b01aa 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op_logic.cu
+++ b/src/operator/tensor/elemwise_binary_broadcast_op_logic.cu
@@ -47,5 +47,14 @@ NNVM_REGISTER_OP(broadcast_lesser)
 NNVM_REGISTER_OP(broadcast_lesser_equal)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::le>);
 
+NNVM_REGISTER_OP(broadcast_logical_and)
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::logical_and>);
+
+NNVM_REGISTER_OP(broadcast_logical_or)
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::logical_or>);
+
+NNVM_REGISTER_OP(broadcast_logical_xor)
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::logical_xor>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/tensor/elemwise_binary_op_logic.cc b/src/operator/tensor/elemwise_binary_op_logic.cc
index 5d328b5..5c0b442 100644
--- a/src/operator/tensor/elemwise_binary_op_logic.cc
+++ b/src/operator/tensor/elemwise_binary_op_logic.cc
@@ -57,5 +57,20 @@ MXNET_OPERATOR_REGISTER_BINARY(_lesser_equal)
 .set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu, mshadow_op::le>)
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
 
+MXNET_OPERATOR_REGISTER_BINARY(_logical_and)
+.add_alias("_Logical_And")
+.set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu, mshadow_op::logical_and>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
+
+MXNET_OPERATOR_REGISTER_BINARY(_logical_or)
+.add_alias("_Logical_Or")
+.set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu, mshadow_op::logical_or>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
+
+MXNET_OPERATOR_REGISTER_BINARY(_logical_xor)
+.add_alias("_Logical_Xor")
+.set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu, mshadow_op::logical_xor>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
+
 }  // namespace op
 }  // namespace mxnet
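
As with the comparison operators above, the logical operators register
MakeZeroGradNodes, so they contribute a zero gradient under autograd
rather than failing. A minimal sketch, assuming a build with these
operators:

    import mxnet as mx
    from mxnet import autograd

    x = mx.nd.array([0.0, 1.0, 2.0])
    x.attach_grad()
    with autograd.record():
        y = mx.nd.logical_and(x, mx.nd.ones_like(x))
    y.backward()
    print(x.grad.asnumpy())  # [0. 0. 0.] -- the gradient is identically zero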
diff --git a/src/operator/tensor/elemwise_binary_op_logic.cu b/src/operator/tensor/elemwise_binary_op_logic.cu
index be5b722..456622d 100644
--- a/src/operator/tensor/elemwise_binary_op_logic.cu
+++ b/src/operator/tensor/elemwise_binary_op_logic.cu
@@ -45,5 +45,14 @@ NNVM_REGISTER_OP(_lesser)
 NNVM_REGISTER_OP(_lesser_equal)
 .set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<gpu, mshadow_op::le>);
 
+NNVM_REGISTER_OP(_logical_and)
+.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<gpu, mshadow_op::logical_and>);
+
+NNVM_REGISTER_OP(_logical_or)
+.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<gpu, mshadow_op::logical_or>);
+
+NNVM_REGISTER_OP(_logical_xor)
+.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<gpu, mshadow_op::logical_xor>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/tensor/elemwise_binary_scalar_op_logic.cc b/src/operator/tensor/elemwise_binary_scalar_op_logic.cc
index 61f1dd0..fafd840 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op_logic.cc
+++ b/src/operator/tensor/elemwise_binary_scalar_op_logic.cc
@@ -59,5 +59,20 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_lesser_equal_scalar)
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .add_alias("_LesserEqualScalar");
 
+MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_logical_and_scalar)
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::logical_and>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_alias("_LogicalAndScalar");
+
+MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_logical_or_scalar)
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::logical_or>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_alias("_LogicalOrScalar");
+
+MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_logical_xor_scalar)
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::logical_xor>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_alias("_LogicalXorScalar");
+
 }  // namespace op
 }  // namespace mxnet
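
One handy consequence of the scalar variants: xor against the scalar 1
behaves as a logical not. A quick sketch:

    import mxnet as mx

    x = mx.nd.array([0.0, 1.0, 2.0])
    print(mx.nd.logical_xor(x, 1).asnumpy())  # [1. 0. 0.]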
diff --git a/src/operator/tensor/elemwise_binary_scalar_op_logic.cu b/src/operator/tensor/elemwise_binary_scalar_op_logic.cu
index 91bcaa8..92b01a0 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op_logic.cu
+++ b/src/operator/tensor/elemwise_binary_scalar_op_logic.cu
@@ -45,5 +45,14 @@ NNVM_REGISTER_OP(_lesser_scalar)
 NNVM_REGISTER_OP(_lesser_equal_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::le>);
 
+NNVM_REGISTER_OP(_logical_and_scalar)
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::logical_and>);
+
+NNVM_REGISTER_OP(_logical_or_scalar)
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::logical_or>);
+
+NNVM_REGISTER_OP(_logical_xor_scalar)
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::logical_xor>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 6e5c908..838d8d8 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1613,6 +1613,27 @@ def test_broadcast_binary_op():
         data = gen_broadcast_data(idx=200)
         check_bmaxmin_gradient(c, data[0], data[1], 0.001, 1e-2, 1e-3)
 
+    def test_band(a, b):
+        c = mx.sym.broadcast_logical_and(a, b)
+        check_binary_op_forward(c, lambda x, y: np.logical_and(x, y), gen_broadcast_data, mx_nd_func=mx.nd.logical_and)
+        # pass idx=200 to gen_broadcast_data so that generated ndarrays' sizes are not too big
+        data = gen_broadcast_data(idx=200)
+        check_bmaxmin_gradient(c, data[0], data[1], 0.001, 1e-2, 1e-3)
+
+    def test_bor(a, b):
+        c = mx.sym.broadcast_logical_or(a, b)
+        check_binary_op_forward(c, lambda x, y: np.logical_or(x, y), gen_broadcast_data, mx_nd_func=mx.nd.logical_or)
+        # pass idx=200 to gen_broadcast_data so that generated ndarrays' sizes are not too big
+        data = gen_broadcast_data(idx=200)
+        check_bmaxmin_gradient(c, data[0], data[1], 0.001, 1e-2, 1e-3)
+
+    def test_bxor(a, b):
+        c = mx.sym.broadcast_logical_xor(a, b)
+        check_binary_op_forward(c, lambda x, y: np.logical_xor(x, y), gen_broadcast_data, mx_nd_func=mx.nd.logical_xor)
+        # pass idx=200 to gen_broadcast_data so that generated ndarrays' sizes are not too big
+        data = gen_broadcast_data(idx=200)
+        check_bmaxmin_gradient(c, data[0], data[1], 0.001, 1e-2, 1e-3)
+
     test_bplus(a, b)
     test_bminus(a, b)
     test_bmul(a, b)
@@ -1623,7 +1644,9 @@ def test_broadcast_binary_op():
     test_bequal(a, b)
     test_bmax(a, b)
     test_bmin(a, b)
-
+    test_band(a, b)
+    test_bor(a, b)
+    test_bxor(a, b)
 
 @with_seed()
 def test_run_convolution_dilated_impulse_response(dil=(1,1), kernel_shape=(3,3), verbose=False):
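
The new tests validate the forward results against numpy's logical_*
functions; the same parity is easy to reproduce by hand:

    import numpy as np
    import mxnet as mx

    a = np.random.randint(0, 2, size=(2, 3)).astype('float32')
    b = np.random.randint(0, 2, size=(2, 1)).astype('float32')
    np.testing.assert_array_equal(
        mx.nd.logical_xor(mx.nd.array(a), mx.nd.array(b)).asnumpy(),
        np.logical_xor(a, b).astype('float32'))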

-- 
To stop receiving notification emails like this one, please contact
haibin@apache.org.