You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2022/12/05 06:16:07 UTC

[tvm] branch main updated: Add recursive on loop with marked kUnrolled (#13536)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e7160d569a Add recursive on loop with marked kUnrolled (#13536)
e7160d569a is described below

commit e7160d569a19aa00b0fd605abd970d0e9ed8b1d0
Author: yin.changsheng <yi...@intellif.com>
AuthorDate: Mon Dec 5 14:15:56 2022 +0800

    Add recursive on loop with marked kUnrolled (#13536)
    
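    In the current LoopPartition pass, when a loop with a partition hint has no
    partition interval and is turned into a kUnrolled loop, the pass returns it
    directly without visiting its body.
    This PR enhances the LoopPartition pass to continue recursing into the body of
    such kUnrolled loops.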
---
 src/tir/transforms/loop_partition.cc               |  3 +-
 .../unittest/test_tir_transform_loop_partition.py  | 69 ++++++++++++++++++++++
 2 files changed, 71 insertions(+), 1 deletion(-)

diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc
index 1d995ef26e..0d08852669 100644
--- a/src/tir/transforms/loop_partition.cc
+++ b/src/tir/transforms/loop_partition.cc
@@ -597,7 +597,8 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim
   if (!opt_cond_value.has_value()) {
     if (has_partition_hint_ && unroll_loop_with_partition_hint_no_interval_ &&
         analyzer_.CanProve(max - min > 0)) {
-      return For(var, min, max - min + 1, ForKind::kUnrolled, body);
+      auto new_body = VisitAndMutate(body);
+      return For(var, min, max - min + 1, ForKind::kUnrolled, new_body);
     }
     return Stmt();
   }
diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py
index fe48aa7d8f..7dd8e79410 100644
--- a/tests/python/unittest/test_tir_transform_loop_partition.py
+++ b/tests/python/unittest/test_tir_transform_loop_partition.py
@@ -677,6 +677,75 @@ def test_loop_partition_unroll_hint():
     assert tvm.ir.structural_equal(mod["main"], partitioned_main)
 
 
+def test_loop_partition_recursive_unroll_hint():
+    @T.prim_func
+    def main():
+        placeholder_0_dm = T.decl_buffer([1, 32, 32, 16], dtype="int8")
+        for i3_0 in T.serial(5, annotations={"pragma_loop_partition_hint": 1}):
+            for i2_0 in T.serial(2, annotations={"pragma_loop_partition_hint": 1}):
+                pad_temp = T.decl_buffer([1, 16, 16, 16], dtype="int8")
+                for ax0, ax1, ax2 in T.grid(16, 16, 16):
+                    if (
+                        6 <= i2_0 * 4 + ax0
+                        and i2_0 * 4 + ax0 < 26
+                        and 6 <= i3_0 * 4 + ax1
+                        and i3_0 * 4 + ax1 < 26
+                    ):
+                        pad_temp[
+                            0,
+                            i2_0 * 4 + ax0 - 6 + 6 - i2_0 * 4,
+                            i3_0 * 4 + ax1 - 6 + 6 - i3_0 * 4,
+                            ax2,
+                        ] = placeholder_0_dm[
+                            0,
+                            i2_0 * 4 + ax0 - 6 - -6,
+                            i3_0 * 4 + ax1 - 6 - -6,
+                            ax2,
+                        ]
+
+    @T.prim_func
+    def partitioned_main():
+        placeholder_0_dm = T.allocate([16384], "int8", "global")
+        placeholder_0_dm_1 = T.buffer_decl([16384], dtype="int8", data=placeholder_0_dm)
+        for i3_0 in T.unroll(2):
+            for i2_0 in T.unroll(2):
+                pad_temp = T.allocate([4096], "int8", "global")
+                pad_temp_1 = T.buffer_decl([4096], dtype="int8", data=pad_temp)
+                for ax0, ax1, ax2 in T.grid(16, 16, 16):
+                    if 6 <= i2_0 * 4 + ax0 and 6 <= i3_0 * 4 + ax1:
+                        pad_temp_1[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm_1[
+                            i2_0 * 2048 + ax0 * 512 + i3_0 * 64 + ax1 * 16 + ax2
+                        ]
+        for i2_0 in T.unroll(2):
+            pad_temp_2 = T.allocate([4096], "int8", "global")
+            pad_temp_3 = T.buffer_decl([4096], dtype="int8", data=pad_temp_2)
+            for ax0, ax1, ax2 in T.grid(16, 16, 16):
+                if 6 <= i2_0 * 4 + ax0:
+                    pad_temp_3[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm_1[
+                        i2_0 * 2048 + ax0 * 512 + ax1 * 16 + ax2 + 128
+                    ]
+        for i3_0 in T.unroll(2):
+            for i2_0 in T.unroll(2):
+                pad_temp_4 = T.allocate([4096], "int8", "global")
+                pad_temp_5 = T.buffer_decl([4096], dtype="int8", data=pad_temp_4)
+                for ax0, ax1, ax2 in T.grid(16, 16, 16):
+                    if 6 <= i2_0 * 4 + ax0 and i3_0 * 4 + ax1 < 14:
+                        pad_temp_5[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm_1[
+                            i2_0 * 2048 + ax0 * 512 + i3_0 * 64 + ax1 * 16 + ax2 + 192
+                        ]
+
+    mod = partition_from_scheduled_tir(
+        main,
+        {
+            "tir.LoopPartition": {
+                "partition_const_loop": True,
+                "unroll_loop_with_partition_hint_no_interval": True,
+            }
+        },
+    )
+    assert tvm.ir.structural_equal(mod["main"], partitioned_main)
+
+
 def test_loop_partition_keep_loop_annotations():
     @T.prim_func
     def before(A: T.Buffer[160, "int32"], B: T.Buffer[160, "int32"]) -> None: