You are viewing a plain text version of this content. The canonical link for it was a hyperlink that is not preserved in this plain-text version.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/02/16 01:45:05 UTC

[incubator-mxnet] branch v1.x updated: [v1.x] ONNX export rewrite Take (#19851)

This is an automated email from the ASF dual-hosted git repository.

zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new c123c32  [v1.x] ONNX export rewrite Take (#19851)
c123c32 is described below

commit c123c32e32b3570b214f20d826e349d44604837a
Author: waytrue17 <52...@users.noreply.github.com>
AuthorDate: Mon Feb 15 17:43:28 2021 -0800

    [v1.x] ONNX export rewrite Take (#19851)
    
    * rewrite take
    
    * fix typo
    
    * add test for raise
    
    * fix test_raise
    
    Co-authored-by: Wei Chu <we...@amazon.com>
---
 .../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 67 +++++++++++++++++++---
 tests/python-pytest/onnx/test_operators.py         | 21 +++++++
 2 files changed, 79 insertions(+), 9 deletions(-)

diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index c5e42a0..fe81af9 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -2393,18 +2393,67 @@ def convert_topk(node, **kwargs):
 def convert_take(node, **kwargs):
     """Map MXNet's Take operator attributes to onnx's Gather operator.
     """
+    from onnx.helper import make_node
+    from onnx import TensorProto
     name, input_nodes, attrs = get_inputs(node, kwargs)
-
     axis = int(attrs.get('axis', 0))
+    mode = str(attrs.get('mode', 'clip'))  # 'clip' when the attribute is absent
 
-    node = onnx.helper.make_node(
-        "Gather",
-        input_nodes,
-        [name],
-        axis=axis,
-        name=name,
-    )
-    return [node]
+    data = input_nodes[0]
+    indices = input_nodes[1]
+
+    nodes = [
+        make_node('Cast', [indices], [name+'_indices'], to=int(TensorProto.INT64)),  # int64 so index math can mix with Shape's int64 output
+    ]
+
+    if mode == 'raise':  # indices used as-is; the ONNX Gather spec treats out-of-range indices as an error
+        nodes += [
+            make_node('Gather', [data, name+'_indices'], [name], axis=axis, name=name)
+        ]
+
+        return nodes
+
+    nodes += [
+        create_tensor([-1], name+'_-1', kwargs["initializer"]),  # shared -1 constant for the offsets below
+        make_node('Shape', [data], [name+'_data_shape']),  # 1-D tensor of data's dims
+    ]
+
+    # corner case: for axis == -1, Slice(axis, axis+1) would be Slice(-1, 0), which is empty, so end the slice at the rank instead
+    if axis == -1:
+        nodes += [
+            make_node('Shape', [name+'_data_shape'], [name+'_data_dim']),  # rank of data, as a 1-element tensor
+            make_node('Add', [name+'_data_dim', name+'_-1'], [name+'_axis_max']),  # rank - 1
+            make_node('Slice', [name+'_data_shape', name+'_axis_max', name+'_data_dim'], [name+'_slice0_out']),  # size of the last dim
+        ]
+
+    else:
+        nodes += [
+            create_tensor([axis], name+'_axis', kwargs["initializer"]),
+            create_tensor([axis+1], name+'_axis+1', kwargs["initializer"]),
+            make_node('Slice', [name+'_data_shape', name+'_axis', name+'_axis+1'], [name+'_slice0_out']),  # size of dim `axis`
+        ]
+
+    if mode == 'clip':  # clamp indices into [0, size - 1] before gathering
+        nodes += [
+            create_tensor([0], name+'_0', kwargs["initializer"]),
+            make_node('Add', [name+'_slice0_out', name+'_-1'], [name+'_max']),  # largest valid index: size - 1
+            make_node('Greater', [name+'_indices', name+'_max'], [name+'_max_mask']),
+            make_node('Where', [name+'_max_mask', name+'_max', name+'_indices'], [name+'_where0_out']),  # too-large -> size - 1
+            make_node('Less', [name+'_indices', name+'_0'], [name+'_min_mask']),
+            make_node('Where', [name+'_min_mask', name+'_0', name+'_where0_out'], [name+'_where1_out']),  # negative -> 0
+            make_node('Gather', [data, name+'_where1_out'], [name], axis=axis, name=name)
+        ]
+
+    elif mode == 'wrap':  # ONNX Mod (default fmod=0) gives integer remainders the divisor's sign, so indices wrap into [0, size)
+        nodes += [
+            make_node('Mod', [name+'_indices', name+'_slice0_out'], [name+'_mod0_out']),
+            make_node('Gather', [data, name+'_mod0_out'], [name], axis=axis, name=name)
+        ]
+
+    else:
+        raise NotImplementedError("mode must be clip, wrap or raise.")
+
+    return nodes
 
 
 @mx_op.register("LayerNorm")
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index 90ec8f5..eb74630 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -1139,3 +1139,24 @@ def test_onnx_export_tile(tmp_path, dtype, reps):
     x = mx.nd.random.normal(0, 100, (5, 6)).astype(dtype)
     M = def_model('tile', reps=reps)
     op_export_test('tile', M, [x], tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64'])
+@pytest.mark.parametrize('axis', [-3, -2, -1, 0, 1, 2])  # every axis of the rank-3 input, negative and positive forms
+@pytest.mark.parametrize('mode', ['clip', 'wrap'])
+def test_onnx_export_take(tmp_path, dtype, axis, mode):  # exports take with default attrs and with each axis/mode combination
+    x = mx.nd.random.normal(0, 10, (3, 4, 5)).astype(dtype)
+    y = mx.random.randint(-100, 100, (6, 7)).astype(dtype)  # most indices fall outside every dim of x, exercising clip/wrap
+    M1 = def_model('take')  # default attributes; NOTE(review): does not use axis/mode, so it is re-exported for every parametrization
+    op_export_test('take1', M1, [x, y], tmp_path)
+    M2 = def_model('take', axis=axis, mode=mode)
+    op_export_test('take2', M2, [x, y], tmp_path)
+
+
+@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64'])
+@pytest.mark.parametrize('axis', [-3, -2, -1, 0, 1, 2])  # every axis of the rank-3 input, negative and positive forms
+def test_onnx_export_take_raise(tmp_path, dtype, axis):  # exports take with mode='raise' using only in-range indices
+    x = mx.nd.random.normal(0, 10, (3, 4, 5)).astype(dtype)
+    y = mx.random.randint(0, 3, (6, 7)).astype(dtype)  # keep indices below 3 (smallest dim of x) so 'raise' never sees out-of-range values; assumes randint's high bound is exclusive
+    M = def_model('take', axis=axis, mode='raise')
+    op_export_test('take', M, [x, y], tmp_path)
\ No newline at end of file