You are viewing a plain-text version of this content. The canonical link for it (a hyperlink in the original page) is not preserved in this copy.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/02/01 01:11:42 UTC

[incubator-mxnet] branch v1.x updated: [v1.x]Onnx support for upsampling (#19795)

This is an automated email from the ASF dual-hosted git repository.

zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 76c8c7d  [v1.x]Onnx support for upsampling (#19795)
76c8c7d is described below

commit 76c8c7df009eb5eacd096fce966d69db46890de4
Author: Zhaoqi Zhu <zh...@gmail.com>
AuthorDate: Sun Jan 31 17:09:32 2021 -0800

    [v1.x]Onnx support for upsampling (#19795)
    
    * bare bone implementation
    
    * Update _op_translations.py
    
    * Update _op_translations.py
    
    * Update _op_translations.py
---
 .../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 29 ++++++++++++++++++++++
 tests/python-pytest/onnx/test_operators.py         |  9 +++++++
 2 files changed, 38 insertions(+)

diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 843d5e2..3972332 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -3400,6 +3400,35 @@ def convert_gather_nd(node, **kwargs):
     return nodes
 
 
+@mx_op.register('UpSampling')
+def convert_upsampling(node, **kwargs):
+    """Map MXNet's UpSampling operator to onnx.
+    """
+    from onnx.helper import make_node
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    scale = int(attrs.get('scale', '1'))
+    sample_type = attrs.get('sample_type')
+    num_args = int(attrs.get('num_args', '1'))
+
+    if num_args > 1:
+        raise NotImplementedError('Upsampling conversion does not currently support num_args > 1')
+
+    if sample_type != 'nearest':
+        raise NotImplementedError('Upsampling conversion does not currently support \
+                                   sample_type != nearest')
+
+    nodes = [
+        create_tensor([], name+'_roi', kwargs['initializer'], dtype='float32'),
+        create_tensor([1, 1, scale, scale], name+'_scales', kwargs['initializer'],
+                      dtype='float32'),
+        make_node('Resize', [input_nodes[0], name+'_roi', name+'_scales'], [name], mode='nearest',
+                  coordinate_transformation_mode='half_pixel')
+        ]
+
+    return nodes
+
+
 @mx_op.register('SwapAxis')
 def convert_swapaxis(node, **kwargs):
     """Map MXNet's SwapAxis operator
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index 038541a..204fe57 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -595,6 +595,15 @@ def test_onnx_export_gather_nd(tmp_path, dtype):
     op_export_test('gather_nd2', M2, [x2, y2], tmp_path)
 
 
+@pytest.mark.parametrize('dtype', ['float16', 'float32'])
+@pytest.mark.parametrize('shape', [(3, 4, 5, 6), (1, 1, 1, 1)])
+@pytest.mark.parametrize('scale', [1, 2, 3])
+def test_onnx_export_upsampling(tmp_path, dtype, shape, scale):
+    A = mx.random.uniform(0, 1, shape).astype(dtype)
+    M = def_model('UpSampling', scale=scale, sample_type='nearest', num_args=1)
+    op_export_test('UpSampling', M, [A], tmp_path)
+
+
 @pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64'])
 @pytest.mark.parametrize('params', [((4, 5, 6), (0, 2)), ((4, 5, 6), (0, 1)),
                                     ((1, 2, 3, 4, 1), (0, 4)),