You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wr...@apache.org on 2023/06/21 03:54:43 UTC

[tvm] branch main updated: [Arith][TIR] Recognize empty extents (#15129)

This is an automated email from the ASF dual-hosted git repository.

wrongtest pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b37ad17ce6 [Arith][TIR] Recognize empty extents (#15129)
b37ad17ce6 is described below

commit b37ad17ce605578df713c2c74e164a2cd6761b7d
Author: Krzysztof Parzyszek <kp...@quicinc.com>
AuthorDate: Tue Jun 20 22:54:36 2023 -0500

    [Arith][TIR] Recognize empty extents (#15129)
    
    Generate empty interval sets when empty extents are encountered. Handle
    empty regions when constructing ScheduleState.
---
 src/arith/int_set.cc                               |  6 +++++
 src/tir/schedule/state.cc                          |  3 +++
 .../python/unittest/test_tir_schedule_analysis.py  | 28 ++++++++++++++++++++++
 3 files changed, 37 insertions(+)

diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc
index cf93a481c2..625488430b 100644
--- a/src/arith/int_set.cc
+++ b/src/arith/int_set.cc
@@ -1053,11 +1053,17 @@ Map<Var, arith::IntSet> AsIntSet(const Map<Var, Range>& var_dom) {
 /*! \brief Helper function to convert IterSumExpr to the actual touched range. */
 static Optional<IntSet> EvalIterSum(const IterSumExpr& iter_min, const PrimExpr& extent,
                                     Analyzer* analyzer) {
+  if (analyzer->CanProve(extent == 0)) {
+    return IntSet::Nothing();
+  }
   if (iter_min->args.empty()) {
     return IntSet::FromMinExtent(iter_min->base, extent);
   }
   ICHECK_EQ(iter_min->args.size(), 1) << "The `EvalIterSum` expects fused iter sum expr";
   const IterSplitExpr& split = iter_min->args[0];
+  if (analyzer->CanProve(split->extent == 0)) {
+    return IntSet::Nothing();
+  }
   if (!analyzer->CanProve(extent >= split->scale)) {
     return NullOpt;
   }
diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc
index 33f8598289..facee629af 100644
--- a/src/tir/schedule/state.cc
+++ b/src/tir/schedule/state.cc
@@ -99,6 +99,9 @@ bool ProducerCoversConsumer(const Array<PrimExpr>& buffer_shape,
     if (produced_region[i].IsNothing()) {
       return false;
     }
+    if (consumed_region[i].IsNothing()) {
+      continue;
+    }
     arith::IntSet produced =
         arith::IntSet::Interval(analyzer->canonical_simplify(produced_region[i].min()),
                                 analyzer->canonical_simplify(produced_region[i].max()));
diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py
index cd91a44b65..4484b6ab39 100644
--- a/tests/python/unittest/test_tir_schedule_analysis.py
+++ b/tests/python/unittest/test_tir_schedule_analysis.py
@@ -417,5 +417,33 @@ def test_is_output_block():
     assert is_output_block(sch, block_rv)
 
 
+def test_empty_grid():
+    @T.prim_func
+    def foo(out: T.Buffer((T.int64(1), T.int64(8), T.int64(8)), "int32")):
+        act = T.alloc_buffer((1, 8, 8), "int32")
+        for z2, y2, x2 in T.grid(1, 8, 8):
+            with T.block("b0"):
+                az, ay, ax = T.axis.remap("SSS", [z2, y2, x2])
+                T.writes(act[az, ay, ax])
+                act[az, ay, ax] = T.int32(0)
+        # Empty grid: z1 has extent 0, so b1's read region of `act` is empty.
+        for z1, y1, x1 in T.grid(0, 8, 8):
+            with T.block("b1"):
+                az, ay, ax = T.axis.remap("SSS", [z1, y1, x1])
+                T.reads(act[az + 1, ay, ax])
+                T.writes(out[az, ay, ax])
+                out[az, ay, ax] = act[az + 1, ay, ax]
+        # The block below is not needed to show the bug, but the 'out'
+        # buffer would be undefined without it.
+        for z2, y2, x2 in T.grid(1, 8, 8):
+            with T.block("b2"):
+                az, ay, ax = T.axis.remap("SSS", [z2, y2, x2])
+                T.writes(out[az, ay, ax])
+                out[az, ay, ax] = T.int32(0)
+
+    # Constructing a Schedule over the empty-extent read region used to crash.
+    sch = tvm.tir.Schedule(foo)
+
+
 if __name__ == "__main__":
     tvm.testing.main()