Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/11/24 05:25:57 UTC

[GitHub] [tvm] yelite opened a new pull request, #13481: [TIR] Fix an error when the result of compute_at has unit loop

yelite opened a new pull request, #13481:
URL: https://github.com/apache/tvm/pull/13481

   This PR fixes an error that occurs when the result of `reverse_compute_at`/`compute_at` contains a unit loop (a loop with extent 1). The error is caused by a dtype mismatch when binding vars in the arith analyzer, and it breaks the tuning of detectron2 models.
   
   I also added a check in `Analyzer::Bind` to prevent similar problems. If the check causes too many test failures, we can remove it from this PR.
   
   cc: @junrushao @zxybazh 
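
   For illustration only, here is a minimal pure-Python sketch of the kind of bind-time dtype check described above. The `Var`, `IntImm`, and `Analyzer` classes below are hypothetical stand-ins, not TVM's actual C++ API; the point is just that binding a variable to an expression of a different integer width is rejected eagerly instead of surfacing as a confusing mismatch later.

```python
# Hypothetical sketch of a dtype-consistency check at bind time.
# These classes are illustrative stand-ins, not TVM's real API.
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    name: str
    dtype: str  # e.g. "int32" or "int64"


@dataclass(frozen=True)
class IntImm:
    value: int
    dtype: str


class Analyzer:
    def __init__(self):
        self.bindings = {}

    def bind(self, var: Var, expr: IntImm):
        # The check: a var may only be bound to an expression of the
        # same dtype; otherwise later simplification mixes widths.
        if var.dtype != expr.dtype:
            raise ValueError(
                f"dtype mismatch: cannot bind {var.dtype} var "
                f"{var.name!r} to {expr.dtype} expression"
            )
        self.bindings[var] = expr


analyzer = Analyzer()
analyzer.bind(Var("i", "int64"), IntImm(0, "int64"))  # OK: unit-loop var
try:
    analyzer.bind(Var("j", "int32"), IntImm(0, "int64"))
except ValueError as e:
    print("rejected:", e)
```

   A unit loop is a natural place for such a mismatch to appear, because the loop variable gets bound to a constant whose dtype may default to something other than the variable's own dtype.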


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] zxybazh commented on pull request #13481: [TIR] Fix an error when the result of compute_at has unit loop

zxybazh commented on PR #13481:
URL: https://github.com/apache/tvm/pull/13481#issuecomment-1325982645

   While the change looks good to me, I found that the test case passes even without this fix. Could you please verify that this regression test is sufficient to reproduce the issue?




[GitHub] [tvm] tvm-bot commented on pull request #13481: [TIR] Fix an error when the result of compute_at has unit loop

tvm-bot commented on PR #13481:
URL: https://github.com/apache/tvm/pull/13481#issuecomment-1325974219

   <!---bot-comment-->
   
   Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from [Reviewers](https://github.com/apache/incubator-tvm/blob/master/CONTRIBUTORS.md#reviewers) by @-ing them in a comment.
   
   <!--bot-comment-ccs-start-->
    * cc @Hzfengsy, @junrushao <sub>See [#10317](https://github.com/apache/tvm/issues/10317) for details</sub><!--bot-comment-ccs-end-->
   
   <sub>Generated by [tvm-bot](https://github.com/apache/tvm/blob/main/ci/README.md#github-actions)</sub>




[GitHub] [tvm] zxybazh commented on a diff in pull request #13481: [TIR] Fix an error when the result of compute_at has unit loop

zxybazh commented on code in PR #13481:
URL: https://github.com/apache/tvm/pull/13481#discussion_r1031806526


##########
tests/python/unittest/test_tir_schedule_compute_at.py:
##########
@@ -1505,5 +1505,56 @@ def main_reverse_compute_at(
     tvm.ir.assert_structural_equal(main_reverse_compute_at, sch.mod["main"])
 
 
+def test_reverse_compute_at_with_unit_loop():
+    @T.prim_func
+    def main(A: T.Buffer[(128, 128), "float32"], D: T.Buffer[(1, 2, 1), "float32"]) -> None:
+        B = T.alloc_buffer([128, 128], dtype="float32")
+        C = T.alloc_buffer([128, 128], dtype="float32")
+        for i_0, j_0, i_1 in T.grid(8, 8, 16):
+            for j_1 in T.serial(16):
+                with T.block("B"):
+                    vi = T.axis.spatial(128, i_0 * 16 + i_1)
+                    vj = T.axis.spatial(128, j_0 * 16 + j_1)
+                    T.reads(A[vi, vj])
+                    T.writes(B[vi, vj])
+                    B[vi, vj] = A[vi, vj] * T.float32(2)
+        for ax0, ax1, ax2 in T.grid(1, 2, 1):
+            with T.block("D"):
+                v0, v1, v2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(B[v0, v1])
+                T.writes(D[v0, v1, v2])
+                D[v0, v1, v2] = B[v0, v1] + T.float32(1)
+
+    @T.prim_func
+    def main_reverse_compute_at(
+        A: T.Buffer[(128, 128), "float32"], D: T.Buffer[(1, 2, 1), "float32"]
+    ):
+        B = T.alloc_buffer([128, 128], dtype="float32")
+        C = T.alloc_buffer([128, 128], dtype="float32")
+        for i_0, j_0, i_1 in T.grid(8, 8, 16):
+            for j_1 in T.serial(16):
+                with T.block("B"):
+                    vi = T.axis.spatial(128, i_0 * 16 + i_1)
+                    vj = T.axis.spatial(128, j_0 * 16 + j_1)
+                    T.reads(A[vi, vj])
+                    T.writes(B[vi, vj])
+                    B[vi, vj] = A[vi, vj] * T.float32(2)
+            for ax0, ax1, ax2 in T.grid(1, 16, 1):
+                with T.block("D"):
+                    T.where(i_0 * 16 + i_1 < 1 and j_0 * 16 + ax1 < 2)
+                    v0 = T.axis.spatial(1, i_0 * 16 + i_1 + ax0)
+                    v1 = T.axis.spatial(2, j_0 * 16 + ax1)
+                    v2 = T.axis.spatial(1, ax2)
+                    T.reads(B[v0, v1])
+                    T.writes(D[v0, v1, v2])
+                    D[v0, v1, v2] = B[v0, v1] + T.float32(1)

Review Comment:
   ```suggestion
       @T.prim_func
       def main(A: T.Buffer[(128, 128), "float32"], D: T.Buffer[(1, 2, 1), "float32"]) -> None:
           B = T.alloc_buffer([128, 128], dtype="float32")
           for i_0, j_0, i_1 in T.grid(T.int64(8), T.int64(8), T.int64(16)):
               for j_1 in T.serial(T.int64(16)):
                   with T.block("B"):
                       vi = T.axis.spatial(T.int64(128), i_0 * T.int64(16) + i_1)
                       vj = T.axis.spatial(T.int64(128), j_0 * T.int64(16) + j_1)
                       T.reads(A[vi, vj])
                       T.writes(B[vi, vj])
                       B[vi, vj] = A[vi, vj] * T.float32(2)
           for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(2), T.int64(1)):
               with T.block("D"):
                   v0, v1, v2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                   T.reads(B[v0, v1])
                   T.writes(D[v0, v1, v2])
                   D[v0, v1, v2] = B[v0, v1] + T.float32(1)
   
       @T.prim_func
       def main_reverse_compute_at(
           A: T.Buffer[(128, 128), "float32"], D: T.Buffer[(1, 2, 1), "float32"]
       ):
           B = T.alloc_buffer([128, 128], dtype="float32")
           for i_0, j_0, i_1 in T.grid(T.int64(8), T.int64(8), T.int64(16)):
               for j_1 in T.serial(T.int64(16)):
                   with T.block("B"):
                       vi = T.axis.spatial(T.int64(128), i_0 * T.int64(16) + i_1)
                       vj = T.axis.spatial(T.int64(128), j_0 * T.int64(16) + j_1)
                       T.reads(A[vi, vj])
                       T.writes(B[vi, vj])
                       B[vi, vj] = A[vi, vj] * T.float32(2)
               for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(16), T.int64(1)):
                   with T.block("D"):
                       T.where(
                           i_0 * T.int64(16) + i_1 < T.int64(1)
                           and j_0 * T.int64(16) + ax1 < T.int64(2)
                       )
                       v0 = T.axis.spatial(T.int64(1), i_0 * T.int64(16) + i_1 + ax0)
                       v1 = T.axis.spatial(T.int64(2), j_0 * T.int64(16) + ax1)
                       v2 = T.axis.spatial(T.int64(1), ax2)
                       T.reads(B[v0, v1])
                       T.writes(D[v0, v1, v2])
                       D[v0, v1, v2] = B[v0, v1] + T.float32(1)
   ```
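
   The suggestion above rewrites the test to use `T.int64` extents, presumably because with default `int32` extents the unit-loop substitution happens to use matching dtypes, so the buggy path is never exercised. A minimal sketch of that idea, using assumed semantics rather than TVM's real classes:

```python
# Illustrative sketch (assumed semantics, not TVM's real behavior):
# why int64 loop extents matter for the regression test. A unit loop's
# var is substituted by the constant 0; if that constant is created
# with a default dtype of int32 while the loop var is int64, the
# substitution itself carries the mismatch the fix guards against.

DEFAULT_INT_DTYPE = "int32"  # assumed default dtype for literal constants


def make_const(value, dtype=DEFAULT_INT_DTYPE):
    return {"value": value, "dtype": dtype}


def unit_loop_dtypes_agree(loop_var_dtype):
    # Buggy path: the constant's dtype ignores the loop var's dtype.
    const = make_const(0)
    return const["dtype"] == loop_var_dtype


print(unit_loop_dtypes_agree("int32"))  # True: int32 test hides the bug
print(unit_loop_dtypes_agree("int64"))  # False: int64 test exposes it
```

   Under this reading, the original `int32` version of the test would pass with or without the fix, which matches the reviewer's observation earlier in the thread.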





[GitHub] [tvm] zxybazh merged pull request #13481: [TIR] Fix an error when the result of compute_at has unit loop

zxybazh merged PR #13481:
URL: https://github.com/apache/tvm/pull/13481




[GitHub] [tvm] yelite commented on pull request #13481: [TIR] Fix an error when the result of compute_at has unit loop

yelite commented on PR #13481:
URL: https://github.com/apache/tvm/pull/13481#issuecomment-1327004926

   > I changed the regression test a bit to reflect the fix's impact. Let me know if the new test makes sense.
   
   This looks good. Thanks for adding this for me!

