You are viewing a plain-text version of this content. The canonical link for it ("here") was a hyperlink that is not preserved in this plain-text rendering.
Posted to commits@tvm.apache.org by wr...@apache.org on 2023/10/30 03:29:36 UTC

(tvm) branch main updated: [Fix][TIR]fix symbolic strides lower (#16000)

This is an automated email from the ASF dual-hosted git repository.

wrongtest pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 57597f62b4 [Fix][TIR]fix symbolic strides lower (#16000)
57597f62b4 is described below

commit 57597f62b44f0ab17adae34b89f5f526b816b759
Author: Wei Tao <11...@qq.com>
AuthorDate: Mon Oct 30 11:29:30 2023 +0800

    [Fix][TIR]fix symbolic strides lower (#16000)
    
    * [Fix][TIR]fix symbolic strides lower
    
    * [Fix][TIR] run the black formatter
---
 src/tir/transforms/ir_utils.cc                     |  3 +-
 .../test_tir_transform_lower_opaque_block.py       | 48 ++++++++++++++++++++++
 2 files changed, 50 insertions(+), 1 deletion(-)

diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index 99ed437659..25c10dd682 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -417,7 +417,8 @@ Array<PrimExpr> GetBufferAllocationShape(const Buffer& buffer) {
   if (buffer->strides.size()) {
     ICHECK_EQ(buffer->shape.size(), buffer->strides.size());
     for (size_t i = buffer->strides.size() - 1; i > 0; --i) {
-      ICHECK(is_zero(floormod(buffer->strides[i - 1], buffer->strides[i])));
+      ICHECK(
+          arith::Analyzer().CanProveEqual(floormod(buffer->strides[i - 1], buffer->strides[i]), 0));
       alloc_shape.Set(i, buffer->strides[i - 1] / buffer->strides[i]);
     }
   }
diff --git a/tests/python/unittest/test_tir_transform_lower_opaque_block.py b/tests/python/unittest/test_tir_transform_lower_opaque_block.py
index 444e36bfbb..ae44d21275 100644
--- a/tests/python/unittest/test_tir_transform_lower_opaque_block.py
+++ b/tests/python/unittest/test_tir_transform_lower_opaque_block.py
@@ -250,6 +250,50 @@ def transformed_strided_buffer_func(
             C[i0 * 4 + i1, j] = B[i1, j] * T.float32(2)
 
 
+@T.prim_func
+def compacted_symbolic_strided_buffer_func(a: T.handle) -> None:  # input: symbolic strides
+    n = T.int32()
+    A = T.match_buffer(a, (1, n, 10240))
+    padded_size = T.meta_var(T.min((n + 63) // 64 * 64, 96))  # symbolic extent, depends on n
+    # strides are symbolic: 72 * padded_size % 72 == 0 needs an arith proof
+    for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160):
+        with T.block(""):
+            A_pad_shared_dyn = T.alloc_buffer(  # dim-1 stride 72 > dim-2 extent 64
+                (1, padded_size, 64), strides=(72 * padded_size, 72, 1), scope="shared.dyn"
+            )
+            for ax0, ax1 in T.grid(96, 64):
+                with T.block("A_pad_shared.dyn"):
+                    T.where(i * 128 + j * 32 + ax0 < (n + 63) // 64 * 64)  # becomes an `if`
+                    A_pad_shared_dyn[0, ax0, ax1] = T.if_then_else(
+                        i * 128 + j * 32 + ax0 < n,
+                        A[0, i * 128 + j * 32 + ax0, k * 64 + ax1],
+                        T.float32(0),
+                    )
+
+
+@T.prim_func
+def transformed_symbolic_strided_buffer_func(a: T.handle):  # expected output IR
+    n = T.int32()
+    A = T.match_buffer(a, (1, n, 10240))
+    for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160):
+        A_pad_shared_dyn = T.allocate(  # dims >= 1: strides[i - 1] / strides[i]
+            [1, T.min((n + 63) // 64 * 64, 96), 72], "float32", "shared.dyn"
+        )
+        A_pad_shared_dyn_1 = T.decl_buffer(  # original shape/strides over the allocation
+            (1, T.min((n + 63) // 64 * 64, 96), 64),
+            data=A_pad_shared_dyn,
+            strides=(72 * T.min((n + 63) // 64 * 64, 96), 72, 1),
+            scope="shared.dyn",
+        )
+        for ax0, ax1 in T.grid(96, 64):
+            if i * 128 + j * 32 + ax0 < (n + 63) // 64 * 64:  # from T.where
+                A_pad_shared_dyn_1[0, ax0, ax1] = T.if_then_else(
+                    i * 128 + j * 32 + ax0 < n,
+                    A[0, i * 128 + j * 32 + ax0, k * 64 + ax1],
+                    T.float32(0),
+                )
+
+
 @T.prim_func
 def annotated_loops(a: T.handle) -> None:
     A = T.match_buffer(a, (16,), "float32")
@@ -301,6 +345,10 @@ def test_strided_buffer():
     _check(compacted_strided_buffer_func, transformed_strided_buffer_func)
 
 
+def test_symbolic_strided_buffer():  # regression test for #16000
+    _check(compacted_symbolic_strided_buffer_func, transformed_symbolic_strided_buffer_func)
+
+
 def test_lower_te():
     x = te.placeholder((1,))
     y = te.compute((1,), lambda i: x[i] + 2)