You are viewing a plain-text version of this content. The canonical (hyperlinked) version is available in the original mailing-list archive.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/09/05 03:40:43 UTC

[GitHub] [tvm] yincs-intellif commented on a diff in pull request #12631: [TIR] Add partition unroll hint pragma

yincs-intellif commented on code in PR #12631:
URL: https://github.com/apache/tvm/pull/12631#discussion_r962452026


##########
tests/python/unittest/test_tir_transform_loop_partition.py:
##########
@@ -619,6 +619,45 @@ def test_condition_mutually_exclusive():
     assert tvm.ir.structural_equal(mod["main"], partitioned_concat_3)
 
 
+def test_loop_partition_unroll_hint():
+    @T.prim_func
+    def main(A: T.Buffer[150528, "int8"], B: T.Buffer[25088, "int8"]) -> None:
+        T.preflattened_buffer(A, [1, 3, 224, 224], "int8", data=A.data)
+        T.preflattened_buffer(B, [1, 224, 7, 16], "int8", data=B.data)
+        for ax0 in T.serial(
+            112,
+            annotations={"pragma_loop_partition_hint": True, "pragma_partition_unroll_hint": True},
+        ):
+            for ax1, ax2, ax3 in T.grid(224, 7, 16):
+                if 3 <= ax0 * 2 + ax2 and ax0 * 2 + ax2 < 227 and ax3 < 3:
+                    B[ax1 * 112 + ax2 * 16 + ax3] = A[ax3 * 50176 + ax1 * 224 + ax0 * 2 + ax2 - 3]
+
+    @T.prim_func
+    def partitioned_main(A: T.Buffer[150528, "int8"], B: T.Buffer[25088, "int8"]) -> None:
+        T.preflattened_buffer(A, [1, 3, 224, 224], dtype="int8", data=A.data)
+        T.preflattened_buffer(B, [1, 224, 7, 16], dtype="int8", data=B.data)
+        for ax1, ax2, ax3 in T.grid(224, 7, 16):
+            if 3 <= ax2 and ax3 < 3:
+                B[ax1 * 112 + ax2 * 16 + ax3] = A[ax3 * 50176 + ax1 * 224 + ax2 - 3]
+        for ax1, ax2, ax3 in T.grid(224, 7, 16):
+            if 1 <= ax2 and ax3 < 3:
+                B[ax1 * 112 + ax2 * 16 + ax3] = A[ax3 * 50176 + ax1 * 224 + ax2 - 1]
+        for ax0, ax1, ax2, ax3 in T.grid(109, 224, 7, 16):
+            if ax3 < 3:
+                B[ax1 * 112 + ax2 * 16 + ax3] = A[ax3 * 50176 + ax1 * 224 + ax0 * 2 + ax2 + 1]
+        for ax1, ax2, ax3 in T.grid(224, 7, 16):
+            if ax2 < 5 and ax3 < 3:
+                B[ax1 * 112 + ax2 * 16 + ax3] = A[ax3 * 50176 + ax1 * 224 + ax2 + 219]
+
+    mod = tvm.ir.module.IRModule.from_expr(main)
+    with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
+        mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
+        mod = tvm.tir.transform.FlattenBuffer()(mod)
+        mod = tvm.tir.transform.LoopPartition()(mod)
+        mod = tvm.tir.transform.Simplify()(mod)
+    assert tvm.ir.structural_equal(mod["main"], partitioned_main)
+
+
 if __name__ == "__main__":
     test_basic()

Review Comment:
   no problem



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org