You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/02/06 06:31:18 UTC
[incubator-mxnet] branch v1.x updated: [v1.x] ONNX export fix slice_axis (#19853)
This is an automated email from the ASF dual-hosted git repository.
zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new 12d7624 [v1.x] ONNX export fix slice_axis (#19853)
12d7624 is described below
commit 12d762406727da219eb23687f9dcf9b0ae2495ac
Author: waytrue17 <52...@users.noreply.github.com>
AuthorDate: Fri Feb 5 22:29:24 2021 -0800
[v1.x] ONNX export fix slice_axis (#19853)
* fix slice_axis
* refactor code
* fix sanity
Co-authored-by: Wei Chu <we...@amazon.com>
---
.../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 32 ++++++++++++++--------
tests/python-pytest/onnx/test_operators.py | 2 ++
2 files changed, 22 insertions(+), 12 deletions(-)
diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index ba433a8..3240077 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -1752,28 +1752,36 @@ def convert_slice_axis(node, **kwargs):
begin = int(attrs.get("begin"))
end = attrs.get("end", None)
- nodes = [
- create_tensor([axis], name+'_axis', kwargs["initializer"]),
- create_tensor([begin], name+'_begin', kwargs["initializer"])
- ]
+ nodes = []
+ create_tensor([axis], name+'_axis', kwargs["initializer"])
+ create_tensor([begin], name+'_begin', kwargs["initializer"])
if not end or end == 'None':
# ONNX doesn't support None for ends. Since ends=None depicts
# length of dimension, passing dimension in this case.
nodes += [
- create_tensor([axis+1], name+"_axis_plus_1", kwargs["initializer"]),
- make_node('Shape', [input_nodes[0]], [name+"_data_shape"]),
- make_node('Slice', [name+'_data_shape', name+'_axis', name+'_axis_plus_1'],
- [name+"_end"])
+ make_node('Shape', [input_nodes[0]], [name+"_data_shape"])
]
+ # corner case when end = None and axis = -1
+ if axis == -1:
+ create_tensor([-1], name+'_-1', kwargs["initializer"])
+ nodes += [
+ make_node('Shape', [name+'_data_shape'], [name+'_data_dim']),
+ make_node('Add', [name+'_data_dim', name+'_-1'], [name+'_axis_max']),
+ make_node('Slice', [name+'_data_shape', name+'_axis_max', name+'_data_dim'], [name+'_end']),
+ ]
+ else:
+ create_tensor([axis+1], name+"_axis_plus_1", kwargs["initializer"])
+ nodes += [
+ make_node('Slice', [name+'_data_shape', name+'_axis', name+'_axis_plus_1'],
+ [name+"_end"])
+ ]
else:
- nodes += [
- create_tensor([int(end)], name+'_end', kwargs["initializer"])
- ]
+ create_tensor([int(end)], name+'_end', kwargs["initializer"])
nodes += [
make_node('Slice', [input_nodes[0], name+'_begin', name+'_end', name+'_axis'],
[name], name=name)
- ]
+ ]
return nodes
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index 7e7bacb..a35f9b6 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -201,9 +201,11 @@ def test_onnx_export_slice_axis(tmp_path, dtype):
M1 = def_model('slice_axis', axis=0, begin=1, end=3)
M2 = def_model('slice_axis', axis=0, begin=1, end=None)
M3 = def_model('slice_axis', axis=1, begin=-3, end=-1)
+ M4 = def_model('slice_axis', axis=-1, begin=-3, end=None)
op_export_test('slice_axis_1', M1, [x], tmp_path)
op_export_test('slice_axis_2', M2, [x], tmp_path)
op_export_test('slice_axis_3', M3, [x], tmp_path)
+ op_export_test('slice_axis_4', M4, [x], tmp_path)
@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64'])