You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/12/03 04:30:37 UTC

[GitHub] nswamy closed pull request #12852: [MXNET-879] ONNX export: Logical operators

nswamy closed pull request #12852: [MXNET-879] ONNX export: Logical operators
URL: https://github.com/apache/incubator-mxnet/pull/12852
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this is a foreign pull request (from a fork), GitHub would not
otherwise display its diff after the merge, so the diff is supplied below:

diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index e2aab6b1efa..eb0108b42db 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -1587,3 +1587,35 @@ def convert_broadcast_equal(node, **kwargs):
     and return the created node.
     """
     return create_basic_op_node('Equal', node, kwargs)
+
+
+@mx_op.register("broadcast_logical_and")
+def convert_broadcast_logical_and(node, **kwargs):
+    """Map MXNet's broadcast logical and operator attributes to onnx's And operator
+    and return the created node.
+    """
+    return create_basic_op_node('And', node, kwargs)
+
+
+@mx_op.register("broadcast_logical_or")
+def convert_broadcast_logical_or(node, **kwargs):
+    """Map MXNet's broadcast logical or operator attributes to onnx's Or operator
+    and return the created node.
+    """
+    return create_basic_op_node('Or', node, kwargs)
+
+
+@mx_op.register("broadcast_logical_xor")
+def convert_broadcast_logical_xor(node, **kwargs):
+    """Map MXNet's broadcast logical xor operator attributes to onnx's Xor operator
+    and return the created node.
+    """
+    return create_basic_op_node('Xor', node, kwargs)
+
+
+@mx_op.register("logical_not")
+def convert_logical_not(node, **kwargs):
+    """Map MXNet's logical not operator attributes to onnx's Not operator
+    and return the created node.
+    """
+    return create_basic_op_node('Not', node, kwargs)
diff --git a/tests/python-pytest/onnx/export/mxnet_export_test.py b/tests/python-pytest/onnx/export/mxnet_export_test.py
index 964d0e760ca..6b858f05e24 100644
--- a/tests/python-pytest/onnx/export/mxnet_export_test.py
+++ b/tests/python-pytest/onnx/export/mxnet_export_test.py
@@ -268,6 +268,45 @@ def test_ops(op_name, inputs, input_tensors, numpy_op):
     test_ops("Equal", input_data, input_tensor,
              np.equal(input_data[0], input_data[1]).astype(np.float32))
 
+
+def get_int_inputs(interval, shape):
+    """Helper to get integer-valued float32 inputs of the given shapes and ranges"""
+    assert len(interval) == len(shape)
+    inputs = []
+    input_tensors = []
+    for idx in range(len(interval)):
+        low, high = interval[idx]
+        inputs.append(np.random.randint(low, high, size=shape[idx]).astype("float32"))
+        input_tensors.append(helper.make_tensor_value_info("input"+str(idx+1),
+                                                        TensorProto.FLOAT, shape=shape[idx]))
+    return inputs, input_tensors
+
+
+@with_seed()
+def test_logical_ops():
+    """Test for logical and, or, not, xor operators"""
+    def test_ops(op_name, inputs, input_tensors, numpy_op):
+        outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=np.shape(inputs[0]))]
+        nodes = [helper.make_node(op_name, ["input"+str(i+1) for i in range(len(inputs))], ["output"])]
+        graph = helper.make_graph(nodes,
+                                  op_name + "_test",
+                                  input_tensors,
+                                  outputs)
+        model = helper.make_model(graph)
+        bkd_rep = backend.prepare(model)
+        output = bkd_rep.run(inputs)
+        npt.assert_almost_equal(output[0], numpy_op)
+    input_data, input_tensor = get_int_inputs([(0, 2), (0, 2)], [(3, 4, 5), (3, 4, 5)])
+    test_ops("And", input_data, input_tensor,
+             np.logical_and(input_data[0], input_data[1]).astype(np.float32))
+    test_ops("Or", input_data, input_tensor,
+             np.logical_or(input_data[0], input_data[1]).astype(np.float32))
+    test_ops("Xor", input_data, input_tensor,
+             np.logical_xor(input_data[0], input_data[1]).astype(np.float32))
+    test_ops("Not", [input_data[0]], [input_tensor[0]],
+             np.logical_not(input_data[0]).astype(np.float32))
+
+
 def _assert_sym_equal(lhs, rhs):
     assert lhs.list_inputs() == rhs.list_inputs()  # input names must be identical
     assert len(lhs.list_outputs()) == len(rhs.list_outputs())  # number of outputs must be identical
diff --git a/tests/python-pytest/onnx/import/test_cases.py b/tests/python-pytest/onnx/import/test_cases.py
index aed68ffa114..f41fe92352d 100644
--- a/tests/python-pytest/onnx/import/test_cases.py
+++ b/tests/python-pytest/onnx/import/test_cases.py
@@ -55,7 +55,6 @@
     'test_argmax',
     'test_argmin',
     'test_min',
-    'test_logical_',
     # enabling partial test cases for matmul
     'test_matmul_3d',
     'test_matmul_4d',


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services