You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/09/29 12:51:07 UTC
[tvm] branch main updated: [Relay] Extend split for blocked ConvertLayout pass (#12886)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0d8c9cef72 [Relay] Extend split for blocked ConvertLayout pass (#12886)
0d8c9cef72 is described below
commit 0d8c9cef7212e62c18814f1632613fb04de6d290
Author: Andrey Malyshev <el...@gmail.com>
AuthorDate: Thu Sep 29 16:50:59 2022 +0400
[Relay] Extend split for blocked ConvertLayout pass (#12886)
* [Relay] Extend split for blocked ConvertLayout pass
* Fix lint hits
* Fix spelling
---
src/relay/op/tensor/transform.cc | 24 ++++++++++-
tests/python/relay/test_pass_convert_op_layout.py | 49 +++++++++++++++++++++++
2 files changed, 72 insertions(+), 1 deletion(-)
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index deb05e8877..985222307a 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -2982,10 +2982,32 @@ InferCorrectLayoutOutput SplitInferCorrectLayout(const Attrs& attrs,
// If new_in_layouts are defined, this code tries to modify the layout.
if (new_in_layouts.defined() && old_in_layouts.defined()) {
+ bool divisible = true;
const auto& sp_dim = old_in_layouts[0][axis];
auto new_index = new_in_layouts[0].IndexOf(sp_dim);
param->axis = new_index;
- ret = new_in_layouts[0];
+ int factor = new_in_layouts[0].FactorOf(sp_dim);
+ if (factor > 1) {
+ if (!param->indices_or_sections.as<IntImmNode>()) {
+ auto ios = Downcast<Array<Integer>>(param->indices_or_sections);
+ Array<Integer> new_ios;
+ for (const auto& v : ios) {
+ const IntImmNode* vint = v.as<IntImmNode>();
+ new_ios.push_back(vint->value / factor);
+ if (vint->value % factor) {
+ divisible = false;
+ }
+ }
+ if (divisible) {
+ param->indices_or_sections = new_ios;
+ }
+ }
+ }
+ if (divisible) {
+ ret = new_in_layouts[0];
+ } else {
+ ret = old_in_layouts[0];
+ }
} else if (old_in_layouts.defined()) {
ret = old_in_layouts[0];
}
diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py
index 3d5af83b8c..223926a877 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -1760,9 +1760,58 @@ def test_conv_split_convert_layout():
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+ def _test_conv_split_convert_layout_blocking():
+ def before():
+ x = relay.var("x", shape=(1, 512, 38, 38))
+ weight = relay.var("weight", shape=(512, 512, 3, 3))
+ y = relay.nn.conv2d(
+ x,
+ weight,
+ channels=512,
+ kernel_size=(3, 3),
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ )
+ y = relay.nn.relu(y)
+ y = relay.op.split(y, indices_or_sections=[256], axis=1).astuple()
+ a = relay.TupleGetItem(y, 0)
+ b = relay.TupleGetItem(y, 1)
+ out = relay.Tuple([a, b])
+ return relay.Function(analysis.free_vars(out), out)
+
+ def expected():
+ x = relay.var("x", shape=(1, 512, 38, 38))
+ weight = relay.var("weight", shape=(512, 512, 3, 3))
+ weight = relay.layout_transform(weight, "OIHW", "OIHW4o")
+ x = relay.layout_transform(x, "NCHW", "NCHW4c")
+ y = relay.op.nn.contrib_conv2d_nchwc(
+ x,
+ weight,
+ channels=512,
+ kernel_size=(3, 3),
+ padding=(0, 0),
+ data_layout="NCHW4c",
+ kernel_layout="OIHW4o",
+ )
+ y = relay.nn.relu(y)
+ y = relay.op.split(y, indices_or_sections=[64], axis=1).astuple()
+ a = relay.TupleGetItem(y, 0)
+ b = relay.TupleGetItem(y, 1)
+ a = relay.layout_transform(a, "NCHW4c", "NCHW")
+ b = relay.layout_transform(b, "NCHW4c", "NCHW")
+ out = relay.Tuple([a, b])
+ return relay.Function(analysis.free_vars(out), out)
+
+ a = before()
+ a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW4c", "OIHW4o"]}))
+ b = run_opt_pass(expected(), transform.InferType())
+
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
_test_conv_split_convert_layout1()
_test_conv_split_convert_layout2()
_test_conv_split_convert_layout3()
+ _test_conv_split_convert_layout_blocking()
def test_conv_strided_slice_axes_convert_layout():