You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wu...@apache.org on 2021/10/01 14:35:51 UTC
[tvm] branch main updated: [TIR][LowerMatchBuffer] Fix lowering strides when source buffer has non-empty strides (#9166)
This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 62a7fb7 [TIR][LowerMatchBuffer] Fix lowering strides when source buffer has non-empty strides (#9166)
62a7fb7 is described below
commit 62a7fb78ce70396e22ca324cc9cd8560b2c4ed42
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Fri Oct 1 10:35:31 2021 -0400
[TIR][LowerMatchBuffer] Fix lowering strides when source buffer has non-empty strides (#9166)
---
src/tir/transforms/lower_match_buffer.cc | 18 +++++---
.../python/unittest/test_tir_lower_match_buffer.py | 52 ++++++++++++++++++++++
2 files changed, 65 insertions(+), 5 deletions(-)
diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc
index b48749c..6bfbcef 100644
--- a/src/tir/transforms/lower_match_buffer.cc
+++ b/src/tir/transforms/lower_match_buffer.cc
@@ -198,11 +198,19 @@ class MatchBufferLower : public StmtExprMutator {
int offset = source->region.size() - buffer->shape.size();
if (!buffer->strides.empty()) {
ICHECK_EQ(buffer->strides.size(), buffer->shape.size());
- PrimExpr stride = make_const(DataType::Int(32), 1);
- for (size_t i = buffer->shape.size(); i > 0; --i) {
- const PrimExpr& shape = source_buffer->shape[i - 1 + offset];
- Bind(buffer->strides[i - 1], stride, buffer->name + ".strides_" + std::to_string(i - 1));
- stride *= shape;
+ if (source_buffer->strides.empty()) {
+ PrimExpr stride = make_const(DataType::Int(32), 1);
+ for (size_t i = buffer->shape.size(); i > 0; --i) {
+ const PrimExpr& shape = source_buffer->shape[i - 1 + offset];
+ Bind(buffer->strides[i - 1], stride, buffer->name + ".strides_" + std::to_string(i - 1));
+ stride *= shape;
+ }
+ } else {
+ ICHECK_EQ(buffer->shape.size() + offset, source_buffer->strides.size());
+ for (size_t i = buffer->shape.size(); i > 0; --i) {
+ const PrimExpr& stride = source_buffer->strides[i - 1 + offset];
+ Bind(buffer->strides[i - 1], stride, buffer->name + ".strides_" + std::to_string(i - 1));
+ }
}
}
diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py
index efb2073..75c95a3 100644
--- a/tests/python/unittest/test_tir_lower_match_buffer.py
+++ b/tests/python/unittest/test_tir_lower_match_buffer.py
@@ -205,6 +205,54 @@ def transformed_high_dim_opaque_access(a: ty.handle) -> None:
@tvm.script.tir
+def high_dim_opaque_access_with_source_strides(a: ty.handle) -> None:
+ A = tir.match_buffer(a, (16, 32, 64), strides=[2576, 80, 1])
+ for i, j, k in tir.grid(16, 2, 4):
+ with tir.block([]):
+ As_0 = tir.var("int32")
+ As_1 = tir.var("int32")
+ tir.reads([])
+ tir.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
+ sub_A = tir.match_buffer(
+ A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
+ (16, 16),
+ strides=[As_0, As_1],
+ offset_factor=1,
+ )
+ tir.evaluate(
+ tir.intrin_test(
+ sub_A.data,
+ sub_A.elem_offset,
+ sub_A.strides[0],
+ sub_A.strides[1],
+ sub_A.shape[0],
+ sub_A.shape[1],
+ dtype="handle",
+ )
+ )
+
+
+@tvm.script.tir
+def transformed_high_dim_opaque_access_with_source_strides(a: ty.handle) -> None:
+ A = tir.match_buffer(a, (16, 32, 64), strides=[2576, 80, 1])
+ for i, j, k in tir.grid(16, 2, 4):
+ with tir.block([]):
+ tir.reads([])
+ tir.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
+ tir.evaluate(
+ tir.intrin_test(
+ A.data,
+ i * 2576 + j * 1280 + k * 16,
+ 80,
+ 1,
+ 16,
+ 16,
+ dtype="handle",
+ )
+ )
+
+
+@tvm.script.tir
def recursive_match(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (64, 64, 64))
B = tir.match_buffer(b, (64, 64, 64))
@@ -469,6 +517,10 @@ def test_opaque_access():
def test_high_dim_opaque_access():
_check(high_dim_opaque_access, transformed_high_dim_opaque_access)
+ _check(
+ high_dim_opaque_access_with_source_strides,
+ transformed_high_dim_opaque_access_with_source_strides,
+ )
def test_recursive_match():