You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/02/03 04:09:07 UTC

[incubator-mxnet] branch v1.x updated: [v1.x] Add ONNX export support for equal_scalar operator (#19824)

This is an automated email from the ASF dual-hosted git repository.

zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 90836ad  [v1.x] Add ONNX export support for equal_scalar operator (#19824)
90836ad is described below

commit 90836ad572313afcfbae3c5b735bbc59e2a95606
Author: Joe Evans <jo...@gmail.com>
AuthorDate: Tue Feb 2 20:07:23 2021 -0800

    [v1.x] Add ONNX export support for equal_scalar operator (#19824)
    
    * Allow axis to be an optional parameter to squeeze, since ONNX supports it now.
    
    * Add ONNX export function for equal_scalar and add a unit test.
    
    * Use Constant instead of ConstantOfShape for scalar functions.
    
    * Check for 'True' value for squeeze_axis.
    
    Co-authored-by: Joe Evans <jo...@amazon.com>
---
 .../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 61 ++++++++++++++++------
 tests/python-pytest/onnx/test_operators.py         | 12 +++++
 2 files changed, 58 insertions(+), 15 deletions(-)

diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 9e13b05..37fc542 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -1758,7 +1758,7 @@ def convert_slice_channel(node, **kwargs):
 
     num_outputs = int(attrs.get("num_outputs"))
     axis = int(attrs.get("axis", 1))
-    squeeze_axis = int(attrs.get("squeeze_axis", 0))
+    squeeze_axis = int(attrs.get("squeeze_axis", 0) in [1, 'True'])
 
     if squeeze_axis == 1 and num_outputs == 1:
         node = onnx.helper.make_node(
@@ -1810,17 +1810,22 @@ def convert_squeeze(node, **kwargs):
 
     axis = attrs.get("axis", None)
     if not axis:
-        raise AttributeError("Squeeze: Missing axis attribute: ONNX currently requires axis to "
-                             "be specified for squeeze operator")
-    axis = convert_string_to_list(axis)
+        node = onnx.helper.make_node(
+            "Squeeze",
+            input_nodes,
+            [name],
+            name=name
+        )
+    else:
+        axis = convert_string_to_list(axis)
 
-    node = onnx.helper.make_node(
-        "Squeeze",
-        input_nodes,
-        [name],
-        axes=axis,
-        name=name,
-    )
+        node = onnx.helper.make_node(
+            "Squeeze",
+            input_nodes,
+            [name],
+            axes=axis,
+            name=name,
+        )
     return [node]
 
 
@@ -3141,8 +3146,7 @@ def convert_greater_scalar(node, **kwargs):
 
     tensor_value = make_tensor(name+"_scalar", input_type, [1], [scalar])
     nodes = [
-        make_node("Shape", [input_nodes[0]], [name+"_shape"]),
-        make_node("ConstantOfShape", [name+"_shape"], [name+"_rhs"], value=tensor_value),
+        make_node("Constant", [], [name+"_rhs"], value=tensor_value),
         make_node("Greater", [input_nodes[0], name+"_rhs"], [name+"_gt"]),
         make_node("Cast", [name+"_gt"], [name], to=input_type, name=name)
     ]
@@ -3171,14 +3175,41 @@ def convert_lesser_scalar(node, **kwargs):
 
     tensor_value = make_tensor(name+"_scalar", input_type, [1], [scalar])
     nodes = [
-        make_node("Shape", [input_nodes[0]], [name+"_shape"]),
-        make_node("ConstantOfShape", [name+"_shape"], [name+"_rhs"], value=tensor_value),
+        make_node("Constant", [], [name+"_rhs"], value=tensor_value),
         make_node("Less", [input_nodes[0], name+"_rhs"], [name+"_lt"]),
         make_node("Cast", [name+"_lt"], [name], to=input_type, name=name)
     ]
     return nodes
 
 
+@mx_op.register("_equal_scalar")
+def convert_equal_scalar(node, **kwargs):
+    """Map MXNet's equal_scalar operator attributes to onnx.
+    """
+    from onnx.helper import make_node, make_tensor
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    scalar = float(attrs.get('scalar'))
+    input_type = kwargs['in_type']
+    dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input_type]
+
+    if str(dtype).startswith('int'):
+        scalar = int(scalar)
+    else:
+        if dtype == 'float16':
+            # when using float16, we must convert it to np.uint16 view first
+            # pylint: disable=too-many-function-args
+            scalar = np.float16(scalar).view(np.uint16)
+
+    tensor_value = make_tensor(name+"_scalar", input_type, [1], [scalar])
+    nodes = [
+        make_node("Constant", [], [name+"_rhs"], value=tensor_value),
+        make_node("Equal", [input_nodes[0], name+"_rhs"], [name+"_eq"]),
+        make_node("Cast", [name+"_eq"], [name], to=input_type, name=name)
+    ]
+    return nodes
+
+
 @mx_op.register("where")
 def convert_where(node, **kwargs):
     """Map MXNet's where operator attributes to onnx's Where
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index 8d09d98..049f33e 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -512,6 +512,18 @@ def test_onnx_export_lesser_scalar(tmp_path, dtype, scalar):
 
 
 @pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"])
+@pytest.mark.parametrize("scalar", [0., 0.1, 0.5, 1., 5, 555.])
+def test_onnx_export_equal_scalar(tmp_path, dtype, scalar):
+    if 'int' in dtype:
+        scalar = int(scalar)
+        x = mx.nd.arange(0, 12, dtype=dtype).reshape((3, 4))
+    else:
+        x = mx.random.uniform(0, 9999, (5,10), dtype=dtype)
+    M = def_model('_internal._equal_scalar', scalar=scalar)
+    op_export_test('_internal._equal_scalar', M, [x], tmp_path)
+
+
+@pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"])
 @pytest.mark.parametrize("shape", [(1,1), (3,3), (10,2), (20,30,40)])
 def test_onnx_export_where(tmp_path, dtype, shape):
     M = def_model('where')