Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/02/22 03:49:18 UTC

[GitHub] [tvm] qsqqsqqsq opened a new pull request #10340: [TIR][Transform] relax LoopPartition restriction

qsqqsqqsq opened a new pull request #10340:
URL: https://github.com/apache/tvm/pull/10340


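   Currently, LoopPartition requires the intersection of all conditions on the loop variable to be non-empty. In some cases, however, the loop can still be partitioned even when that intersection is empty. In the example below, the three `if` guards on `i1` are mutually exclusive, so their common intersection is empty, yet together they split `[0, 128)` into the ranges `[0, 64)`, `[64, 96)` and `[96, 128)`: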
   ```python
   from tvm.script import tir as T
   
   
   @T.prim_func
   def concat_func(
       placeholder: T.Buffer[(1, 64, 28, 28), "int8"],
       placeholder_1: T.Buffer[(1, 32, 28, 28), "int8"],
       placeholder_2: T.Buffer[(1, 32, 28, 28), "int8"],
       T_concat: T.Buffer[(1, 128, 28, 28), "int8"],
   ) -> None:
       for i1 in T.serial(128, annotations={"pragma_loop_partition_hint": 1}):
           for i2, i3 in T.grid(28, 28):
               if 96 <= i1:
                   T.store(
                       T_concat.data,
                       i1 * 784 + i2 * 28 + i3,
                       T.load("int8", placeholder_2.data, i1 * 784 + i2 * 28 + i3 - 75264),
                       True,
                   )
               if 64 <= i1 and i1 < 96:
                   T.store(
                       T_concat.data,
                       i1 * 784 + i2 * 28 + i3,
                       T.load("int8", placeholder_1.data, i1 * 784 + i2 * 28 + i3 - 50176),
                       True,
                   )
               if i1 < 64:
                   T.store(
                       T_concat.data,
                       i1 * 784 + i2 * 28 + i3,
                       T.load("int8", placeholder.data, i1 * 784 + i2 * 28 + i3),
                       True,
                   )
   ```
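A minimal pure-Python sketch (not TVM code) that models each `if` guard in `concat_func` as a half-open interval on `i1`; the names `guards` and `intersect` are hypothetical. It shows why the old restriction blocked this loop: the three guards have an empty common intersection, yet every iteration satisfies exactly one of them.

```python
EXTENT = 128  # extent of the i1 loop in concat_func

# Each guard modelled as a half-open interval [lo, hi) on i1 (hypothetical encoding).
guards = {
    "placeholder_2": (96, EXTENT),  # 96 <= i1
    "placeholder_1": (64, 96),      # 64 <= i1 and i1 < 96
    "placeholder": (0, 64),         # i1 < 64
}


def intersect(intervals):
    """Return the common sub-interval of all intervals, or None if it is empty."""
    lo = max(a for a, _ in intervals)
    hi = min(b for _, b in intervals)
    return (lo, hi) if lo < hi else None


# The intersection of all three guards is empty...
common = intersect(list(guards.values()))

# ...yet every i1 in [0, EXTENT) satisfies exactly one guard.
hits_per_i1 = [
    sum(lo <= i1 < hi for lo, hi in guards.values()) for i1 in range(EXTENT)
]
```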
   This PR relaxes the restriction so that LoopPartition keeps trying to partition the loop as long as the intersection of some subset of the conditions is non-empty.
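The relaxed behaviour can be illustrated with a simplified model. This is a hypothetical pure-Python sketch, not the C++ LoopPartition implementation, and `partition` is an illustrative name: cut the iteration range at every guard boundary, then record which guards hold on the whole of each piece, so that their `if` could be dropped there. For `concat_func` this yields the pieces `[0, 64)`, `[64, 96)` and `[96, 128)`, each with exactly one always-true guard, which matches the three loops of `partitioned_concat_3` in the test added by this PR.

```python
from typing import List, Tuple

Interval = Tuple[int, int]  # half-open [lo, hi)


def partition(extent: int, guards: List[Interval]) -> List[Tuple[Interval, List[int]]]:
    """Split [0, extent) at every guard boundary.

    For each resulting piece, return the indices of the guards that are true
    on the whole piece (hypothetical model of dropping the `if` there).
    """
    cuts = {0, extent}
    for lo, hi in guards:
        cuts.update(c for c in (lo, hi) if 0 < c < extent)
    points = sorted(cuts)
    pieces = []
    for lo, hi in zip(points, points[1:]):
        active = [k for k, (g_lo, g_hi) in enumerate(guards) if g_lo <= lo and hi <= g_hi]
        pieces.append(((lo, hi), active))
    return pieces


# Guards of concat_func, in source order: 96 <= i1, 64 <= i1 < 96, i1 < 64.
guards = [(96, 128), (64, 96), (0, 64)]
result = partition(128, guards)
```

Because every guard boundary is itself a cut point, each guard is either true on a whole piece or false on all of it. In the real pass the conditions are symbolic expressions handled by TVM's arithmetic utilities, which this sketch does not model.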
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] Hzfengsy merged pull request #10340: [TIR][Transform] relax LoopPartition restriction

Hzfengsy merged pull request #10340:
URL: https://github.com/apache/tvm/pull/10340


   





[GitHub] [tvm] qsqqsqqsq commented on a change in pull request #10340: [TIR][Transform] relax LoopPartition restriction

qsqqsqqsq commented on a change in pull request #10340:
URL: https://github.com/apache/tvm/pull/10340#discussion_r811609270



##########
File path: tests/python/unittest/test_tir_transform_loop_partition.py
##########
@@ -565,6 +566,78 @@ def test_explicit_partition_hint():
     assert tvm.ir.structural_equal(mod["main"], partitioned_concat)
 
 
+@T.prim_func
+def partitioned_concat_3(
+    placeholder: T.Buffer[(1, 64, 28, 28), "int8"],
+    placeholder_1: T.Buffer[(1, 32, 28, 28), "int8"],
+    placeholder_2: T.Buffer[(1, 32, 28, 28), "int8"],
+    T_concat: T.Buffer[(1, 128, 28, 28), "int8"],
+) -> None:
+    for i1, i2, i3 in T.grid(64, 28, 28):
+        T.store(
+            T_concat.data,
+            i1 * 784 + i2 * 28 + i3,
+            T.load("int8", placeholder.data, i1 * 784 + i2 * 28 + i3),
+            True,
+        )
+    for i1, i2, i3 in T.grid(32, 28, 28):
+        T.store(
+            T_concat.data,
+            i1 * 784 + i2 * 28 + i3 + 50176,
+            T.load("int8", placeholder_1.data, i1 * 784 + i2 * 28 + i3),
+            True,
+        )
+    for i1, i2, i3 in T.grid(32, 28, 28):
+        T.store(
+            T_concat.data,
+            i1 * 784 + i2 * 28 + i3 + 75264,
+            T.load("int8", placeholder_2.data, i1 * 784 + i2 * 28 + i3),
+            True,
+        )
+
+
+@T.prim_func
+def concat_func_3(
+    placeholder: T.Buffer[(1, 64, 28, 28), "int8"],
+    placeholder_1: T.Buffer[(1, 32, 28, 28), "int8"],
+    placeholder_2: T.Buffer[(1, 32, 28, 28), "int8"],
+    T_concat: T.Buffer[(1, 128, 28, 28), "int8"],
+) -> None:
+    for i1 in T.serial(128, annotations={"pragma_loop_partition_hint": 1}):
+        for i2, i3 in T.grid(28, 28):
+            if 96 <= i1:
+                T.store(
+                    T_concat.data,
+                    i1 * 784 + i2 * 28 + i3,
+                    T.load("int8", placeholder_2.data, i1 * 784 + i2 * 28 + i3 - 75264),
+                    True,
+                )
+            if 64 <= i1 and i1 < 96:
+                T.store(
+                    T_concat.data,
+                    i1 * 784 + i2 * 28 + i3,
+                    T.load("int8", placeholder_1.data, i1 * 784 + i2 * 28 + i3 - 50176),
+                    True,
+                )
+            if i1 < 64:
+                T.store(
+                    T_concat.data,
+                    i1 * 784 + i2 * 28 + i3,
+                    T.load("int8", placeholder.data, i1 * 784 + i2 * 28 + i3),
+                    True,
+                )
+
+
+def test_condition_mutually_exclusive():
+    mod = IRModule.from_expr(concat_func_3)
+    with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}):
+        mod = tvm.tir.transform.FlattenBuffer()(mod)

Review comment:
       FlattenBuffer is used to lower `annotations={"pragma_loop_partition_hint": 1}` into `T.attr(i1, "pragma_loop_partition_hint", 1)`. The script parser can't handle the following form, so I use the annotation instead:
   ```python
       T.attr(i1, "pragma_loop_partition_hint", 1)
       for i1, i2, i3 in T.grid(256, 28, 28):
           ...
   ```







[GitHub] [tvm] Hzfengsy commented on a change in pull request #10340: [TIR][Transform] relax LoopPartition restriction

Hzfengsy commented on a change in pull request #10340:
URL: https://github.com/apache/tvm/pull/10340#discussion_r811586894



##########
File path: tests/python/unittest/test_tir_transform_loop_partition.py
##########

Review comment:
       Do we really need FlattenBuffer in this testcase?



