You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/05/12 17:49:16 UTC

[incubator-mxnet] branch v1.x updated: [v1.x] ONNX export support broadcast_not_equal (#20259)

This is an automated email from the ASF dual-hosted git repository.

zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 3de8641  [v1.x] ONNX export support broadcast_not_equal  (#20259)
3de8641 is described below

commit 3de864163b0940fa679cc7c8bd1069245ddcc17d
Author: waytrue17 <52...@users.noreply.github.com>
AuthorDate: Wed May 12 10:45:59 2021 -0700

    [v1.x] ONNX export support broadcast_not_equal  (#20259)
    
    * add ONNX export support for broadcast_not_equal
    
    * increase the number of candidate node output names checked per converted node from 10 to 100
    
    Co-authored-by: Wei Chu <we...@amazon.com>
---
 python/mxnet/onnx/mx2onnx/_export_onnx.py            |  2 +-
 .../_op_translations/_op_translations_opset12.py     | 20 ++++++++++++++++++++
 tests/python-pytest/onnx/test_operators.py           | 11 +++++++++++
 3 files changed, 32 insertions(+), 1 deletion(-)

diff --git a/python/mxnet/onnx/mx2onnx/_export_onnx.py b/python/mxnet/onnx/mx2onnx/_export_onnx.py
index 307095b..e3aa59e 100644
--- a/python/mxnet/onnx/mx2onnx/_export_onnx.py
+++ b/python/mxnet/onnx/mx2onnx/_export_onnx.py
@@ -352,7 +352,7 @@ class MXNetGraph(object):
                 )
             if isinstance(converted, list):
                 # Collect all the node's output names
-                node_possible_names = [name] + [name + str(i) for i in range(10)]
+                node_possible_names = [name] + [name + str(i) for i in range(100)]
                 node_output_names = []
                 # Collect all the graph's output names
                 graph_output_names = []
diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py
index 9415388..0d9e21b 100644
--- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py
+++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py
@@ -2315,6 +2315,26 @@ def convert_broadcast_equal(node, **kwargs):
     return nodes
 
 
+@mx_op.register("broadcast_not_equal")
+def convert_broadcast_not_equal(node, **kwargs):
+    """Map MXNet's broadcast_not_equal operator to onnx's Equal -> Not -> Cast
+    nodes and return the list of created nodes.
+    """
+    from onnx.helper import make_node
+    name, input_nodes, _ = get_inputs(node, kwargs)
+    input_dtypes = get_input_dtypes(node, kwargs)
+
+    dtype = input_dtypes[0]  # the output is cast to the first input's dtype
+    dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype]
+
+    nodes = [
+        make_node("Equal", input_nodes, [name+"_equal"]),  # element-wise equality (boolean)
+        make_node("Not", [name+"_equal"], [name+"_not"]),  # invert it to get "not equal"
+        make_node("Cast", [name+"_not"], [name], name=name, to=int(dtype_t))  # back to the input dtype
+    ]
+    return nodes
+
+
 @mx_op.register("broadcast_logical_and")
 def convert_broadcast_logical_and(node, **kwargs):
     """Map MXNet's broadcast logical and operator attributes to onnx's And operator
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index dc6d389..533d7e9 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -393,6 +393,17 @@ def test_onnx_export_broadcast_equal(tmp_path, dtype):
     op_export_test('broadcast_equal', M, [x, y], tmp_path)
 
 
+@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64'])
+def test_onnx_export_broadcast_not_equal(tmp_path, dtype):
+    M = def_model('broadcast_not_equal')
+    x = mx.nd.zeros((4,5,6), dtype=dtype)
+    y = mx.nd.ones((4,5,6), dtype=dtype)
+    op_export_test('broadcast_not_equal', M, [x, y], tmp_path)  # same shapes; every element differs
+    x1 = mx.nd.ones((4,5,6), dtype=dtype)
+    y1 = mx.nd.ones((5,6), dtype=dtype)
+    op_export_test('broadcast_not_equal', M, [x1, y1], tmp_path)  # y1 broadcasts over x1; every element equal
+
+
 @pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64'])
 def test_onnx_export_broadcast_minimum(tmp_path, dtype):
     M = def_model('broadcast_minimum')