You are viewing a plain-text version of this content. The canonical link to it was a hyperlink and is not preserved in this text version.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/02/06 06:31:18 UTC

[incubator-mxnet] branch v1.x updated: [v1.x] ONNX export fix slice_axis (#19853)

This is an automated email from the ASF dual-hosted git repository.

zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 12d7624  [v1.x] ONNX export fix slice_axis (#19853)
12d7624 is described below

commit 12d762406727da219eb23687f9dcf9b0ae2495ac
Author: waytrue17 <52...@users.noreply.github.com>
AuthorDate: Fri Feb 5 22:29:24 2021 -0800

    [v1.x] ONNX export fix slice_axis (#19853)
    
    * fix slice_axis
    
    * refactor code
    
    * fix sanity
    
    Co-authored-by: Wei Chu <we...@amazon.com>
---
 .../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 32 ++++++++++++++--------
 tests/python-pytest/onnx/test_operators.py         |  2 ++
 2 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index ba433a8..3240077 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -1752,28 +1752,36 @@ def convert_slice_axis(node, **kwargs):
     begin = int(attrs.get("begin"))
     end = attrs.get("end", None)
 
-    nodes = [
-        create_tensor([axis], name+'_axis', kwargs["initializer"]),
-        create_tensor([begin], name+'_begin', kwargs["initializer"])
-    ]
+    nodes = []
+    create_tensor([axis], name+'_axis', kwargs["initializer"])
+    create_tensor([begin], name+'_begin', kwargs["initializer"])
     if not end or end == 'None':
         # ONNX doesn't support None for ends. Since ends=None depicts
         # length of dimension, passing dimension in this case.
         nodes += [
-            create_tensor([axis+1], name+"_axis_plus_1", kwargs["initializer"]),
-            make_node('Shape', [input_nodes[0]], [name+"_data_shape"]),
-            make_node('Slice', [name+'_data_shape', name+'_axis', name+'_axis_plus_1'],
-                      [name+"_end"])
+            make_node('Shape', [input_nodes[0]], [name+"_data_shape"])
         ]
+        # corner case when end = None and axis = -1
+        if axis == -1:
+            create_tensor([-1], name+'_-1', kwargs["initializer"])
+            nodes += [
+                make_node('Shape', [name+'_data_shape'], [name+'_data_dim']),
+                make_node('Add', [name+'_data_dim', name+'_-1'], [name+'_axis_max']),
+                make_node('Slice', [name+'_data_shape', name+'_axis_max', name+'_data_dim'], [name+'_end']),
+            ]
+        else:
+            create_tensor([axis+1], name+"_axis_plus_1", kwargs["initializer"])
+            nodes += [
+                make_node('Slice', [name+'_data_shape', name+'_axis', name+'_axis_plus_1'],
+                          [name+"_end"])
+            ]
     else:
-        nodes += [
-            create_tensor([int(end)], name+'_end', kwargs["initializer"])
-        ]
+        create_tensor([int(end)], name+'_end', kwargs["initializer"])
 
     nodes += [
         make_node('Slice', [input_nodes[0], name+'_begin', name+'_end', name+'_axis'],
                   [name], name=name)
-    ]
+        ]
 
     return nodes
 
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index 7e7bacb..a35f9b6 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -201,9 +201,11 @@ def test_onnx_export_slice_axis(tmp_path, dtype):
     M1 = def_model('slice_axis', axis=0, begin=1, end=3)
     M2 = def_model('slice_axis', axis=0, begin=1, end=None)
     M3 = def_model('slice_axis', axis=1, begin=-3, end=-1)
+    M4 = def_model('slice_axis', axis=-1, begin=-3, end=None)
     op_export_test('slice_axis_1', M1, [x], tmp_path)
     op_export_test('slice_axis_2', M2, [x], tmp_path)
     op_export_test('slice_axis_3', M3, [x], tmp_path)
+    op_export_test('slice_axis_4', M4, [x], tmp_path)
 
 
 @pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64'])