You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/05/12 17:49:16 UTC
[incubator-mxnet] branch v1.x updated: [v1.x] ONNX export support broadcast_not_equal (#20259)
This is an automated email from the ASF dual-hosted git repository.
zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new 3de8641 [v1.x] ONNX export support broadcast_not_equal (#20259)
3de8641 is described below
commit 3de864163b0940fa679cc7c8bd1069245ddcc17d
Author: waytrue17 <52...@users.noreply.github.com>
AuthorDate: Wed May 12 10:45:59 2021 -0700
[v1.x] ONNX export support broadcast_not_equal (#20259)
* add ONNX export for broadcast_not_equal
* increase the number of candidate node output names checked per converted node from 10 to 100
Co-authored-by: Wei Chu <we...@amazon.com>
---
python/mxnet/onnx/mx2onnx/_export_onnx.py | 2 +-
.../_op_translations/_op_translations_opset12.py | 20 ++++++++++++++++++++
tests/python-pytest/onnx/test_operators.py | 11 +++++++++++
3 files changed, 32 insertions(+), 1 deletion(-)
diff --git a/python/mxnet/onnx/mx2onnx/_export_onnx.py b/python/mxnet/onnx/mx2onnx/_export_onnx.py
index 307095b..e3aa59e 100644
--- a/python/mxnet/onnx/mx2onnx/_export_onnx.py
+++ b/python/mxnet/onnx/mx2onnx/_export_onnx.py
@@ -352,7 +352,7 @@ class MXNetGraph(object):
)
if isinstance(converted, list):
# Collect all the node's output names
- node_possible_names = [name] + [name + str(i) for i in range(10)]
+ node_possible_names = [name] + [name + str(i) for i in range(100)]
node_output_names = []
# Collect all the graph's output names
graph_output_names = []
diff --git a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py
index 9415388..0d9e21b 100644
--- a/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py
+++ b/python/mxnet/onnx/mx2onnx/_op_translations/_op_translations_opset12.py
@@ -2315,6 +2315,26 @@ def convert_broadcast_equal(node, **kwargs):
return nodes
+@mx_op.register("broadcast_not_equal")
+def convert_broadcast_not_equal(node, **kwargs):
+ """Map MXNet's broadcast_not_equal operator to a chain of ONNX nodes
+ (Equal -> Not -> Cast) and return the list of created nodes.
+
+ The boolean result of Not is cast back to the dtype of the first input.
+ """
+ from onnx.helper import make_node
+ name, input_nodes, _ = get_inputs(node, kwargs)
+ input_dtypes = get_input_dtypes(node, kwargs)
+
+ # Only the first input's dtype is used; both inputs presumably share it — TODO confirm.
+ dtype = input_dtypes[0]
+ dtype_t = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[dtype]
+
+ nodes = [
+ # Compute element-wise a != b as Not(Equal(a, b)).
+ make_node("Equal", input_nodes, [name+"_equal"]),
+ make_node("Not", [name+"_equal"], [name+"_not"]),
+ # Cast the boolean mask to the first input's dtype (assumed to match
+ # MXNet's output dtype for this op — confirm); only this final node
+ # carries the operator's name and produces its output.
+ make_node("Cast", [name+"_not"], [name], name=name, to=int(dtype_t))
+ ]
+ return nodes
+
+
@mx_op.register("broadcast_logical_and")
def convert_broadcast_logical_and(node, **kwargs):
"""Map MXNet's broadcast logical and operator attributes to onnx's And operator
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index dc6d389..533d7e9 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -393,6 +393,17 @@ def test_onnx_export_broadcast_equal(tmp_path, dtype):
op_export_test('broadcast_equal', M, [x, y], tmp_path)
+# NOTE(review): float16 is not parametrized here, although
+# test_onnx_export_broadcast_minimum below includes it — confirm this is intentional.
+@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64'])
+def test_onnx_export_broadcast_not_equal(tmp_path, dtype):
+ """Run op_export_test on broadcast_not_equal twice: once with same-shape
+ inputs that differ at every element, and once with broadcast inputs
+ ((4,5,6) vs (5,6)) that are equal at every element.
+ """
+ M = def_model('broadcast_not_equal')
+ # Same shape, zeros vs ones: every element is "not equal".
+ x = mx.nd.zeros((4,5,6), dtype=dtype)
+ y = mx.nd.ones((4,5,6), dtype=dtype)
+ op_export_test('broadcast_not_equal', M, [x, y], tmp_path)
+ # Broadcast (5,6) against (4,5,6), ones vs ones: every element is "equal".
+ # NOTE(review): neither case mixes equal and unequal elements within one tensor.
+ x1 = mx.nd.ones((4,5,6), dtype=dtype)
+ y1 = mx.nd.ones((5,6), dtype=dtype)
+ op_export_test('broadcast_not_equal', M, [x1, y1], tmp_path)
+
+
@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64'])
def test_onnx_export_broadcast_minimum(tmp_path, dtype):
M = def_model('broadcast_minimum')