You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ro...@apache.org on 2019/01/26 00:06:22 UTC
[incubator-mxnet] branch master updated: ONNX export: Add Crop,
Deconvolution and fix the default stride of Pooling to 1 (#12399)
This is an automated email from the ASF dual-hosted git repository.
roshrini pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 25e915b ONNX export: Add Crop, Deconvolution and fix the default stride of Pooling to 1 (#12399)
25e915b is described below
commit 25e915bd401f7ac4c639c935f775deccebec96d3
Author: Przemyslaw Tredak <pt...@gmail.com>
AuthorDate: Fri Jan 25 16:04:56 2019 -0800
ONNX export: Add Crop, Deconvolution and fix the default stride of Pooling to 1 (#12399)
* Added Deconvolution and Crop to ONNX exporter
* Added default for pool_type
---
.../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 66 +++++++++++++++++++++-
tests/python-pytest/onnx/test_cases.py | 3 +-
2 files changed, 66 insertions(+), 3 deletions(-)
diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 51deb4f..8e3c46d 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -219,6 +219,68 @@ def convert_convolution(node, **kwargs):
return [conv_node]
+@mx_op.register("Deconvolution")
+def convert_deconvolution(node, **kwargs):
+ """Map MXNet's deconvolution operator attributes to onnx's ConvTranspose operator
+ and return the created node.
+ """
+ name, inputs, attrs = get_inputs(node, kwargs)
+
+ kernel_dims = list(parse_helper(attrs, "kernel"))
+ stride_dims = list(parse_helper(attrs, "stride", [1, 1]))
+ pad_dims = list(parse_helper(attrs, "pad", [0, 0]))
+ num_group = int(attrs.get("num_group", 1))
+ dilations = list(parse_helper(attrs, "dilate", [1, 1]))
+ adj_dims = list(parse_helper(attrs, "adj", [0, 0]))
+
+ pad_dims = pad_dims + pad_dims
+
+ deconv_node = onnx.helper.make_node(
+ "ConvTranspose",
+ inputs=inputs,
+ outputs=[name],
+ kernel_shape=kernel_dims,
+ strides=stride_dims,
+ dilations=dilations,
+ output_padding=adj_dims,
+ pads=pad_dims,
+ group=num_group,
+ name=name
+ )
+
+ return [deconv_node]
+
+
+@mx_op.register("Crop")
+def convert_crop(node, **kwargs):
+ """Map MXNet's crop operator attributes to onnx's Crop operator
+ and return the created node.
+ """
+ name, inputs, attrs = get_inputs(node, kwargs)
+ num_inputs = len(inputs)
+
+ y, x = list(parse_helper(attrs, "offset", [0, 0]))
+ h, w = list(parse_helper(attrs, "h_w", [0, 0]))
+ if num_inputs > 1:
+ h, w = kwargs["out_shape"][-2:]
+ border = [x, y, x + w, y + h]
+
+ crop_node = onnx.helper.make_node(
+ "Crop",
+ inputs=[inputs[0]],
+ outputs=[name],
+ border=border,
+ scale=[1, 1],
+ name=name
+ )
+
+ logging.warning(
+ "Using an experimental ONNX operator: Crop. " \
+ "Its definition can change.")
+
+ return [crop_node]
+
+
@mx_op.register("FullyConnected")
def convert_fully_connected(node, **kwargs):
"""Map MXNet's FullyConnected operator attributes to onnx's Gemm operator
@@ -583,8 +645,8 @@ def convert_pooling(node, **kwargs):
name, input_nodes, attrs = get_inputs(node, kwargs)
kernel = eval(attrs["kernel"])
- pool_type = attrs["pool_type"]
- stride = eval(attrs["stride"]) if attrs.get("stride") else None
+ pool_type = attrs["pool_type"] if attrs.get("pool_type") else "max"
+ stride = eval(attrs["stride"]) if attrs.get("stride") else (1, 1)
global_pool = get_boolean_attribute_value(attrs, "global_pool")
p_value = attrs.get('p_value', 'None')
diff --git a/tests/python-pytest/onnx/test_cases.py b/tests/python-pytest/onnx/test_cases.py
index 6ec3709..b20db23 100644
--- a/tests/python-pytest/onnx/test_cases.py
+++ b/tests/python-pytest/onnx/test_cases.py
@@ -113,7 +113,8 @@ BASIC_MODEL_TESTS = {
'test_Softmax',
'test_softmax_functional',
'test_softmax_lastdim',
- ]
+ ],
+ 'export': ['test_ConvTranspose2d']
}
STANDARD_MODEL = {