You are viewing a plain text version of this content. The canonical link for it was a hyperlink in the original message and is not preserved in this plain-text rendering.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/06/18 06:57:02 UTC
[tvm] branch main updated: [Relay][Convert Layout] Enable layout
transformation for image.resize op (#8205)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0f4c065 [Relay][Convert Layout] Enable layout transformation for image.resize op (#8205)
0f4c065 is described below
commit 0f4c0654ef94c2252d0075e726b2c6589430d9d7
Author: Jorn Tuyls <jt...@users.noreply.github.com>
AuthorDate: Fri Jun 18 08:56:49 2021 +0200
[Relay][Convert Layout] Enable layout transformation for image.resize op (#8205)
* Enable layout transformation for image.resize op
* Change str map function to str and index retrieval
* Fix for pytorch frontend segmentation models test
---
python/tvm/relay/op/image/_image.py | 31 ++++++++
src/relay/op/image/resize.cc | 26 +++++++
tests/python/relay/test_pass_convert_op_layout.py | 86 +++++++++++++++++++++++
3 files changed, 143 insertions(+)
diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py
index 5b7fd32..2071a43 100644
--- a/python/tvm/relay/op/image/_image.py
+++ b/python/tvm/relay/op/image/_image.py
@@ -26,6 +26,7 @@ from tvm.topi.utils import get_const_tuple
from .. import op as reg
from .. import strategy
from ..op import OpPattern
+from .image import resize
# resize
@@ -58,6 +59,36 @@ def compute_resize(attrs, inputs, out_type):
reg.register_injective_schedule("image.resize")
+@reg.register_convert_op_layout("image.resize")
+def convert_image_resize(attrs, inputs, tinfos, desired_layouts):
+ """Convert Layout pass registration for image resize op.
+
+ Parameters
+ ----------
+ attrs : tvm.ir.Attrs
+ Attributes of current resize op
+ inputs : list of tvm.relay.Expr
+ The args of the Relay expr to be legalized
+ tinfos : list of types
+ List of input and output types
+ desired_layouts : list of layout strings
+ List of layouts defining our desired
+ layout for the data input.
+
+ Returns
+ -------
+ result : tvm.relay.Expr
+ The transformed expr
+ """
+
+ new_attrs = dict(attrs)
+ assert len(desired_layouts) == 1, "Only one desired layout is expected"
+ desired_layout = str(desired_layouts[0])
+ assert desired_layout != "default", "Layout cannot be default"
+ new_attrs["layout"] = desired_layout
+ return resize(*inputs, **new_attrs)
+
+
@script
def _resize_shape_func(image_shape, size, batch_axis, height_axis, width_axis, channel_axis):
out = output_tensor((4,), "int64")
diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc
index 9c3d601..2c90d7b 100644
--- a/src/relay/op/image/resize.cc
+++ b/src/relay/op/image/resize.cc
@@ -33,6 +33,31 @@ namespace relay {
TVM_REGISTER_NODE_TYPE(ResizeAttrs);
+template <typename T>
+Array<Array<Layout> > ResizeInferCorrectLayout(const Attrs& attrs,
+ const Array<Layout>& new_in_layouts,
+ const Array<Layout>& old_in_layouts,
+ const Array<tvm::relay::Type>& old_in_types) {
+ // NOTE: Discard "const" qualifier here.
+ T* params = const_cast<T*>(attrs.as<T>());
+
+ if (new_in_layouts.defined()) {
+ ICHECK_EQ(new_in_layouts.size(), 1);
+
+ Layout raw_layout(params->layout);
+ Layout new_layout = new_in_layouts[0];
+ Layout old_layout = old_in_layouts[0];
+ if (!new_layout.Equals(old_layout) && raw_layout.Equals(old_layout) &&
+ new_layout->axes.size() == old_layout->axes.size()) {
+ // Follow input layout
+ params->layout = new_layout.name();
+ }
+ }
+
+ Layout inferred_layout(params->layout);
+ return Array<Array<Layout> >{{inferred_layout}, {inferred_layout}};
+}
+
bool ResizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
ICHECK_EQ(types.size(), 2);
@@ -102,6 +127,7 @@ RELAY_REGISTER_OP("image.resize")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(5)
.add_type_rel("Resize", ResizeRel)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ResizeInferCorrectLayout<ResizeAttrs>)
.set_attr<TOpPattern>("TOpPattern", kInjective);
TVM_REGISTER_NODE_TYPE(Resize3dAttrs);
diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py
index 4710d50..88590c9 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -1797,6 +1797,90 @@ def test_conv_reduce_convert_layout():
_test_conv_reduce_convert_layout2()
+def test_image_resize_convert_layout():
+ def _test_image_resize_convert_layout_nchw_to_nhwc():
+ def before():
+ x = relay.var("x", shape=(1, 2, 4, 4))
+ y = relay.image.resize(x, (8, 8))
+ y = relay.Function([x], y)
+ return y
+
+ def expected():
+ x = relay.var("x", shape=(1, 2, 4, 4))
+ x = relay.layout_transform(x, "NCHW", "NHWC")
+ y = relay.image.resize(x, (8, 8), layout="NHWC")
+ y = relay.layout_transform(y, "NHWC", "NCHW")
+ y = relay.Function(relay.analysis.free_vars(y), y)
+ return y
+
+ a = before()
+ a = run_opt_pass(a, transform.ConvertLayout({"image.resize": ["NHWC"]}))
+ b = run_opt_pass(expected(), transform.InferType())
+
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+ def _test_image_resize_convert_layout_nhwc_to_nchw():
+ def before():
+ x = relay.var("x", shape=(1, 4, 4, 2))
+ y = relay.image.resize(x, (8, 8), layout="NHWC")
+ y = relay.Function([x], y)
+ return y
+
+ def expected():
+ x = relay.var("x", shape=(1, 4, 4, 2))
+ x = relay.layout_transform(x, "NHWC", "NCHW")
+ y = relay.image.resize(x, (8, 8), layout="NCHW")
+ y = relay.layout_transform(y, "NCHW", "NHWC")
+ y = relay.Function(relay.analysis.free_vars(y), y)
+ return y
+
+ a = before()
+ a = run_opt_pass(a, transform.ConvertLayout({"image.resize": ["NCHW"]}))
+ b = run_opt_pass(expected(), transform.InferType())
+
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+ _test_image_resize_convert_layout_nchw_to_nhwc()
+ _test_image_resize_convert_layout_nhwc_to_nchw()
+
+
+def test_conv_image_resize_convert_layout():
+ """Check that layout transforms are propagated through image resize."""
+
+ def before():
+ x = relay.var("x", shape=(1, 56, 56, 64))
+ weight = relay.var("weight", shape=(3, 3, 64, 64))
+ y = relay.nn.conv2d(
+ x,
+ weight,
+ channels=64,
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ data_layout="NHWC",
+ kernel_layout="HWIO",
+ )
+ y = relay.image.resize(y, (112, 112), layout="NHWC")
+ y = relay.Function(analysis.free_vars(y), y)
+ return y
+
+ def expected():
+ x = relay.var("x", shape=(1, 56, 56, 64))
+ w = relay.var("weight", shape=(3, 3, 64, 64))
+ x = relay.layout_transform(x, "NHWC", "NCHW")
+ w = relay.layout_transform(w, "HWIO", "OIHW")
+ y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1))
+ y = relay.image.resize(y, (112, 112), layout="NCHW")
+ y = relay.layout_transform(y, "NCHW", "NHWC")
+ y = relay.Function(analysis.free_vars(y), y)
+ return y
+
+ a = before()
+ a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
+ b = run_opt_pass(expected(), transform.InferType())
+
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+
if __name__ == "__main__":
test_qnn_binary_no_convert_layout()
test_no_convert_layout()
@@ -1828,3 +1912,5 @@ if __name__ == "__main__":
test_conv_squeeze_convert_layout()
test_conv_reduce_convert_layout()
test_conv_strided_slice_axes_convert_layout()
+ test_image_resize_convert_layout()
+ test_conv_image_resize_convert_layout()