You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text version).
Posted to commits@tvm.apache.org by wu...@apache.org on 2021/10/01 14:35:51 UTC

[tvm] branch main updated: [TIR][LowerMatchBuffer] Fix lowering strides when source buffer has non-empty strides (#9166)

This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 62a7fb7  [TIR][LowerMatchBuffer] Fix lowering strides when source buffer has non-empty strides (#9166)
62a7fb7 is described below

commit 62a7fb78ce70396e22ca324cc9cd8560b2c4ed42
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Fri Oct 1 10:35:31 2021 -0400

    [TIR][LowerMatchBuffer] Fix lowering strides when source buffer has non-empty strides (#9166)
---
 src/tir/transforms/lower_match_buffer.cc           | 18 +++++---
 .../python/unittest/test_tir_lower_match_buffer.py | 52 ++++++++++++++++++++++
 2 files changed, 65 insertions(+), 5 deletions(-)

diff --git a/src/tir/transforms/lower_match_buffer.cc b/src/tir/transforms/lower_match_buffer.cc
index b48749c..6bfbcef 100644
--- a/src/tir/transforms/lower_match_buffer.cc
+++ b/src/tir/transforms/lower_match_buffer.cc
@@ -198,11 +198,19 @@ class MatchBufferLower : public StmtExprMutator {
     int offset = source->region.size() - buffer->shape.size();
     if (!buffer->strides.empty()) {
       ICHECK_EQ(buffer->strides.size(), buffer->shape.size());
-      PrimExpr stride = make_const(DataType::Int(32), 1);
-      for (size_t i = buffer->shape.size(); i > 0; --i) {
-        const PrimExpr& shape = source_buffer->shape[i - 1 + offset];
-        Bind(buffer->strides[i - 1], stride, buffer->name + ".strides_" + std::to_string(i - 1));
-        stride *= shape;
+      if (source_buffer->strides.empty()) {
+        PrimExpr stride = make_const(DataType::Int(32), 1);
+        for (size_t i = buffer->shape.size(); i > 0; --i) {
+          const PrimExpr& shape = source_buffer->shape[i - 1 + offset];
+          Bind(buffer->strides[i - 1], stride, buffer->name + ".strides_" + std::to_string(i - 1));
+          stride *= shape;
+        }
+      } else {
+        ICHECK_EQ(buffer->shape.size() + offset, source_buffer->strides.size());
+        for (size_t i = buffer->shape.size(); i > 0; --i) {
+          const PrimExpr& stride = source_buffer->strides[i - 1 + offset];
+          Bind(buffer->strides[i - 1], stride, buffer->name + ".strides_" + std::to_string(i - 1));
+        }
       }
     }
 
diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py
index efb2073..75c95a3 100644
--- a/tests/python/unittest/test_tir_lower_match_buffer.py
+++ b/tests/python/unittest/test_tir_lower_match_buffer.py
@@ -205,6 +205,54 @@ def transformed_high_dim_opaque_access(a: ty.handle) -> None:
 
 
 @tvm.script.tir
+def high_dim_opaque_access_with_source_strides(a: ty.handle) -> None:
+    A = tir.match_buffer(a, (16, 32, 64), strides=[2576, 80, 1])
+    for i, j, k in tir.grid(16, 2, 4):
+        with tir.block([]):
+            As_0 = tir.var("int32")
+            As_1 = tir.var("int32")
+            tir.reads([])
+            tir.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
+            sub_A = tir.match_buffer(
+                A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
+                (16, 16),
+                strides=[As_0, As_1],
+                offset_factor=1,
+            )
+            tir.evaluate(
+                tir.intrin_test(
+                    sub_A.data,
+                    sub_A.elem_offset,
+                    sub_A.strides[0],
+                    sub_A.strides[1],
+                    sub_A.shape[0],
+                    sub_A.shape[1],
+                    dtype="handle",
+                )
+            )
+
+
+@tvm.script.tir
+def transformed_high_dim_opaque_access_with_source_strides(a: ty.handle) -> None:
+    A = tir.match_buffer(a, (16, 32, 64), strides=[2576, 80, 1])
+    for i, j, k in tir.grid(16, 2, 4):
+        with tir.block([]):
+            tir.reads([])
+            tir.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
+            tir.evaluate(
+                tir.intrin_test(
+                    A.data,
+                    i * 2576 + j * 1280 + k * 16,
+                    80,
+                    1,
+                    16,
+                    16,
+                    dtype="handle",
+                )
+            )
+
+
+@tvm.script.tir
 def recursive_match(a: ty.handle, b: ty.handle) -> None:
     A = tir.match_buffer(a, (64, 64, 64))
     B = tir.match_buffer(b, (64, 64, 64))
@@ -469,6 +517,10 @@ def test_opaque_access():
 
 def test_high_dim_opaque_access():
     _check(high_dim_opaque_access, transformed_high_dim_opaque_access)
+    _check(
+        high_dim_opaque_access_with_source_strides,
+        transformed_high_dim_opaque_access_with_source_strides,
+    )
 
 
 def test_recursive_match():