You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2022/12/05 06:16:07 UTC
[tvm] branch main updated: Add recursive on loop with marked kUnrolled (#13536)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e7160d569a Add recursive on loop with marked kUnrolled (#13536)
e7160d569a is described below
commit e7160d569a19aa00b0fd605abd970d0e9ed8b1d0
Author: yin.changsheng <yi...@intellif.com>
AuthorDate: Mon Dec 5 14:15:56 2022 +0800
Add recursive on loop with marked kUnrolled (#13536)
Currently, when the LoopPartition pass marks a loop as kUnrolled, it returns immediately without visiting the loop body.
This PR enhances the LoopPartition pass so that it continues to recurse into the body of a loop marked kUnrolled.
---
src/tir/transforms/loop_partition.cc | 3 +-
.../unittest/test_tir_transform_loop_partition.py | 69 ++++++++++++++++++++++
2 files changed, 71 insertions(+), 1 deletion(-)
diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc
index 1d995ef26e..0d08852669 100644
--- a/src/tir/transforms/loop_partition.cc
+++ b/src/tir/transforms/loop_partition.cc
@@ -597,7 +597,8 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim
if (!opt_cond_value.has_value()) {
if (has_partition_hint_ && unroll_loop_with_partition_hint_no_interval_ &&
analyzer_.CanProve(max - min > 0)) {
- return For(var, min, max - min + 1, ForKind::kUnrolled, body);
+ auto new_body = VisitAndMutate(body);
+ return For(var, min, max - min + 1, ForKind::kUnrolled, new_body);
}
return Stmt();
}
diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py
index fe48aa7d8f..7dd8e79410 100644
--- a/tests/python/unittest/test_tir_transform_loop_partition.py
+++ b/tests/python/unittest/test_tir_transform_loop_partition.py
@@ -677,6 +677,75 @@ def test_loop_partition_unroll_hint():
assert tvm.ir.structural_equal(mod["main"], partitioned_main)
+def test_loop_partition_recursive_unroll_hint():
+ @T.prim_func
+ def main():
+ placeholder_0_dm = T.decl_buffer([1, 32, 32, 16], dtype="int8")
+ for i3_0 in T.serial(5, annotations={"pragma_loop_partition_hint": 1}):
+ for i2_0 in T.serial(2, annotations={"pragma_loop_partition_hint": 1}):
+ pad_temp = T.decl_buffer([1, 16, 16, 16], dtype="int8")
+ for ax0, ax1, ax2 in T.grid(16, 16, 16):
+ if (
+ 6 <= i2_0 * 4 + ax0
+ and i2_0 * 4 + ax0 < 26
+ and 6 <= i3_0 * 4 + ax1
+ and i3_0 * 4 + ax1 < 26
+ ):
+ pad_temp[
+ 0,
+ i2_0 * 4 + ax0 - 6 + 6 - i2_0 * 4,
+ i3_0 * 4 + ax1 - 6 + 6 - i3_0 * 4,
+ ax2,
+ ] = placeholder_0_dm[
+ 0,
+ i2_0 * 4 + ax0 - 6 - -6,
+ i3_0 * 4 + ax1 - 6 - -6,
+ ax2,
+ ]
+
+ @T.prim_func
+ def partitioned_main():
+ placeholder_0_dm = T.allocate([16384], "int8", "global")
+ placeholder_0_dm_1 = T.buffer_decl([16384], dtype="int8", data=placeholder_0_dm)
+ for i3_0 in T.unroll(2):
+ for i2_0 in T.unroll(2):
+ pad_temp = T.allocate([4096], "int8", "global")
+ pad_temp_1 = T.buffer_decl([4096], dtype="int8", data=pad_temp)
+ for ax0, ax1, ax2 in T.grid(16, 16, 16):
+ if 6 <= i2_0 * 4 + ax0 and 6 <= i3_0 * 4 + ax1:
+ pad_temp_1[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm_1[
+ i2_0 * 2048 + ax0 * 512 + i3_0 * 64 + ax1 * 16 + ax2
+ ]
+ for i2_0 in T.unroll(2):
+ pad_temp_2 = T.allocate([4096], "int8", "global")
+ pad_temp_3 = T.buffer_decl([4096], dtype="int8", data=pad_temp_2)
+ for ax0, ax1, ax2 in T.grid(16, 16, 16):
+ if 6 <= i2_0 * 4 + ax0:
+ pad_temp_3[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm_1[
+ i2_0 * 2048 + ax0 * 512 + ax1 * 16 + ax2 + 128
+ ]
+ for i3_0 in T.unroll(2):
+ for i2_0 in T.unroll(2):
+ pad_temp_4 = T.allocate([4096], "int8", "global")
+ pad_temp_5 = T.buffer_decl([4096], dtype="int8", data=pad_temp_4)
+ for ax0, ax1, ax2 in T.grid(16, 16, 16):
+ if 6 <= i2_0 * 4 + ax0 and i3_0 * 4 + ax1 < 14:
+ pad_temp_5[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm_1[
+ i2_0 * 2048 + ax0 * 512 + i3_0 * 64 + ax1 * 16 + ax2 + 192
+ ]
+
+ mod = partition_from_scheduled_tir(
+ main,
+ {
+ "tir.LoopPartition": {
+ "partition_const_loop": True,
+ "unroll_loop_with_partition_hint_no_interval": True,
+ }
+ },
+ )
+ assert tvm.ir.structural_equal(mod["main"], partitioned_main)
+
+
def test_loop_partition_keep_loop_annotations():
@T.prim_func
def before(A: T.Buffer[160, "int32"], B: T.Buffer[160, "int32"]) -> None: