Posted to commits@tvm.apache.org by ma...@apache.org on 2022/09/29 12:51:07 UTC

[tvm] branch main updated: [Relay] Extend split for blocked ConvertLayout pass (#12886)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0d8c9cef72 [Relay] Extend split for blocked ConvertLayout pass (#12886)
0d8c9cef72 is described below

commit 0d8c9cef7212e62c18814f1632613fb04de6d290
Author: Andrey Malyshev <el...@gmail.com>
AuthorDate: Thu Sep 29 16:50:59 2022 +0400

    [Relay] Extend split for blocked ConvertLayout pass (#12886)
    
    * [Relay] Extend split for blocked ConvertLayout pass
    
    * Fix lint hits
    
    * Fix spelling
---
 src/relay/op/tensor/transform.cc                  | 24 ++++++++++-
 tests/python/relay/test_pass_convert_op_layout.py | 49 +++++++++++++++++++++++
 2 files changed, 72 insertions(+), 1 deletion(-)
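
For context before the diff: when ConvertLayout moves this graph to a blocked layout such as NCHW4c, the channel axis is packed into blocks of 4, so a split at channel index 256 on an NCHW tensor corresponds to a split at block index 64 (256 / 4) on the NCHW4c tensor. The C++ change below performs exactly this rescaling, and falls back to the original layout whenever a split index is not evenly divisible by the blocking factor. A minimal Python sketch of that divisibility rule (rescale_split_indices is a hypothetical helper for illustration, not a TVM API):

    def rescale_split_indices(indices, factor):
        """Divide split indices by the layout blocking factor.

        Returns the rescaled indices when every index is divisible by
        the factor, or None when one is not, in which case the pass
        keeps the original layout around the split.
        """
        rescaled = []
        for index in indices:
            if index % factor:
                return None  # not divisible: fall back to the old layout
            rescaled.append(index // factor)
        return rescaled

    # Splitting 512 NCHW channels at [256] and converting to NCHW4c
    # (blocking factor 4) moves the split point to block index 64.
    assert rescale_split_indices([256], 4) == [64]
    # An index that is not a multiple of the factor blocks the rescale.
    assert rescale_split_indices([255], 4) is None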

diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index deb05e8877..985222307a 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -2982,10 +2982,32 @@ InferCorrectLayoutOutput SplitInferCorrectLayout(const Attrs& attrs,
 
   // If new_in_layouts are defined, this code tries to modify the layout.
   if (new_in_layouts.defined() && old_in_layouts.defined()) {
+    bool divisible = true;
     const auto& sp_dim = old_in_layouts[0][axis];
     auto new_index = new_in_layouts[0].IndexOf(sp_dim);
     param->axis = new_index;
-    ret = new_in_layouts[0];
+    int factor = new_in_layouts[0].FactorOf(sp_dim);
+    if (factor > 1) {
+      if (!param->indices_or_sections.as<IntImmNode>()) {
+        auto ios = Downcast<Array<Integer>>(param->indices_or_sections);
+        Array<Integer> new_ios;
+        for (const auto& v : ios) {
+          const IntImmNode* vint = v.as<IntImmNode>();
+          new_ios.push_back(vint->value / factor);
+          if (vint->value % factor) {
+            divisible = false;
+          }
+        }
+        if (divisible) {
+          param->indices_or_sections = new_ios;
+        }
+      }
+    }
+    if (divisible) {
+      ret = new_in_layouts[0];
+    } else {
+      ret = old_in_layouts[0];
+    }
   } else if (old_in_layouts.defined()) {
     ret = old_in_layouts[0];
   }
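
The hunk above uses Layout::IndexOf to locate the split dimension in the new layout and Layout::FactorOf to obtain its blocking factor; only when every split index divides evenly by that factor does it commit to new_in_layouts[0], otherwise it keeps old_in_layouts[0]. The Python bindings expose the same layout queries, which makes the values easy to check (a sketch, assuming a local TVM build on the path):

    import tvm

    blocked = tvm.tir.layout("NCHW4c")
    # The split in the test below runs on axis "C" (axis=1 of NCHW);
    # in NCHW4c that primal axis keeps index 1 and carries factor 4.
    assert blocked.index_of("C") == 1
    assert blocked.factor_of("C") == 4
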
diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py
index 3d5af83b8c..223926a877 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -1760,9 +1760,58 @@ def test_conv_split_convert_layout():
 
         assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
+    def _test_conv_split_convert_layout_blocking():
+        def before():
+            x = relay.var("x", shape=(1, 512, 38, 38))
+            weight = relay.var("weight", shape=(512, 512, 3, 3))
+            y = relay.nn.conv2d(
+                x,
+                weight,
+                channels=512,
+                kernel_size=(3, 3),
+                data_layout="NCHW",
+                kernel_layout="OIHW",
+            )
+            y = relay.nn.relu(y)
+            y = relay.op.split(y, indices_or_sections=[256], axis=1).astuple()
+            a = relay.TupleGetItem(y, 0)
+            b = relay.TupleGetItem(y, 1)
+            out = relay.Tuple([a, b])
+            return relay.Function(analysis.free_vars(out), out)
+
+        def expected():
+            x = relay.var("x", shape=(1, 512, 38, 38))
+            weight = relay.var("weight", shape=(512, 512, 3, 3))
+            weight = relay.layout_transform(weight, "OIHW", "OIHW4o")
+            x = relay.layout_transform(x, "NCHW", "NCHW4c")
+            y = relay.op.nn.contrib_conv2d_nchwc(
+                x,
+                weight,
+                channels=512,
+                kernel_size=(3, 3),
+                padding=(0, 0),
+                data_layout="NCHW4c",
+                kernel_layout="OIHW4o",
+            )
+            y = relay.nn.relu(y)
+            y = relay.op.split(y, indices_or_sections=[64], axis=1).astuple()
+            a = relay.TupleGetItem(y, 0)
+            b = relay.TupleGetItem(y, 1)
+            a = relay.layout_transform(a, "NCHW4c", "NCHW")
+            b = relay.layout_transform(b, "NCHW4c", "NCHW")
+            out = relay.Tuple([a, b])
+            return relay.Function(analysis.free_vars(out), out)
+
+        a = before()
+        a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW4c", "OIHW4o"]}))
+        b = run_opt_pass(expected(), transform.InferType())
+
+        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
     _test_conv_split_convert_layout1()
     _test_conv_split_convert_layout2()
     _test_conv_split_convert_layout3()
+    _test_conv_split_convert_layout_blocking()
 
 
 def test_conv_strided_slice_axes_convert_layout():
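
To try the new behavior outside the test suite, the graph from before() can be converted directly. A sketch assuming a local TVM build, using the same APIs as the test file above:

    import tvm
    from tvm import relay
    from tvm.relay import analysis, transform

    # Same graph as before(): conv2d + relu + split at channel 256.
    x = relay.var("x", shape=(1, 512, 38, 38))
    weight = relay.var("weight", shape=(512, 512, 3, 3))
    y = relay.nn.conv2d(x, weight, channels=512, kernel_size=(3, 3),
                        data_layout="NCHW", kernel_layout="OIHW")
    y = relay.op.split(relay.nn.relu(y), indices_or_sections=[256], axis=1).astuple()
    out = relay.Tuple([relay.TupleGetItem(y, 0), relay.TupleGetItem(y, 1)])
    func = relay.Function(analysis.free_vars(out), out)

    mod = tvm.IRModule.from_expr(func)
    seq = tvm.transform.Sequential(
        [transform.InferType(),
         transform.ConvertLayout({"nn.conv2d": ["NCHW4c", "OIHW4o"]})]
    )
    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)
    # Printing the module should show contrib_conv2d_nchwc and the
    # split rescaled to indices_or_sections=[64], as in expected().
    print(mod)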