Posted to commits@tvm.apache.org by ju...@apache.org on 2023/10/02 03:22:38 UTC

[tvm] branch unity updated: [Unity][TIR][Unittest] Fix failed 0-rank rfactor test (#15846)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new d85ea256fa [Unity][TIR][Unittest] Fix failed 0-rank rfactor test (#15846)
d85ea256fa is described below

commit d85ea256fa0d15d721599e4631b99cdef1eec9b8
Author: Ruihang Lai <ru...@cs.cmu.edu>
AuthorDate: Sun Oct 1 23:22:29 2023 -0400

    [Unity][TIR][Unittest] Fix failed 0-rank rfactor test (#15846)
    
    This PR fixes the failed 0-rank rfactor test of TIR scheduling.
    
    The failure is caused by a dtype mismatch. Specifically, when we
    declare a 0-rank buffer in TVMScript using `T.Buffer(shape=())`, the
    buffer's `elem_offset` field is set to `T.int64(0)`: no dtype can be
    inferred from an empty (0-rank) shape, so the dtype falls back to
    int64.
    
    So in this case, the declared 0-rank buffer before rfactor has
    `T.int64(0)` as `elem_offset`. Because the rfactor buffer is created
    from this buffer, the rfactor buffer also has `T.int64(0)` as
    `elem_offset`.
    
    On the other hand, the rfactor buffer has shape `(128,)`, which is
    not empty, so the expected buffer written in TVMScript has
    `T.int32(0)` as its `elem_offset`. This causes the mismatch.
    
    After some thought, I think the right fix is simply to update the
    expected TIR so that it explicitly sets the `elem_offset` field to
    `T.int64(0)`, which avoids the mismatch. The reason is that the dtype
    of `elem_offset` here does not affect the behavior of lowering/codegen,
    and it is difficult to choose between int64 and int32 for a zero-rank
    buffer without any other context.
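
The mismatch described above can be illustrated with a small toy model in plain Python. This is only a sketch of the reasoning in the commit message, not TVM's implementation: the helper `offset_dtype` and the rule "take the dtype of the first dimension, else fall back to int64" are assumptions made for illustration, and shapes are modeled simply as lists of per-dimension dtype strings.

```python
from typing import List


def offset_dtype(shape_dtypes: List[str], fallback: str = "int64") -> str:
    """Hypothetical helper: choose the dtype of `elem_offset`.

    Assumption for this sketch: the dtype is taken from the first shape
    dimension when there is one, and falls back to `fallback` otherwise.
    """
    return shape_dtypes[0] if shape_dtypes else fallback


# B is declared with an empty shape, so nothing can be inferred from it.
declared_b = offset_dtype([])

# The rfactor buffer is created from B, so it inherits B's elem_offset dtype.
rfactor_actual = declared_b

# The expected buffer was written with shape [128] (one int32 dimension).
rfactor_expected = offset_dtype(["int32"])

# The fix pins the expected buffer's elem_offset dtype explicitly.
rfactor_expected_fixed = "int64"

print("actual:", rfactor_actual)
print("expected before fix:", rfactor_expected)
print("match before fix:", rfactor_actual == rfactor_expected)
print("match after fix:", rfactor_actual == rfactor_expected_fixed)
```

In this model the comparison fails only because the two buffers arrive at their `elem_offset` dtype by different routes, which is why pinning the expected value is enough to make them agree.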
---
 tests/python/unittest/test_tir_schedule_rfactor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/unittest/test_tir_schedule_rfactor.py b/tests/python/unittest/test_tir_schedule_rfactor.py
index 43374d3751..9856c08204 100644
--- a/tests/python/unittest/test_tir_schedule_rfactor.py
+++ b/tests/python/unittest/test_tir_schedule_rfactor.py
@@ -501,7 +501,7 @@ def rowsum_zero_dim(a: T.handle, b: T.handle) -> None:
 def rowsum_zero_dim_rfactor(a: T.handle, b: T.handle) -> None:
     A = T.match_buffer(a, [128])
     B = T.match_buffer(b, [])
-    B_rf = T.alloc_buffer([128])
+    B_rf = T.alloc_buffer([128], elem_offset=T.int64(0))
 
     for i in range(128):
         with T.block("B_rf"):