You are viewing a plain-text version of this content. (The hyperlink to the canonical version was not preserved in this plain-text copy.)
Posted to commits@tvm.apache.org by wr...@apache.org on 2023/10/30 03:29:36 UTC
(tvm) branch main updated: [Fix][TIR]fix symbolic strides lower (#16000)
This is an automated email from the ASF dual-hosted git repository.
wrongtest pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 57597f62b4 [Fix][TIR]fix symbolic strides lower (#16000)
57597f62b4 is described below
commit 57597f62b44f0ab17adae34b89f5f526b816b759
Author: Wei Tao <11...@qq.com>
AuthorDate: Mon Oct 30 11:29:30 2023 +0800
[Fix][TIR]fix symbolic strides lower (#16000)
* [Fix][TIR]fix symbolic strides lower
* [Fix][TIR] run the black formatter
---
src/tir/transforms/ir_utils.cc | 3 +-
.../test_tir_transform_lower_opaque_block.py | 48 ++++++++++++++++++++++
2 files changed, 50 insertions(+), 1 deletion(-)
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index 99ed437659..25c10dd682 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -417,7 +417,8 @@ Array<PrimExpr> GetBufferAllocationShape(const Buffer& buffer) {
if (buffer->strides.size()) {
ICHECK_EQ(buffer->shape.size(), buffer->strides.size());
for (size_t i = buffer->strides.size() - 1; i > 0; --i) {
- ICHECK(is_zero(floormod(buffer->strides[i - 1], buffer->strides[i])));
+ ICHECK(
+ arith::Analyzer().CanProveEqual(floormod(buffer->strides[i - 1], buffer->strides[i]), 0));
alloc_shape.Set(i, buffer->strides[i - 1] / buffer->strides[i]);
}
}
diff --git a/tests/python/unittest/test_tir_transform_lower_opaque_block.py b/tests/python/unittest/test_tir_transform_lower_opaque_block.py
index 444e36bfbb..ae44d21275 100644
--- a/tests/python/unittest/test_tir_transform_lower_opaque_block.py
+++ b/tests/python/unittest/test_tir_transform_lower_opaque_block.py
@@ -250,6 +250,50 @@ def transformed_strided_buffer_func(
C[i0 * 4 + i1, j] = B[i1, j] * T.float32(2)
+@T.prim_func
+def compacted_symbolic_strided_buffer_func(a: T.handle) -> None:
+ n = T.int32()
+ A = T.match_buffer(a, (1, n, 10240))
+ padded_size = T.meta_var(T.min((n + 63) // 64 * 64, 96))
+ # with T.block("root"):
+ for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160):
+ with T.block(""):
+ A_pad_shared_dyn = T.alloc_buffer(
+ (1, padded_size, 64), strides=(72 * padded_size, 72, 1), scope="shared.dyn"
+ )
+ for ax0, ax1 in T.grid(96, 64):
+ with T.block("A_pad_shared.dyn"):
+ T.where(i * 128 + j * 32 + ax0 < (n + 63) // 64 * 64)
+ A_pad_shared_dyn[0, ax0, ax1] = T.if_then_else(
+ i * 128 + j * 32 + ax0 < n,
+ A[0, i * 128 + j * 32 + ax0, k * 64 + ax1],
+ T.float32(0),
+ )
+
+
+@T.prim_func
+def transformed_symbolic_strided_buffer_func(a: T.handle):
+ n = T.int32()
+ A = T.match_buffer(a, (1, n, 10240))
+ for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160):
+ A_pad_shared_dyn = T.allocate(
+ [1, T.min((n + 63) // 64 * 64, 96), 72], "float32", "shared.dyn"
+ )
+ A_pad_shared_dyn_1 = T.decl_buffer(
+ (1, T.min((n + 63) // 64 * 64, 96), 64),
+ data=A_pad_shared_dyn,
+ strides=(72 * T.min((n + 63) // 64 * 64, 96), 72, 1),
+ scope="shared.dyn",
+ )
+ for ax0, ax1 in T.grid(96, 64):
+ if i * 128 + j * 32 + ax0 < (n + 63) // 64 * 64:
+ A_pad_shared_dyn_1[0, ax0, ax1] = T.if_then_else(
+ i * 128 + j * 32 + ax0 < n,
+ A[0, i * 128 + j * 32 + ax0, k * 64 + ax1],
+ T.float32(0),
+ )
+
+
@T.prim_func
def annotated_loops(a: T.handle) -> None:
A = T.match_buffer(a, (16,), "float32")
@@ -301,6 +345,10 @@ def test_strided_buffer():
_check(compacted_strided_buffer_func, transformed_strided_buffer_func)
+def test_symbolic_strided_buffer():
+ _check(compacted_symbolic_strided_buffer_func, transformed_symbolic_strided_buffer_func)
+
+
def test_lower_te():
x = te.placeholder((1,))
y = te.compute((1,), lambda i: x[i] + 2)