You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ro...@apache.org on 2019/01/26 00:06:22 UTC

[incubator-mxnet] branch master updated: ONNX export: Add Crop, Deconvolution and fix the default stride of Pooling to 1 (#12399)

This is an automated email from the ASF dual-hosted git repository.

roshrini pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 25e915b  ONNX export: Add Crop, Deconvolution and fix the default stride of Pooling to 1 (#12399)
25e915b is described below

commit 25e915bd401f7ac4c639c935f775deccebec96d3
Author: Przemyslaw Tredak <pt...@gmail.com>
AuthorDate: Fri Jan 25 16:04:56 2019 -0800

    ONNX export: Add Crop, Deconvolution and fix the default stride of Pooling to 1 (#12399)
    
    * Added Deconvolution and Crop to ONNX exporter
    
    * Added default for pool_type
---
 .../mxnet/contrib/onnx/mx2onnx/_op_translations.py | 66 +++++++++++++++++++++-
 tests/python-pytest/onnx/test_cases.py             |  3 +-
 2 files changed, 66 insertions(+), 3 deletions(-)

diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 51deb4f..8e3c46d 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -219,6 +219,68 @@ def convert_convolution(node, **kwargs):
     return [conv_node]
 
 
+@mx_op.register("Deconvolution")
+def convert_deconvolution(node, **kwargs):
+    """Map MXNet's Deconvolution operator to an ONNX ConvTranspose node.
+    Unset stride/pad/dilate/adj attributes default to one value per kernel dim.
+    """
+    name, inputs, attrs = get_inputs(node, kwargs)
+    kernel_dims = list(parse_helper(attrs, "kernel"))
+    ndim = len(kernel_dims)  # number of spatial dims; every per-dim default must match it
+    stride_dims = list(parse_helper(attrs, "stride", [1] * ndim))
+    pad_dims = list(parse_helper(attrs, "pad", [0] * ndim))
+    num_group = int(attrs.get("num_group", 1))
+    dilations = list(parse_helper(attrs, "dilate", [1] * ndim))
+    adj_dims = list(parse_helper(attrs, "adj", [0] * ndim))
+
+    pad_dims = pad_dims + pad_dims  # ONNX pads: all begin values, then all end values
+
+    deconv_node = onnx.helper.make_node(
+        "ConvTranspose",
+        inputs=inputs,
+        outputs=[name],
+        kernel_shape=kernel_dims,
+        strides=stride_dims,
+        dilations=dilations,
+        output_padding=adj_dims,  # MXNet 'adj' corresponds to ONNX 'output_padding'
+        pads=pad_dims,
+        group=num_group,
+        name=name
+    )
+
+    return [deconv_node]
+
+
+@mx_op.register("Crop")
+def convert_crop(node, **kwargs):
+    """Map MXNet's Crop operator to the experimental ONNX Crop operator: take an
+    (h, w) window starting at offset (y, x); a 2nd input only supplies (h, w).
+    """
+    name, inputs, attrs = get_inputs(node, kwargs)
+    num_inputs = len(inputs)
+
+    y, x = parse_helper(attrs, "offset", [0, 0])
+    h, w = parse_helper(attrs, "h_w", [0, 0])
+    if num_inputs > 1:
+        h, w = kwargs["out_shape"][-2:]  # presumably the node's inferred NCHW output shape; confirm
+    border = [x, y, 0, 0]  # (left, top, right, bottom); with 'scale' set, window starts at (left, top)
+
+    crop_node = onnx.helper.make_node(
+        "Crop",
+        inputs=[inputs[0]],  # the optional reference input only determines (h, w)
+        outputs=[name],
+        border=border,
+        scale=[h, w],  # ONNX Crop 'scale' is the (height, width) of the cropped window
+        name=name
+    )
+
+    logging.warning(
+        "Using an experimental ONNX operator: Crop. "
+        "Its definition can change.")
+
+    return [crop_node]
+
+
 @mx_op.register("FullyConnected")
 def convert_fully_connected(node, **kwargs):
     """Map MXNet's FullyConnected operator attributes to onnx's Gemm operator
@@ -583,8 +645,8 @@ def convert_pooling(node, **kwargs):
     name, input_nodes, attrs = get_inputs(node, kwargs)
 
     kernel = eval(attrs["kernel"])
-    pool_type = attrs["pool_type"]
-    stride = eval(attrs["stride"]) if attrs.get("stride") else None
+    pool_type = attrs["pool_type"] if attrs.get("pool_type") else "max"
+    stride = eval(attrs["stride"]) if attrs.get("stride") else (1, 1)
     global_pool = get_boolean_attribute_value(attrs, "global_pool")
     p_value = attrs.get('p_value', 'None')
 
diff --git a/tests/python-pytest/onnx/test_cases.py b/tests/python-pytest/onnx/test_cases.py
index 6ec3709..b20db23 100644
--- a/tests/python-pytest/onnx/test_cases.py
+++ b/tests/python-pytest/onnx/test_cases.py
@@ -113,7 +113,8 @@ BASIC_MODEL_TESTS = {
              'test_Softmax',
              'test_softmax_functional',
              'test_softmax_lastdim',
-             ]
+             ],
+    'export': ['test_ConvTranspose2d']
 }
 
 STANDARD_MODEL = {