You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/02/22 06:21:36 UTC

[GitHub] [tvm] qsqqsqqsq commented on a change in pull request #10340: [TIR][Transform] relax LoopPartition restriction

qsqqsqqsq commented on a change in pull request #10340:
URL: https://github.com/apache/tvm/pull/10340#discussion_r811609270



##########
File path: tests/python/unittest/test_tir_transform_loop_partition.py
##########
@@ -565,6 +566,78 @@ def test_explicit_partition_hint():
     assert tvm.ir.structural_equal(mod["main"], partitioned_concat)
 
 
+@T.prim_func
+def partitioned_concat_3(
+    placeholder: T.Buffer[(1, 64, 28, 28), "int8"],
+    placeholder_1: T.Buffer[(1, 32, 28, 28), "int8"],
+    placeholder_2: T.Buffer[(1, 32, 28, 28), "int8"],
+    T_concat: T.Buffer[(1, 128, 28, 28), "int8"],
+) -> None:
+    for i1, i2, i3 in T.grid(64, 28, 28):
+        T.store(
+            T_concat.data,
+            i1 * 784 + i2 * 28 + i3,
+            T.load("int8", placeholder.data, i1 * 784 + i2 * 28 + i3),
+            True,
+        )
+    for i1, i2, i3 in T.grid(32, 28, 28):
+        T.store(
+            T_concat.data,
+            i1 * 784 + i2 * 28 + i3 + 50176,
+            T.load("int8", placeholder_1.data, i1 * 784 + i2 * 28 + i3),
+            True,
+        )
+    for i1, i2, i3 in T.grid(32, 28, 28):
+        T.store(
+            T_concat.data,
+            i1 * 784 + i2 * 28 + i3 + 75264,
+            T.load("int8", placeholder_2.data, i1 * 784 + i2 * 28 + i3),
+            True,
+        )
+
+
+@T.prim_func
+def concat_func_3(
+    placeholder: T.Buffer[(1, 64, 28, 28), "int8"],
+    placeholder_1: T.Buffer[(1, 32, 28, 28), "int8"],
+    placeholder_2: T.Buffer[(1, 32, 28, 28), "int8"],
+    T_concat: T.Buffer[(1, 128, 28, 28), "int8"],
+) -> None:
+    for i1 in T.serial(128, annotations={"pragma_loop_partition_hint": 1}):
+        for i2, i3 in T.grid(28, 28):
+            if 96 <= i1:
+                T.store(
+                    T_concat.data,
+                    i1 * 784 + i2 * 28 + i3,
+                    T.load("int8", placeholder_2.data, i1 * 784 + i2 * 28 + i3 - 75264),
+                    True,
+                )
+            if 64 <= i1 and i1 < 96:
+                T.store(
+                    T_concat.data,
+                    i1 * 784 + i2 * 28 + i3,
+                    T.load("int8", placeholder_1.data, i1 * 784 + i2 * 28 + i3 - 50176),
+                    True,
+                )
+            if i1 < 64:
+                T.store(
+                    T_concat.data,
+                    i1 * 784 + i2 * 28 + i3,
+                    T.load("int8", placeholder.data, i1 * 784 + i2 * 28 + i3),
+                    True,
+                )
+
+
+def test_condition_mutually_exclusive():
+    mod = IRModule.from_expr(concat_func_3)
+    with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
+        mod = tvm.tir.transform.FlattenBuffer()(mod)

Review comment:
       FlattenBuffer is used to transform `annotations={"pragma_loop_partition_hint": 1}` into `T.attr(i1, "pragma_loop_partition_hint", 1)`. Python can't parse the following script, so I use annotations instead.
   ```python
       T.attr(i1, "pragma_loop_partition_hint", 1)
       for i1, i2, i3 in T.grid(256, 28, 28):
          ...
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org