You are viewing a plain-text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wr...@apache.org on 2023/06/21 03:54:43 UTC
[tvm] branch main updated: [Arith][TIR] Recognize empty extents (#15129)
This is an automated email from the ASF dual-hosted git repository.
wrongtest pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b37ad17ce6 [Arith][TIR] Recognize empty extents (#15129)
b37ad17ce6 is described below
commit b37ad17ce605578df713c2c74e164a2cd6761b7d
Author: Krzysztof Parzyszek <kp...@quicinc.com>
AuthorDate: Tue Jun 20 22:54:36 2023 -0500
[Arith][TIR] Recognize empty extents (#15129)
Generate empty interval sets when empty extents are encountered. Handle
empty regions when constructing ScheduleState.
---
src/arith/int_set.cc | 6 +++++
src/tir/schedule/state.cc | 3 +++
.../python/unittest/test_tir_schedule_analysis.py | 28 ++++++++++++++++++++++
3 files changed, 37 insertions(+)
diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc
index cf93a481c2..625488430b 100644
--- a/src/arith/int_set.cc
+++ b/src/arith/int_set.cc
@@ -1053,11 +1053,17 @@ Map<Var, arith::IntSet> AsIntSet(const Map<Var, Range>& var_dom) {
/*! \brief Helper function to convert IterSumExpr to the actual touched range. */
static Optional<IntSet> EvalIterSum(const IterSumExpr& iter_min, const PrimExpr& extent,
Analyzer* analyzer) {
+ if (analyzer->CanProve(extent == 0)) {
+ return IntSet::Nothing();
+ }
if (iter_min->args.empty()) {
return IntSet::FromMinExtent(iter_min->base, extent);
}
ICHECK_EQ(iter_min->args.size(), 1) << "The `EvalIterSum` expects fused iter sum expr";
const IterSplitExpr& split = iter_min->args[0];
+ if (analyzer->CanProve(split->extent == 0)) {
+ return IntSet::Nothing();
+ }
if (!analyzer->CanProve(extent >= split->scale)) {
return NullOpt;
}
diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc
index 33f8598289..facee629af 100644
--- a/src/tir/schedule/state.cc
+++ b/src/tir/schedule/state.cc
@@ -99,6 +99,9 @@ bool ProducerCoversConsumer(const Array<PrimExpr>& buffer_shape,
if (produced_region[i].IsNothing()) {
return false;
}
+ if (consumed_region[i].IsNothing()) {
+ continue;
+ }
arith::IntSet produced =
arith::IntSet::Interval(analyzer->canonical_simplify(produced_region[i].min()),
analyzer->canonical_simplify(produced_region[i].max()));
diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py
index cd91a44b65..4484b6ab39 100644
--- a/tests/python/unittest/test_tir_schedule_analysis.py
+++ b/tests/python/unittest/test_tir_schedule_analysis.py
@@ -417,5 +417,33 @@ def test_is_output_block():
assert is_output_block(sch, block_rv)
+def test_empty_grid():
+ @T.prim_func
+ def foo(out: T.Buffer((T.int64(1), T.int64(8), T.int64(8)), "int32")):
+ act = T.alloc_buffer((1, 8, 8), "int32")
+ for z2, y2, x2 in T.grid(1, 8, 8):
+ with T.block("b0"):
+ az, ay, ax = T.axis.remap("SSS", [z2, y2, x2])
+ T.writes(act[az, ay, ax])
+ act[az, ay, az] = T.int32(0)
+ # Empty grid:
+ for z1, y1, x1 in T.grid(0, 8, 8):
+ with T.block("b1"):
+ az, ay, ax = T.axis.remap("SSS", [z1, y1, x1])
+ T.reads(act[az + 1, ay, ax])
+ T.writes(out[az, ay, ax])
+ out[az, ay, ax] = act[az + 1, ay, ax]
+ # The block below is not needed to show the bug, but the 'out'
+ # buffer would be undefined without it.
+ for z2, y2, x2 in T.grid(1, 8, 8):
+ with T.block("b2"):
+ az, ay, ax = T.axis.remap("SSS", [z2, y2, x2])
+ T.writes(out[az, ay, ax])
+ out[az, ay, az] = T.int32(0)
+
+ # This caused a crash before.
+ sch = tvm.tir.Schedule(foo)
+
+
if __name__ == "__main__":
tvm.testing.main()