You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by ju...@apache.org on 2021/12/29 19:10:22 UTC

[tvm] branch main updated: [TensorIR] fix region cover check (#9810)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 654a687  [TensorIR] fix region cover check (#9810)
654a687 is described below

commit 654a687e5c67d1ccaaf54ffb8da79af3c113ce03
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Thu Dec 30 03:09:45 2021 +0800

    [TensorIR] fix region cover check (#9810)
---
 src/tir/schedule/state.cc                          |  4 ++--
 .../test_tir_schedule_state_cached_flags.py        | 23 ++++++++++++++++++++++
 2 files changed, 25 insertions(+), 2 deletions(-)

diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc
index faeb0b9..04b7dd5 100644
--- a/src/tir/schedule/state.cc
+++ b/src/tir/schedule/state.cc
@@ -104,9 +104,9 @@ bool ProducerCoversConsumer(const Array<PrimExpr>& buffer_shape,
     arith::IntSet produced = arith::Intersect({produced_region[i], buffer_size});
     arith::IntSet consumed = arith::Intersect({consumed_region[i], buffer_size});
     PrimExpr produced_min = analyzer->Simplify(produced.min());
-    PrimExpr produced_max = analyzer->Simplify(produced.max() - produced_min + 1);
+    PrimExpr produced_max = analyzer->Simplify(produced.max());
     PrimExpr consumed_min = analyzer->Simplify(consumed.min());
-    PrimExpr consumed_max = analyzer->Simplify(consumed.max() - consumed_min + 1);
+    PrimExpr consumed_max = analyzer->Simplify(consumed.max());
     if (!analyzer->CanProve((produced_min <= consumed_min) && (consumed_max <= produced_max))) {
       return false;
     }
diff --git a/tests/python/unittest/test_tir_schedule_state_cached_flags.py b/tests/python/unittest/test_tir_schedule_state_cached_flags.py
index d86af72..e88eacd 100644
--- a/tests/python/unittest/test_tir_schedule_state_cached_flags.py
+++ b/tests/python/unittest/test_tir_schedule_state_cached_flags.py
@@ -353,6 +353,18 @@ def non_perfect_tiling_cache(a: T.handle, b: T.handle) -> None:
                 )
 
 
+@T.prim_func
+def uncovered_producer_region(A: T.Buffer[(128,), "float32"], B: T.Buffer[(128,), "float32"]):
+    for i in range(120):
+        with T.block("producer"):
+            vi = T.axis.S((0, 120), i)
+            A[vi] = 1.0
+    for i in range(120):
+        with T.block("consumer"):
+            vi = T.axis.S((8, 128), i + 8)
+            B[vi] = A[vi]
+
+
 # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
 
 
@@ -757,5 +769,16 @@ def test_non_perfect_tiling_cache():
     # pylint: enable=protected-access
 
 
+def test_uncovered_producer_region():
+    s = tir.ScheduleState(uncovered_producer_region, debug_mask="all")
+    # pylint: disable=protected-access
+    assert s._get_cached_flags(_get_block(s, "consumer")) == CachedFlags(
+        affine_binding=True,
+        region_cover=False,
+        stage_pipeline=True,
+    )
+    # pylint: enable=protected-access
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))