You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2020/12/08 14:18:16 UTC

[tvm] branch main updated: [Relay][Frontend][Onnx] MaxUnpool Operator (#7036)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7a0d10c  [Relay][Frontend][Onnx] MaxUnpool Operator (#7036)
7a0d10c is described below

commit 7a0d10ce8d4df27734949b944992e5c516f620d0
Author: Josh Fromm <jw...@uw.edu>
AuthorDate: Tue Dec 8 06:18:01 2020 -0800

    [Relay][Frontend][Onnx] MaxUnpool Operator (#7036)
    
    * Added maxunpool test.
    
    * MaxUnpool implemented and tested.
    
    * Lint fix.
    
    * Add explicit output shape in tests.
---
 python/tvm/relay/frontend/onnx.py          | 57 ++++++++++++++++++++++++
 tests/python/frontend/onnx/test_forward.py | 69 ++++++++++++++++++++++++++++++
 2 files changed, 126 insertions(+)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index d65f567..0b6ebdb 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -622,6 +622,62 @@ class MaxPool(Pool):
     name = "max_pool"
 
 
+class MaxUnpool(OnnxOpConverter):
+    """Operator converter for MaxUnpool.
+
+    The result is a zero tensor into which the values of `data` are scattered at
+    the flat positions given by `indices`.
+    """
+
+    @classmethod
+    def _impl_v11(cls, inputs, attr, params):
+        """Build the Relay expression for a MaxUnpool node.
+
+        `inputs` holds the data, the flat scatter indices and an optional output
+        shape; `attr` supplies `kernel_shape` plus optional `strides` and `pads`.
+        An explicit output shape overrides the one derived from the attributes.
+        """
+        # Unpack inputs and attributes.
+        data = inputs[0]
+        data_type = infer_type(data).checked_type.dtype
+        indices = inputs[1]
+        output_shape = inputs[2]
+        kernel_shape = list(attr.get("kernel_shape"))
+        rank = len(kernel_shape)
+        strides = list(attr.get("strides", [1] * rank))
+        pads = attr.get("pads", None)
+
+        if output_shape is not None:
+            total_output_shape = output_shape
+        else:
+            # Per spatial axis: in * stride + (kernel - stride) - (pad_begin + pad_end).
+            # ONNX orders pads as [x1_begin, x2_begin, ..., x1_end, x2_end, ...], so the
+            # begin/end pair of spatial axis i sits at positions i and i + rank.
+            pads = [0] * (2 * rank) if pads is None else list(pads)
+            extra = [
+                kernel - stride - pads[i] - pads[i + rank]
+                for i, (kernel, stride) in enumerate(zip(kernel_shape, strides))
+            ]
+            # The two leading (non-spatial) dimensions are left untouched.
+            multiplier = _expr.const([1, 1] + strides, dtype="int64")
+            offset = _expr.const([0, 0] + extra, dtype="int64")
+            total_output_shape = multiplier * _op.shape_of(data, dtype="int64") + offset
+
+        # Create a tensor of zeros then scatter our data through it.
+        zeros_tensor = _op.zeros(total_output_shape, data_type)
+        # We need to flatten all our tensors before scattering.
+        flat_tensor = _op.scatter(
+            _op.reshape(zeros_tensor, [-1]),
+            _op.reshape(indices, [-1]),
+            _op.reshape(data, [-1]),
+            axis=0,
+        )
+        # Now reshape back to the full output shape.
+        output_tensor = _op.reshape(flat_tensor, total_output_shape)
+
+        return output_tensor
+
+
 class LpPool(OnnxOpConverter):
     """A helper class for lppool op converters."""
 
@@ -2330,6 +2386,7 @@ def _get_convert_map(opset):
         "AveragePool": AveragePool.get_converter(opset),
         "LpPool": LpPool.get_converter(opset),
         "MaxPool": MaxPool.get_converter(opset),
+        "MaxUnpool": MaxUnpool.get_converter(opset),
         "Conv": Conv.get_converter(opset),
         "ConvTranspose": ConvTranspose.get_converter(opset),
         "GlobalAveragePool": Renamer("global_avg_pool2d"),
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 3ddc80a..1e0b729 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -3915,6 +3915,74 @@ def test_size():
     verify_size(input_data)
 
 
+@tvm.testing.uses_gpu
+def test_maxunpool():
+    """Run MaxUnpool models through verify_with_ort_with_inputs in several configurations."""
+
+    def verify_maxunpool(data, indices, kernel_shape, strides, output_shape=None, pads=None):
+        """Build a one-node MaxUnpool model and pass it to verify_with_ort_with_inputs."""
+        input_names = ["xT", "xI"]
+        input_info = [
+            helper.make_tensor_value_info("xT", TensorProto.FLOAT, list(data.shape)),
+            helper.make_tensor_value_info("xI", TensorProto.INT64, list(indices.shape)),
+        ]
+        input_values = [data, indices]
+        if output_shape is not None:
+            input_names.append("output_shape")
+            input_info.append(
+                helper.make_tensor_value_info(
+                    "output_shape", TensorProto.INT64, list(output_shape.shape)
+                )
+            )
+            input_values.append(output_shape)
+        else:
+            # Compute expected output shape
+            output_shape = np.asarray(([1, 1] + list(strides))) * np.asarray(list(data.shape))
+            output_shape += np.asarray(([0, 0] + list(kernel_shape))) - np.asarray(
+                ([0, 0] + list(strides))
+            )
+            if pads is not None:
+                # ONNX pads are [x1_begin, x2_begin, ..., x1_end, x2_end, ...]: add the
+                # begin row to the end row to get the total padding of each axis.
+                total_pad = np.sum(np.reshape(list(pads), [2, -1]), axis=0)
+                output_shape -= np.asarray([0, 0] + list(total_pad))
+        output_shape = [int(i) for i in output_shape]
+
+        node = helper.make_node(
+            "MaxUnpool", inputs=input_names, outputs=["y"], kernel_shape=kernel_shape
+        )
+        # Optional attributes are attached only when the caller supplies them.
+        if pads is not None:
+            node.attribute.append(helper.make_attribute("pads", pads))
+        if strides is not None:
+            node.attribute.append(helper.make_attribute("strides", strides))
+
+        graph = helper.make_graph(
+            [node],
+            "maxunpool_test",
+            inputs=input_info,
+            outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, output_shape)],
+        )
+
+        model = helper.make_model(graph, producer_name="maxunpool_test")
+        verify_with_ort_with_inputs(model, input_values, use_vm=True, opset=11)
+
+    # Basic test
+    xT = np.array([[[[5, 6], [7, 8]]]], dtype=np.float32)
+    xI = np.array([[[[0, 7], [13, 15]]]], dtype=np.int64)
+    verify_maxunpool(xT, xI, [2, 2], strides=[2, 2])
+    # Small stride
+    verify_maxunpool(xT, xI, [2, 2], strides=[1, 1])
+    # Big kernel
+    verify_maxunpool(xT, xI, [3, 3], strides=[2, 2])
+    # With output shape
+    output_shape = np.array((1, 1, 5, 5), dtype=np.int64)
+    verify_maxunpool(xT, xI, [2, 2], strides=[2, 2], output_shape=output_shape)
+    # With explicit reverse padding
+    pads = np.asarray([1, 1, 1, 1]).astype(np.int64)
+    verify_maxunpool(xT, xI, [2, 2], strides=[2, 2], pads=pads)
+
+
 if __name__ == "__main__":
     test_flatten()
     test_reshape()
@@ -3992,3 +4060,4 @@ if __name__ == "__main__":
     test_range()
     test_loop()
     test_size()
+    test_maxunpool()