You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2021/12/25 04:13:43 UTC
[tvm] branch main updated: [TIR] Affine utility support iter lowerbound and diagnostics (#9699)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e2dcba2 [TIR] Affine utility support iter lowerbound and diagnostics (#9699)
e2dcba2 is described below
commit e2dcba2fde4f129c04bef6ab6310c74ba05f2a36
Author: wrongtest <wr...@gmail.com>
AuthorDate: Sat Dec 25 12:13:09 2021 +0800
[TIR] Affine utility support iter lowerbound and diagnostics (#9699)
* Enable free vars, iter lower bounds and diagnostics in the affine utility
* Fix lint issues and a comparison bug
* Use an iter shift instead of the IterMark min for the lower bound
* Add a test case for a fused iter sum with multiple lower bounds
* Add more affine-check test cases; fix bugs for a single iter and for duplicate constraints on an iter
* Add a newline to a comment
* Forbid unmatched predicates
Co-authored-by: baoxinqi <wr...@intellif.com>
---
include/tvm/arith/iter_affine_map.h | 8 +-
src/arith/int_set.cc | 4 +-
src/arith/iter_affine_map.cc | 471 ++++++++++++++++-----
src/tir/schedule/analysis/analysis.cc | 4 +-
tests/python/unittest/test_arith_intset.py | 87 ++++
.../python/unittest/test_arith_iter_affine_map.py | 200 +++++++++
tests/python/unittest/test_tir_schedule_reorder.py | 5 +-
tests/python/unittest/test_tir_schedule_rfactor.py | 8 +-
.../test_tir_schedule_state_cached_flags.py | 55 +++
9 files changed, 720 insertions(+), 122 deletions(-)
diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h
index 6c72cbe..22b4cd5 100644
--- a/include/tvm/arith/iter_affine_map.h
+++ b/include/tvm/arith/iter_affine_map.h
@@ -49,6 +49,7 @@
#define TVM_ARITH_ITER_AFFINE_MAP_H_
#include <tvm/arith/analyzer.h>
+#include <tvm/ir/diagnostic.h>
#include <tvm/ir/expr.h>
#include <tvm/tir/var.h>
@@ -275,13 +276,14 @@ class IterSumExpr : public IterMapExpr {
* \param predicate The predicate constraints on the input iterators
* \param require_bijective A boolean flag that indicates whether the mapping should be bijective.
* \param analyzer Analyzer used to get context information.
+ * \param diag_ctx Diagnostic context.
*
* \return The detected pattern if a match exists,
* otherwise return an empty array.
*/
Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& predicate, bool require_bijective,
- arith::Analyzer* analyzer);
+ arith::Analyzer* analyzer, DiagnosticContext diag_ctx);
/*!
* \brief Use IterVarMap detector to rewrite and simplify the indices
*
@@ -333,6 +335,7 @@ Map<Var, PrimExpr> InverseAffineIterMap(const Array<IterSumExpr>& iter_map,
* \param predicate The predicate constraints on the input iterators
* \param require_bijective A boolean flag that indicates whether the mapping should be bijective.
* \param analyzer Analyzer used to get context information.
+ * \param diag_ctx Diagnostic context.
*
* \return The result list has length len(bindings) + 1
[0, len(bindings)): The iter map matching result. The inner list is of length 2.
@@ -344,7 +347,8 @@ Map<Var, PrimExpr> InverseAffineIterMap(const Array<IterSumExpr>& iter_map,
Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,
const Map<Var, Range>& input_iters,
const Array<Var>& sub_iters, const PrimExpr& predicate,
- bool require_bijective, arith::Analyzer* analyzer);
+ bool require_bijective, arith::Analyzer* analyzer,
+ DiagnosticContext diag_ctx);
} // namespace arith
} // namespace tvm
diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc
index e620e3b..55a1a5a 100644
--- a/src/arith/int_set.cc
+++ b/src/arith/int_set.cc
@@ -835,9 +835,10 @@ Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& region,
for (const Range& range : region) {
affine_indices.push_back(range->min);
}
+ DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule()));
iter_sum_exprs = DetectIterMap(
/*indices=*/affine_indices, /*input_iters=*/var_dom,
- /*predicate=*/predicate, /*require_bijective=*/false, analyzer);
+ /*predicate=*/predicate, /*require_bijective=*/false, analyzer, diag_ctx);
}
if (iter_sum_exprs.empty()) {
return NullOpt;
@@ -857,6 +858,7 @@ Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& region,
if (!analyzer->CanProve(range->extent >= split->scale)) {
return NullOpt;
}
+
const PrimExpr& base = sum_expr->base;
// IterSplitExpr: (source // lower_factor) % extent * scale
// where `(source // lower_factor) % extent` is within [0, extent - 1]
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index 34c35ce..c9d4b1e 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -160,13 +160,22 @@ class IterMarkSplitCollector {
}
};
+/*! \brief Record form of IterMark(x, extent) + offset */
+struct IterMarkWithOffset {
+ IterMark mark;
+ PrimExpr offset{0};
+ IterMarkWithOffset() {}
+ IterMarkWithOffset(IterMark mark, PrimExpr offset) : mark(mark), offset(offset) {}
+};
+
/*! \brief Rewriter to rewrite PrimExpr to IterMapExpr when possible */
class IterMapRewriter : public ExprMutator {
public:
using Parent = ExprMutator;
- explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters)
- : analyzer_(analyzer) {
+ explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters,
+ DiagnosticContext diag_ctx)
+ : analyzer_(analyzer), diag_ctx_(diag_ctx) {
for (auto kv : input_iters) {
const Var& var = kv.first;
const Range& vrng = kv.second;
@@ -192,9 +201,10 @@ class IterMapRewriter : public ExprMutator {
return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr)));
}
- IterSumExpr RewriteIterConstraint(const PrimExpr& expr,
- const PrimExpr& predicate_induced_extent) {
- return NormalizeToIterOnBoundExpr(ToIterSumExpr(DirectMutate(expr)), predicate_induced_extent);
+ IterSumExpr RewriteIterConstraint(const PrimExpr& expr, const PrimExpr& predicate_induced_min,
+ const PrimExpr& predicate_induced_max) {
+ return NormalizeToIterOnBoundExpr(ToIterSumExpr(DirectMutate(expr)), predicate_induced_min,
+ predicate_induced_max);
}
/*!
@@ -224,13 +234,21 @@ class IterMapRewriter : public ExprMutator {
// The splits do not overlap with each other.
collector.Collect(bindings);
for (const IterMark& mark : collector.visited_) {
- if (TryNormalizeSplits(mark, collector.mark2splits_[mark], require_bijective).empty())
+ if (TryNormalizeSplits(mark, collector.mark2splits_[mark], require_bijective).empty()) {
+ diag_ctx_.Emit(Diagnostic::Error(mark->source->span)
+ << "Fail to normalize iter mark splits: " << mark);
return false;
+ }
}
if (require_bijective) {
// all input marks must be visited
for (const IterMark& mark : input_marks_) {
- if (collector.visited_.count(mark) == 0) return false;
+ if (collector.visited_.count(mark) == 0) {
+ diag_ctx_.Emit(Diagnostic::Error(mark->source->span)
+ << "The mapping is not bijective because input iter mark " << mark
+ << " is not covered, ");
+ return false;
+ }
}
}
return true;
@@ -278,7 +296,7 @@ class IterMapRewriter : public ExprMutator {
PrimExpr VisitExpr(const PrimExpr& input_expr) final {
auto expr = ExprMutator::VisitExpr(input_expr);
if (expr->IsInstance<IterMapExprNode>()) {
- ++unresolved_count_;
+ Fail(Diagnostic::Error(input_expr->span));
}
return expr;
}
@@ -328,6 +346,13 @@ class IterMapRewriter : public ExprMutator {
}
};
+ void Fail(const Diagnostic& diagnostic) {
+ unresolved_count_++;
+ if (diag_ctx_.defined()) {
+ diag_ctx_.Emit(diagnostic);
+ }
+ }
+
// Internal analyzer
Analyzer* analyzer_;
// Counter to keep track of unresolved cases.
@@ -336,8 +361,9 @@ class IterMapRewriter : public ExprMutator {
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> var_map_;
// input iter marks
std::vector<IterMark> input_marks_;
- // The map for sum that maps flattened form to IterMark with normal form and extent
- // Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
+ // The map for sum that maps flattened form to IterMark with normal form and extent (and possibly
+ // an extra offset)
+ // Example(1): expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
// predicate: j*2 + k < 9
// Then, flattened form = IterSum(IterSplit(i, scale=9),
// IterSplit(j, scale=2),
@@ -347,11 +373,24 @@ class IterMapRewriter : public ExprMutator {
// IterSplit(k, scale=1)),
// extent=9)
// scale=1))
- std::unordered_map<IterSumExpr, IterMark, IterSumHash, IterSumEqual> sum_fuse_map_;
+ // Example(2): expr = i*8 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
+ // predicate: 1 <= j*2 + k < 9
+ // Then, flattened form = IterSum(IterSplit(i, scale=8),
+ // IterSplit(j, scale=2),
+ // IterSplit(k, scale=1))
+ // normal form = IterSum(IterSplit(i, scale=8),
+ // IterSplit(IterMark(IterSum(IterSplit(j, scale=2),
+ // IterSplit(k, scale=1), base=-1),
+ // extent=9-1)
+ // scale=1),
+ // base=1)
+ std::unordered_map<IterSumExpr, IterMarkWithOffset, IterSumHash, IterSumEqual> sum_fuse_map_;
// The map for sum that maps normal form to flattened form
std::unordered_map<IterSumExpr, IterSumExpr, IterSumHash, IterSumEqual> flattened_map_;
// The flattened forms of constrained iters
std::vector<IterSumExpr> constrained_iters_flattened_;
+ // Diagnostic context
+ DiagnosticContext diag_ctx_;
/*!
* \brief Look for a split in splits that is not used such that its lower_factor is smallest.
@@ -407,19 +446,32 @@ class IterMapRewriter : public ExprMutator {
}
if (j == splits.size()) {
// we do not allow incomplete split if the bindings should be bijective
- if (require_bijective) return Array<IterSplitExpr>();
+ if (require_bijective) {
+ diag_ctx_.Emit(
+ Diagnostic::Error(mark->source->span)
+ << "Do not allow incomplete split in bijective checking, expected_lower_factor="
+ << expected_lower_factor);
+ return Array<IterSplitExpr>();
+ }
// look for the next split skipping this lower factor
// For example, y \in [0, 24) has 3 splits [y / 6, (y / 2) % 6, y % 2]
// It is valid to only have [y / 6, y % 2] if bijective is not required
// We can skip (y / 2) % 6
j = SearchSkipLowerFactor(splits, used, expected_lower_factor);
// split not found
- if (j == splits.size()) return Array<IterSplitExpr>();
+ if (j == splits.size()) {
+ diag_ctx_.Emit(Diagnostic::Error(mark->source->span)
+ << "Fail to find split skipping the lower factor in bijective-free "
+ "checking, expected_lower_factor="
+ << expected_lower_factor);
+ return Array<IterSplitExpr>();
+ }
}
used[j] = true;
iters.push_back(splits[j]);
expected_lower_factor = splits[j]->lower_factor * splits[j]->extent;
}
+
// Case 1. bijective is required.
// We check the extent we calculate is consistent with the extent of the mark
// Case 2. bijective is not required.
@@ -427,42 +479,73 @@ class IterMapRewriter : public ExprMutator {
// For example, y \in [0, 24) [(y / 2) % 6, y % 2] is valid, but y \in [0, 25) is not.
if ((require_bijective && !analyzer_->CanProveEqual(expected_lower_factor, mark->extent)) ||
(!require_bijective && !CanProveDivisible(mark->extent, expected_lower_factor))) {
+ diag_ctx_.Emit(Diagnostic::Error(mark->source->span)
+ << "Mark extent of " << mark
+ << " is not compatible with expected_lower_factor=" << expected_lower_factor);
return Array<IterSplitExpr>();
}
return Array<IterSplitExpr>(iters.rbegin(), iters.rend());
}
/*!
- * \brief Normalize the left hand side of iter constraint(expr < predicate_induced_extent)
- * \param expr The left hand side of iter constraint.
- * \param predicate_induced_extent Extent from iter constraint.
+ * \brief Normalize the iter expression with constraint (min <= expr < max)
+ * \param expr The iter expression.
+ * \param predicate_induced_min Closed lower bound from iter constraint, may be undefined.
+ * \param predicate_induced_max Open upper bound from iter constraint, may be undefined.
* \return The Normalized expression.
*/
- IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr,
- const PrimExpr& predicate_induced_extent) {
- // We are normalizing the left hand side of iter constraint(iter < predicate_induced_extent)
- Optional<IterSplitExpr> opt = TryFuseIters(expr);
+ IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, PrimExpr predicate_induced_min,
+ PrimExpr predicate_induced_max) {
+ // normalize to zero base
+ PrimExpr base = expr->base;
+ if (!is_zero(base)) {
+ expr.CopyOnWrite()->base = 0;
+ if (predicate_induced_min.defined()) predicate_induced_min = predicate_induced_min - base;
+ if (predicate_induced_max.defined()) predicate_induced_max = predicate_induced_max - base;
+ }
+ Optional<IterSumExpr> opt = TryFuseIters(expr);
+ ICHECK(!opt.defined() || opt.value()->args.size() == 1);
// scale should be 1
- if (opt.defined() && is_one(opt.value()->scale)) {
- IterSumExpr sum = Downcast<IterSumExpr>(opt.value()->source->source);
+ if (opt.defined() && is_one(opt.value()->args[0]->scale)) {
+ const IterSplitExpr split = opt.value()->args[0];
+ IterSumExpr structured_form = Downcast<IterSumExpr>(split->source->source);
// get the flattened form
- auto it = flattened_map_.find(sum);
+ auto it = flattened_map_.find(structured_form);
ICHECK(it != flattened_map_.end());
IterSumExpr flattened_form = it->second;
- // get the mark
+ // get the mark and offset of the structured_form
auto it_mark = sum_fuse_map_.find(flattened_form);
ICHECK(it_mark != sum_fuse_map_.end());
- IterMark mark = it_mark->second;
- mark.CopyOnWrite()->extent = min(predicate_induced_extent, mark->extent);
- // update the bound of the lhs based on predicate_induced_extent
- sum_fuse_map_[flattened_form] = mark;
+ IterMark mark = it_mark->second.mark;
+ PrimExpr mark_offset = it_mark->second.offset;
+ PrimExpr iter_min = mark_offset;
+ PrimExpr iter_max = iter_min + mark->extent;
+ if (predicate_induced_min.defined()) {
+ iter_min = max(predicate_induced_min, iter_min);
+ }
+ if (predicate_induced_max.defined()) {
+ iter_max = min(predicate_induced_max, iter_max);
+ }
+ if (!is_zero(iter_min)) {
+ // structured form's offset should be updated
+ flattened_map_.erase(structured_form);
+ structured_form.CopyOnWrite()->base = -iter_min;
+ mark.CopyOnWrite()->source = structured_form;
+ flattened_map_[structured_form] = flattened_form;
+ }
+ mark.CopyOnWrite()->extent = iter_max - iter_min;
+ sum_fuse_map_[flattened_form] = {mark, iter_min};
+
// we need to note down the flattened form of constrained iterators
// to check the validity of constraints, see also CheckConstraints()
constrained_iters_flattened_.push_back(flattened_form);
- expr.CopyOnWrite()->args = Array<IterSplitExpr>({opt.value()});
+ expr.CopyOnWrite()->args = Array<IterSplitExpr>({split});
+ expr.CopyOnWrite()->base = base + iter_min;
return expr;
}
- ++unresolved_count_;
+ Fail(Diagnostic::Error(expr->span)
+ << "Fail to normalize " << expr << " with predicate bound [" << predicate_induced_min
+ << ", " << predicate_induced_max << ")");
return expr;
}
@@ -473,16 +556,12 @@ class IterMapRewriter : public ExprMutator {
*/
IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) {
// We are normalizing a regular iter
- if (expr->args.size() <= 1) return expr;
- PrimExpr base = expr->base;
- expr.CopyOnWrite()->base = make_zero(expr->dtype);
- Optional<IterSplitExpr> opt = TryFuseIters(expr);
- expr.CopyOnWrite()->base = base;
+ if (expr->args.size() < 1) return expr;
+ Optional<IterSumExpr> opt = TryFuseIters(expr);
if (opt.defined()) {
- expr.CopyOnWrite()->args = Array<IterSplitExpr>({opt.value()});
- return expr;
+ return opt.value();
} else {
- ++unresolved_count_;
+ Fail(Diagnostic::Error(expr->span) << "Fail to normalize iter sum with offset: " << expr);
return expr;
}
}
@@ -504,17 +583,16 @@ class IterMapRewriter : public ExprMutator {
}
/*!
- * \brief IterSum = x1*c1 + x2*c2 + ... + xn*cn
- * = (x1*s1 + x2*s2 + ... + xn)*cn
- * = y*cn (IterMark y => x1*s1 + x2*s2 + ... + xn)
- * = [IterSplit(IterMark(y), scale=cn)]
- * return a corresponding IterSplitExpr if needed.
+ * \brief IterSum = x1*c1 + x2*c2 + ... + xn*cn + base
+ * = (x1*s1 + x2*s2 + ... + xn)*cn + base
+ * = y*cn (IterMark y => x1*s1 + x2*s2 + ... + xn) + base
+ * = [IterSplit(IterMark(y), scale=cn)] + base
+ * return a corresponding IterSumExpr with extra offset if needed.
* Try to normalize IterSum into a fused IterMark
* \param expr The input sum.
- * \return The split with the fused IterMark if succeed.
+ * \return The sum with the fused IterMark and extra offset if succeed.
*/
- Optional<IterSplitExpr> TryFuseIters(IterSumExpr expr) {
- if (!is_zero(expr->base)) return NullOpt;
+ Optional<IterSumExpr> TryFuseIters(IterSumExpr expr) {
// select the iterators in order
std::vector<bool> visited(expr->args.size(), false);
std::vector<IterSplitExpr> flattened_iters, grouped_iters;
@@ -530,8 +608,13 @@ class IterMapRewriter : public ExprMutator {
}
}
}
- if (!base_scale) return NullOpt;
+ if (!base_scale) {
+ diag_ctx_.Emit(Diagnostic::Error(expr->span)
+ << "Fuse iters failed, can not find a valid base scale");
+ return NullOpt;
+ }
// check if it can be remapped into a fused pattern.
+ PrimExpr expected_extra_base = 0;
PrimExpr expected_scale = base_scale.value();
for (size_t i = 0; i < expr->args.size();) {
// find j such that expr->args[j] has expected scale
@@ -539,7 +622,11 @@ class IterMapRewriter : public ExprMutator {
for (; j < expr->args.size(); ++j) {
if (!visited[j] && analyzer_->CanProveEqual(expr->args[j]->scale, expected_scale)) break;
}
- if (j == expr->args.size()) return NullOpt;
+ if (j == expr->args.size()) {
+ diag_ctx_.Emit(Diagnostic::Error(expr->span)
+ << "Fuse iters failed, can not find expected scale " << expected_scale);
+ return NullOpt;
+ }
// look for the longest constrained iter started from expr->args[j]
// Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
// predicate: j*2 + k < 9
@@ -569,15 +656,21 @@ class IterMapRewriter : public ExprMutator {
break;
}
}
- if (k == expr->args.size()) return NullOpt;
+ if (k == expr->args.size()) {
+ diag_ctx_.Emit(Diagnostic::Error(expr->span)
+ << "Fuse iters failed, can not find flattened iter match constraint "
+ << constraint_to_match.value());
+ return NullOpt;
+ }
visited[k] = true;
flattened_iters.push_back(expr->args[k]);
}
auto iter = sum_fuse_map_.find(constraint_to_match.value());
ICHECK(iter != sum_fuse_map_.end());
- IterMark iter_matched = iter->second;
- grouped_iters.emplace_back(iter_matched, expected_scale);
- expected_scale *= iter_matched->extent;
+ const IterMarkWithOffset& iter_matched = iter->second;
+ grouped_iters.emplace_back(iter_matched.mark, expected_scale);
+ expected_extra_base += iter_matched.offset * expected_scale;
+ expected_scale *= iter_matched.mark->extent;
// move forward
i += constraint_to_match.value()->args.size();
} else {
@@ -594,18 +687,28 @@ class IterMapRewriter : public ExprMutator {
IterSumExpr structured_form = expr, flattened_form = expr;
flattened_form.CopyOnWrite()->args =
Array<IterSplitExpr>(flattened_iters.rbegin(), flattened_iters.rend());
+ flattened_form.CopyOnWrite()->base = 0;
structured_form.CopyOnWrite()->args =
Array<IterSplitExpr>(grouped_iters.rbegin(), grouped_iters.rend());
+ structured_form.CopyOnWrite()->base = 0;
auto it = sum_fuse_map_.find(flattened_form);
if (it != sum_fuse_map_.end()) {
// old iter
- return IterSplitExpr(it->second, base_scale.value());
+ if (!analyzer_->CanProveEqual(expected_extra_base, it->second.offset * base_scale.value())) {
+ // the extra offset is not consistent with old
+ diag_ctx_.Emit(Diagnostic::Error(expr->span)
+ << "Fuse iters failed, the extra offset is not consistent with old");
+ return NullOpt;
+ }
+ return IterSumExpr({IterSplitExpr(it->second.mark, base_scale.value())},
+ expr->base + expected_extra_base);
} else {
// new iter, form a new mark
IterMark mark = IterMark(structured_form, div(expected_scale, base_scale.value()));
- sum_fuse_map_[flattened_form] = mark;
+ sum_fuse_map_[flattened_form] = IterMarkWithOffset(mark, 0);
flattened_map_[structured_form] = flattened_form;
- return IterSplitExpr(mark, base_scale.value());
+ return IterSumExpr({IterSplitExpr(mark, base_scale.value())},
+ expr->base + expected_extra_base);
}
}
@@ -667,34 +770,126 @@ class IterMapRewriter : public ExprMutator {
struct IterConstraint {
// The expr of the iter
PrimExpr iter;
+ // The expr of the lower_bound
+ PrimExpr lower_bound;
// The expr of the upper_bound
PrimExpr upper_bound;
// The size of the iter, which is the number of nodes
size_t expr_size = 0;
- IterConstraint(PrimExpr iter, PrimExpr upper_bound, size_t size)
- : iter(std::move(iter)), upper_bound(std::move(upper_bound)), expr_size(size) {}
+ IterConstraint(PrimExpr iter, PrimExpr lower_bound, PrimExpr upper_bound, size_t size)
+ : iter(std::move(iter)),
+ lower_bound(std::move(lower_bound)),
+ upper_bound(std::move(upper_bound)),
+ expr_size(size) {}
};
/*!
* \brief Split the predicate into `(a < b) && (c < d) && ...`
* \param pred The predicate to be split.
- * \return A list of pairs, each element of which are lhs and rhs of the '<' sign,
- * empty if the split failed.
+ * \return A list of IterConstraint, empty if the split failed.
*/
-std::vector<IterConstraint> MatchUpperBoundConstraints(PrimExpr pred) {
+std::vector<IterConstraint> MatchBoundConstraints(PrimExpr pred,
+ const Map<Var, Range>& input_iters) {
std::vector<IterConstraint> result;
arith::PVar<PrimExpr> lhs, rhs, rest;
for (;;) {
- if ((rest && (lhs < rhs)).Match(pred)) {
- result.emplace_back(lhs.Eval(), rhs.Eval(), 0);
- pred = rest.Eval();
+ // try to extract comparisons
+ bool is_finish = false;
+ bool is_greater = false;
+ bool is_equal = false;
+ if ((rest && (lhs < rhs)).Match(pred) || ((lhs < rhs) && rest).Match(pred)) {
+ // pass
} else if ((lhs < rhs).Match(pred)) {
- result.emplace_back(lhs.Eval(), rhs.Eval(), 0);
- break;
+ is_finish = true;
+ } else if ((rest && (lhs <= rhs)).Match(pred) || ((lhs <= rhs) && rest).Match(pred)) {
+ is_equal = true;
+ } else if ((lhs <= rhs).Match(pred)) {
+ is_equal = true;
+ is_finish = true;
+ } else if ((rest && (lhs > rhs)).Match(pred) || ((lhs > rhs) && rest).Match(pred)) {
+ is_greater = true;
+ } else if ((lhs > rhs).Match(pred)) {
+ is_greater = true;
+ is_finish = true;
+ } else if ((rest && (lhs >= rhs)).Match(pred) || ((lhs >= rhs) && rest).Match(pred)) {
+ is_greater = true;
+ is_equal = true;
+ } else if ((lhs >= rhs).Match(pred)) {
+ is_greater = true;
+ is_equal = true;
+ is_finish = true;
} else {
return std::vector<IterConstraint>();
}
+ PrimExpr lhs_expr = lhs.Eval();
+ PrimExpr rhs_expr = rhs.Eval();
+ // we only accept integer predicates
+ if (!((lhs_expr->dtype.is_int() || lhs_expr->dtype.is_uint()) &&
+ (rhs_expr->dtype.is_int() || rhs_expr->dtype.is_uint()))) {
+ return std::vector<IterConstraint>();
+ }
+ // determine iter and bound; if we cannot distinguish them simply,
+ // try to divide (lhs - rhs) into itervar-aware and itervar-free parts
+ auto f_use_itervar = [&input_iters](const VarNode* v) {
+ return input_iters.count(GetRef<Var>(v));
+ };
+ bool bound_at_left;
+ if (is_const_int(lhs_expr) || !UsesVar(lhs_expr, f_use_itervar)) {
+ bound_at_left = true;
+ } else if (is_const_int(rhs_expr) || !UsesVar(rhs_expr, f_use_itervar)) {
+ bound_at_left = false;
+ } else {
+ bound_at_left = false; // accumulate bound to rhs
+ PrimExpr sum_parts = lhs_expr - rhs_expr;
+ lhs_expr = 0;
+ rhs_expr = 0;
+ std::function<void(const PrimExpr&, bool)> f_extract =
+ [&lhs_expr, &rhs_expr, f_use_itervar, &f_extract](const PrimExpr& part, bool sign) {
+ if (const AddNode* add = part.as<AddNode>()) {
+ f_extract(add->a, sign);
+ f_extract(add->b, sign);
+ } else if (const SubNode* sub = part.as<SubNode>()) {
+ f_extract(sub->a, sign);
+ f_extract(sub->b, !sign);
+ } else if (UsesVar(part, f_use_itervar)) {
+ lhs_expr = sign ? lhs_expr + part : lhs_expr - part;
+ } else {
+ rhs_expr = sign ? rhs_expr - part : rhs_expr + part;
+ }
+ };
+ f_extract(sum_parts, true);
+ arith::Analyzer analyzer;
+ lhs_expr = analyzer.Simplify(lhs_expr);
+ rhs_expr = analyzer.Simplify(rhs_expr);
+ }
+ PrimExpr lower_bound, upper_bound, iter;
+ if (is_greater) {
+ if (bound_at_left) {
+ // bound > iter
+ upper_bound = is_equal ? lhs_expr + 1 : lhs_expr;
+ iter = rhs_expr;
+ } else {
+ // iter > bound
+ lower_bound = is_equal ? rhs_expr : rhs_expr + 1;
+ iter = lhs_expr;
+ }
+ } else {
+ if (bound_at_left) {
+ // bound < iter
+ lower_bound = is_equal ? lhs_expr : lhs_expr + 1;
+ iter = rhs_expr;
+ } else {
+ // iter < bound
+ upper_bound = is_equal ? rhs_expr + 1 : rhs_expr;
+ iter = lhs_expr;
+ }
+ }
+ result.emplace_back(iter, lower_bound, upper_bound, 0);
+ if (is_finish) {
+ break;
+ }
+ pred = rest.Eval();
}
return result;
}
@@ -711,14 +906,17 @@ bool IterRangeSanityCheck(const Map<Var, Range>& iter_ranges) {
Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& predicate, bool require_bijective,
- arith::Analyzer* analyzer) {
+ arith::Analyzer* analyzer, DiagnosticContext diag_ctx) {
// Overall detection algorithm is divided into two steps:
// - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns.
// - Step1: IterIndependenceChecker checks if the iterator are independent.
-
if (!IterRangeSanityCheck(input_iters)) return Array<IterSumExpr>();
- std::vector<IterConstraint> constraints = MatchUpperBoundConstraints(predicate);
- if (!is_one(predicate) && constraints.empty()) return Array<IterSumExpr>();
+ std::vector<IterConstraint> constraints = MatchBoundConstraints(predicate, input_iters);
+ if (!is_one(predicate) && constraints.empty()) {
+ diag_ctx.Emit(Diagnostic::Error(predicate->span)
+ << "Fail to collect constraints from iteration predicate: " << predicate);
+ return Array<IterSumExpr>();
+ }
// We have to make sure when we visit an iterator, all the constraints related with its successors
// in the iter var graph has been visited, where the expression of this iterator will contain the
@@ -731,13 +929,17 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var,
constraints.begin(), constraints.end(),
[](const IterConstraint& a, const IterConstraint& b) { return a.expr_size < b.expr_size; });
- IterMapRewriter rewriter(analyzer, input_iters);
+ IterMapRewriter rewriter(analyzer, input_iters, diag_ctx);
// Step0.0: rewrite constraints in the order from size-small ones to size-big ones
for (const IterConstraint& constraint : constraints) {
- PrimExpr res = rewriter.RewriteIterConstraint(constraint.iter, constraint.upper_bound);
+ rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound, constraint.upper_bound);
if (rewriter.unresolved_count() != 0) return Array<IterSumExpr>();
}
- if (!rewriter.CheckConstraints()) return Array<IterSumExpr>();
+ if (!rewriter.CheckConstraints()) {
+ diag_ctx.Emit(Diagnostic::Error(predicate->span)
+ << "Illegal iteration constraints: " << predicate);
+ return Array<IterSumExpr>();
+ }
// Step0.1: rewrite indices
Array<IterSumExpr> results;
for (PrimExpr value : indices) {
@@ -745,7 +947,10 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var,
if (rewriter.unresolved_count() != 0) return Array<IterSumExpr>();
}
// Step1: IterIndependenceChecker checks if the iterator are independent.
- if (!rewriter.CheckMapping(results, require_bijective)) return Array<IterSumExpr>();
+ if (!rewriter.CheckMapping(results, require_bijective)) {
+ diag_ctx.Emit(Diagnostic::Error(predicate->span) << "Iterators are not independent");
+ return Array<IterSumExpr>();
+ }
return results;
}
@@ -754,7 +959,8 @@ TVM_REGISTER_GLOBAL("arith.DetectIterMap")
.set_body_typed([](const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& input_pred, bool is_bijective) {
arith::Analyzer ana;
- return DetectIterMap(indices, input_iters, input_pred, is_bijective, &ana);
+ DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule()));
+ return DetectIterMap(indices, input_iters, input_pred, is_bijective, &ana, diag_ctx);
});
PrimExpr IterMapRewriter::VisitExpr_(const VarNode* op) {
@@ -768,7 +974,6 @@ PrimExpr IterMapRewriter::VisitExpr_(const AddNode* op) {
if (!IsIndexType(op->dtype)) {
return Parent::VisitExpr_(op);
}
-
PrimExpr a = this->DirectMutate(op->a);
PrimExpr b = this->DirectMutate(op->b);
@@ -858,7 +1063,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) {
if (a->IsInstance<IterMapExprNode>() && b->IsInstance<IterMapExprNode>()) {
// cannot multiply two iterators, mark as unresolved.
- ++unresolved_count_;
+ Fail(Diagnostic::Error(op->span) << "Cannot multiply two iterators: " << GetRef<PrimExpr>(op));
return GetRef<PrimExpr>(op);
}
@@ -894,7 +1099,9 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs,
lhs.CopyOnWrite()->scale = make_const(rhs->dtype, 1);
} else {
// mark as unresolved.
- ++unresolved_count_;
+ Fail(Diagnostic::Error(orig->span)
+ << "Can not prove floordiv rhs " << rhs << " divisible by lhs scale " << lhs->scale
+ << ", lhs=" << lhs);
return orig;
}
}
@@ -916,7 +1123,8 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs,
return std::move(lhs);
} else {
// mark as unresolved.
- ++unresolved_count_;
+ Fail(Diagnostic::Error(orig->span)
+ << "Can not prove floordiv lhs extent " << lhs->extent << " divisible by rhs " << rhs);
return orig;
}
}
@@ -944,16 +1152,24 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) {
if (b->IsInstance<IterMapExprNode>()) {
// cannot divide an iterator, mark as unresolved.
- ++unresolved_count_;
+ Fail(Diagnostic::Error(op->span) << "Cannot divide an iterator: " << GetRef<PrimExpr>(op));
return GetRef<PrimExpr>(op);
}
if (a->IsInstance<IterSumExprNode>()) {
IterSumExpr ret = Downcast<IterSumExpr>(a);
- if (Optional<IterSplitExpr> opt = TryFuseIters(ret)) {
- return SplitFloorDivConst(opt.value(), b, GetRef<PrimExpr>(op));
+ if (Optional<IterSumExpr> opt = TryFuseIters(ret)) {
+ IterSumExpr sum = opt.value();
+ if (!is_zero(sum->base)) {
+ Fail(Diagnostic::Error(op->span)
+ << "Fuse IterSumExpr " << ret
+ << " failed, cannot floordiv an IterSumExpr with nonzero base");
+ return GetRef<PrimExpr>(op);
+ }
+ ICHECK_EQ(sum->args.size(), 1U);
+ return SplitFloorDivConst(sum->args[0], b, GetRef<PrimExpr>(op));
} else {
- ++unresolved_count_;
+ Fail(Diagnostic::Error(op->span) << "Fuse IterSumExpr " << ret << " failed");
return GetRef<PrimExpr>(op);
}
} else {
@@ -977,7 +1193,8 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs,
rhs = floordiv(rhs, lhs->scale);
} else {
// mark as unresolved.
- ++unresolved_count_;
+ Fail(Diagnostic::Error(orig->span) << "Can not prove floormod rhs " << rhs
+ << " divisible by " << lhs->scale << ", lhs=" << lhs);
return orig;
}
}
@@ -991,7 +1208,8 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs,
return std::move(lhs);
} else {
// mark as unresolved.
- ++unresolved_count_;
+ Fail(Diagnostic::Error(orig->span)
+ << "Can not prove floormod lhs extent " << lhs->extent << " divisible by rhs " << rhs);
return orig;
}
}
@@ -1019,16 +1237,23 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) {
if (b->IsInstance<IterMapExprNode>()) {
// cannot mod an iterator, mark as unresolved.
- ++unresolved_count_;
+ Fail(Diagnostic::Error(op->span) << "Cannot mod an iterator: " << GetRef<PrimExpr>(op));
return GetRef<PrimExpr>(op);
}
if (a->IsInstance<IterSumExprNode>()) {
IterSumExpr ret = Downcast<IterSumExpr>(a);
- if (Optional<IterSplitExpr> opt = TryFuseIters(ret)) {
- return SplitFloorModConst(opt.value(), b, GetRef<PrimExpr>(op));
+ if (Optional<IterSumExpr> opt = TryFuseIters(ret)) {
+ IterSumExpr sum = opt.value();
+ if (!is_zero(sum->base)) {
+ Fail(Diagnostic::Error(op->span)
+ << "Fuse IterSumExpr " << ret
+ << " failed, cannot floormod an IterSumExpr with nonzero base");
+ return GetRef<PrimExpr>(op);
+ }
+ return SplitFloorModConst(sum->args[0], b, GetRef<PrimExpr>(op));
} else {
- ++unresolved_count_;
+ Fail(Diagnostic::Error(op->span) << "Fail to fuse iters of " << ret);
return GetRef<PrimExpr>(op);
}
} else {
@@ -1039,19 +1264,21 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) {
}
/*! * \brief Given an IterVarMapExpr, transform it to normal PrimExpr. */
-class IterMapToExprNormalizer {
+class IterMapToExprNormalizer : public ExprMutator {
public:
explicit IterMapToExprNormalizer(Analyzer* analyzer) : analyzer_(analyzer) {}
- PrimExpr Convert(const IterMapExpr& expr) {
+ PrimExpr Convert(const IterMapExpr& expr) { return VisitExpr(expr); }
+
+ private:
+ /*! \brief Override VisitExpr for iter expr type processing */
+ PrimExpr VisitExpr(const PrimExpr& expr) override {
if (const auto* op = expr.as<IterSplitExprNode>()) {
return ConvertIterSplitExpr(GetRef<IterSplitExpr>(op));
} else if (const auto* op = expr.as<IterSumExprNode>()) {
return ConvertIterSumExpr(GetRef<IterSumExpr>(op));
} else {
- ICHECK(expr.defined());
- LOG(FATAL) << "Unknown IterMapExpr type " << expr->GetTypeKey();
- return 0;
+ return ExprMutator::VisitExpr(expr);
}
}
@@ -1071,7 +1298,7 @@ class IterMapToExprNormalizer {
} else if (const auto* op = expr->source->source.as<IterSumExprNode>()) {
source = ConvertIterSumExpr(GetRef<IterSumExpr>(op));
} else {
- LOG(FATAL) << "Unexpected source of IterSplitExpr";
+ source = VisitExpr(expr->source->source);
}
if (analyzer_->CanProve(expr->extent == expr->source->extent) && is_one(expr->lower_factor)) {
return source * expr->scale;
@@ -1100,8 +1327,9 @@ Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>& indices, const Map<Var, R
const PrimExpr& input_pred, bool require_bijective) {
if (!IterRangeSanityCheck(input_iters)) return indices;
Analyzer analyzer;
+ DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule()));
Array<IterSumExpr> rewrite =
- DetectIterMap(indices, input_iters, input_pred, require_bijective, &analyzer);
+ DetectIterMap(indices, input_iters, input_pred, require_bijective, &analyzer, diag_ctx);
if (rewrite.empty()) {
return indices;
}
@@ -1128,8 +1356,9 @@ Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>& indices, const Map<Var, R
class SubspaceDivider {
public:
explicit SubspaceDivider(Analyzer* analyzer, const IterMarkSplitCollector& collector,
- const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& sub_iters)
- : analyzer_(analyzer), collector_(collector), sub_iters_(sub_iters) {}
+ const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& sub_iters,
+ DiagnosticContext diag_ctx)
+ : analyzer_(analyzer), collector_(collector), sub_iters_(sub_iters), diag_ctx_(diag_ctx) {}
size_t unresolved_count() const { return unresolved_count_; }
@@ -1190,7 +1419,10 @@ class SubspaceDivider {
return DivisionResult(IterSumExpr({}, 0), 1, IterSumExpr({}, expr->base), 1);
} else if (expr->args.size() == 1) {
// arg + base, if arg=Y*E(X)+X, then arg+base = Y*E(X)+(X+base)
- if (!is_one(expr->args[0]->scale)) return Fail();
+ if (!is_one(expr->args[0]->scale)) {
+ return Fail(Diagnostic::Error(expr->span)
+ << "Expect split scale be 1, got " << expr->args[0]->scale);
+ }
DivisionResult res = DivideIterSplitExpr(expr->args[0]);
if (!is_zero(expr->base)) res = AddBase(res, expr->base);
return res;
@@ -1208,7 +1440,9 @@ class SubspaceDivider {
DivisionResult arg_division = DivideIterSplitExpr(arg);
IterSplitExpr new_arg;
if (arg_division.IsInner()) {
- if (!inner) return Fail();
+ if (!inner)
+ return Fail(Diagnostic::Error(expr->span)
+ << "Current division is inner but outer division exists for previous args");
new_arg = arg_division.GetInnerAsSplit();
inner_args.push_back(new_arg);
inner = true;
@@ -1217,11 +1451,13 @@ class SubspaceDivider {
outer_args.push_back(new_arg);
inner = false;
} else {
- return Fail();
+ return Fail(Diagnostic::Error(expr->span)
+ << "Division of " << arg << " is neither inner nor outer");
}
extent *= new_arg->extent;
}
- if (!scale_is_one) return Fail();
+ if (!scale_is_one)
+ return Fail(Diagnostic::Error(expr->span) << "Expect all iter sum arg's scale be 1");
bool need_predicate = !analyzer_->CanProveEqual(extent, mark_extent);
const IterMark& outer_mark = MarkFromArgsAndBase(outer_args, 0);
const IterMark& inner_mark = MarkFromArgsAndBase(inner_args, expr->base);
@@ -1240,7 +1476,8 @@ class SubspaceDivider {
inner_preds_ = inner_preds_ && (converter.Convert(inner_source) < mark_extent);
return DivisionResult::Inner(inner_source, mark_extent);
} else {
- return Fail();
+ return Fail(Diagnostic::Error(expr->span)
+ << "Either inner or outer args should exists if need predicate: " << expr);
}
}
return DivisionResult(outer_source, outer_mark->extent, inner_source, inner_mark->extent);
@@ -1250,8 +1487,11 @@ class SubspaceDivider {
PrimExpr GetInnerPreds() const { return inner_preds_; }
private:
- DivisionResult Fail() {
+ DivisionResult Fail(const Diagnostic& diagnostic) {
unresolved_count_++;
+ if (diag_ctx_.defined()) {
+ diag_ctx_.Emit(diagnostic);
+ }
return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0);
}
@@ -1330,7 +1570,10 @@ class SubspaceDivider {
if (!used[j] && analyzer_->CanProveEqual(splits[j]->lower_factor, expected_lower_factor))
break;
}
- if (j == splits.size()) return Fail();
+ if (j == splits.size())
+ return Fail(Diagnostic::Error(expr->span)
+ << "Can not find expected lower factor " << expected_lower_factor
+ << " in splits of " << expr->source);
used[j] = true;
if (!encountered_boundary) {
inner_iters.push_back(splits[j]);
@@ -1341,7 +1584,9 @@ class SubspaceDivider {
if (analyzer_->CanProveEqual(expected_lower_factor, mark_division.inner_extent))
encountered_boundary = true;
}
- if (!encountered_boundary) return Fail();
+ if (!encountered_boundary)
+ return Fail(Diagnostic::Error(expr->span)
+ << "Can not find inner/outer boundary of " << expr);
for (const IterSplitExpr& inner_iter : inner_iters) {
IterSplitExpr new_iter = inner_iter;
new_iter.CopyOnWrite()->source = inner_mark;
@@ -1355,7 +1600,8 @@ class SubspaceDivider {
split_map_.emplace(outer_iter, DivisionResult::Outer(new_iter, outer_iter->extent));
}
} else {
- return Fail();
+ return Fail(Diagnostic::Error(expr->span)
+ << "Source expr to divide is neither var nor IterSumExpr");
}
return split_map_.at(expr);
}
@@ -1371,15 +1617,18 @@ class SubspaceDivider {
std::unordered_map<IterSplitExpr, DivisionResult, ObjectPtrHash, ObjectPtrEqual> split_map_;
// predicate of outer space and inner space;
PrimExpr outer_preds_{Bool(true)}, inner_preds_{Bool(true)};
+ // diagnostic context
+ DiagnosticContext diag_ctx_;
};
Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,
const Map<Var, Range>& input_iters,
const Array<Var>& sub_iters, const PrimExpr& predicate,
- bool require_bijective, arith::Analyzer* analyzer) {
+ bool require_bijective, arith::Analyzer* analyzer,
+ DiagnosticContext diag_ctx) {
if (!IterRangeSanityCheck(input_iters)) return Array<Array<IterMark>>();
const Array<IterSumExpr>& maps =
- DetectIterMap(bindings, input_iters, predicate, require_bijective, analyzer);
+ DetectIterMap(bindings, input_iters, predicate, require_bijective, analyzer, diag_ctx);
if (maps.empty()) return {};
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> inner_iter_set;
@@ -1389,7 +1638,7 @@ Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,
IterMarkSplitCollector collector;
collector.Collect(maps);
- SubspaceDivider subspace_divider(analyzer, collector, inner_iter_set);
+ SubspaceDivider subspace_divider(analyzer, collector, inner_iter_set, diag_ctx);
std::vector<Array<IterMark>> results;
for (const IterSumExpr& expr : maps) {
@@ -1409,7 +1658,9 @@ TVM_REGISTER_GLOBAL("arith.SubspaceDivide")
const Array<Var>& sub_iters, const PrimExpr& predicate,
bool require_bijective) {
arith::Analyzer ana;
- return SubspaceDivide(bindings, root_iters, sub_iters, predicate, require_bijective, &ana);
+ DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule()));
+ return SubspaceDivide(bindings, root_iters, sub_iters, predicate, require_bijective, &ana,
+ diag_ctx);
});
class InverseAffineIterMapTransformer {
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index 6d744a6..0a7d57e 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -415,12 +415,14 @@ bool IsAffineBinding(const BlockRealize& realize, const Map<Var, Range>& loop_va
if (loop_var_ranges.empty()) {
return true;
}
+ DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule()));
Array<arith::IterSumExpr> results = arith::DetectIterMap(
/*indices=*/realize->iter_values,
/*input_iters=*/loop_var_ranges,
/*predicate=*/realize->predicate,
/*require_bijective=*/false,
- /*analyzer=*/analyzer);
+ /*analyzer=*/analyzer,
+ /*diag_ctx*/ diag_ctx);
if (results.empty()) {
return false;
}
diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py
index 618fd55..b40f3c9 100644
--- a/tests/python/unittest/test_arith_intset.py
+++ b/tests/python/unittest/test_arith_intset.py
@@ -16,6 +16,8 @@
# under the License.
import tvm
from tvm import te
+from tvm import tir
+from tvm.ir.base import structural_equal
class IntSetChecker:
@@ -233,6 +235,90 @@ def test_region_lower_bound_negative_scale():
assert int_set_1.max_value.value == 35
+def test_region_lower_bound_for_non_perfect_tile():
+ h1 = tvm.tir.Var("h1", "int32")
+ h2 = tvm.tir.Var("h2", "int32")
+ h3 = tvm.tir.Var("h3", "int32")
+ analyzer = tvm.arith.Analyzer()
+
+ def do_test_point_access(point, predicates, var_dom, expect):
+ regions = tvm.arith.estimate_region_lower_bound(
+ region=[
+ tvm.ir.Range.from_min_extent(min_value=point, extent=1),
+ ],
+ var_dom=var_dom,
+ predicate=tvm.tir.all(*predicates),
+ )
+ if expect is None: # expect a failure
+ assert regions is None
+ else:
+ assert len(regions) == 1
+ for binding, expect_min, expect_max in expect:
+ min_diff = expect_min - regions[0].min_value
+ assert analyzer.simplify(tir.stmt_functor.substitute(min_diff, binding), 3) == 0
+ max_diff = expect_max - regions[0].max_value
+ assert analyzer.simplify(tir.stmt_functor.substitute(max_diff, binding), 3) == 0
+
+ # non-uniform tiling, single inner variable
+ # h3 == 0: region is [1, 9]
+ # 0 < h3 <= 26: region is [h3 * 8, h3 * 8 + 9]
+ # h3 > 26: region is [h3 * 8, 223]
+ do_test_point_access(
+ point=h3 * 8 + h2,
+ predicates=[1 <= h3 * 8 + h2, h3 * 8 + h2 < 224],
+ var_dom={
+ h2: tvm.ir.Range(begin=0, end=10),
+ },
+ expect=[
+ (
+ {},
+ tvm.tir.max(h3 * 8, 1),
+ tvm.tir.max(h3 * 8, 1)
+ - tvm.tir.max(h3 * 8, 214)
+ - tvm.tir.max(1 - h3 * 8, 0)
+ + 223,
+ ),
+ ({h3: 0}, 1, 9),
+ ({h3: 10}, h3 * 8, h3 * 8 + 9),
+ ({h3: 27}, h3 * 8, 223),
+ ],
+ )
+
+ # non-uniform tiling, two inner variables
+ do_test_point_access(
+ point=h3 * 8 + h2 * 5 + h1,
+ predicates=[1 <= h3 * 8 + h2 * 5 + h1, h3 * 8 + h2 * 5 + h1 < 224],
+ var_dom={
+ h2: tvm.ir.Range(begin=0, end=2),
+ h1: tvm.ir.Range(begin=0, end=5),
+ },
+ expect=[
+ (
+ {},
+ tvm.tir.max(h3 * 8, 1),
+ tvm.tir.max(h3 * 8, 1)
+ - tvm.tir.max(h3 * 8, 214)
+ - tvm.tir.max(1 - h3 * 8, 0)
+ + 223,
+ ),
+ ({h3: 0}, 1, 9),
+ ({h3: 10}, h3 * 8, h3 * 8 + 9),
+ ({h3: 27}, h3 * 8, 223),
+ ],
+ )
+
+ # should fail on incompatible predicates
+ do_test_point_access(
+ point=h3 * 8 + h2 * 5 + h1,
+ predicates=[1 <= h3 * 8 + h2 * 5 + h1, h3 * 8 + h1 * 2 + h2 < 224],
+ var_dom={
+ h2: tvm.ir.Range(begin=0, end=2),
+ h1: tvm.ir.Range(begin=0, end=5),
+ },
+ expect=None,
+ )
+
+
def test_union_lower_bound():
neg_inf = tvm.arith.int_set.neg_inf()
pos_inf = tvm.arith.int_set.pos_inf()
@@ -257,4 +343,5 @@ if __name__ == "__main__":
test_region_lower_bound_split_predicate()
test_region_lower_bound_multiple_variables()
test_region_lower_bound_negative_scale()
+ test_region_lower_bound_for_non_perfect_tile()
test_union_lower_bound()
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py
index c307034..6b3c295 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -199,8 +199,171 @@ def test_predicate():
x = tvm.tir.Var("x", "int32"), 13
y = tvm.tir.Var("y", "int32"), 10
+    # available constraints
+ # upper bound only
res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] < 128)
+ assert len(res) == 1
+ assert_iter_sum_pattern(res[0], 128, 0)
+ res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] <= 127)
+ assert len(res) == 1
+ assert_iter_sum_pattern(res[0], 128, 0)
+ # lower bound only
+ res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] > 5)
+ assert len(res) == 1
+ assert_iter_sum_pattern(res[0], 124, 6)
+ res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 + y[0] >= 6)
+ assert len(res) == 1
+ assert_iter_sum_pattern(res[0], 124, 6)
+
+ # lower bound + upper bound
+ res = tvm.arith.detect_iter_map(
+ [x[0] * 10 + y[0]],
+ var_dom([x, y]),
+ tvm.tir.And(x[0] * 10 + y[0] > 5, x[0] * 10 + y[0] < 128),
+ )
+ assert len(res) == 1
+ assert_iter_sum_pattern(res[0], 122, 6)
+ res = tvm.arith.detect_iter_map(
+ [x[0] * 10 + y[0]],
+ var_dom([x, y]),
+ tvm.tir.And(x[0] * 10 + y[0] >= 6, x[0] * 10 + y[0] <= 127),
+ )
+ assert len(res) == 1
+ assert_iter_sum_pattern(res[0], 122, 6)
+
+ # constraint on one fused iter
+ i = tvm.tir.Var("i", "int32")
+ j = tvm.tir.Var("j", "int32")
+ k = tvm.tir.Var("k", "int32")
+ res = tvm.arith.detect_iter_map(
+ [i * 8 + j * 2 + k],
+ var_dom([(i, 11), (j, 5), (k, 2)]),
+ tvm.tir.all(1 <= j * 2 + k, j * 2 + k < 9),
+ )
+ assert_iter_sum_pattern(res[0], 88, 1)
+
+ # constraint on single var
+ res = tvm.arith.detect_iter_map([i], var_dom([(i, 48)]), tvm.tir.all(i < 10))
+ assert_iter_sum_pattern(res[0], 10, 0)
+
+ # iterations are subparts of constraint, invalid, case 1
+ res = tvm.arith.detect_iter_map(
+ [i, j, k],
+ var_dom([(i, 128), (j, 128), (k, 128)]),
+ tvm.tir.all(i * 16384 + j * 128 + k < 100),
+ )
+ assert len(res) == 0
+
+ # iterations are subparts of constraint, invalid, case 2
+ res = tvm.arith.detect_iter_map(
+ [i * 128 + j, k],
+ var_dom([(i, 128), (j, 128), (k, 128)]),
+ tvm.tir.all(i * 16384 + j * 128 + k < 100),
+ )
+ assert len(res) == 0
+
+ # constraint on nested fused iters
+ res = tvm.arith.detect_iter_map(
+ [i * 8 + j * 2 + k],
+ var_dom([(i, 11), (j, 5), (k, 2)]),
+ tvm.tir.all(1 <= j * 2 + k, j * 2 + k < 9, 3 <= i * 8 + j * 2 + k, i * 8 + j * 2 + k < 25),
+ )
+ assert_iter_sum_pattern(res[0], 22, 3)
+
+ # duplicate constraint on one fused iter
+ res = tvm.arith.detect_iter_map(
+ [i * 6 + j * 2 + k],
+ var_dom([(i, 11), (j, 5), (k, 2)]),
+ tvm.tir.all(1 <= j * 2 + k, 2 <= j * 2 + k, j * 2 + k < 8, j * 2 + k < 9),
+ )
+ assert_iter_sum_pattern(res[0], 66, 2)
+
+ # duplicate constraint on nested fused iters
+ res = tvm.arith.detect_iter_map(
+ [i * 6 + j * 2 + k],
+ var_dom([(i, 11), (j, 5), (k, 2)]),
+ tvm.tir.all(
+ 1 <= j * 2 + k,
+ 2 <= j * 2 + k,
+ j * 2 + k < 8,
+ j * 2 + k < 9,
+ 3 <= i * 6 + j * 2 + k,
+ i * 6 + j * 2 + k < 25,
+ 1 <= i * 6 + j * 2 + k,
+ i * 6 + j * 2 + k < 18,
+ ),
+ )
+ assert_iter_sum_pattern(res[0], 15, 3)
+
+ # constraint on non-disjoint fused iters should fail
+ res = tvm.arith.detect_iter_map(
+ [i * 8 + j * 2 + k],
+ var_dom([(i, 11), (j, 5), (k, 2)]),
+ tvm.tir.all(2 <= j * 2 + k, 0 <= i * 4 + j),
+ )
+ assert len(res) == 0
+
+ # constraint on many disjoint fused iters, case 1
+    # i4 * 6 + i5 in [3, 18), extent=15 (= scale of i3)
+ # i2 * 30 + i3 * 15 in [30, 90), extent=60 (= scale of i1)
+ # i1 * 60 in [60, 240), extent=180 (= scale of i0)
+ i0 = tvm.tir.Var("i0", "int32")
+ i1 = tvm.tir.Var("i1", "int32")
+ i2 = tvm.tir.Var("i2", "int32")
+ i3 = tvm.tir.Var("i3", "int32")
+ i4 = tvm.tir.Var("i4", "int32")
+ i5 = tvm.tir.Var("i5", "int32")
+ res = tvm.arith.detect_iter_map(
+ [i0 * 180 + i1 * 60 + i2 * 30 + i3 * 15 + i4 * 6 + i5],
+ var_dom([(i0, 3), (i1, 4), (i2, 3), (i3, 2), (i4, 3), (i5, 6)]),
+ tvm.tir.all(1 <= i1, 2 <= i2 * 2 + i3, 3 <= i4 * 6 + i5),
+ )
+ assert_iter_sum_pattern(res[0], 540, 93)
+
+ # constraint on many disjoint fused iters, case 2
+ res = tvm.arith.detect_iter_map(
+ [i0 * 45 + i1 * 45 + i2 * 9 + i3 * 4 + i4],
+ var_dom([(i0, 3), (i1, 2), (i2, 5), (i3, 3), (i4, 4)]),
+ tvm.tir.all(3 <= i1 * 5 + i2, i1 * 5 + i2 < 8, 1 <= i3 * 4 + i4, i3 * 4 + i4 < 10),
+ )
+ assert_iter_sum_pattern(res[0], 135, 28)
+
+ # constraint on split iters
+ res = tvm.arith.detect_iter_map(
+ [i % 16, i // 16],
+ var_dom([(i, 1024)]),
+ tvm.tir.all(3 <= i % 16, i % 16 < 10, 4 <= i // 16, i // 16 < 12),
+ require_bijective=True,
+ )
+ assert_iter_sum_pattern(res[0], 7, 3)
+ assert_iter_sum_pattern(res[1], 8, 4)
+
+ # constraint on split iters, nested case 1
+ res = tvm.arith.detect_iter_map(
+ [(i * 32 + j) % 16],
+ var_dom([(i, 5), (j, 32)]),
+ tvm.tir.all(3 <= (i * 32 + j) % 16, (i * 32 + j) % 16 < 10),
+ )
+ assert_iter_sum_pattern(res[0], 7, 3)
+
+ # constraint on split iters, nested case 2
+ res = tvm.arith.detect_iter_map(
+ [(i * 32 + j) % 16],
+ var_dom([(i, 5), (j, 32)]),
+ tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 32),
+ )
+ assert len(res) == 0
+ res = tvm.arith.detect_iter_map(
+ [(i * 32 + j - 1) % 16, (i * 32 + j - 1) // 16],
+ var_dom([(i, 5), (j, 32)]),
+ tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 64),
+ )
+ assert_iter_sum_pattern(res[0], 16, 0)
+ assert_iter_sum_pattern(res[1], 4, 0)
+
+ # non-standard form of predicate
+ res = tvm.arith.detect_iter_map([x[0] * 10 + y[0]], var_dom([x, y]), x[0] * 10 < 128 - y[0])
assert len(res) == 1
assert_iter_sum_pattern(res[0], 128, 0)
@@ -651,6 +814,10 @@ def test_normalize_iter_map_to_expr():
)
tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(res[1]), flm(x[0], 5))
+    # iter mark wraps a complex expr
+ split = tvm.arith.IterSplitExpr(tvm.arith.IterMark(x[0] * y[0] + 1, 1024), 1, 1024, 1)
+ tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(split), x[0] * y[0] + 1)
+
def test_inverse_affine_iter_map():
analyzer = tvm.arith.Analyzer()
@@ -712,6 +879,38 @@ def test_inverse_affine_iter_map():
assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0
+def test_free_variables():
+ x = tvm.tir.Var("x", "int32")
+ y = tvm.tir.Var("y", "int32")
+ z = tvm.tir.Var("z", "int32")
+
+ # illegal iter if z is within dom
+ res = tvm.arith.detect_iter_map([z * 19 + y * 3 + x], var_dom([(x, 3), (y, 3), (z, 3)]))
+ assert len(res) == 0
+
+    # iter is valid if z is free, even if there are linear forms of z
+ res = tvm.arith.detect_iter_map(
+ [z * 19 + y * 3 + x],
+ var_dom(
+ [
+ (x, 3),
+ (y, 3),
+ ]
+ ),
+ )
+ assert_iter_sum_pattern(res[0], 9, z * 19)
+ res = tvm.arith.detect_iter_map(
+ [z * z + y * 3 + x],
+ var_dom(
+ [
+ (x, 3),
+ (y, 3),
+ ]
+ ),
+ )
+ assert_iter_sum_pattern(res[0], 9, z * z)
+
+
if __name__ == "__main__":
test_split()
test_trivial()
@@ -722,3 +921,4 @@ if __name__ == "__main__":
test_subspace_division()
test_complex()
test_inverse_affine_iter_map()
+ test_free_variables()
diff --git a/tests/python/unittest/test_tir_schedule_reorder.py b/tests/python/unittest/test_tir_schedule_reorder.py
index 8267a36..fd2d82d 100644
--- a/tests/python/unittest/test_tir_schedule_reorder.py
+++ b/tests/python/unittest/test_tir_schedule_reorder.py
@@ -217,9 +217,8 @@ def test_reorder_with_predicate():
sch = tir.Schedule(elementwise_predicate, debug_mask="all")
block_b = sch.get_block("B")
i, j, k, l = sch.get_loops(block_b)
- sch.reorder(l, i)
- tvm.ir.assert_structural_equal(elementwise_reordered_with_predicate, sch.mod["main"])
- verify_trace_roundtrip(sch=sch, mod=elementwise_predicate)
+ with pytest.raises(tvm.tir.ScheduleError):
+ sch.reorder(l, i)
def test_reorder_fail_with_multi_appearance_loops():
diff --git a/tests/python/unittest/test_tir_schedule_rfactor.py b/tests/python/unittest/test_tir_schedule_rfactor.py
index 35d4f5f..f5fc5a7 100644
--- a/tests/python/unittest/test_tir_schedule_rfactor.py
+++ b/tests/python/unittest/test_tir_schedule_rfactor.py
@@ -690,11 +690,9 @@ def test_reduction_rfactor_predicate(): # pylint: disable=invalid-name
s = tir.Schedule(rowsum_predicate, debug_mask="all")
B = s.get_block("B")
_, ko, _ = s.get_loops(B)
- rf_block = s.rfactor(ko, 1)
- tvm.ir.assert_structural_equal(s.mod["main"], rowsum_predicate_rfactor)
- assert s.get(rf_block).same_as(s.get(s.get_block("B_rf")))
- assert s.get(B).same_as(s.get(s.get_block("B")))
- verify_trace_roundtrip(s, mod=rowsum_predicate)
+ # TODO: should be a tvm.tir.ScheduleError
+ with pytest.raises(tvm.TVMError):
+ rf_block = s.rfactor(ko, 1)
def test_reduction_rfactor_with_annotation():
diff --git a/tests/python/unittest/test_tir_schedule_state_cached_flags.py b/tests/python/unittest/test_tir_schedule_state_cached_flags.py
index e3bd000..d86af72 100644
--- a/tests/python/unittest/test_tir_schedule_state_cached_flags.py
+++ b/tests/python/unittest/test_tir_schedule_state_cached_flags.py
@@ -314,6 +314,45 @@ def warp_memory_negative(a: T.handle, c: T.handle) -> None:
C[warp_id * 32 + lane_id, vj] = B[vj, warp_id, lane_id] + 1.0
+@T.prim_func
+def non_perfect_tiling_cache(a: T.handle, b: T.handle) -> None:
+ X = T.match_buffer(a, [224, 224], dtype="float32")
+ Y = T.match_buffer(b, [224, 224], dtype="float32")
+ cache = T.alloc_buffer([224, 224], dtype="float32")
+ for hh_0, ww_0 in T.grid(28, 28):
+ for ax0 in T.serial(0, 10):
+ for ax1 in T.serial(0, 10):
+ with T.block("cache"):
+ h = T.axis.spatial(224, hh_0 * 8 - 1 + ax0)
+ w = T.axis.spatial(224, ww_0 * 8 - 1 + ax1)
+ T.where(
+ 1 <= hh_0 * 8 + ax0
+ and hh_0 * 8 + ax0 < 225
+ and 1 <= ww_0 * 8 + ax1
+ and ww_0 * 8 + ax1 < 225
+ )
+ cache[h, w] = X[h, w]
+ for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3):
+ with T.block("compute"):
+ h = T.axis.spatial(224, hh_0 * 8 + hh_1)
+ w = T.axis.spatial(224, ww_0 * 8 + ww_1)
+ kh, kw = T.axis.remap("RR", [khh, kww])
+ with T.init():
+ Y[h, w] = 0.0
+ Y[h, w] = T.max(
+ Y[h, w],
+ T.if_then_else(
+ T.likely(1 <= h + kh, dtype="bool")
+ and T.likely(h + kh < 225, dtype="bool")
+ and T.likely(1 <= w + kw, dtype="bool")
+ and T.likely(w + kw < 225, dtype="bool"),
+ cache[h + kh - 1, w + kw - 1],
+ 0.0,
+ dtype="float32",
+ ),
+ )
+
+
# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
@@ -702,5 +741,21 @@ def test_warp_memory_negative():
# pylint: enable=protected-access
+def test_non_perfect_tiling_cache():
+ s = tir.ScheduleState(non_perfect_tiling_cache, debug_mask="all")
+ # pylint: disable=protected-access
+ assert s._get_cached_flags(_get_block(s, "cache")) == CachedFlags(
+ affine_binding=False,
+ region_cover=True,
+ stage_pipeline=True,
+ )
+ assert s._get_cached_flags(_get_block(s, "compute")) == CachedFlags(
+ affine_binding=True,
+ region_cover=False,
+ stage_pipeline=True,
+ )
+ # pylint: enable=protected-access
+
+
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))