You are viewing a plain-text version of this content. The canonical version is available at the original (hyperlinked) archive page.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/11/10 06:05:37 UTC

[tvm] branch feature/2022-11-09/printer-explicit-ir-node updated: fix

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch feature/2022-11-09/printer-explicit-ir-node
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/feature/2022-11-09/printer-explicit-ir-node by this push:
     new 0bf774dd79 fix
0bf774dd79 is described below

commit 0bf774dd79bb17260bdb75ae298ae5f270ca2a2b
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Wed Nov 9 22:05:24 2022 -0800

    fix
---
 tests/python/unittest/test_tvmscript_roundtrip.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index dd6706762d..f22e61e183 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -90,9 +90,9 @@ def opt_gemm_lower():
         def mmult(A: T.handle, B: T.handle, C: T.handle) -> None:
             # function attr dict
             T.func_attr({"global_symbol": "mmult", "tir.noalias": True})
-            A_1 = T.match_buffer(A, [1024 * 1024], elem_offset=0, align=64, offset_factor=1)
+            A_1 = T.match_buffer(A, [16384], elem_offset=0, align=64, offset_factor=1)
             B_1 = T.match_buffer(B, [1024, 1024], elem_offset=0, align=64, offset_factor=1)
-            C_1 = T.match_buffer(C, [1024 * 1024], elem_offset=0, align=64, offset_factor=1)
+            C_1 = T.match_buffer(C, [16384], elem_offset=0, align=64, offset_factor=1)
             # body
             packedB_data = T.allocate([32768], "float32", "global")
             packedB = T.buffer_decl(
@@ -3008,7 +3008,7 @@ def comm_reducer_single_reduce_group():
     def comm_reducer_single_reduce_group(a: T.handle, b: T.handle) -> None:
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
         threadIdx_x = T.env_thread("threadIdx.x")
-        A = T.match_buffer(a, [128 * 128], dtype="float32")
+        A = T.match_buffer(a, [16384], dtype="float32")
         for i in T.serial(0, 128):
             T.launch_thread(threadIdx_x, 128)
             reduce_temp0_data = T.allocate([1], "float32", "local")
@@ -3024,7 +3024,7 @@ def comm_reducer_multiple_reduce_groups():
     def comm_reducer_multiple_reduce_groups(a: T.handle, b: T.handle) -> None:
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
         threadIdx_x = T.env_thread("threadIdx.x")
-        A = T.match_buffer(a, [128 * 128], dtype="float32")
+        A = T.match_buffer(a, [16384], dtype="float32")
         for i in T.serial(0, 128):
             T.launch_thread(threadIdx_x, 128)
             reduce_temp0_data = T.allocate([1], "float32", "local")