You are viewing a plain text version of this content; the canonical link was omitted in this plain-text rendering.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/02/03 04:09:07 UTC
[incubator-mxnet] branch v1.x updated: [v1.x] Add ONNX export
support for equal_scalar operator (#19824)
This is an automated email from the ASF dual-hosted git repository.
zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new 90836ad [v1.x] Add ONNX export support for equal_scalar operator (#19824)
90836ad is described below
commit 90836ad572313afcfbae3c5b735bbc59e2a95606
Author: Joe Evans <jo...@gmail.com>
AuthorDate: Tue Feb 2 20:07:23 2021 -0800
[v1.x] Add ONNX export support for equal_scalar operator (#19824)
* Allow axis to be an optional parameter to squeeze, since onnx supports it now.
* Add onnx export function for equal_scalar, add unit test.
* Use Constant instead of ConstantOfShape for scalar functions.
* Check for 'True' value for squeeze_axis.
Co-authored-by: Joe Evans <jo...@amazon.com>
---
.../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 61 ++++++++++++++++------
tests/python-pytest/onnx/test_operators.py | 12 +++++
2 files changed, 58 insertions(+), 15 deletions(-)
diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 9e13b05..37fc542 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -1758,7 +1758,7 @@ def convert_slice_channel(node, **kwargs):
num_outputs = int(attrs.get("num_outputs"))
axis = int(attrs.get("axis", 1))
- squeeze_axis = int(attrs.get("squeeze_axis", 0))
+ squeeze_axis = int(attrs.get("squeeze_axis", 0) in [1, 'True'])
if squeeze_axis == 1 and num_outputs == 1:
node = onnx.helper.make_node(
@@ -1810,17 +1810,22 @@ def convert_squeeze(node, **kwargs):
axis = attrs.get("axis", None)
if not axis:
- raise AttributeError("Squeeze: Missing axis attribute: ONNX currently requires axis to "
- "be specified for squeeze operator")
- axis = convert_string_to_list(axis)
+ node = onnx.helper.make_node(
+ "Squeeze",
+ input_nodes,
+ [name],
+ name=name
+ )
+ else:
+ axis = convert_string_to_list(axis)
- node = onnx.helper.make_node(
- "Squeeze",
- input_nodes,
- [name],
- axes=axis,
- name=name,
- )
+ node = onnx.helper.make_node(
+ "Squeeze",
+ input_nodes,
+ [name],
+ axes=axis,
+ name=name,
+ )
return [node]
@@ -3141,8 +3146,7 @@ def convert_greater_scalar(node, **kwargs):
tensor_value = make_tensor(name+"_scalar", input_type, [1], [scalar])
nodes = [
- make_node("Shape", [input_nodes[0]], [name+"_shape"]),
- make_node("ConstantOfShape", [name+"_shape"], [name+"_rhs"], value=tensor_value),
+ make_node("Constant", [], [name+"_rhs"], value=tensor_value),
make_node("Greater", [input_nodes[0], name+"_rhs"], [name+"_gt"]),
make_node("Cast", [name+"_gt"], [name], to=input_type, name=name)
]
@@ -3171,14 +3175,41 @@ def convert_lesser_scalar(node, **kwargs):
tensor_value = make_tensor(name+"_scalar", input_type, [1], [scalar])
nodes = [
- make_node("Shape", [input_nodes[0]], [name+"_shape"]),
- make_node("ConstantOfShape", [name+"_shape"], [name+"_rhs"], value=tensor_value),
+ make_node("Constant", [], [name+"_rhs"], value=tensor_value),
make_node("Less", [input_nodes[0], name+"_rhs"], [name+"_lt"]),
make_node("Cast", [name+"_lt"], [name], to=input_type, name=name)
]
return nodes
+@mx_op.register("_equal_scalar")
+def convert_equal_scalar(node, **kwargs):
+ """Map MXNet's equal_scalar operator attributes to onnx.
+ """
+ from onnx.helper import make_node, make_tensor
+ name, input_nodes, attrs = get_inputs(node, kwargs)
+
+ scalar = float(attrs.get('scalar'))
+ input_type = kwargs['in_type']
+ dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input_type]
+
+ if str(dtype).startswith('int'):
+ scalar = int(scalar)
+ else:
+ if dtype == 'float16':
+ # when using float16, we must convert it to np.uint16 view first
+ # pylint: disable=too-many-function-args
+ scalar = np.float16(scalar).view(np.uint16)
+
+ tensor_value = make_tensor(name+"_scalar", input_type, [1], [scalar])
+ nodes = [
+ make_node("Constant", [], [name+"_rhs"], value=tensor_value),
+ make_node("Equal", [input_nodes[0], name+"_rhs"], [name+"_eq"]),
+ make_node("Cast", [name+"_eq"], [name], to=input_type, name=name)
+ ]
+ return nodes
+
+
@mx_op.register("where")
def convert_where(node, **kwargs):
"""Map MXNet's where operator attributes to onnx's Where
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index 8d09d98..049f33e 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -512,6 +512,18 @@ def test_onnx_export_lesser_scalar(tmp_path, dtype, scalar):
@pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"])
+@pytest.mark.parametrize("scalar", [0., 0.1, 0.5, 1., 5, 555.])
+def test_onnx_export_equal_scalar(tmp_path, dtype, scalar):
+ if 'int' in dtype:
+ scalar = int(scalar)
+ x = mx.nd.arange(0, 12, dtype=dtype).reshape((3, 4))
+ else:
+ x = mx.random.uniform(0, 9999, (5,10), dtype=dtype)
+ M = def_model('_internal._equal_scalar', scalar=scalar)
+ op_export_test('_internal._equal_scalar', M, [x], tmp_path)
+
+
+@pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"])
@pytest.mark.parametrize("shape", [(1,1), (3,3), (10,2), (20,30,40)])
def test_onnx_export_where(tmp_path, dtype, shape):
M = def_model('where')