Posted to commits@tvm.apache.org by ju...@apache.org on 2023/10/02 03:22:38 UTC
[tvm] branch unity updated: [Unity][TIR][Unittest] Fix failed 0-rank rfactor test (#15846)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new d85ea256fa [Unity][TIR][Unittest] Fix failed 0-rank rfactor test (#15846)
d85ea256fa is described below
commit d85ea256fa0d15d721599e4631b99cdef1eec9b8
Author: Ruihang Lai <ru...@cs.cmu.edu>
AuthorDate: Sun Oct 1 23:22:29 2023 -0400
[Unity][TIR][Unittest] Fix failed 0-rank rfactor test (#15846)
This PR fixes the failing 0-rank rfactor test in the TIR schedule unit tests.
The failure is caused by a dtype mismatch. Specifically, when we
declare a 0-rank buffer in TVMScript using `T.Buffer(shape=())`, the
buffer's `elem_offset` field is set to `T.int64(0)`: a 0-rank shape
has no dimensions from which a dtype can be inferred, so the dtype
falls back to int64.
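The inference rule described above can be modeled in a few lines of plain Python. This is a hypothetical, simplified sketch of the behavior the commit message describes, not TVM's actual implementation. It assumes that a non-empty shape lends the index dtype of its first dimension to `elem_offset`, and it represents dtypes as plain strings.

```python
# Hypothetical, simplified model of the elem_offset dtype rule described
# in the commit message. NOT TVM's implementation; dtypes are plain strings.

def infer_elem_offset_dtype(shape_dtypes):
    """Return the dtype for a buffer's default elem_offset.

    shape_dtypes: list of dtype strings, one per shape dimension.
    Assumption: a non-empty shape lends its first dimension's dtype to
    elem_offset. A 0-rank shape has nothing to infer from, so the dtype
    falls back to "int64".
    """
    if len(shape_dtypes) > 0:
        return shape_dtypes[0]
    return "int64"


print(infer_elem_offset_dtype([]))         # 0-rank, like T.Buffer(shape=()) -> int64
print(infer_elem_offset_dtype(["int32"]))  # shape (128,) with an int32 extent -> int32
```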
In this test, the 0-rank buffer declared before rfactor therefore has
`T.int64(0)` as its `elem_offset`. Because the rfactor buffer is
created from this buffer, it inherits the same `T.int64(0)`
`elem_offset`.
On the other hand, the rfactor buffer has the non-empty shape `(128,)`,
so the expected buffer written in TVMScript gets `T.int32(0)` as its
`elem_offset`. This int64/int32 difference causes the mismatch.
After some consideration, I think the best fix is simply to update the
expected TIR so that it explicitly sets the `elem_offset` field to
`T.int64(0)`, which avoids the mismatch. The dtype of `elem_offset`
here does not affect the behavior of lowering/codegen, and it is
difficult to choose between int64 and int32 for a zero-rank buffer
without any other context.
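The mismatch and the fix can be illustrated with a toy model. The sketch below is hypothetical and is not TVM's buffer or rfactor API. `ToyBuffer`, `declare`, and `rfactor_buffer` are illustrative names that keep only the fields relevant here. It assumes, as the commit message describes, that the rfactor buffer keeps the source buffer's `elem_offset` dtype, and that a buffer written in TVMScript without an explicit `elem_offset` infers it from its shape.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyBuffer:
    # Hypothetical stand-in for a TIR buffer: only the fields relevant here.
    shape_dtypes: tuple
    elem_offset_dtype: str


def declare(shape_dtypes, elem_offset_dtype=None):
    # Mimics writing a buffer in TVMScript: without an explicit elem_offset,
    # the dtype is inferred from the shape, with an int64 fallback for 0-rank.
    if elem_offset_dtype is None:
        elem_offset_dtype = shape_dtypes[0] if shape_dtypes else "int64"
    return ToyBuffer(tuple(shape_dtypes), elem_offset_dtype)


def rfactor_buffer(src, new_shape_dtypes):
    # The rfactor buffer is derived from the original reduction buffer,
    # so it keeps the source buffer's elem_offset dtype.
    return ToyBuffer(tuple(new_shape_dtypes), src.elem_offset_dtype)


B = declare([])                      # 0-rank B: elem_offset dtype is int64
B_rf = rfactor_buffer(B, ["int32"])  # shape (128,), but inherits int64 offset

old_expected = declare(["int32"])                             # T.alloc_buffer([128])
new_expected = declare(["int32"], elem_offset_dtype="int64")  # ..., elem_offset=T.int64(0)

print(B_rf == old_expected)  # False: int64 vs int32 elem_offset
print(B_rf == new_expected)  # True
```

Comparing by value mirrors the structural comparison the test relies on. Only the explicitly pinned `elem_offset` dtype makes the expected buffer match the one produced by rfactor.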
---
tests/python/unittest/test_tir_schedule_rfactor.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tests/python/unittest/test_tir_schedule_rfactor.py b/tests/python/unittest/test_tir_schedule_rfactor.py
index 43374d3751..9856c08204 100644
--- a/tests/python/unittest/test_tir_schedule_rfactor.py
+++ b/tests/python/unittest/test_tir_schedule_rfactor.py
@@ -501,7 +501,7 @@ def rowsum_zero_dim(a: T.handle, b: T.handle) -> None:
def rowsum_zero_dim_rfactor(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, [128])
B = T.match_buffer(b, [])
- B_rf = T.alloc_buffer([128])
+ B_rf = T.alloc_buffer([128], elem_offset=T.int64(0))
for i in range(128):
with T.block("B_rf"):