You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/11/24 20:24:07 UTC

[GitHub] [tvm] zxybazh commented on a diff in pull request #13481: [TIR] Fix an error when the result of compute_at has unit loop

zxybazh commented on code in PR #13481:
URL: https://github.com/apache/tvm/pull/13481#discussion_r1031806526


##########
tests/python/unittest/test_tir_schedule_compute_at.py:
##########
@@ -1505,5 +1505,56 @@ def main_reverse_compute_at(
     tvm.ir.assert_structural_equal(main_reverse_compute_at, sch.mod["main"])
 
 
+def test_reverse_compute_at_with_unit_loop():
+    @T.prim_func
+    def main(A: T.Buffer[(128, 128), "float32"], D: T.Buffer[(1, 2, 1), "float32"]) -> None:
+        B = T.alloc_buffer([128, 128], dtype="float32")
+        C = T.alloc_buffer([128, 128], dtype="float32")
+        for i_0, j_0, i_1 in T.grid(8, 8, 16):
+            for j_1 in T.serial(16):
+                with T.block("B"):
+                    vi = T.axis.spatial(128, i_0 * 16 + i_1)
+                    vj = T.axis.spatial(128, j_0 * 16 + j_1)
+                    T.reads(A[vi, vj])
+                    T.writes(B[vi, vj])
+                    B[vi, vj] = A[vi, vj] * T.float32(2)
+        for ax0, ax1, ax2 in T.grid(1, 2, 1):
+            with T.block("D"):
+                v0, v1, v2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(B[v0, v1])
+                T.writes(D[v0, v1, v2])
+                D[v0, v1, v2] = B[v0, v1] + T.float32(1)
+
+    @T.prim_func
+    def main_reverse_compute_at(
+        A: T.Buffer[(128, 128), "float32"], D: T.Buffer[(1, 2, 1), "float32"]
+    ):
+        B = T.alloc_buffer([128, 128], dtype="float32")
+        C = T.alloc_buffer([128, 128], dtype="float32")
+        for i_0, j_0, i_1 in T.grid(8, 8, 16):
+            for j_1 in T.serial(16):
+                with T.block("B"):
+                    vi = T.axis.spatial(128, i_0 * 16 + i_1)
+                    vj = T.axis.spatial(128, j_0 * 16 + j_1)
+                    T.reads(A[vi, vj])
+                    T.writes(B[vi, vj])
+                    B[vi, vj] = A[vi, vj] * T.float32(2)
+            for ax0, ax1, ax2 in T.grid(1, 16, 1):
+                with T.block("D"):
+                    T.where(i_0 * 16 + i_1 < 1 and j_0 * 16 + ax1 < 2)
+                    v0 = T.axis.spatial(1, i_0 * 16 + i_1 + ax0)
+                    v1 = T.axis.spatial(2, j_0 * 16 + ax1)
+                    v2 = T.axis.spatial(1, ax2)
+                    T.reads(B[v0, v1])
+                    T.writes(D[v0, v1, v2])
+                    D[v0, v1, v2] = B[v0, v1] + T.float32(1)

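Review Comment:
   Suggested change: drop the unused buffer `C` from both `main` and `main_reverse_compute_at`, and write the loop extents and index constants as `T.int64(...)`.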
   ```suggestion
       @T.prim_func
       def main(A: T.Buffer[(128, 128), "float32"], D: T.Buffer[(1, 2, 1), "float32"]) -> None:
           B = T.alloc_buffer([128, 128], dtype="float32")
           for i_0, j_0, i_1 in T.grid(T.int64(8), T.int64(8), T.int64(16)):
               for j_1 in T.serial(T.int64(16)):
                   with T.block("B"):
                       vi = T.axis.spatial(T.int64(128), i_0 * T.int64(16) + i_1)
                       vj = T.axis.spatial(T.int64(128), j_0 * T.int64(16) + j_1)
                       T.reads(A[vi, vj])
                       T.writes(B[vi, vj])
                       B[vi, vj] = A[vi, vj] * T.float32(2)
           for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(2), T.int64(1)):
               with T.block("D"):
                   v0, v1, v2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                   T.reads(B[v0, v1])
                   T.writes(D[v0, v1, v2])
                   D[v0, v1, v2] = B[v0, v1] + T.float32(1)
   
       @T.prim_func
       def main_reverse_compute_at(
           A: T.Buffer[(128, 128), "float32"], D: T.Buffer[(1, 2, 1), "float32"]
       ):
           B = T.alloc_buffer([128, 128], dtype="float32")
           for i_0, j_0, i_1 in T.grid(T.int64(8), T.int64(8), T.int64(16)):
               for j_1 in T.serial(T.int64(16)):
                   with T.block("B"):
                       vi = T.axis.spatial(T.int64(128), i_0 * T.int64(16) + i_1)
                       vj = T.axis.spatial(T.int64(128), j_0 * T.int64(16) + j_1)
                       T.reads(A[vi, vj])
                       T.writes(B[vi, vj])
                       B[vi, vj] = A[vi, vj] * T.float32(2)
               for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(16), T.int64(1)):
                   with T.block("D"):
                       T.where(
                           i_0 * T.int64(16) + i_1 < T.int64(1)
                           and j_0 * T.int64(16) + ax1 < T.int64(2)
                       )
                       v0 = T.axis.spatial(T.int64(1), i_0 * T.int64(16) + i_1 + ax0)
                       v1 = T.axis.spatial(T.int64(2), j_0 * T.int64(16) + ax1)
                       v2 = T.axis.spatial(T.int64(1), ax2)
                       T.reads(B[v0, v1])
                       T.writes(D[v0, v1, v2])
                       D[v0, v1, v2] = B[v0, v1] + T.float32(1)
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org