You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/02/01 01:11:42 UTC
[incubator-mxnet] branch v1.x updated: [v1.x]Onnx support for upsampling (#19795)
This is an automated email from the ASF dual-hosted git repository.
zha0q1 pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new 76c8c7d [v1.x]Onnx support for upsampling (#19795)
76c8c7d is described below
commit 76c8c7df009eb5eacd096fce966d69db46890de4
Author: Zhaoqi Zhu <zh...@gmail.com>
AuthorDate: Sun Jan 31 17:09:32 2021 -0800
[v1.x]Onnx support for upsampling (#19795)
* bare bone implementation
* Update _op_translations.py
* Update _op_translations.py
* Update _op_translations.py
---
.../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 29 ++++++++++++++++++++++
tests/python-pytest/onnx/test_operators.py | 9 +++++++
2 files changed, 38 insertions(+)
diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 843d5e2..3972332 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -3400,6 +3400,35 @@ def convert_gather_nd(node, **kwargs):
return nodes
+@mx_op.register('UpSampling')
+def convert_upsampling(node, **kwargs):
+ """Map MXNet's UpSampling operator to onnx.
+ """
+ from onnx.helper import make_node
+ name, input_nodes, attrs = get_inputs(node, kwargs)
+
+ scale = int(attrs.get('scale', '1'))
+ sample_type = attrs.get('sample_type')
+ num_args = int(attrs.get('num_args', '1'))
+
+ if num_args > 1:
+ raise NotImplementedError('Upsampling conversion does not currently support num_args > 1')
+
+ if sample_type != 'nearest':
+ raise NotImplementedError('Upsampling conversion does not currently support \
+ sample_type != nearest')
+
+ nodes = [
+ create_tensor([], name+'_roi', kwargs['initializer'], dtype='float32'),
+ create_tensor([1, 1, scale, scale], name+'_scales', kwargs['initializer'],
+ dtype='float32'),
+ make_node('Resize', [input_nodes[0], name+'_roi', name+'_scales'], [name], mode='nearest',
+ coordinate_transformation_mode='half_pixel')
+ ]
+
+ return nodes
+
+
@mx_op.register('SwapAxis')
def convert_swapaxis(node, **kwargs):
"""Map MXNet's SwapAxis operator
diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py
index 038541a..204fe57 100644
--- a/tests/python-pytest/onnx/test_operators.py
+++ b/tests/python-pytest/onnx/test_operators.py
@@ -595,6 +595,15 @@ def test_onnx_export_gather_nd(tmp_path, dtype):
op_export_test('gather_nd2', M2, [x2, y2], tmp_path)
+@pytest.mark.parametrize('dtype', ['float16', 'float32'])
+@pytest.mark.parametrize('shape', [(3, 4, 5, 6), (1, 1, 1, 1)])
+@pytest.mark.parametrize('scale', [1, 2, 3])
+def test_onnx_export_upsampling(tmp_path, dtype, shape, scale):
+ A = mx.random.uniform(0, 1, shape).astype(dtype)
+ M = def_model('UpSampling', scale=scale, sample_type='nearest', num_args=1)
+ op_export_test('UpSampling', M, [A], tmp_path)
+
+
@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64'])
@pytest.mark.parametrize('params', [((4, 5, 6), (0, 2)), ((4, 5, 6), (0, 1)),
((1, 2, 3, 4, 1), (0, 4)),