You are viewing a plain-text version of this content. The link to the canonical version was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/06/18 06:57:02 UTC

[tvm] branch main updated: [Relay][Convert Layout] Enable layout transformation for image.resize op (#8205)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0f4c065  [Relay][Convert Layout] Enable layout transformation for image.resize op (#8205)
0f4c065 is described below

commit 0f4c0654ef94c2252d0075e726b2c6589430d9d7
Author: Jorn Tuyls <jt...@users.noreply.github.com>
AuthorDate: Fri Jun 18 08:56:49 2021 +0200

    [Relay][Convert Layout] Enable layout transformation for image.resize op (#8205)
    
    * Enable layout transformation for image.resize op
    
    * Change str map function to str and index retrieval
    
    * Fix for pytorch frontend segmentation models test
---
 python/tvm/relay/op/image/_image.py               | 31 ++++++++
 src/relay/op/image/resize.cc                      | 26 +++++++
 tests/python/relay/test_pass_convert_op_layout.py | 86 +++++++++++++++++++++++
 3 files changed, 143 insertions(+)

diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py
index 5b7fd32..2071a43 100644
--- a/python/tvm/relay/op/image/_image.py
+++ b/python/tvm/relay/op/image/_image.py
@@ -26,6 +26,7 @@ from tvm.topi.utils import get_const_tuple
 from .. import op as reg
 from .. import strategy
 from ..op import OpPattern
+from .image import resize
 
 
 # resize
@@ -58,6 +59,36 @@ def compute_resize(attrs, inputs, out_type):
 reg.register_injective_schedule("image.resize")
 
 
+@reg.register_convert_op_layout("image.resize")
+def convert_image_resize(attrs, inputs, tinfos, desired_layouts):
+    """Convert Layout pass registration for image resize op.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current resize op
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    tinfos : list of types
+        List of input and output types
+    desired_layouts : list of layout strings
+        List of layouts defining our desired
+        layout for the data input.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The transformed expr
+    """
+
+    new_attrs = dict(attrs)
+    assert len(desired_layouts) == 1, "Only one desired layout is expected"
+    desired_layout = str(desired_layouts[0])
+    assert desired_layout != "default", "Layout cannot be default"
+    new_attrs["layout"] = desired_layout
+    return resize(*inputs, **new_attrs)
+
+
 @script
 def _resize_shape_func(image_shape, size, batch_axis, height_axis, width_axis, channel_axis):
     out = output_tensor((4,), "int64")
diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc
index 9c3d601..2c90d7b 100644
--- a/src/relay/op/image/resize.cc
+++ b/src/relay/op/image/resize.cc
@@ -33,6 +33,31 @@ namespace relay {
 
 TVM_REGISTER_NODE_TYPE(ResizeAttrs);
 
+template <typename T>
+Array<Array<Layout> > ResizeInferCorrectLayout(const Attrs& attrs,
+                                               const Array<Layout>& new_in_layouts,
+                                               const Array<Layout>& old_in_layouts,
+                                               const Array<tvm::relay::Type>& old_in_types) {
+  // NOTE: Discard "const" qualifier here.
+  T* params = const_cast<T*>(attrs.as<T>());
+
+  if (new_in_layouts.defined()) {
+    ICHECK_EQ(new_in_layouts.size(), 1);
+
+    Layout raw_layout(params->layout);
+    Layout new_layout = new_in_layouts[0];
+    Layout old_layout = old_in_layouts[0];
+    if (!new_layout.Equals(old_layout) && raw_layout.Equals(old_layout) &&
+        new_layout->axes.size() == old_layout->axes.size()) {
+      // Follow input layout
+      params->layout = new_layout.name();
+    }
+  }
+
+  Layout inferred_layout(params->layout);
+  return Array<Array<Layout> >{{inferred_layout}, {inferred_layout}};
+}
+
 bool ResizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                const TypeReporter& reporter) {
   ICHECK_EQ(types.size(), 2);
@@ -102,6 +127,7 @@ RELAY_REGISTER_OP("image.resize")
     .add_argument("data", "Tensor", "The input tensor.")
     .set_support_level(5)
     .add_type_rel("Resize", ResizeRel)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ResizeInferCorrectLayout<ResizeAttrs>)
     .set_attr<TOpPattern>("TOpPattern", kInjective);
 
 TVM_REGISTER_NODE_TYPE(Resize3dAttrs);
diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py
index 4710d50..88590c9 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -1797,6 +1797,90 @@ def test_conv_reduce_convert_layout():
     _test_conv_reduce_convert_layout2()
 
 
+def test_image_resize_convert_layout():
+    def _test_image_resize_convert_layout_nchw_to_nhwc():
+        def before():
+            x = relay.var("x", shape=(1, 2, 4, 4))
+            y = relay.image.resize(x, (8, 8))
+            y = relay.Function([x], y)
+            return y
+
+        def expected():
+            x = relay.var("x", shape=(1, 2, 4, 4))
+            x = relay.layout_transform(x, "NCHW", "NHWC")
+            y = relay.image.resize(x, (8, 8), layout="NHWC")
+            y = relay.layout_transform(y, "NHWC", "NCHW")
+            y = relay.Function(relay.analysis.free_vars(y), y)
+            return y
+
+        a = before()
+        a = run_opt_pass(a, transform.ConvertLayout({"image.resize": ["NHWC"]}))
+        b = run_opt_pass(expected(), transform.InferType())
+
+        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+    def _test_image_resize_convert_layout_nhwc_to_nchw():
+        def before():
+            x = relay.var("x", shape=(1, 4, 4, 2))
+            y = relay.image.resize(x, (8, 8), layout="NHWC")
+            y = relay.Function([x], y)
+            return y
+
+        def expected():
+            x = relay.var("x", shape=(1, 4, 4, 2))
+            x = relay.layout_transform(x, "NHWC", "NCHW")
+            y = relay.image.resize(x, (8, 8), layout="NCHW")
+            y = relay.layout_transform(y, "NCHW", "NHWC")
+            y = relay.Function(relay.analysis.free_vars(y), y)
+            return y
+
+        a = before()
+        a = run_opt_pass(a, transform.ConvertLayout({"image.resize": ["NCHW"]}))
+        b = run_opt_pass(expected(), transform.InferType())
+
+        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+    _test_image_resize_convert_layout_nchw_to_nhwc()
+    _test_image_resize_convert_layout_nhwc_to_nchw()
+
+
+def test_conv_image_resize_convert_layout():
+    """Check that layout transforms are propagated through image resize."""
+
+    def before():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight = relay.var("weight", shape=(3, 3, 64, 64))
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+        y = relay.image.resize(y, (112, 112), layout="NHWC")
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        w = relay.var("weight", shape=(3, 3, 64, 64))
+        x = relay.layout_transform(x, "NHWC", "NCHW")
+        w = relay.layout_transform(w, "HWIO", "OIHW")
+        y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1))
+        y = relay.image.resize(y, (112, 112), layout="NCHW")
+        y = relay.layout_transform(y, "NCHW", "NHWC")
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    a = before()
+    a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+
 if __name__ == "__main__":
     test_qnn_binary_no_convert_layout()
     test_no_convert_layout()
@@ -1828,3 +1912,5 @@ if __name__ == "__main__":
     test_conv_squeeze_convert_layout()
     test_conv_reduce_convert_layout()
     test_conv_strided_slice_axes_convert_layout()
+    test_image_resize_convert_layout()
+    test_conv_image_resize_convert_layout()