You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by cs...@apache.org on 2022/11/16 19:45:15 UTC
[tvm] branch main updated: [TIR][Analysis][Arith] Implement basic data-flow analysis (#13130)
This is an automated email from the ASF dual-hosted git repository.
csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a80cdc26e2 [TIR][Analysis][Arith] Implement basic data-flow analysis (#13130)
a80cdc26e2 is described below
commit a80cdc26e291abc52bbd70c950023d9e0340464d
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Wed Nov 16 13:45:07 2022 -0600
[TIR][Analysis][Arith] Implement basic data-flow analysis (#13130)
An optional utility to track known buffer values through a TIR PrimFunc, allowing simplifications based on known values.
* Updated documentation following review comments
* Unit tests for rewrites, including negative numerators for div/mod
* Fix linting error
* Added brief description on what a control graph is
* Updates based on review comments
* Updated T.assume(expr) to T.evaluate(T.assume(expr))
---
include/tvm/tir/op_attr_types.h | 31 +
src/arith/conjunctive_normal_form.cc | 26 +-
src/arith/constraint_extract.cc | 39 +-
src/arith/constraint_extract.h | 31 +-
src/arith/ir_visitor_with_analyzer.h | 6 +-
src/arith/rewrite_simplify.cc | 53 +-
src/arith/transitive_comparison_analyzer.cc | 2 +-
src/arith/unwrap_vector_expr.cc | 90 ++
.../{constraint_extract.h => unwrap_vector_expr.h} | 30 +-
src/tir/analysis/control_flow_graph.cc | 1647 ++++++++++++++++++++
src/tir/analysis/control_flow_graph.h | 653 ++++++++
src/tir/transforms/simplify.cc | 105 +-
.../python/unittest/test_arith_rewrite_simplify.py | 61 +-
.../python/unittest/test_tir_transform_simplify.py | 645 +++++++-
14 files changed, 3361 insertions(+), 58 deletions(-)
diff --git a/include/tvm/tir/op_attr_types.h b/include/tvm/tir/op_attr_types.h
index 6b5d6c48dd..fa409b27d1 100644
--- a/include/tvm/tir/op_attr_types.h
+++ b/include/tvm/tir/op_attr_types.h
@@ -32,6 +32,8 @@
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/packed_func.h>
+#include <ostream>
+
namespace tvm {
namespace tir {
/*!
@@ -92,6 +94,35 @@ enum class CallEffectKind : int {
kControlJump = 6,
};
+inline std::ostream& operator<<(std::ostream& os, CallEffectKind side_effect) {
+ switch (side_effect) {
+ case CallEffectKind::kExprAnnotation:
+ return os << "kExprAnnotation";
+
+ case CallEffectKind::kPure:
+ return os << "kPure";
+
+ case CallEffectKind::kReadState:
+ return os << "kReadState";
+
+ case CallEffectKind::kUpdateState:
+ return os << "kUpdateState";
+
+ case CallEffectKind::kSpecialCallArg:
+ return os << "kSpecialCallArg";
+
+ case CallEffectKind::kEmbedInfo:
+ return os << "kEmbedInfo";
+
+ case CallEffectKind::kControlJump:
+ return os << "kControlJump";
+
+ default:
+ LOG(FATAL) << "Unknown CallEffectKind: " << static_cast<int>(side_effect);
+ return os;
+ }
+}
+
/*! \brief Use integer to record the kind. */
using TCallEffectKind = Integer;
diff --git a/src/arith/conjunctive_normal_form.cc b/src/arith/conjunctive_normal_form.cc
index 19d6a234e6..1c5f31a913 100644
--- a/src/arith/conjunctive_normal_form.cc
+++ b/src/arith/conjunctive_normal_form.cc
@@ -248,14 +248,14 @@ void AndOfOrs::TrySimplifyOr(Key* a_ptr, Key* b_ptr, Analyzer* analyzer) {
Key& a = *a_ptr;
Key& b = *b_ptr;
PrimExpr joint = GetExpr(a) || GetExpr(b);
- PrimExpr simplified = analyzer->Simplify(joint);
+ PrimExpr simplified = analyzer->rewrite_simplify(joint);
if (!ExprDeepEqual()(simplified, joint)) {
if (auto* simplified_or = simplified.as<OrNode>()) {
a = GetKey(simplified_or->a);
b = GetKey(simplified_or->b);
} else {
- a = GetKey(simplified);
- b = key_false_;
+ a = key_false_;
+ b = GetKey(simplified);
}
}
}
@@ -264,14 +264,14 @@ void AndOfOrs::TrySimplifyAnd(Key* a_ptr, Key* b_ptr, Analyzer* analyzer) {
Key& a = *a_ptr;
Key& b = *b_ptr;
PrimExpr joint = GetExpr(a) && GetExpr(b);
- PrimExpr simplified = analyzer->Simplify(joint);
+ PrimExpr simplified = analyzer->rewrite_simplify(joint);
if (!ExprDeepEqual()(simplified, joint)) {
if (auto* simplified_and = simplified.as<AndNode>()) {
a = GetKey(simplified_and->a);
b = GetKey(simplified_and->b);
} else {
- a = GetKey(simplified);
- b = key_true_;
+ a = key_true_;
+ b = GetKey(simplified);
}
}
}
@@ -362,6 +362,20 @@ void AndOfOrs::SimplifyAcrossChunks(Analyzer* analyzer) {
// (A or B) and (A or C) => A or (B and C)
auto& key_i = i_chunk[i_distinct_index.value()];
auto& key_j = j_chunk[j_distinct_index.value()];
+
+ // When attempting to simplify (B and C), the analyzer may
+ // assume that A is false.
+ PrimExpr known = [&]() {
+ PrimExpr known = Bool(true);
+ for (const auto& key : i_chunk) {
+ if (&key != &key_i) {
+ known = known && analyzer->Simplify(!GetExpr(key));
+ }
+ }
+ return known;
+ }();
+
+ With<ConstraintContext> context(analyzer, known);
TrySimplifyAnd(&key_i, &key_j, analyzer);
}
}
diff --git a/src/arith/constraint_extract.cc b/src/arith/constraint_extract.cc
index d0bf57497e..b873adcb5c 100644
--- a/src/arith/constraint_extract.cc
+++ b/src/arith/constraint_extract.cc
@@ -31,23 +31,42 @@
namespace tvm {
namespace arith {
-void CollectConstraints(const PrimExpr& expr, Analyzer* analyzer, std::vector<PrimExpr>* collect) {
- collect->push_back(expr);
+template <typename F>
+void CollectConstraints(PrimExpr expr, F callback, bool keep_composite_constraints) {
+ if (keep_composite_constraints) {
+ callback(expr);
+ }
PVar<PrimExpr> x, y;
if ((x && y).Match(expr)) {
- CollectConstraints(x.Eval(), analyzer, collect);
- CollectConstraints(y.Eval(), analyzer, collect);
- } else if ((!(x || y)).Match(expr)) {
- CollectConstraints(analyzer->rewrite_simplify(tir::Not(x.Eval())), analyzer, collect);
- CollectConstraints(analyzer->rewrite_simplify(tir::Not(y.Eval())), analyzer, collect);
+ CollectConstraints(x.Eval(), callback, keep_composite_constraints);
+ CollectConstraints(y.Eval(), callback, keep_composite_constraints);
+ } else if (!keep_composite_constraints) {
+ callback(expr);
+ }
+}
+
+std::vector<PrimExpr> ExtractConstraints(const PrimExpr& expr, bool keep_composite_constraints) {
+ std::vector<PrimExpr> out;
+ CollectConstraints(
+ expr, [&](const PrimExpr& part) { out.push_back(part); }, keep_composite_constraints);
+ return out;
+}
+
+template <typename F>
+void CollectComponents(PrimExpr expr, F callback) {
+ PVar<PrimExpr> x, y;
+ if ((x || y).Match(expr)) {
+ CollectComponents(x.Eval(), callback);
+ CollectComponents(y.Eval(), callback);
+ } else {
+ callback(expr);
}
}
-std::vector<PrimExpr> ExtractConstraints(const PrimExpr& expr) {
+std::vector<PrimExpr> ExtractComponents(const PrimExpr& expr) {
std::vector<PrimExpr> out;
- Analyzer analyzer;
- CollectConstraints(expr, &analyzer, &out);
+ CollectComponents(expr, [&](const PrimExpr& part) { out.push_back(part); });
return out;
}
diff --git a/src/arith/constraint_extract.h b/src/arith/constraint_extract.h
index ea6e0a7441..815eafeebd 100644
--- a/src/arith/constraint_extract.h
+++ b/src/arith/constraint_extract.h
@@ -42,6 +42,35 @@ namespace arith {
* Example: `i==5 || j==3` => `[i==5 || j==3]`
* Example: `!(i>5 || j==3)` => `[!(i==5 || j==3), i<=5, j!=3]`
*
+ * If `keep_composite_constraints` is true (default), a constraint
+ * that can be decomposed will be included in the output. If false,
+ * such composite constraints will be excluded.
+ *
+ * Example, removing composite: `i<=5 && j!=3` (the simplified form of `!(i>5 || j==3)`) => `[i<=5, j!=3]`
+ *
+ * Intended for use in bounds analysis or simplification within a
+ * conditional, or identifying independent conditionals that may be
+ * hoisted.
+ *
+ * \param expr The expression to be analyzed
+ *
+ * \param keep_composite_constraints Whether to include composite
+ * constraints in the output.
+ *
+ * \returns A vector of independent constraints
+ */
+std::vector<PrimExpr> ExtractConstraints(const PrimExpr& expr,
+ bool keep_composite_constraints = true);
+
+/* \brief Returns components that are false if the expression is false.
+ *
+ * Utility to break up a boolean expression into independent
+ * components.
+ *
+ * Example: `i==5 || j==3` => `[i==5, j==3]`
+ * Example: `i==5 && j==3` => `[i==5 && j==3]`
+ * Example: `i<=5 || j!=3` (the simplified form of `!(i>5 && j==3)`) => `[i<=5, j!=3]`
+ *
* Intended for use in bounds analysis or simplification within a
* conditional, or identifying independent conditionals that may be
* hoisted.
@@ -50,7 +79,7 @@ namespace arith {
*
* \returns A vector of independent constraints
*/
-std::vector<PrimExpr> ExtractConstraints(const PrimExpr& expr);
+std::vector<PrimExpr> ExtractComponents(const PrimExpr& expr);
} // namespace arith
} // namespace tvm
diff --git a/src/arith/ir_visitor_with_analyzer.h b/src/arith/ir_visitor_with_analyzer.h
index f41a628f3c..416b2af196 100644
--- a/src/arith/ir_visitor_with_analyzer.h
+++ b/src/arith/ir_visitor_with_analyzer.h
@@ -57,7 +57,11 @@ class IRVisitorWithAnalyzer : public tir::StmtExprVisitor {
/*! \brief internal analyzer field. */
arith::Analyzer analyzer_;
- private:
+ /*! \brief Extract a constraint from a conditional statement
+ *
+ * Intended for preparing an argument for use in
+ * `With<ConstraintContext>`.
+ */
PrimExpr ExtractRealCondition(PrimExpr condition) const;
};
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index d0fb943334..e6d876cf5a 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -292,7 +292,7 @@ std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c
// we will compare the already simplified result with the constraint,
// so simplify the constraint as well
PrimExpr new_constraint = operator()(constraint);
- for (const PrimExpr& subconstraint : ExtractConstraints(new_constraint)) {
+ for (const PrimExpr& subconstraint : ExtractConstraints(new_constraint, false)) {
if (SideEffect(subconstraint) <= CallEffectKind::kPure) {
literal_constraints_.push_back(subconstraint);
PrimExpr negation;
@@ -1734,7 +1734,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) {
// Pattern var to match any expression
PVar<PrimExpr> x, y;
// Pattern var match IntImm
- PVar<IntImm> c1, c2;
+ PVar<IntImm> c1, c2, c3;
PVar<int> lanes;
if (op->dtype.lanes() != 1) {
@@ -1761,6 +1761,55 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) {
TVM_TRY_REWRITE(x == c1 && x != c2, x == c1 && c1 != c2);
TVM_TRY_REWRITE(x != c2 && x == c1, x == c1 && c1 != c2);
+
+ TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && floormod(x, c2) == c3, x == c1 * c2 + c3);
+ TVM_TRY_RECURSIVE_REWRITE(floormod(x, c2) == c3 && floordiv(x, c2) == c1, x == c1 * c2 + c3);
+
+ TVM_TRY_RECURSIVE_REWRITE_IF(0 <= x - y * c1 &&
+ x - y * c1<c1, y == floordiv(x, c1), c1.Eval()->value> 0);
+ TVM_TRY_RECURSIVE_REWRITE_IF(x - y * c1 < c1 && 0 <= x - y * c1, y == floordiv(x, c1),
+ c1.Eval()->value > 0);
+
+ TVM_TRY_RECURSIVE_REWRITE(c1 < x - y * c1 && x - y * c1 <= 0, y == floordiv(x, c1));
+ TVM_TRY_RECURSIVE_REWRITE(x - y * c1 < c1 && 0 <= x - y * c1, y == floordiv(x, c1));
+ TVM_TRY_RECURSIVE_REWRITE_IF(0 <= x + y * c2 && x + y * c2 < c1, y == floordiv(x, c1),
+ c2.Eval()->value == -c1.Eval()->value);
+ TVM_TRY_RECURSIVE_REWRITE_IF(x + y * c2 < c1 && 0 <= x + y * c2, y == floordiv(x, c1),
+ c2.Eval()->value == -c1.Eval()->value);
+
+ TVM_TRY_RECURSIVE_REWRITE_IF(x < c1 && floormod(x, c2) < c3,
+ x < c1 - c2 + c3 && floormod(x, c2) < c3,
+ c1.Eval()->value % c2.Eval()->value == 0);
+ TVM_TRY_RECURSIVE_REWRITE_IF(
+ x < c1 && floormod(x, c2) < c3, x < c1 - floormod(c1, c2) + c3 && floormod(x, c2) < c3,
+ (c1.Eval()->value % c2.Eval()->value + c2.Eval()->value) % c2.Eval()->value >
+ c3.Eval()->value);
+
+ TVM_TRY_RECURSIVE_REWRITE_IF(x <= c1 && floormod(x, c2) < c3,
+ x < c1 + 1 - c2 + c3 && floormod(x, c2) < c3,
+ (c1.Eval()->value + 1) % c2.Eval()->value == 0);
+ TVM_TRY_RECURSIVE_REWRITE_IF(
+ x <= c1 && floormod(x, c2) < c3, x < c1 + 1 - floormod(c1, c2) + c3 && floormod(x, c2) < c3,
+ (((c1.Eval()->value + 1) % c2.Eval()->value) + c2.Eval()->value) % c2.Eval()->value >
+ c3.Eval()->value);
+
+ TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && floormod(x, c2) < c3,
+ c1 * c2 <= x && x < c1 * c2 + c3);
+ TVM_TRY_RECURSIVE_REWRITE(floormod(x, c2) < c3 && floordiv(x, c2) == c1,
+ c1 * c2 <= x && x < c1 * c2 + c3);
+ TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && floormod(x, c2) <= c3,
+ c1 * c2 <= x && x <= c1 * c2 + c3);
+ TVM_TRY_RECURSIVE_REWRITE(floormod(x, c2) <= c3 && floordiv(x, c2) == c1,
+ c1 * c2 <= x && x <= c1 * c2 + c3);
+
+ TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && c3 <= floormod(x, c2),
+ c1 * c2 + c3 <= x && x < (c1 + 1) * c2);
+ TVM_TRY_RECURSIVE_REWRITE(c3 <= floormod(x, c2) && floordiv(x, c2) == c1,
+ c1 * c2 + c3 <= x && x < (c1 + 1) * c2);
+ TVM_TRY_RECURSIVE_REWRITE(floordiv(x, c2) == c1 && c3 < floormod(x, c2),
+ c1 * c2 + c3 < x && x < (c1 + 1) * c2);
+ TVM_TRY_RECURSIVE_REWRITE(c3 < floormod(x, c2) && floordiv(x, c2) == c1,
+ c1 * c2 + c3 < x && x < (c1 + 1) * c2);
return ret;
}
diff --git a/src/arith/transitive_comparison_analyzer.cc b/src/arith/transitive_comparison_analyzer.cc
index b71096a479..36c2fb7707 100644
--- a/src/arith/transitive_comparison_analyzer.cc
+++ b/src/arith/transitive_comparison_analyzer.cc
@@ -547,7 +547,7 @@ std::function<void()> TransitiveComparisonAnalyzer::EnterConstraint(const PrimEx
void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr,
std::vector<Comparison>* vec) {
- for (const auto& subexpr : ExtractConstraints(expr)) {
+ for (const auto& subexpr : ExtractConstraints(expr, false)) {
if (tir::SideEffect(expr) <= tir::CallEffectKind::kPure) {
if (auto cmp = FromExpr(subexpr)) {
vec->push_back(cmp.value());
diff --git a/src/arith/unwrap_vector_expr.cc b/src/arith/unwrap_vector_expr.cc
new file mode 100644
index 0000000000..6a3e8c3d43
--- /dev/null
+++ b/src/arith/unwrap_vector_expr.cc
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file unwrap_vector_expr.cc
+ * \brief Utility for extracting a single lane from a vector expression
+ */
+
+#include "unwrap_vector_expr.h"
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/op.h>
+
+#include <unordered_map>
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+class Scalarizer : public ExprMutator {
+ public:
+ explicit Scalarizer(PrimExpr lane) : lane_(lane) {}
+
+ PrimExpr VisitExpr_(const RampNode* op) final { return op->base + lane_ * op->stride; }
+
+ PrimExpr VisitExpr_(const BroadcastNode* op) final { return op->value; }
+
+ PrimExpr VisitExpr_(const VarNode* op) final {
+ Var var = GetRef<Var>(op);
+
+ auto it = let_var_remap_.find(op);
+ if (it != let_var_remap_.end()) {
+ return it->second;
+ } else {
+ return ExprMutator::VisitExpr_(op);
+ }
+ }
+ PrimExpr VisitExpr_(const LetNode* op) final {
+ if (op->value.dtype().lanes() == 1) {
+ return ExprMutator::VisitExpr_(op);
+ }
+
+ auto it = let_var_remap_.find(op->var.get());
+ ICHECK(it == let_var_remap_.end()) << "Duplicate binding of variable " << op->var;
+
+ Var new_var(op->var->name_hint + "_scalar", op->var.dtype().element_of());
+ let_var_remap_[op->var.get()] = new_var;
+
+ PrimExpr value = this->VisitExpr(op->value);
+ PrimExpr body = this->VisitExpr(op->body);
+
+ let_var_remap_.erase(op->var.get());
+ return Let(op->var, value, body);
+ }
+
+ private:
+ // The lane to extract
+ PrimExpr lane_;
+
+ // Vector let-bound variables currently in scope, mapped to their scalar replacements
+ std::unordered_map<const VarNode*, Var> let_var_remap_;
+};
+
+PrimExpr UnwrapVectorExpr(const PrimExpr& vector_expr, const PrimExpr& lane) {
+ return Scalarizer(lane)(vector_expr);
+}
+
+} // namespace arith
+} // namespace tvm
diff --git a/src/arith/constraint_extract.h b/src/arith/unwrap_vector_expr.h
similarity index 57%
copy from src/arith/constraint_extract.h
copy to src/arith/unwrap_vector_expr.h
index ea6e0a7441..9f18964043 100644
--- a/src/arith/constraint_extract.h
+++ b/src/arith/unwrap_vector_expr.h
@@ -18,13 +18,13 @@
*/
/*!
- * \file contraint_extract.h
+ * \file unwrap_vector_expr.h
*
* \brief Centralized location for extraction of constraints from a boolean expression.
*/
-#ifndef TVM_ARITH_CONSTRAINT_EXTRACT_H_
-#define TVM_ARITH_CONSTRAINT_EXTRACT_H_
+#ifndef TVM_ARITH_UNWRAP_VECTOR_EXPR_H_
+#define TVM_ARITH_UNWRAP_VECTOR_EXPR_H_
#include <tvm/tir/expr.h>
@@ -33,26 +33,24 @@
namespace tvm {
namespace arith {
-/* \brief Returns constraints that are true if the expression is true.
+/* \brief Unwraps a component of a vector expression
*
- * Utility to break up a boolean expression into independent
- * constraints.
+ * Utility to break up a vector expression into a specific component
+ * of the expression.
*
- * Example: `i==5 && j==3` => `[i==5 && j==3, i==5, j==3]`
- * Example: `i==5 || j==3` => `[i==5 || j==3]`
- * Example: `!(i>5 || j==3)` => `[!(i==5 || j==3), i<=5, j!=3]`
+ * Example: `Ramp(start, stride, n)` => `start + stride*lane`
+ * Example: `Broadcast(value, n)` => `value`
+ * Example: `2*Ramp(start, stride, n) + Broadcast(value,n)` => `2*(start + stride*lane) + value`
*
- * Intended for use in bounds analysis or simplification within a
- * conditional, or identifying independent conditionals that may be
- * hoisted.
+ * \param vector_expr The vectorized expression to examine
*
- * \param expr The expression to be analyzers
+ * \param lane Which lane of the vectorized expression to extract.
*
- * \returns A vector of independent constraints
+ * \returns A scalar expression
*/
-std::vector<PrimExpr> ExtractConstraints(const PrimExpr& expr);
+PrimExpr UnwrapVectorExpr(const PrimExpr& vector_expr, const PrimExpr& lane);
} // namespace arith
} // namespace tvm
-#endif // TVM_ARITH_CONSTRAINT_EXTRACT_H_
+#endif // TVM_ARITH_UNWRAP_VECTOR_EXPR_H_
diff --git a/src/tir/analysis/control_flow_graph.cc b/src/tir/analysis/control_flow_graph.cc
new file mode 100644
index 0000000000..42c5c8bb82
--- /dev/null
+++ b/src/tir/analysis/control_flow_graph.cc
@@ -0,0 +1,1647 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file control_flow_graph.cc
+ * \brief Control-flow graph construction for tracking known buffer values
+ */
+
+#include "control_flow_graph.h"
+
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <numeric>
+#include <optional>
+#include <queue>
+#include <set>
+#include <sstream>
+#include <unordered_set>
+
+#include "../../arith/conjunctive_normal_form.h"
+#include "../../arith/constraint_extract.h"
+#include "../../arith/ir_mutator_with_analyzer.h"
+#include "../../arith/ir_visitor_with_analyzer.h"
+#include "../../arith/narrow_predicate_expression.h"
+#include "../../arith/unwrap_vector_expr.h"
+
+namespace tvm {
+namespace tir {
+
+using namespace arith;
+
+namespace {
+bool HasBufferLoad(PrimExpr expr) {
+ struct Visitor : public ExprVisitor {
+ void VisitExpr_(const BufferLoadNode* node) override { found_buffer_load = true; }
+ bool found_buffer_load{false};
+ };
+
+ Visitor visitor;
+ visitor(expr);
+ return visitor.found_buffer_load;
+}
+
+Optional<PrimExpr> SubstituteParamValues(const Array<Var>& param_vars,
+ const Array<PrimExpr>& param_values,
+ const PrimExpr& expr) {
+ ICHECK_EQ(param_vars.size(), param_values.size())
+ << "Expression was defined as having " << param_vars.size() << " parameters, but received "
+ << param_values.size() << " arguments.";
+
+ Map<tir::Var, PrimExpr> var_map;
+ for (size_t i = 0; i < param_values.size(); i++) {
+ var_map.Set(param_vars[i], param_values[i]);
+ }
+
+ return Substitute(expr, var_map);
+}
+} // namespace
+
+PrimExpr BufferTouch::BeforeLoopIteration() const {
+ PrimExpr loop_predicate = Bool(true);
+ for (auto it = loop_var_expressions.rbegin(); it != loop_var_expressions.rend(); it++) {
+ const Var& loop_var = it->first;
+ const PrimExpr& loop_expr = it->second;
+ loop_predicate = (loop_var <= loop_expr) || ((loop_var == loop_expr) && loop_predicate);
+ }
+ return loop_predicate;
+}
+
+PrimExpr BufferTouch::AtLoopIteration() const {
+ PrimExpr loop_predicate = Bool(true);
+ for (auto it = loop_var_expressions.rbegin(); it != loop_var_expressions.rend(); it++) {
+ const Var& loop_var = it->first;
+ const PrimExpr& loop_expr = it->second;
+ loop_predicate = (loop_var == loop_expr) && loop_predicate;
+ }
+ return loop_predicate;
+}
+
+PrimExpr BufferTouch::AfterLoopIteration() const {
+ PrimExpr loop_predicate = Bool(true);
+ for (auto it = loop_var_expressions.rbegin(); it != loop_var_expressions.rend(); it++) {
+ const Var& loop_var = it->first;
+ const PrimExpr& loop_expr = it->second;
+ loop_predicate = (loop_var >= loop_expr) || ((loop_var == loop_expr) && loop_predicate);
+ }
+ return loop_predicate;
+}
+
+bool BufferTouch::IsSubsetOf(const BufferTouch& other, Analyzer* analyzer) const {
+ if (this->buffer.same_as(other.buffer)) {
+ With<ConstraintContext> constraint(analyzer, predicate);
+
+ return analyzer->CanProve(other.predicate);
+ } else {
+ return false;
+ }
+}
+
+bool BufferTouch::IsDistinctFrom(const BufferTouch& other, Analyzer* analyzer) const {
+ if (this->buffer.same_as(other.buffer)) {
+ With<ConstraintContext> constraint(analyzer, predicate);
+
+ return analyzer->CanProve(!other.predicate);
+ } else {
+ return true;
+ }
+}
+
+std::ostream& operator<<(std::ostream& os, const BufferTouch& tp) {
+ auto touch_type = [&]() {
+ if (tp.touch_type == BufferTouch::AccessType::Read) {
+ return "read";
+ } else if (tp.touch_type == BufferTouch::AccessType::Write) {
+ return "write";
+ } else if (tp.touch_type == BufferTouch::AccessType::Assume) {
+ return "assume";
+ } else {
+ return "???";
+ }
+ }();
+
+ os << "BufferTouch(" << tp.buffer->name << ", " << touch_type << ", " << tp.predicate
+ << ", value = " << tp.value << ")";
+ return os;
+}
+
+class BufferConstraintApply : public IRMutatorWithAnalyzer {
+ public:
+ using Parent = IRMutatorWithAnalyzer;
+
+ BufferConstraintApply(const Map<Buffer, Array<Var>>& axis_var_lookup,
+ const std::vector<BufferTouch>& knowns, Analyzer* analyzer)
+ : Parent(analyzer), axis_var_lookup_(axis_var_lookup), knowns_(knowns) {}
+
+ using Parent::VisitExpr_;
+
+ PrimExpr VisitExpr_(const BufferLoadNode* op) override {
+ for (const auto& known : knowns_) {
+ if (!op->buffer.same_as(known.buffer)) {
+ continue;
+ }
+
+ Optional<Var> lane_var = NullOpt;
+ IntImm num_lanes;
+
+ Array<PrimExpr> indices = op->indices.Map([&](const auto& index) {
+ if (index.dtype().lanes() == 1) {
+ return index;
+ } else {
+ ICHECK(!lane_var) << "Multiple indices found with non-scalar values";
+ lane_var = Var("lane", index.dtype().element_of());
+ num_lanes = IntImm(index.dtype().element_of(), index.dtype().lanes());
+ return UnwrapVectorExpr(index, lane_var.value());
+ }
+ });
+
+ auto axis_vars = axis_var_lookup_.at(op->buffer);
+ PrimExpr predicate = SubstituteParamValues(axis_vars, indices, known.predicate).value();
+
+ std::optional<With<ConstraintContext>> context;
+ if (lane_var.defined()) {
+ Var lanes = lane_var.value();
+ PrimExpr known = (IntImm(lanes.dtype(), 0) <= lanes) && (lanes < num_lanes);
+ context.emplace(analyzer_, known);
+ }
+
+ if (analyzer_->CanProve(predicate)) {
+ return SubstituteParamValues(axis_vars, op->indices, known.value).value();
+ }
+ }
+
+ return GetRef<PrimExpr>(op);
+ }
+
+ private:
+ const Map<Buffer, Array<Var>>& axis_var_lookup_;
+ const std::vector<BufferTouch>& knowns_;
+};
+
+/*! \brief Extract the control-flow graph
+ *
+ * Walk through a statement, populating the control-flow graph.
+ */
+class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer {
+ public:
+ static void Build(ControlFlowGraph* out, const Stmt& stmt) {
+ ControlFlowGraphBuilder extractor(out);
+ extractor.AppendControlBlock();
+ extractor(stmt);
+ }
+
+ private:
+ ControlFlowGraphBuilder(ControlFlowGraph* out) : out_(out) {}
+
+ using Parent = IRVisitorWithAnalyzer;
+ using Parent::VisitExpr_;
+ using Parent::VisitStmt_;
+
+ void VisitStmt(const Stmt& stmt) override {
+ // Update the lookup table to determine which control-flow block
+ // contains the start of the specified statement. This is used
+ // later to determine which set of known values should be used to
+ // simplify a statement.
+ out_->control_flow_lookup_[stmt.get()] = CurrentControlBlock();
+ Stmt prev_stmt = current_stmt_;
+ current_stmt_ = stmt;
+ Parent::VisitStmt(stmt);
+ current_stmt_ = prev_stmt;
+ }
+
+ void VisitStmt_(const EvaluateNode* op) override {
+ if (auto* call = op->value.as<CallNode>()) {
+ if (call->op.same_as(builtin::assume())) {
+ Assume(call->args[0], true);
+ return;
+ }
+ }
+
+ Parent::VisitStmt_(op);
+ }
+
+ void Assume(PrimExpr assumption, bool from_assume_statement) {
+ for (const auto& expr : ExtractConstraints(assumption, false)) {
+ AssumeConstraintComponent(expr, from_assume_statement);
+ }
+ }
+
+ void AssumeConstraintComponent(PrimExpr assumption, bool from_assume_statement) {
+ PrimExpr additional_predicate = Bool(true);
+
+ std::vector<PrimExpr> buffer_exprs;
+ for (const auto& expr : ExtractComponents(assumption)) {
+ auto side_effect = tir::SideEffect(expr);
+ if (side_effect <= tir::CallEffectKind::kPure) {
+ // Pulling out portions of the assumption that do not depend
+ // on a buffer value allows the following two forms to be
+ // treated identically.
+ //
+ // Option 1: if i < 3: T.assume(buf[i] == value)
+ // Option 2: T.assume(i>=3 or buf[i] == value)
+ additional_predicate = additional_predicate && logical_not(expr);
+ } else if (side_effect == tir::CallEffectKind::kReadState) {
+ buffer_exprs.push_back(expr);
+ } else {
+ LOG(FATAL) << "Assumption must be pure or read-only, but contained expression " << expr
+ << " with side-effect \'" << side_effect << "\'";
+ }
+ }
+
+ if (buffer_exprs.empty()) {
+ out_->non_buffer_assumptions_.push_back(!CurrentScopePredicate() || assumption);
+ return;
+ }
+
+ CHECK_EQ(buffer_exprs.size(), 1) << "T.assume must contain only a single buffer expression";
+
+ auto* as_equal_node = buffer_exprs[0].as<tir::EQNode>();
+ CHECK(as_equal_node || !from_assume_statement)
+ << "T.assume buffer constraint must be of the form 'buffer[indices] == "
+ "value', but received "
+ << assumption;
+ if (!as_equal_node) {
+ // This assumption is an inequality on a data-dependent
+ // conditional. Not an error for this to occur, but also not
+ // something that is currently supported.
+ return;
+ }
+
+ tir::BufferLoad load;
+ PrimExpr value;
+ if (auto* as_load = as_equal_node->a.as<tir::BufferLoadNode>()) {
+ load = GetRef<tir::BufferLoad>(as_load);
+ value = as_equal_node->b;
+ } else if (auto* as_load = as_equal_node->b.as<tir::BufferLoadNode>()) {
+ load = GetRef<tir::BufferLoad>(as_load);
+ value = as_equal_node->a;
+ } else if (!from_assume_statement) {
+ return;
+ } else {
+ LOG(FATAL) << "T.assume buffer constraint must be of the form 'buffer[indices] == value'";
+ }
+
+ auto has_side_effect = tir::SideEffect(value) > tir::CallEffectKind::kPure;
+ CHECK(!has_side_effect || !from_assume_statement)
+ << "Buffer value in constraint must be pure expression, but was " << value;
+ if (has_side_effect) {
+ return;
+ }
+
+ {
+ InternalConstraintContext context(this, additional_predicate);
+ VisitAccess(load, BufferTouch::AccessType::Assume, value);
+ }
+ // Appending a control block ensures that all control blocks have
+ // at most one statement that changes the known buffer contents.
+ auto prev_block = CurrentControlBlock();
+ auto new_block = AppendControlBlock();
+ MarkControlFlow(prev_block, new_block);
+ }
+
+ void VisitExpr_(const LetNode* op) override {
+ std::optional<BindLetVar> binding;
+ if (UsesLoopVar(op->value)) {
+ binding.emplace(this, op->var, op->value);
+ }
+ Parent::VisitExpr_(op);
+ }
+
+ void VisitStmt_(const LetStmtNode* op) override {
+ std::optional<BindLetVar> binding;
+ if (UsesLoopVar(op->value)) {
+ binding.emplace(this, op->var, op->value);
+ }
+ Parent::VisitStmt_(op);
+ }
+
+ void VisitExpr_(const BufferLoadNode* op) override {
+ Parent::VisitExpr_(op);
+ BufferLoad load = GetRef<BufferLoad>(op);
+ VisitAccess(load, BufferTouch::AccessType::Read, load);
+ }
+
+ void VisitStmt_(const BufferStoreNode* op) override {
+ Parent::VisitStmt_(op);
+ VisitAccess(GetRef<BufferStore>(op), BufferTouch::AccessType::Write, op->value);
+ // Appending a control block ensures that all control blocks have
+ // at most one statement that changes the buffer contents.
+ auto prev_block = CurrentControlBlock();
+ auto new_block = AppendControlBlock();
+ MarkControlFlow(prev_block, new_block);
+ }
+
+ void VisitStmt_(const ForNode* op) override {
+ out_->iterator_ranges_.Set(op->loop_var, Range::FromMinExtent(op->min, op->extent));
+
+ auto before_loop = CurrentControlBlock();
+ size_t loop_start = -1;
+
+ {
+ BindActiveLoopVar binding(this, op->loop_var, op->min, op->extent);
+ loop_start = AppendControlBlock();
+ Parent::VisitStmt_(op);
+ }
+
+ auto loop_end = CurrentControlBlock();
+ auto after_loop = AppendControlBlock();
+ PrimExpr max_iterator_value = analyzer_.Simplify(op->min + op->extent - 1);
+ {
+ auto [forward, backward] = MarkControlFlow(before_loop, loop_start);
+ backward.post_condition = (op->loop_var == op->min);
+ forward.var_remap = {{op->loop_var, op->min}};
+ }
+ {
+ auto [forward, backward] = MarkControlFlow(loop_end, after_loop);
+ backward.var_remap = {{op->loop_var, max_iterator_value}};
+ forward.post_condition = (op->loop_var == max_iterator_value);
+ }
+ {
+ auto [forward, backward] = MarkControlFlow(loop_end, loop_start);
+ backward.var_remap = {{op->loop_var, op->loop_var - 1}};
+ forward.var_remap = {{op->loop_var, op->loop_var + 1}};
+ backward.post_condition = (op->loop_var > op->min);
+ forward.post_condition = (op->loop_var < max_iterator_value);
+ }
+ }
+
+ /*! \brief Build the control-flow blocks and edges for an IfThenElse.
+  *
+  * A dedicated block is appended as the branching point, followed by
+  * the blocks of the then-case, the blocks of the (possibly absent)
+  * else-case, and finally a block where the two branches merge.
+  */
+ void VisitStmt_(const IfThenElseNode* op) override {
+   this->VisitExpr(op->condition);
+
+   const PrimExpr then_condition = ExtractRealCondition(op->condition);
+
+   const size_t pre_branch = CurrentControlBlock();
+   const size_t branch_point = AppendControlBlock();
+   MarkControlFlow(pre_branch, branch_point);
+
+   {
+     // The constraint must be active before the block is appended, so
+     // that the block's scope predicate includes it.
+     InternalConstraintContext then_scope(this, then_condition);
+     const size_t then_entry = AppendControlBlock();
+     if (then_scope.assume.defined()) {
+       Assume(then_scope.assume.value(), false);
+     }
+     auto [into_then, from_branch_point] = MarkControlFlow(branch_point, then_entry);
+     into_then.post_condition = then_condition;
+     from_branch_point.post_condition = then_condition;
+     this->VisitStmt(op->then_case);
+   }
+   const size_t then_exit = CurrentControlBlock();
+
+   const auto else_condition = analyzer_.rewrite_simplify(!then_condition);
+   {
+     InternalConstraintContext else_scope(this, else_condition);
+     const size_t else_entry = AppendControlBlock();
+     if (else_scope.assume.defined()) {
+       Assume(else_scope.assume.value(), false);
+     }
+     auto [into_else, from_branch_point] = MarkControlFlow(branch_point, else_entry);
+     into_else.post_condition = else_condition;
+     from_branch_point.post_condition = else_condition;
+
+     // The else-block is created even when there is no else-case, so
+     // that both outcomes of the condition are represented.
+     if (op->else_case.defined()) {
+       this->VisitStmt(op->else_case.value());
+     }
+   }
+   const size_t else_exit = CurrentControlBlock();
+
+   const size_t merge_point = AppendControlBlock();
+
+   if (!HasBufferLoad(then_condition)) {
+     // The condition still holds (or still fails) when each branch
+     // finishes, so it can be attached to the merging edges.
+     {
+       auto [out_of_then, from_then] = MarkControlFlow(then_exit, merge_point);
+       out_of_then.post_condition = then_condition;
+       from_then.post_condition = then_condition;
+     }
+     {
+       auto [out_of_else, from_else] = MarkControlFlow(else_exit, merge_point);
+       out_of_else.post_condition = else_condition;
+       from_else.post_condition = else_condition;
+     }
+   } else {
+     // The buffer value may have changed during the body of the
+     // condition, so we can't provide it as a post-condition.
+     MarkControlFlow(then_exit, merge_point);
+     MarkControlFlow(else_exit, merge_point);
+   }
+ }
+
+ /*! \brief Internal utility, checks for a dependence on loop iterators
+  *
+  * \param expr The expression to inspect
+  *
+  * \return True if `expr` references any variable recorded in
+  * `loop_dependent_vars_`, false otherwise.
+  */
+ bool UsesLoopVar(const PrimExpr& expr) {
+   const auto is_loop_dependent = [this](const VarNode* candidate) {
+     return loop_dependent_vars_.count(candidate) > 0;
+   };
+   return UsesVar(expr, is_loop_dependent);
+ }
+
+ /*! \brief Record the interaction with the buffer.
+  *
+  * The resulting BufferTouch is appended to the touch points of the
+  * most recently added control block.
+  *
+  * \param node The TIR node that accesses the buffer.  Should be
+  * either a BufferLoad or BufferStore node.
+  *
+  * \param touch_type The type of buffer access being performed.  A
+  * BufferStore should always use AccessType::Write.  A BufferLoad
+  * may use either AccessType::Read or AccessType::Assume, depending
+  * on whether the BufferLoad occurs within `builtin::assume`.
+  *
+  * \param known_value_expr The value in the buffer following the access.
+  */
+ template <typename BufferAccess>
+ void VisitAccess(const BufferAccess& node, BufferTouch::AccessType touch_type,
+                  PrimExpr known_value_expr) {
+   auto& active_block = out_->control_flow_.back();
+   active_block.touch_points.push_back(active_block.MakeBufferTouch(
+       out_, node->buffer, node->indices, touch_type, known_value_expr));
+ }
+
+ /*! \brief Return a predicate for having reached the current
+  * control-flow block
+  *
+  * The result is the conjunction of every entry in `conditions_`,
+  * starting from a boolean `true`.  For example, while inside an
+  * IfThenElse, will return the IfThenElse's condition.
+  */
+ PrimExpr CurrentScopePredicate() const {
+   return std::accumulate(
+       conditions_.begin(), conditions_.end(), PrimExpr(Bool(true)),
+       [](PrimExpr conjunction, const PrimExpr& next) { return conjunction && next; });
+ }
+
+ /*! \brief Add a new control block, returning its index
+  *
+  * The new block records a snapshot of the loop iterators, the
+  * loop-dependent let bindings, and the scope predicate that are
+  * active at the time of the call.
+  */
+ size_t AppendControlBlock() {
+   const size_t new_index = out_->control_flow_.size();
+   out_->control_flow_.emplace_back();
+
+   auto& fresh = out_->control_flow_.back();
+   fresh.scope_predicate = CurrentScopePredicate();
+   fresh.let_bindings_using_loop = let_bindings_using_loop_;
+   fresh.active_loop_iterators = active_loop_iterators_;
+   return new_index;
+ }
+
+ /*! \brief The index of the current control block
+  *
+  * This is the index of the most recently appended block.
+  */
+ size_t CurrentControlBlock() {
+   const size_t num_blocks = out_->control_flow_.size();
+   return num_blocks - 1;
+ }
+
+ /*! \brief Mark a possible control from one block to another
+  *
+  * The edge is recorded twice: as a successor of `from_block`, and
+  * as a predecessor of `to_block`.  Both edges are created without
+  * any variable remapping or post-condition; callers fill these in
+  * through the returned references (e.g. replacing `i` with `i-1`
+  * when entering the next loop iteration, or replacing `i` with
+  * `n-1` when concluding a loop).
+  *
+  * \param from_block The block from which control leaves
+  *
+  * \param to_block The block to which control enters
+  *
+  * \return References to the newly added {successor edge,
+  * predecessor edge}.  These refer to elements of the blocks' edge
+  * lists, so they should be used before further edges are appended
+  * to the same lists.
+  */
+ std::pair<ControlFlowGraph::ControlFlowEdge&, ControlFlowGraph::ControlFlowEdge&> MarkControlFlow(
+     size_t from_block, size_t to_block) {
+   // Both indices are used to subscript `control_flow_` below, so
+   // they must be strictly less than its size.
+   ICHECK_LT(from_block, out_->control_flow_.size());
+   ICHECK_LT(to_block, out_->control_flow_.size());
+
+   auto& forward = out_->control_flow_[from_block].successors.emplace_back(
+       ControlFlowGraph::ControlFlowEdge{to_block, {}, NullOpt});
+   auto& backward = out_->control_flow_[to_block].predecessors.emplace_back(
+       ControlFlowGraph::ControlFlowEdge{from_block, {}, NullOpt});
+   return {forward, backward};
+ }
+
+ // Internal utility, scoped guard for entering/leaving a constraint.
+ //
+ // On construction the constraint is entered into the builder's
+ // analyzer.  A pure constraint is additionally pushed onto
+ // `conditions_`; a constraint that at most reads state is instead
+ // exposed through `assume`.  On destruction, any condition pushed
+ // by this guard is removed again.
+ struct InternalConstraintContext {
+   InternalConstraintContext(ControlFlowGraphBuilder* self, PrimExpr constraint)
+       : self(self), analyzer_context(&self->analyzer_, constraint) {
+     old_num_constraints = self->conditions_.size();
+
+     const auto effect = tir::SideEffect(constraint);
+     const bool is_pure = effect <= tir::CallEffectKind::kPure;
+     const bool at_most_reads_state = effect <= tir::CallEffectKind::kReadState;
+     if (is_pure) {
+       self->conditions_.push_back(constraint);
+     } else if (at_most_reads_state) {
+       assume = constraint;
+     }
+
+     new_num_constraints = self->conditions_.size();
+   }
+
+   ~InternalConstraintContext() {
+     ICHECK_EQ(self->conditions_.size(), new_num_constraints)
+         << "Internal error: Each condition should only be popped once.";
+     auto& active_conditions = self->conditions_;
+     active_conditions.erase(active_conditions.begin() + old_num_constraints,
+                             active_conditions.end());
+   }
+
+   // The builder whose state is being modified.
+   ControlFlowGraphBuilder* self{nullptr};
+   // Scoped constraint on the builder's analyzer.
+   With<ConstraintContext> analyzer_context;
+   // Size of `conditions_` before/after construction of this guard.
+   size_t old_num_constraints{0};
+   size_t new_num_constraints{0};
+   // Set only for constraints that are not pure but at most read state.
+   Optional<PrimExpr> assume{NullOpt};
+
+   // Neither copyable nor movable.
+   InternalConstraintContext(const InternalConstraintContext&) = delete;
+   InternalConstraintContext& operator=(const InternalConstraintContext&) = delete;
+   InternalConstraintContext(InternalConstraintContext&&) = delete;
+   InternalConstraintContext& operator=(InternalConstraintContext&&) = delete;
+ };
+
+ // Internal utility, scoped guard for tracking an active loop.
+ //
+ // While alive, the loop is the last entry of
+ // `active_loop_iterators_`.  The loop variable is also added to
+ // `loop_dependent_vars_`, where it remains after the guard is
+ // destroyed.
+ struct BindActiveLoopVar {
+   BindActiveLoopVar(ControlFlowGraphBuilder* self, Var var, PrimExpr loop_min,
+                     PrimExpr loop_extent)
+       : self(self), var(var) {
+     // Entry is {loop_var, loop_min, loop_max, loop_range}.
+     self->active_loop_iterators_.push_back({var, loop_min, loop_min + (loop_extent - 1),
+                                             Range::FromMinExtent(loop_min, loop_extent)});
+     self->loop_dependent_vars_.insert(var.get());
+   }
+
+   ~BindActiveLoopVar() {
+     auto& active_loops = self->active_loop_iterators_;
+     active_loops.pop_back();
+   }
+
+   ControlFlowGraphBuilder* self;
+   Var var;
+
+   // Neither copyable nor movable.
+   BindActiveLoopVar(const BindActiveLoopVar&) = delete;
+   BindActiveLoopVar& operator=(const BindActiveLoopVar&) = delete;
+   BindActiveLoopVar(BindActiveLoopVar&&) = delete;
+   BindActiveLoopVar& operator=(BindActiveLoopVar&&) = delete;
+ };
+
+ // Internal utility, scoped guard for tracking a variable binding.
+ //
+ // While alive, `var` is registered both in `loop_dependent_vars_`
+ // and, together with its bound value, in `let_bindings_using_loop_`.
+ // Both registrations are removed on destruction.
+ struct BindLetVar {
+   BindLetVar(ControlFlowGraphBuilder* self, Var var, PrimExpr value) : self(self), var(var) {
+     self->loop_dependent_vars_.insert(var.get());
+     self->let_bindings_using_loop_.Set(var, value);
+   }
+
+   ~BindLetVar() {
+     self->let_bindings_using_loop_.erase(var);
+     self->loop_dependent_vars_.erase(var.get());
+   }
+
+   ControlFlowGraphBuilder* self;
+   Var var;
+
+   // Neither copyable nor movable.
+   BindLetVar(const BindLetVar&) = delete;
+   BindLetVar& operator=(const BindLetVar&) = delete;
+   BindLetVar(BindLetVar&&) = delete;
+   BindLetVar& operator=(BindLetVar&&) = delete;
+ };
+
+ struct LoopEntry {  // NOTE(review): not referenced in the visible code; active_loop_iterators_ uses ControlFlowGraph::ControlFlowBlock::LoopEntry instead -- confirm whether this struct is still needed
+ Var loop_var;  // the loop iterator
+ PrimExpr loop_min;  // presumably the first value taken by loop_var -- TODO confirm
+ PrimExpr loop_max;  // presumably the last value taken by loop_var (inclusive) -- TODO confirm
+ Range loop_range;  // presumably the range covered by loop_var -- TODO confirm
+ };
+
+ // Loops currently being visited (pushed/popped by BindActiveLoopVar).  Tracked in order to
+ // know which Vars to write in terms of the buffer indices and substitute out of the predicate.
+ std::vector<ControlFlowGraph::ControlFlowBlock::LoopEntry> active_loop_iterators_;
+
+ // Track all loop iterators, along with values derived from loop iterators (queried by UsesLoopVar).
+ std::unordered_set<const VarNode*> loop_dependent_vars_;
+
+ // Any let binding that depends, directly or indirectly, on a loop
+ // binding. When making a predicate in terms of the buffer indices,
+ // these need to be substituted out.
+ // (A std::unordered_map<const VarNode*, PrimExpr> declaration of this member is left commented out in the original.)
+ Map<Var, PrimExpr> let_bindings_using_loop_;
+
+ // Conditions of the enclosing scopes (see InternalConstraintContext); tracked in order to know what conditions limit the buffer access
+ std::vector<PrimExpr> conditions_;
+
+ // Track in order to know what statement initiated the buffer access
+ Stmt current_stmt_;  // NOTE(review): not referenced in the visible portion of this class
+
+ // Output data structure, filled in while visiting
+ ControlFlowGraph* out_;
+};
+
+std::pair<BufferTouch, Map<Var, Range>> ControlFlowGraph::ControlFlowBlock::MakeBufferTouch(
+ const tir::Buffer& buf, Array<Var> index_variables, Array<PrimExpr> indices,
+ BufferTouch::AccessType touch_type, PrimExpr known_value_expr) const {
+ const auto& current_block = *this;
+ // Purpose: build a BufferTouch for an access to `buf` at `indices`, with its predicate and known value rewritten in terms of `index_variables`; also returns the ranges of the free parameters left over from solving for the loop variables.
+ Analyzer local_analyzer;
+
+ Optional<Var> lane_var = NullOpt;
+ IntImm num_lanes;
+ // Scalar indices are kept as-is.  A multi-lane index is passed through UnwrapVectorExpr with a fresh `lane` variable; at most one such index is allowed (see the ICHECK).
+ Array<PrimExpr> index_expressions = indices.Map([&](const auto& index) {
+ if (index.dtype().lanes() == 1) {
+ return index;
+ } else {
+ ICHECK(!lane_var) << "Multiple indices found with non-scalar values";
+ lane_var = Var("lane", index.dtype().element_of());
+ num_lanes = IntImm(index.dtype().element_of(), index.dtype().lanes());
+ return UnwrapVectorExpr(index, lane_var.value());
+ }
+ });
+ // Variables to be solved for, and their ranges: the block's active loop iterators (plus the lane variable, if any).
+ Array<Var> loop_vars;
+
+ Map<Var, Range> loop_ranges;
+ for (const auto& loop_entry : current_block.active_loop_iterators) {
+ loop_vars.push_back(loop_entry.loop_var);
+ loop_ranges.Set(loop_entry.loop_var, loop_entry.loop_range);
+ }
+
+ // If the indices contain multiple lanes, treat the lane variable
+ // as an additional loop iterator to be solved for and substituted
+ // out.
+ if (lane_var) {
+ loop_vars.push_back(lane_var.value());
+ loop_ranges.Set(lane_var.value(), Range::FromMinExtent(0, num_lanes));
+ }
+ // Solve the system `index_variables[i] == index_expressions[i]` for the loop variables.
+ IntConstraintsTransform transform = [&]() {
+ ICHECK_EQ(index_variables.size(), index_expressions.size());
+
+ Array<PrimExpr> relations;
+
+ for (size_t i = 0; i < index_expressions.size(); i++) {
+ PrimExpr expr = index_expressions[i];
+ Var var = index_variables[i];
+ // Let-bound variables are replaced by their bound values before solving.
+ expr = Substitute(expr, current_block.let_bindings_using_loop);
+ relations.push_back(var == expr);
+ }
+
+ IntConstraints system(loop_vars, loop_ranges, relations);
+ return arith::SolveLinearEquations(system);
+ }();
+
+ Map<Var, PrimExpr> loop_var_to_axis_var = transform->src_to_dst;
+ Map<Var, Range> free_params = transform->dst->ranges;
+ PrimExpr transform_predicate =
+ std::accumulate(transform->dst->relations.begin(), transform->dst->relations.end(),
+ PrimExpr(Bool(true)), [](PrimExpr a, PrimExpr b) { return a && b; });  // conjunction of all relations of the solved system
+
+ transform_predicate = SimplifyAsAndOfOrs(transform_predicate, &local_analyzer);
+ // Returns a map {free parameter -> replacement expression} for the free parameters that can currently be eliminated.
+ auto find_removable_params = [&]() -> Map<Var, PrimExpr> {
+ Map<Var, PrimExpr> removable_params;
+
+ // The arith::SolveLinearEquations is more general than the
+ // utilities in iter_affine_map.h, but can introduce free
+ // parameters that could later be determined with the known
+ // constraints. This step removes all such free parameters.
+ for (const auto& expr : ExtractConstraints(transform_predicate)) {
+ if (auto* as_equal = expr.as<EQNode>()) {
+ auto check_expr = [&](const PrimExpr& a, const PrimExpr& b) {  // handles `a == b` where `a` is a free parameter and `b` uses no free parameter
+ auto* var_ptr = a.as<VarNode>();
+ if (!var_ptr) {
+ return;
+ }
+
+ Var var = GetRef<Var>(var_ptr);
+ if (free_params.count(var) == 0) {
+ return;
+ }
+
+ bool uses_free_param =
+ UsesVar(b, [&](const VarNode* v) { return free_params.count(GetRef<Var>(v)) > 0; });
+ if (uses_free_param) {
+ return;
+ }
+ removable_params.Set(var, b);
+ };
+ check_expr(as_equal->a, as_equal->b);  // try both orientations of the equality
+ check_expr(as_equal->b, as_equal->a);
+ }
+ }
+
+ // In addition, the arith::SolveLinearEquations can introduce
+ // free parameters with an extent of one. Filtering them out here
+ // avoids needing to track them through later simplifications.
+ for (const auto [var, range] : free_params) {
+ if (is_one(range->extent)) {
+ removable_params.Set(var, range->min);
+ }
+ }
+
+ return removable_params;
+ };
+ for (auto removable_params = find_removable_params(); removable_params.size() > 0;
+ removable_params = find_removable_params()) {  // repeat until no removable parameter is found
+ auto update = [&](const PrimExpr& expr) {
+ return local_analyzer.Simplify(Substitute(expr, removable_params));
+ };
+
+ Map<Var, PrimExpr> new_map;
+ for (const auto [loop_var, expr] : loop_var_to_axis_var) {
+ static_cast<void>(expr); // gcc 7.x bug, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767
+ new_map.Set(loop_var, update(expr));
+ }
+ loop_var_to_axis_var = new_map;
+
+ transform_predicate = update(transform_predicate);
+ // The substituted parameters are no longer free.
+ for (const auto [var, expr] : removable_params) {
+ static_cast<void>(expr); // gcc 7.x bug, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767
+ free_params.erase(var);
+ }
+ }
+
+ // Normalization function, applied to both the predicate and the
+ // known value. Converts from an expression in terms of loop
+ // iterators to an expression in terms of buffer indices.
+ auto normalize_expr = [&](PrimExpr expr) -> PrimExpr {
+ expr = Substitute(expr, current_block.let_bindings_using_loop);
+
+ if (lane_var) {
+ expr = UnwrapVectorExpr(expr, lane_var.value());
+ }
+ expr = Substitute(expr, loop_var_to_axis_var);
+
+ return expr;
+ };
+
+ // Collect the current loop variables, along with an expression for
+ // the loop variables in terms of the buffer axis variables. This
+ // is used during forward/backward propagation to generate predicate
+ // tracking whether a loop iteration has been reached.
+ std::vector<std::pair<Var, PrimExpr>> loop_var_expressions;
+ for (const auto& entry : current_block.active_loop_iterators) {
+ auto expr_it = loop_var_to_axis_var.find(entry.loop_var);
+ ICHECK(expr_it != loop_var_to_axis_var.end());
+ loop_var_expressions.push_back({entry.loop_var, (*expr_it).second});
+ }
+
+ // The full predicate is composed of the values required to reach
+ // the scope of the BufferStore or builtin::assume(), any bounds
+ // implied by solving for the axis variables, and any additional
+ // statements resulting from unpacking the expression contained in
+ // builtin::assume().
+ PrimExpr scope_predicate = normalize_expr(current_block.scope_predicate);
+ transform_predicate = normalize_expr(transform_predicate);
+
+ known_value_expr = local_analyzer.Simplify(normalize_expr(known_value_expr));
+
+ // Deliberately use an analyzer without scope-based information,
+ // to avoid simplifying `scope_predicate` to True.
+ PrimExpr predicate_expr = local_analyzer.Simplify(transform_predicate && scope_predicate);
+
+ BufferTouch buffer_touch = {buf, predicate_expr, known_value_expr, loop_var_expressions,
+ touch_type};
+ // Free parameters that could not be eliminated are returned alongside the touch.
+ return {buffer_touch, free_params};
+}
+
+/*! \brief Build a BufferTouch, registering its free parameters with the graph
+ *
+ * Looks up (or creates) the index variables of `buf` through
+ * `graph`, delegates to the overload that accepts explicit index
+ * variables, and stores every free parameter returned by that
+ * overload in `graph->free_predicate_parameters_`.
+ *
+ * \param graph The graph being built.  Must not be null.
+ * \param buf The buffer being accessed
+ * \param indices The indices at which the buffer is accessed
+ * \param touch_type The type of access being performed
+ * \param known_value_expr The value in the buffer following the access
+ *
+ * \return The BufferTouch describing the access
+ */
+BufferTouch ControlFlowGraph::ControlFlowBlock::MakeBufferTouch(ControlFlowGraph* graph,
+                                                                const tir::Buffer& buf,
+                                                                const Array<PrimExpr>& indices,
+                                                                BufferTouch::AccessType touch_type,
+                                                                PrimExpr known_value_expr) const {
+  ICHECK(graph);
+
+  Array<Var> axis_vars = graph->GetIndexVariables(buf, indices);
+  auto [touch, unresolved_params] =
+      MakeBufferTouch(buf, axis_vars, indices, touch_type, known_value_expr);
+
+  for (const auto& [param, range] : unresolved_params) {
+    graph->free_predicate_parameters_.Set(param, range);
+  }
+  return touch;
+}
+
+/*! \brief Construct the control-flow graph of a statement
+ *
+ * \param stmt The statement to analyze
+ * \param max_revisits Forwarded to both propagation passes
+ */
+ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, size_t max_revisits) {
+  // First collect the blocks, edges and buffer touches ...
+  ControlFlowGraphBuilder::Build(this, stmt);
+  // ... then run the forward and backward propagation passes.
+  this->ForwardPropagateKnownValues(max_revisits);
+  this->BackwardPropagateUnusedValues(max_revisits);
+}
+
+/*! \brief Print an edge: its target index, followed by the variable
+ * remapping and post-condition when present.
+ */
+std::ostream& operator<<(std::ostream& os, const ControlFlowGraph::ControlFlowEdge& edge) {
+  os << edge.index;
+
+  const bool has_remap = edge.var_remap.size() > 0;
+  if (has_remap) {
+    os << " with remap " << edge.var_remap;
+  }
+
+  const bool has_post_condition = edge.post_condition.defined();
+  if (has_post_condition) {
+    os << " with postcondition " << edge.post_condition;
+  }
+
+  return os;
+}
+
+/*! \brief Print a control block: its predecessors, active loop
+ * iterators, known/unused buffer states, touch points and successors.
+ */
+std::ostream& operator<<(std::ostream& os, const ControlFlowGraph::ControlFlowBlock& block) {
+  // Writes `items` as a ", "-separated list, using `write_item` for
+  // each entry.
+  const auto write_list = [&os](const auto& items, auto&& write_item) {
+    bool is_first = true;
+    for (const auto& item : items) {
+      if (!is_first) {
+        os << ", ";
+      }
+      is_first = false;
+      write_item(item);
+    }
+  };
+
+  os << "Predecessors: [";
+  write_list(block.predecessors, [&os](const auto& edge) { os << edge; });
+  os << "]\n";
+
+  os << "Active loop iterators: [";
+  write_list(block.active_loop_iterators, [&os](const auto& loop) { os << loop.loop_var; });
+  os << "]\n";
+
+  os << "Before block knowns: " << block.known_at_block_start << "\n";
+  os << "Before block unused: " << block.unused_at_block_start << "\n";
+
+  size_t touch_index = 0;
+  for (const auto& touch : block.touch_points) {
+    os << "Touch[" << touch_index << "] = " << touch << "\n";
+    ++touch_index;
+  }
+
+  os << "After block: " << block.known_at_block_end << "\n";
+  os << "After block unused: " << block.unused_at_block_end << "\n";
+
+  os << "Successors: [";
+  write_list(block.successors, [&os](const auto& edge) { os << edge; });
+  os << "]";
+  return os;
+}
+
+/*! \brief Print the graph: a summary line, followed by one
+ * tab-indented entry per control block.
+ */
+std::ostream& operator<<(std::ostream& os, const ControlFlowGraph& pattern) {
+  const auto& blocks = pattern.control_flow_;
+
+  // The summary line is only terminated when block entries follow.
+  const char* after_summary = blocks.empty() ? "" : "\n";
+  os << "Touch pattern contains " << blocks.size() << " control blocks." << after_summary;
+
+  for (size_t block_index = 0; block_index < blocks.size(); block_index++) {
+    os << "\t"
+       << "ControlBlock[" << block_index << "] = " << blocks[block_index] << "\n";
+  }
+
+  return os;
+}
+
+/*! \brief Check whether two touches describe the same constraint
+ *
+ * Two touches are equivalent when they refer to the same buffer with
+ * the same access type, their predicates are structurally equal or
+ * provably imply one another, and their values are structurally
+ * equal or provably equal.
+ *
+ * \param other The touch to compare against
+ * \param analyzer The analyzer used for the proofs
+ */
+bool BufferTouch::IsEquivalentTo(const BufferTouch& other, Analyzer* analyzer) const {
+  // Constraints must apply to the same buffer to be equivalent
+  if (!buffer.same_as(other.buffer) || touch_type != other.touch_type) {
+    return false;
+  }
+
+  ExprDeepEqual structurally_equal;
+
+  // True if `conclusion` can be proven while assuming `premise`.
+  const auto proves = [analyzer](const PrimExpr& premise, const PrimExpr& conclusion) -> bool {
+    With<ConstraintContext> scope(analyzer, premise);
+    return analyzer->CanProve(conclusion);
+  };
+
+  // The provers are only consulted if the cheap structural check fails.
+  if (!structurally_equal(predicate, other.predicate)) {
+    const bool mutually_implied =
+        proves(predicate, other.predicate) && proves(other.predicate, predicate);
+    if (!mutually_implied) {
+      return false;
+    }
+  }
+
+  // The known value must be equal
+  return structurally_equal(value, other.value) || analyzer->CanProveEqual(value, other.value);
+}
+
+/*! \brief Print the constraints of a BufferState, one per line,
+ * without a trailing newline.
+ */
+std::ostream& operator<<(std::ostream& os, const BufferState& state) {
+  const size_t num_constraints = state.constraints_.size();
+  for (size_t index = 0; index < num_constraints; index++) {
+    if (index > 0) {
+      os << "\n";
+    }
+    os << "constraints[" << index << "] = " << state.constraints_[index];
+  }
+  return os;
+}
+
+/*! \brief Rewrite an expression using this state's constraints
+ *
+ * The rewrite itself is performed by BufferConstraintApply.
+ *
+ * \param expr The expression to rewrite
+ * \param axis_var_lookup The axis variables of each buffer
+ * \param analyzer The analyzer handed to the rewriter
+ *
+ * \return The rewritten expression
+ */
+PrimExpr BufferState::SubstituteKnownBufferValues(
+    PrimExpr expr, const Map<tir::Buffer, Array<tir::Var>>& axis_var_lookup,
+    Analyzer* analyzer) const {
+  BufferConstraintApply known_value_rewriter(axis_var_lookup, constraints_, analyzer);
+  PrimExpr rewritten = known_value_rewriter(std::move(expr));
+  return rewritten;
+}
+
+/*! \brief Restrict every constraint to where `condition` also holds
+ *
+ * \param condition The expression AND-ed onto each predicate
+ */
+void BufferState::AddCondition(const PrimExpr& condition) {
+  for (BufferTouch& known : constraints_) {
+    PrimExpr restricted = known.predicate && condition;
+    known.predicate = restricted;
+  }
+}
+
+/*! \brief Apply a variable remapping to the predicate of every constraint
+ *
+ * Predicates that are changed by the substitution are re-simplified;
+ * all others are left untouched.
+ *
+ * \param var_remap The variable replacements to apply
+ * \param analyzer The analyzer used for simplification
+ */
+void BufferState::Substitute(const Map<Var, PrimExpr>& var_remap, Analyzer* analyzer) {
+  if (var_remap.size() == 0) {
+    return;
+  }
+
+  for (BufferTouch& known : constraints_) {
+    PrimExpr remapped = tvm::tir::Substitute(known.predicate, var_remap);
+    if (remapped.same_as(known.predicate)) {
+      continue;
+    }
+    known.predicate = SimplifyAsAndOfOrs(remapped, analyzer);
+  }
+}
+
+/*! \brief Simplify the predicate of every constraint
+ *
+ * \param analyzer The analyzer used for simplification
+ */
+void BufferState::Simplify(Analyzer* analyzer) {
+  for (BufferTouch& known : constraints_) {
+    PrimExpr simplified = SimplifyAsAndOfOrs(known.predicate, analyzer);
+    known.predicate = simplified;
+  }
+}
+
+/*! \brief Merge the constraints of another state into this one
+ *
+ * A constraint of `b` is folded into the first existing constraint
+ * on the same buffer whose value is provably equal, by OR-ing the
+ * predicates.  If no such constraint exists, it is appended.
+ *
+ * \param b The state whose constraints are merged in
+ * \param analyzer The analyzer used for proofs and simplification
+ */
+void BufferState::Union(const BufferState& b, Analyzer* analyzer) {
+  for (const auto& incoming : b.constraints_) {
+    auto match = std::find_if(
+        constraints_.begin(), constraints_.end(), [&](const BufferTouch& existing) {
+          return existing.buffer.same_as(incoming.buffer) &&
+                 analyzer->CanProveEqual(existing.value, incoming.value);
+        });
+
+    if (match == constraints_.end()) {
+      constraints_.push_back(incoming);
+    } else {
+      match->predicate = SimplifyAsAndOfOrs(match->predicate || incoming.predicate, analyzer);
+    }
+  }
+}
+
+/*! \brief Keep only the constraints that hold in both states
+ *
+ * Every pair of constraints on the same buffer is considered.  A
+ * pair contributes a constraint over the overlap of the two
+ * predicates if, within that overlap, the two values are provably
+ * equal.
+ *
+ * \param b The state to intersect with
+ * \param analyzer The analyzer used for proofs and simplification
+ */
+void BufferState::Intersection(const BufferState& b, Analyzer* analyzer) {
+  std::vector<BufferTouch> shared;
+
+  for (const auto& lhs : constraints_) {
+    for (const auto& rhs : b.constraints_) {
+      if (!lhs.buffer.same_as(rhs.buffer)) {
+        continue;
+      }
+
+      PrimExpr overlap = SimplifyAsAndOfOrs(lhs.predicate && rhs.predicate, analyzer);
+      if (is_zero(overlap)) {
+        continue;
+      }
+
+      // The values only need to agree where both constraints apply.
+      With<ConstraintContext> scope(analyzer, overlap);
+      if (analyzer->CanProveEqual(lhs.value, rhs.value)) {
+        shared.push_back({lhs.buffer, overlap, lhs.value});
+      }
+    }
+  }
+
+  constraints_ = std::move(shared);
+}
+
+/*! \brief Split the domain of a set of expressions into regions,
+ * based on which known buffer constraint applies to each BufferLoad.
+ *
+ * Starting from a single region whose predicate is `true`, every
+ * visited BufferLoad that touches at least one known constraint
+ * refines the existing regions: each region is intersected with the
+ * predicate of each applicable constraint (and with the remaining
+ * "unknown" predicate, if non-empty), and the value known for that
+ * BufferLoad within the intersection is recorded.
+ */
+class BufferRegionCollector : public ExprVisitor {
+ public:
+  struct Region {
+    // Predicate describing where this region applies.
+    PrimExpr region_predicate;
+    // For each BufferLoad visited so far, the value known for it
+    // within this region, or NullOpt if the value is unknown.
+    std::unordered_map<const BufferLoadNode*, Optional<PrimExpr>> known_values;
+  };
+
+  /*! \brief Collect the regions for a set of expressions
+   *
+   * \param axis_var_lookup The axis variables of each buffer
+   * \param knowns The known buffer constraints
+   * \param exprs The expressions to visit.  Undefined entries are skipped.
+   * \param analyzer The analyzer used for simplification
+   *
+   * \return The collected regions
+   */
+  static std::vector<Region> Collect(const Map<Buffer, Array<Var>>& axis_var_lookup,
+                                     const std::vector<BufferTouch>& knowns,
+                                     const std::vector<Optional<PrimExpr>>& exprs,
+                                     Analyzer* analyzer) {
+    BufferRegionCollector collector(axis_var_lookup, knowns, analyzer);
+    for (const auto& expr : exprs) {
+      if (expr) {
+        collector(expr.value());
+      }
+    }
+
+    return collector.regions_;
+  }
+
+ private:
+  using Parent = ExprVisitor;
+
+  BufferRegionCollector(const Map<Buffer, Array<Var>>& axis_var_lookup,
+                        const std::vector<BufferTouch>& knowns, Analyzer* analyzer)
+      : analyzer_(analyzer), axis_var_lookup_(axis_var_lookup), knowns_(knowns) {
+    // Initially, a single region covering everything, with no known values.
+    regions_.push_back(Region{Bool(true), {}});
+  }
+
+  using Parent::VisitExpr_;
+
+  void VisitExpr_(const BufferLoadNode* op) override {
+    // Helper struct for the known values of this BufferLoad
+    struct Known {
+      PrimExpr predicate;
+      Optional<PrimExpr> value;
+    };
+
+    std::vector<Known> new_regions;
+
+    // Where none of the constraints collected below apply.
+    PrimExpr unknown_region = Bool(true);
+
+    for (const BufferTouch& constraint : knowns_) {
+      if (!op->buffer.same_as(constraint.buffer)) {
+        // This is a different buffer, so continue searching.
+        continue;
+      }
+
+      // Express the constraint in terms of the indices of this load.
+      auto axis_vars = axis_var_lookup_.at(op->buffer);
+      PrimExpr touch_predicate =
+          SubstituteParamValues(axis_vars, op->indices, constraint.predicate).value();
+      touch_predicate = SimplifyAsAndOfOrs(touch_predicate, analyzer_);
+
+      if (!is_zero(touch_predicate)) {
+        Optional<PrimExpr> known_value =
+            SubstituteParamValues(axis_vars, op->indices, constraint.value);
+        new_regions.push_back(Known{touch_predicate, known_value});
+
+        unknown_region = unknown_region && !touch_predicate;
+        unknown_region = SimplifyAsAndOfOrs(unknown_region, analyzer_);
+      }
+    }
+
+    if (new_regions.size()) {
+      if (!is_zero(unknown_region)) {
+        new_regions.insert(new_regions.begin(), Known{unknown_region, NullOpt});
+      }
+
+      // Refine every existing region by every region of this load,
+      // discarding empty intersections.
+      std::vector<Region> updated_regions;
+      for (const auto& prev_region : regions_) {
+        for (const auto& new_region : new_regions) {
+          PrimExpr intersection =
+              SimplifyAsAndOfOrs(prev_region.region_predicate && new_region.predicate, analyzer_);
+
+          if (!is_zero(intersection)) {
+            Region merged{intersection, prev_region.known_values};
+            merged.known_values[op] = new_region.value;
+            updated_regions.push_back(std::move(merged));
+          }
+        }
+      }
+      regions_ = updated_regions;
+    }
+  }
+
+  Analyzer* analyzer_;
+  std::vector<Region> regions_;
+  const Map<Buffer, Array<Var>>& axis_var_lookup_;
+  const std::vector<BufferTouch>& knowns_;
+};
+
+/*! \brief Replace BufferLoad nodes with their known values
+ *
+ * BufferLoad nodes are looked up by node pointer in `known_values`.
+ * Loads that are absent from the map, or whose entry is undefined,
+ * are returned unchanged.
+ */
+class BufferRegionValueReplacer : public IRMutatorWithAnalyzer {
+ public:
+  /*! \brief Substitute the known values into `expr`, then simplify
+   *
+   * \param known_values The known value, if any, of each BufferLoad
+   * \param expr The expression to rewrite
+   * \param analyzer The analyzer used for the final simplification
+   *
+   * \return The rewritten and simplified expression
+   */
+  static PrimExpr Apply(
+      const std::unordered_map<const BufferLoadNode*, Optional<PrimExpr>>& known_values,
+      PrimExpr expr, Analyzer* analyzer) {
+    BufferRegionValueReplacer replacer(known_values, analyzer);
+    // Simplification must occur after the substitution, as the
+    // substituted known values may enable further simplifications.
+    return analyzer->Simplify(replacer(expr));
+  }
+
+ private:
+  using Parent = IRMutatorWithAnalyzer;
+
+  BufferRegionValueReplacer(
+      const std::unordered_map<const BufferLoadNode*, Optional<PrimExpr>>& known_values,
+      Analyzer* analyzer)
+      : Parent(analyzer), known_values_(known_values) {}
+
+  using Parent::VisitExpr_;
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) override {
+    auto found = known_values_.find(op);
+    if (found != known_values_.end() && found->second) {
+      return found->second.value();
+    }
+    return GetRef<PrimExpr>(op);
+  }
+
+  const std::unordered_map<const BufferLoadNode*, Optional<PrimExpr>>& known_values_;
+};
+
+void BufferState::ApplyTouches(const Map<Buffer, Array<Var>>& axis_var_lookup,
+ const std::vector<BufferTouch>& touch_points, Analyzer* analyzer) {  // Updates constraints_ to reflect the non-Read touches in touch_points
+ std::vector<BufferTouch> new_knowns;  // constraints introduced by the touches
+ Map<Buffer, PrimExpr> keep_prior_known_at;  // per buffer: where no touch was applied, so earlier constraints remain valid
+ // Step 1: for each non-Read touch, collect the new constraints and the region in which earlier constraints are replaced.
+ for (auto& touch : touch_points) {
+ if (touch.touch_type == BufferTouch::AccessType::Read) {
+ continue;
+ }
+
+ PrimExpr known_value = touch.value;
+
+ PrimExpr predicate = touch.predicate && touch.AfterLoopIteration();
+ auto regions = BufferRegionCollector::Collect(axis_var_lookup, constraints_,
+ {predicate, touch.value}, analyzer);
+ // Within each region, BufferLoads in the predicate and value are replaced by the values known for them (where available).
+ for (const auto& region : regions) {
+ PrimExpr updated_predicate = BufferRegionValueReplacer::Apply(
+ region.known_values, region.region_predicate && predicate, analyzer);
+
+ updated_predicate = SimplifyAsAndOfOrs(updated_predicate, analyzer);
+ PrimExpr updated_value =
+ BufferRegionValueReplacer::Apply(region.known_values, known_value, analyzer);
+
+ if (!is_zero(updated_predicate)) {
+ if (auto it = keep_prior_known_at.find(touch.buffer); it != keep_prior_known_at.end()) {
+ keep_prior_known_at.Set(touch.buffer, (*it).second && !updated_predicate);
+ } else {
+ keep_prior_known_at.Set(touch.buffer, !updated_predicate);
+ }
+ // A value that still contains a BufferLoad is not recorded as a new known value.
+ if (!HasBufferLoad(updated_value)) {
+ BufferTouch new_constraint{touch.buffer, updated_predicate, updated_value};
+ new_knowns.push_back(new_constraint);
+ }
+ }
+ }
+ }
+ // Step 2: narrow the earlier constraints to where they were not touched.
+ if (keep_prior_known_at.size()) {
+ for (auto& constraint : constraints_) {
+ if (auto it = keep_prior_known_at.find(constraint.buffer); it != keep_prior_known_at.end()) {
+ constraint.predicate = SimplifyAsAndOfOrs(constraint.predicate && (*it).second, analyzer);
+ }
+ }
+ }
+ // Step 3: merge new constraints into existing ones with a provably equal value; append the rest.
+ if (new_knowns.size()) {
+ std::vector<bool> used(new_knowns.size(), false);  // used[i]: new_knowns[i] was merged into an existing constraint
+
+ for (auto& constraint : constraints_) {
+ PrimExpr expand_known_at = Bool(false);
+
+ PrimExpr prev_value = constraint.value;
+
+ for (size_t i = 0; i < new_knowns.size(); i++) {
+ if (new_knowns[i].buffer.same_as(constraint.buffer)) {
+ Optional<PrimExpr> overwritten_with = new_knowns[i].value;
+ if (overwritten_with && analyzer->CanProveEqual(prev_value, overwritten_with.value())) {
+ expand_known_at =
+ SimplifyAsAndOfOrs(expand_known_at || new_knowns[i].predicate, analyzer);
+ used[i] = true;
+ }
+ }
+ }
+
+ if (!is_zero(expand_known_at)) {
+ constraint.predicate =
+ SimplifyAsAndOfOrs(constraint.predicate || expand_known_at, analyzer);
+ }
+ }
+
+ for (size_t i = 0; i < new_knowns.size(); i++) {
+ if (!used[i]) {
+ constraints_.push_back(new_knowns[i]);
+ }
+ }
+ }
+ // Step 4: drop constraints whose predicate is zero.
+ constraints_.erase(
+ std::remove_if(constraints_.begin(), constraints_.end(),
+ [&](const auto& constraint) { return is_zero(constraint.predicate); }),
+ constraints_.end());
+}
+
+/*! \brief Propagate the set of unused buffer regions backwards
+ * through the touches of a block.
+ *
+ * For each buffer, the predicates of all Write touches and of all
+ * Read touches are accumulated (each AND-ed with the touch's
+ * `BeforeLoopIteration()`); touches of any other type are ignored.
+ * Written regions widen the existing constraint on the same buffer,
+ * or introduce a new constraint whose value is `builtin::undef()`.
+ * Regions that are read are then removed from every constraint on
+ * that buffer.  Constraints whose predicate becomes zero are dropped.
+ *
+ * \param axis_var_lookup The axis variables of each buffer.  Not
+ * used by the current implementation.
+ *
+ * \param touch_points The touches of the block
+ *
+ * \param analyzer The analyzer used for simplification
+ */
+void BufferState::BackpropUnusedIndices(const Map<Buffer, Array<Var>>& axis_var_lookup,
+                                        const std::vector<BufferTouch>& touch_points,
+                                        Analyzer* analyzer) {
+  Map<Buffer, PrimExpr> regions_written;
+  Map<Buffer, PrimExpr> regions_read;
+  for (auto it = touch_points.rbegin(); it != touch_points.rend(); it++) {
+    const auto& touch = *it;
+
+    // Select the map to update based on the access type.
+    Map<Buffer, PrimExpr>* to_update{nullptr};
+    if (touch.touch_type == BufferTouch::AccessType::Write) {
+      to_update = &regions_written;
+    } else if (touch.touch_type == BufferTouch::AccessType::Read) {
+      to_update = &regions_read;
+    } else {
+      continue;
+    }
+
+    PrimExpr prev = to_update->Get(touch.buffer).value_or(Bool(false));
+    PrimExpr new_predicate = touch.predicate && touch.BeforeLoopIteration();
+    to_update->Set(touch.buffer, prev || new_predicate);
+  }
+
+  // Simplify the accumulated predicate of every buffer in a map.
+  auto update_map = [&](auto& map) {
+    Map<Buffer, PrimExpr> new_map;
+    for (auto [buffer, predicate] : map) {
+      new_map.Set(buffer, SimplifyAsAndOfOrs(predicate, analyzer));
+    }
+    map = std::move(new_map);
+  };
+  update_map(regions_written);
+  update_map(regions_read);
+
+  // If the buffer already has an unused region, widen its predicate
+  for (auto& prev_unused : constraints_) {
+    if (auto opt_predicate = regions_written.Get(prev_unused.buffer)) {
+      PrimExpr new_predicate = prev_unused.predicate || opt_predicate.value();
+      prev_unused.predicate = SimplifyAsAndOfOrs(new_predicate, analyzer);
+      regions_written.erase(prev_unused.buffer);
+    }
+  }
+
+  // Otherwise, add new "touch" to represent the unused values
+  for (auto [buffer, predicate] : regions_written) {
+    constraints_.push_back(
+        BufferTouch{buffer, predicate, tir::Call(buffer->dtype, builtin::undef(), {})});
+  }
+
+  // If buffer is read out, narrow the predicate
+  for (auto& prev_unused : constraints_) {
+    if (auto opt_pred = regions_read.Get(prev_unused.buffer)) {
+      PrimExpr predicate = opt_pred.value();
+      prev_unused.predicate = SimplifyAsAndOfOrs(prev_unused.predicate && !predicate, analyzer);
+    }
+  }
+
+  // Clean-up and remove any empty constraints
+  constraints_.erase(
+      std::remove_if(constraints_.begin(), constraints_.end(),
+                     [](const auto& constraint) { return is_zero(constraint.predicate); }),
+      constraints_.end());
+}
+
+/*! \brief Rewrite each predicate through NarrowPredicateExpression
+ * with the given free parameters, then simplify it.
+ *
+ * \param free_predicate_parameters The free parameters and their ranges
+ * \param analyzer The analyzer used for simplification
+ */
+void BufferState::RemoveFreeParameters(const Map<Var, Range>& free_predicate_parameters,
+                                       Analyzer* analyzer) {
+  for (BufferTouch& known : constraints_) {
+    PrimExpr narrowed = NarrowPredicateExpression(known.predicate, free_predicate_parameters);
+    known.predicate = narrowed;
+    known.predicate = SimplifyAsAndOfOrs(known.predicate, analyzer);
+  }
+}
+
+/*! \brief Check whether two states hold equivalent constraints
+ *
+ * The states must hold the same number of constraints, and
+ * constraints at the same position must be equivalent.
+ *
+ * \param other The state to compare against
+ * \param analyzer The analyzer used when comparing constraints
+ */
+bool BufferState::IsEquivalentTo(const BufferState& other, Analyzer* analyzer) const {
+  if (constraints_.size() != other.constraints_.size()) {
+    return false;
+  }
+
+  // Pairwise comparison, stopping at the first mismatch.
+  return std::equal(constraints_.begin(), constraints_.end(), other.constraints_.begin(),
+                    [analyzer](const BufferTouch& lhs, const BufferTouch& rhs) {
+                      return lhs.IsEquivalentTo(rhs, analyzer);
+                    });
+}
+
+/*! \brief Look up the index variables previously generated for a buffer
+ *
+ * \param buf The buffer to look up
+ *
+ * \return The index variables of the buffer, or NullOpt if none
+ * have been generated for it.
+ */
+Optional<Array<Var>> ControlFlowGraph::GetIndexVariables(const Buffer& buf) const {
+  return axis_var_lookup_.Get(buf);
+}
+
+/*! \brief Get the index variables of a buffer, generating them if needed
+ *
+ * On the first request for a buffer, one variable is generated per
+ * entry of `indices`, named `<buffer name>_axis_<i>` and typed as
+ * the element type of that index.  The variables are cached, and
+ * returned as-is on later requests for the same buffer.
+ *
+ * \param buf The buffer being accessed
+ * \param indices The indices used to access the buffer
+ *
+ * \return The index variables of the buffer
+ */
+Array<Var> ControlFlowGraph::GetIndexVariables(const Buffer& buf, const Array<PrimExpr>& indices) {
+  auto cached = axis_var_lookup_.find(buf);
+  if (cached != axis_var_lookup_.end()) {
+    return (*cached).second;
+  }
+
+  Array<Var> axis_vars;
+  size_t axis = 0;
+  for (const PrimExpr& index : indices) {
+    std::stringstream name;
+    name << buf->name << "_axis_" << axis;
+    axis_vars.push_back(Var(name.str(), index.dtype().element_of()));
+    ++axis;
+  }
+
+  axis_var_lookup_.Set(buf, axis_vars);
+  return axis_vars;
+}
+
+void ControlFlowGraph::ForwardPropagateKnownValues(size_t max_revisits) {
+ // Values to visit when searching. Using a std::set to
+ // preferentially visit nodes near the start of the control flow.
+ std::set<size_t> to_visit;
+
+ // Map from a block's index to the number of times it has been visited
+ std::unordered_map<size_t, size_t> visit_count_lookup;
+
+ // Initialize the locations to search from, propagating values
+ // forward from all locations that have a known value.
+ for (size_t i = 0; i < control_flow_.size(); i++) {
+ bool has_known_value = false;
+ for (const auto& touch : control_flow_[i].touch_points) {
+ if (!HasBufferLoad(touch.value)) {
+ has_known_value = true;
+ break;
+ }
+ }
+
+ if (has_known_value) {
+ to_visit.insert(i);
+ }
+ }
+
+ Analyzer analyzer;
+ analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension(
+ arith::RewriteSimplifier::kTransitivelyProveInequalities |
+ arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches));
+
+ analyzer.Bind(iterator_ranges_);
+ analyzer.Bind(free_predicate_parameters_);
+
+ while (to_visit.size()) {
+ size_t visiting = *to_visit.begin();
+ to_visit.erase(visiting);
+
+ size_t num_previous_visits = visit_count_lookup[visiting]++;
+
+ ControlFlowBlock& block = control_flow_[visiting];
+
+ // Step 1: Collect known values provided from each predecessor
+ block.known_at_block_start = [&]() -> BufferState {
+ if (num_previous_visits >= max_revisits) {
+ return BufferState();
+ }
+
+ // Validate internal constraint. This should be true by
+ // construction, as ControlFlowGraphBuilder only builds graphs
+ // that have two or fewer predecessors.
+ ICHECK_LE(block.predecessors.size(), 2)
+ << "InternalError: Each block should have at most two predecessors. "
+ << "Graph constructed in ControlFlowGraphBuilder did not satisfy this constraint.";
+
+ std::vector<BufferState> states;
+ for (const auto& pred : block.predecessors) {
+ const auto& pred_block = control_flow_[pred.index];
+ BufferState state = pred_block.known_at_block_end;
+ state.Substitute(pred.var_remap, &analyzer);
+ states.push_back(state);
+ }
+
+ if (std::all_of(block.predecessors.begin(), block.predecessors.end(),
+ [&](const auto& pred) { return visit_count_lookup[pred.index] == 0; })) {
+ // Predecessors, if any, are unvisited.
+ return {};
+ } else if (block.predecessors.size() == 1) {
+ // Block has only a single predecessor
+ return states[0];
+ }
+
+ const auto& pred_a = block.predecessors[0];
+ const auto& pred_b = block.predecessors[1];
+
+ auto& priors_a = states[0];
+ auto& priors_b = states[1];
+
+ // During the first visit of a block, predecessor blocks may be
+ // unvisited, even though we preferentially visit earlier blocks
+ // first. (e.g. During the first visit of the start of a For
+ // loop, the end of the For loop has not yet been visited.) If
+ // this is the case, assume the best-case scenario that all
+ // knowns are consistent, and rely on a later visit to
+ // resolve/remove any conflicts.
+ if (visit_count_lookup[pred_a.index] == 0) {
+ return priors_b;
+ } else if (visit_count_lookup[pred_b.index] == 0) {
+ return priors_a;
+ }
+
+ if (pred_a.post_condition && pred_b.post_condition) {
+ // The predicate can identify which predecessor block applies
+ // (e.g. i==0 for the first loop iteration, i>0 for remaining
+ // loop iterations). Therefore, we can use all buffer
+ // constraints, conditional on having come from the
+ // predecessor that provides it.
+ priors_a.AddCondition(pred_a.post_condition.value());
+ priors_b.AddCondition(pred_b.post_condition.value());
+ priors_a.Union(priors_b, &analyzer);
+ return priors_a;
+ } else {
+ // We don't know which predecessor applies. Therefore, the
+ // only buffer constraints that can be used are those that
+ // appear in both predecessors.
+ priors_a.Intersection(priors_b, &analyzer);
+ return priors_a;
+ }
+ }();
+
+ // Step 2: Collect knowns provided as a result of executing this block
+ auto post_state = [&]() {
+ if (num_previous_visits >= max_revisits) {
+ return BufferState();
+ }
+ auto post_state = block.known_at_block_start;
+ post_state.ApplyTouches(axis_var_lookup_, block.touch_points, &analyzer);
+ post_state.RemoveFreeParameters(free_predicate_parameters_, &analyzer);
+ return post_state;
+ }();
+
+ // Step 3: If any changes are made to the post knowns since the
+ // previous time we visited this block, mark the successor block
+ // as needing to be visited.
+ if (num_previous_visits == 0 ||
+ !post_state.IsEquivalentTo(block.known_at_block_end, &analyzer)) {
+ block.known_at_block_end = std::move(post_state);
+ for (const auto& successor : block.successors) {
+ to_visit.insert(successor.index);
+ }
+ }
+ }
+}
+
+ void ControlFlowGraph::BackwardPropagateUnusedValues(size_t max_revisits) {
+ // Backward pass: worklist of block indices still to be (re)visited. Using a std::set to
+ // preferentially visit nodes near the end of the control flow (highest pending index first).
+ std::set<size_t> to_visit;
+
+ // Map from a block's index to the number of times that block has been visited so far
+ std::unordered_map<size_t, size_t> visit_count_lookup;
+
+ // Initialize the locations to search from, propagating values
+ // backward from anywhere that performs a write.
+ for (size_t i = 0; i < control_flow_.size(); i++) {
+ const auto& touch_points = control_flow_[i].touch_points;
+ bool performs_write = std::any_of(
+ touch_points.begin(), touch_points.end(),
+ [](const auto& touch) { return touch.touch_type == BufferTouch::AccessType::Write; });
+ if (performs_write) {
+ to_visit.insert(i);
+ }
+ }
+
+ Analyzer analyzer;
+ analyzer.rewrite_simplify.SetEnabledExtensions(
+ arith::RewriteSimplifier::kTransitivelyProveInequalities);
+
+ analyzer.Bind(iterator_ranges_);
+ analyzer.Bind(free_predicate_parameters_);
+
+ while (to_visit.size()) {
+ size_t visiting = *to_visit.rbegin();
+ to_visit.erase(visiting);
+
+ size_t num_previous_visits = visit_count_lookup[visiting]++;
+
+ ControlFlowBlock& block = control_flow_[visiting];
+
+ // Step 1: Collect known unused indices provided by each successor (or drop them all once this block has already been visited max_revisits times)
+ block.unused_at_block_end = [&]() -> BufferState {
+ if (num_previous_visits >= max_revisits) {
+ return BufferState();
+ }
+ ICHECK_LE(block.successors.size(), 2)
+ << "Each block should have at most two successors, but block " << visiting
+ << " breaks this requirement";
+
+ std::vector<BufferState> states;
+ for (const auto& successor : block.successors) {
+ const auto& successor_block = control_flow_[successor.index];
+ BufferState state = successor_block.unused_at_block_start;
+ state.Substitute(successor.var_remap, &analyzer);
+ states.push_back(state);
+ }
+
+ if (std::all_of(block.successors.begin(), block.successors.end(), [&](const auto& successor) {
+ return visit_count_lookup[successor.index] == 0;
+ })) {
+ // Successors, if any, are unvisited, so nothing is known to be unused yet.
+ return {};
+ } else if (block.successors.size() == 1) {
+ // Block has only a single successor, so its unused indices carry over as-is.
+ return states[0];
+ }
+
+ const auto& successor_a = block.successors[0];
+ const auto& successor_b = block.successors[1];
+
+ auto& post_a = states[0];
+ auto& post_b = states[1];
+
+ // During the first visit of a block, successor blocks may be
+ // unvisited, even though we preferentially visit later blocks
+ // first. (e.g. During the first visit of the end of a For
+ // loop, the start of the For loop has not yet been visited.)
+ // If this is the case, assume the best-case scenario that all
+ // knowns are consistent, and rely on a later visit to
+ // resolve/remove any conflicts.
+ if (visit_count_lookup[successor_a.index] == 0) {
+ return post_b;
+ } else if (visit_count_lookup[successor_b.index] == 0) {
+ return post_a;
+ }
+
+ if (successor_a.post_condition && successor_b.post_condition) {
+ // The predicate can identify which successor block applies
+ // (e.g. i==n-1 for the last loop iteration, i<n-1 for earlier
+ // loop iterations). Therefore, we can use all buffer
+ // constraints, conditional on proceeding to the
+ // successor that provides it.
+ post_a.AddCondition(successor_a.post_condition.value());
+ post_b.AddCondition(successor_b.post_condition.value());
+ post_a.Union(post_b, &analyzer);
+ return post_a;
+ } else {
+ // We don't know which successor applies. Therefore, the
+ // only buffer constraints that can be used are those that
+ // appear in both successors.
+ post_a.Intersection(post_b, &analyzer);
+ return post_a;
+ }
+ }();
+
+ // Step 2: Collect indices known to be unused prior to executing this block (empty once the revisit limit is reached)
+ auto unused_at_block_start = [&]() {
+ if (num_previous_visits >= max_revisits) {
+ return BufferState();
+ }
+ auto prior_state = block.unused_at_block_end;
+ prior_state.BackpropUnusedIndices(axis_var_lookup_, block.touch_points, &analyzer);
+ prior_state.RemoveFreeParameters(free_predicate_parameters_, &analyzer);
+ return prior_state;
+ }();
+
+ // Step 3: If the unused indices at the start of this block changed
+ // since the previous time we visited this block, mark the
+ // predecessor blocks as needing to be visited. The first visit always propagates.
+ if (num_previous_visits == 0 ||
+ !unused_at_block_start.IsEquivalentTo(block.unused_at_block_start, &analyzer)) {
+ block.unused_at_block_start = std::move(unused_at_block_start);
+ for (const auto& pred : block.predecessors) {
+ to_visit.insert(pred.index);
+ }
+ }
+ }
+ }
+
+ // Check whether `store`, treated as occurring in the control-flow block that
+ // contains `context`, only writes indices recorded as unused at the end of
+ // that block.  Returns false when no axis variables exist for the buffer, or
+ // when no single unused-constraint can be shown to cover the written indices.
+ bool ControlFlowGraph::IsOverwrittenWithoutEffect(const tir::BufferStore& store,
+                                                   const Stmt& context) const {
+   Optional<Array<Var>> axis_vars = GetIndexVariables(store->buffer);
+   if (!axis_vars) {
+     return false;
+   }
+
+   auto block_entry = control_flow_lookup_.find(context.get());
+   ICHECK(block_entry != control_flow_lookup_.end())
+       << "Context " << PrettyPrint(context) << " did not occur within analyzed statement";
+   const auto& enclosing_block = control_flow_[block_entry->second];
+
+   // Model the store as a Write touch whose value is a load of the same location.
+   auto [write_touch, touch_free_params] = enclosing_block.MakeBufferTouch(
+       store->buffer, axis_vars.value(), store->indices, BufferTouch::AccessType::Write,
+       BufferLoad(store->buffer, store->indices));
+
+   Analyzer scoped_analyzer;
+   scoped_analyzer.Bind(free_predicate_parameters_);
+   scoped_analyzer.Bind(iterator_ranges_);
+   scoped_analyzer.Bind(touch_free_params);
+   scoped_analyzer.rewrite_simplify.SetEnabledExtensions(
+       RewriteSimplifier::kTransitivelyProveInequalities);
+
+   // Indices written by the store during the current loop iteration.
+   PrimExpr written = SimplifyAsAndOfOrs(write_touch.predicate && write_touch.AtLoopIteration(),
+                                         &scoped_analyzer);
+
+   for (const auto& unused : enclosing_block.unused_at_block_end.constraints_) {
+     if (!write_touch.buffer.same_as(unused.buffer)) {
+       continue;
+     }
+     // The store has no effect if no written index falls outside the unused region.
+     PrimExpr uncovered = SimplifyAsAndOfOrs(written && !unused.predicate, &scoped_analyzer);
+     if (is_zero(uncovered)) {
+       return true;
+     }
+   }
+   return false;
+ }
+
+ // Simplify `expr` as if it were evaluated at the start of the control-flow
+ // block that contains `context`.
+ //
+ // All entries of `non_buffer_assumptions_` are AND-ed together and applied
+ // as a constraint on `analyzer`, BufferLoads are replaced using the values
+ // known at the start of the block, and the result is simplified.  `context`
+ // must be registered in `control_flow_lookup_`; otherwise the ICHECK fails.
+ PrimExpr ControlFlowGraph::SimplifyInContext(PrimExpr expr, const tir::Stmt& context,
+                                              Analyzer* analyzer) const {
+   auto it = control_flow_lookup_.find(context.get());
+   // The message names ControlFlowGraph; it previously referred to the
+   // stale class name BufferTouchPattern.
+   ICHECK(it != control_flow_lookup_.end())
+       << "Context did not occur in the Stmt provided to ControlFlowGraph's constructor";
+   const auto& block = control_flow_[it->second];
+
+   PrimExpr constraint = Bool(true);
+   for (const auto& known : non_buffer_assumptions_) {
+     constraint = constraint && known;
+   }
+   // Keep the combined assumption active on `analyzer` for the rest of this function.
+   With<ConstraintContext> constraint_context(analyzer, constraint);
+
+   expr = block.known_at_block_start.SubstituteKnownBufferValues(std::move(expr),
+                                                                 axis_var_lookup_, analyzer);
+   return analyzer->Simplify(std::move(expr));
+ }
+
+} // namespace tir
+} // namespace tvm
diff --git a/src/tir/analysis/control_flow_graph.h b/src/tir/analysis/control_flow_graph.h
new file mode 100644
index 0000000000..aa9023ba29
--- /dev/null
+++ b/src/tir/analysis/control_flow_graph.h
@@ -0,0 +1,653 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file control_flow_graph.h
+ * \brief Utility for extracting and interacting with buffer touch points
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/container/array.h>
+#include <tvm/tir/buffer.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/var.h>
+
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#ifndef TVM_TIR_ANALYSIS_CONTROL_FLOW_GRAPH_H_
+#define TVM_TIR_ANALYSIS_CONTROL_FLOW_GRAPH_H_
+
+namespace tvm {
+namespace tir {
+
+/*! \brief Represents an interaction with a buffer */
+ struct BufferTouch {
+   enum class AccessType {
+     /*! \brief Buffer access occurs in BufferLoad */
+     Read,
+
+     /*! \brief Buffer access occurs in BufferStore */
+     Write,
+
+     /*! \brief Buffer access occurs in tir::builtin::assume() */
+     Assume,
+   };
+
+   /*! \brief Construct an Assume touch with no enclosing loops
+    *
+    * \param buffer The buffer being touched
+    *
+    * \param predicate A predicate that is true when this touch applies
+    *
+    * \param value The value in the buffer after the touch
+    */
+   BufferTouch(Buffer buffer, PrimExpr predicate, PrimExpr value)
+       : buffer(std::move(buffer)),
+         predicate(std::move(predicate)),
+         value(std::move(value)),
+         loop_var_expressions(),
+         touch_type(AccessType::Assume) {}
+
+   /*! \brief Construct a touch of the given type
+    *
+    * \param buffer The buffer being touched
+    *
+    * \param predicate A predicate that is true when this touch applies
+    *
+    * \param value The value in the buffer after the touch
+    *
+    * \param loop_var_expressions One entry for each loop that contains
+    * the touch, see the member of the same name.
+    *
+    * \param touch_type How the buffer was interacted with
+    */
+   BufferTouch(Buffer buffer, PrimExpr predicate, PrimExpr value,
+               std::vector<std::pair<Var, PrimExpr>> loop_var_expressions, AccessType touch_type)
+       : buffer(std::move(buffer)),
+         predicate(std::move(predicate)),
+         value(std::move(value)),
+         loop_var_expressions(std::move(loop_var_expressions)),
+         touch_type(touch_type) {}
+
+   /*! \brief The buffer being touched */
+   Buffer buffer;
+
+   /*! \brief A predicate that is true when this touch applies
+    *
+    * May be in terms of axis variables to indicate touches that impact
+    * only a portion of a buffer.
+    */
+   PrimExpr predicate;
+
+   /*! \brief The value in this buffer after the touch
+    *
+    * May be in terms of axis variables to indicate a known
+    * non-constant value.  May be in terms of a BufferLoad to indicate
+    * an unknown value.
+    */
+   PrimExpr value;
+
+   /*! \brief Active loops during the buffer touch
+    *
+    * The vector contains one entry for each loop that contains the
+    * buffer touch.  The `Var` item in each entry is the loop variable
+    * itself.  The `PrimExpr` item is an expression for the loop
+    * variable in terms of the buffer axis variables in
+    * `ControlFlowGraph::axis_var_lookup_`.
+    *
+    * Used to construct boolean expressions indicating whether the loop
+    * iteration that performs this touch has been reached.
+    */
+   std::vector<std::pair<Var, PrimExpr>> loop_var_expressions;
+
+   /*! \brief How the buffer was interacted with
+    *
+    * When used as a constraint (e.g. in BufferState), should use
+    * Assume.
+    */
+   AccessType touch_type{AccessType::Assume};
+
+   /*! \brief Generate a boolean expression that is true for indices
+    * accessed by this touch during this iteration or a previous
+    * loop iteration.
+    *
+    * Used during forward propagation, to track known values that were
+    * written in the current loop iteration, or in a preceding loop
+    * iteration.
+    */
+   PrimExpr BeforeLoopIteration() const;
+
+   /*! \brief Generate a boolean expression that is true for indices
+    * accessed by this touch during this loop iteration.
+    *
+    * Used during speculative no-op insertion checks, to specify which
+    * indices must be later overwritten for a store to have no impact
+    * on final results.
+    */
+   PrimExpr AtLoopIteration() const;
+
+   /*! \brief Generate a boolean expression that is true for indices
+    * accessed by this touch during this loop iteration or a
+    * subsequent loop iteration.
+    *
+    * Used during backward propagation, to track indices that are
+    * overwritten in the current loop iteration or in a later loop
+    * iteration.
+    */
+   PrimExpr AfterLoopIteration() const;
+
+   /*! \brief Checks if this touch affects a subset of indices of another
+    *
+    * Returns true if the indices accessed by this touch can be proven
+    * to be a subset of the indices accessed by the other touch.
+    * Returns false if this cannot be proven.
+    */
+   bool IsSubsetOf(const BufferTouch& other, arith::Analyzer* analyzer) const;
+
+   /*! \brief Checks if this touch affects distinct indices from another
+    *
+    * Returns true if it can be proven that the two predicates cannot
+    * be simultaneously true.  Returns false if it cannot be proven
+    * that the two predicates are distinct.
+    */
+   bool IsDistinctFrom(const BufferTouch& other, arith::Analyzer* analyzer) const;
+
+   /*! \brief Checks if this touch is equivalent to another
+    *
+    * Presumably returns true if the two touches can be proven
+    * equivalent, and false if equivalence cannot be proven.
+    *
+    * NOTE(review): the comment previously attached here was a copy of
+    * the IsDistinctFrom documentation.  The exact equivalence criteria
+    * are defined by the implementation and should be confirmed there.
+    */
+   bool IsEquivalentTo(const BufferTouch& other, arith::Analyzer* analyzer) const;
+
+   friend std::ostream& operator<<(std::ostream& os, const BufferTouch& expr);
+ };
+
+/*! \brief Represents the known state of buffers at a specific point */
+ class BufferState {
+ public:
+ /*! \brief Default constructor
+ *
+ * Initialize the buffer state with no known information.
+ */
+ BufferState() {}
+
+ /*! \brief Replace BufferLoad instances with known values
+ *
+ * \param expr The expression to be updated.
+ *
+ * \param axis_var_lookup A map from buffer to the variables
+ * representing positions along the buffer's axes.
+ *
+ * \param analyzer The analyzer to use when validating a
+ * constraint's predicate.
+ *
+ * \returns The modified expression. If no substitutions are made,
+ * the original expression is returned.
+ */
+ PrimExpr SubstituteKnownBufferValues(PrimExpr expr,
+ const Map<Buffer, Array<Var>>& axis_var_lookup,
+ arith::Analyzer* analyzer) const;
+
+ /*! \brief Apply a condition to all known constraints
+ *
+ * For example, when propagating pre-loop constraints into the body
+ * of a loop, add a condition that the loop iterator is zero.
+ *
+ * \param condition The condition to apply
+ */
+ void AddCondition(const PrimExpr& condition);
+
+ /*! \brief Perform a variable substitution for all constraints
+ *
+ * For example, when propagating constraints from the end of a loop
+ * to the beginning, replace `i` with `i-1`.
+ *
+ * \param var_remap The variable remapping to apply.
+ *
+ * \param analyzer The analyzer made available to the substitution
+ * (presumably used to simplify the updated constraints; confirm
+ * against the implementation).
+ */
+ void Substitute(const Map<Var, PrimExpr>& var_remap, arith::Analyzer* analyzer);
+
+ /*! \brief Simplify the predicate of all constraints
+ *
+ * \param analyzer The analyzer with which to simplify
+ */
+ void Simplify(arith::Analyzer* analyzer);
+
+ /*! \brief Update the known buffer values based on buffer touches
+ *
+ * For any Write or Assume touches, update the known values. For
+ * any Read touches, ignore. Used to determine known values at the
+ * end of a control flow block, given the known values at the start.
+ *
+ * \param axis_var_lookup A map from buffer to the variables
+ * representing positions along the buffer's axes.
+ *
+ * \param touch_points The buffer touch points to apply
+ *
+ * \param analyzer The analyzer to use for simplifications
+ */
+ void ApplyTouches(const Map<Buffer, Array<Var>>& axis_var_lookup,
+ const std::vector<BufferTouch>& touch_points, arith::Analyzer* analyzer);
+
+ /*! \brief Update unused buffer locations based on buffer touches
+ *
+ * For any Write, mark the written-to indices as unused. (That is,
+ * immediately prior to assigning `buf[i] = expr`, the value stored
+ * at `buf[i]` is irrelevant.) For any Read, mark the read-from
+ * indices as used. This method is used to determine unused buffer
+ * indices at the start of a control flow block, given the unused
+ * buffer indices at the end.
+ *
+ * \param axis_var_lookup A map from buffer to the variables
+ * representing positions along the buffer's axes.
+ *
+ * \param touch_points The buffer touch points to apply
+ *
+ * \param analyzer The analyzer to use for simplifications
+ */
+ void BackpropUnusedIndices(const Map<Buffer, Array<Var>>& axis_var_lookup,
+ const std::vector<BufferTouch>& touch_points,
+ arith::Analyzer* analyzer);
+
+ /*! \brief Remove free parameters from the constraints
+ *
+ * \param free_predicate_parameters The free parameters to remove,
+ * each mapped to its range.
+ *
+ * \param analyzer The analyzer with which to simplify after removal
+ */
+ void RemoveFreeParameters(const Map<Var, Range>& free_predicate_parameters,
+ arith::Analyzer* analyzer);
+
+ /*! \brief Check if two buffer states are equivalent
+ *
+ * \param other The buffer state to compare against
+ *
+ * \param analyzer The analyzer used to check equality of PrimExpr
+ *
+ * \return True if the two states are provably equivalent, false otherwise.
+ */
+ bool IsEquivalentTo(const BufferState& other, arith::Analyzer* analyzer) const;
+
+ /*! \brief Add known values provided by another state
+ *
+ * \param other The state with which to merge constraints
+ *
+ * \param analyzer The analyzer with which to simplify the result
+ */
+ void Union(const BufferState& other, arith::Analyzer* analyzer);
+
+ /*! \brief Remove all known values not consistent with another state
+ *
+ * \param other The state with which to merge constraints
+ *
+ * \param analyzer The analyzer with which to simplify the result
+ */
+ void Intersection(const BufferState& other, arith::Analyzer* analyzer);
+
+ friend std::ostream& operator<<(std::ostream& os, const BufferState&);
+
+ private:
+ friend class ControlFlowGraph;
+ /*! \brief The known constraints */
+ std::vector<BufferTouch> constraints_;
+ };
+
+/*! \brief Represents the flow of control through a `tir::Stmt`
+ *
+ * This class contains an internal representation of the possible
+ * control flow that may occur during execution of a `tir::Stmt`. It
+ * consists of a collection of ControlFlowBlock objects, each of which
+ * represents a subset of operations performed during execution, along
+ * with edges that represent allowed transitions between
+ * `ControlFlowBlock`.
+ *
+ * In addition, the following restrictions are used.
+ *
+ * 1. Each block may have at most two predecessors, and at most two
+ * successors.
+ *
+ * 2. Within each block, values stored in a buffer do not change.
+ * That is, encountering a `BufferStore` node requires creating a
+ * new block.
+ *
+ * For example, consider the following PrimFunc
+ *
+ * ```python
+ * @T.prim_func
+ * def func(B: T.Buffer[16, "float32"]):
+ * for i in T.serial(16):
+ * if i < 8:
+ * B[i] = i
+ * else:
+ * B[i] = i-8
+ * ```
+ *
+ * The control flow graph would have eight control blocks.
+ *
+ * 1. function_entry, from the start of the function through the
+ * evaluation of the loop's extent.
+ *
+ * Predecessors: n/a
+ * Successors: loop_start
+ *
+ * 2. loop_start, after entering the body of the loop, through the
+ * evaluation of the conditional `i < 8`
+ *
+ * Predecessors: function_entry, after_conditional
+ * Successors: then_clause_start, else_clause_start
+ *
+ * 3. then_clause_start, after entering the then_clause of `i < 8`,
+ * through evaluation of the value `i`.
+ *
+ * Predecessors: loop_start
+ * Successors: then_clause_end
+ *
+ * 4. then_clause_end, after storing to `B[i]` prior to exiting the
+ * then_clause.
+ *
+ * Predecessors: then_clause_start
+ * Successors: after_conditional
+ *
+ * 5. else_clause_start, after entering the else_clause of `i < 8`,
+ * through evaluation of the value `i-8`.
+ *
+ * Predecessors: loop_start
+ * Successors: else_clause_end
+ *
+ * 6. else_clause_end, after storing to `B[i]` prior to exiting the
+ * else_clause.
+ *
+ * Predecessors: else_clause_start
+ * Successors: after_conditional
+ *
+ * 7. after_conditional, after the end of the if/then/else, before the
+ * end of the loop body
+ *
+ * Predecessors: then_clause_end, else_clause_end
+ * Successors: loop_start, after_loop
+ *
+ * 8. after_loop, after the loop
+ *
+ * Predecessors: after_conditional
+ * Successors: n/a
+ *
+ *
+ * By identifying `BufferStore` nodes whose value does not depend on
+ * values stored in input buffers (e.g. initializing `buf[i] = 0.0`),
+ * or whose values are provided using `builtin::assume()`
+ * (e.g. `T.assume(buf[i] == 0.0)`), the value stored in a buffer at
+ * those indices may be known for a given control block. These known
+ * values can then be propagated forward to successor blocks, to be
+ * used in context-dependent simplifications.
+ *
+ * In addition to the allowed transitions between control-flow
+ * blocks, each block also tracks the buffer touch points that occur
+ * during the block: which indices are read from a buffer, which
+ * values are written to which indices of a buffer, and which
+ * assumptions are provided using `builtin::assume()`.
+ *
+ * Note: The current implementation only tracks the values of
+ * buffers that are constrained to a specific value, and does not
+ * track inequalities that may partially constrain buffer values.
+ * That is, entering a scoped context with a data-dependent equality
+ * condition (e.g. `if buf[i] == value`) is tracked, but entering a
+ * scoped context with a data-dependent inequality condition
+ * (e.g. `if buf[i] > value`) is not tracked.
+ */
+ class ControlFlowGraph {
+ public:
+ /*! \brief Construct the control-flow graph for a TIR statement
+ *
+ * \param stmt The statement to be analyzed
+ *
+ * \param max_revisits NOTE(review): presumably the per-block visit
+ * limit handed to ForwardPropagateKnownValues and
+ * BackwardPropagateUnusedValues; confirm against the constructor's
+ * definition.
+ */
+ explicit ControlFlowGraph(const Stmt& stmt, size_t max_revisits = 5);
+
+ /*! \brief Check if a write is overwritten without impacting final results
+ *
+ * \param store The store to be examined
+ *
+ * \param context The context in which the buffer store occurs, used
+ * to identify the control-flow block in which the store occurs. In
+ * most cases, this will be the same object as the `store` itself.
+ *
+ * \return True if the specified store can be proven to be
+ * overwritten without contributing to any later statements.
+ * Returns false otherwise.
+ */
+ bool IsOverwrittenWithoutEffect(const BufferStore& store, const Stmt& context) const;
+
+ /*! \brief Simplify the expression, assuming it occurs within the given context
+ *
+ * \param expr The expression to be simplified. Does not need to
+ * have occurred within the statement used to construct this
+ * ControlFlowGraph.
+ *
+ * \param context The statement where this expression occurred, or
+ * is to be inserted. Must occur within the statement used to
+ * construct this ControlFlowGraph.
+ *
+ * \param analyzer The analyzer to be used for simplifications
+ *
+ * \returns The simplified expression
+ */
+ PrimExpr SimplifyInContext(PrimExpr expr, const Stmt& context, arith::Analyzer* analyzer) const;
+
+ /*! \brief Remove the specified BufferStore from the control-flow
+ * graph
+ *
+ * Removes the specified store, which may reflow known values.
+ * This is necessary when simplifying sequential stores of the same
+ * value. Otherwise, the first could be removed as a no-op because
+ * it is overwritten by the second, and the second could be removed
+ * as a no-op because it is the same value as the first.
+ *
+ * \param store The store to remove
+ */
+ void RemoveStore(const tir::BufferStore& store);
+
+ friend std::ostream& operator<<(std::ostream& os, const ControlFlowGraph& pattern);
+
+ private:
+ /*! \brief Return index variables representing locations within a
+ * buffer.
+ *
+ * For a given buffer, will always return the same set of variables.
+ * The variables are generated and cached on the first call for a
+ * buffer.
+ *
+ * \param buf The buffer being accessed
+ *
+ * \param indices The indices at which the buffer is being accessed.
+ * These are used to set the dtype of the buffer axis variables.
+ *
+ * \returns Variables representing a position along the buffer's axis.
+ */
+ Array<Var> GetIndexVariables(const Buffer& buf, const Array<PrimExpr>& indices);
+
+ /*! \brief Return index variables representing locations within a
+ * buffer, if they have been generated before.
+ *
+ * For a given buffer, will always return the same set of variables.
+ *
+ * \param buf The buffer being accessed
+ *
+ * \returns Variables representing a position along the buffer's
+ * axis, or NullOpt if no variables have been generated for `buf`.
+ */
+ Optional<Array<Var>> GetIndexVariables(const Buffer& buf) const;
+
+ /*! \brief Propagate known values from known BufferStore/assume to
+ * subsequent control flow blocks
+ *
+ * \param max_revisits The number of visits after which a block's
+ * known values are reset to an empty state instead of being
+ * recomputed.
+ */
+ void ForwardPropagateKnownValues(size_t max_revisits);
+
+ /*! \brief Propagate overwritten/unused indices to preceding control
+ * flow blocks
+ *
+ * \param max_revisits The number of visits after which a block's
+ * unused indices are reset to an empty state instead of being
+ * recomputed.
+ */
+ void BackwardPropagateUnusedValues(size_t max_revisits);
+
+ struct ControlFlowEdge {
+ /*! \brief The block at the other end of the control flow edge
+ *
+ * Lookup index into `control_flow_`. For an edge stored in
+ * `predecessors` this is the source block; for an edge stored in
+ * `successors` this is the destination block.
+ */
+ size_t index;
+
+ /*! \brief Variable remaps
+ *
+ * e.g. Replacing loop iterator `i` with `i-1` when following an
+ * edge from the end of a loop to the beginning of the loop.
+ */
+ Map<Var, PrimExpr> var_remap;
+
+ /*! \brief Condition that must be true after following this edge
+ *
+ * This is applied after variable remapping. For example, `i >
+ * loop_min` when following an edge from the end of a loop to
+ * the beginning of the loop.
+ */
+ Optional<PrimExpr> post_condition;
+ };
+ friend std::ostream& operator<<(std::ostream& os, const ControlFlowEdge& edge);
+
+ struct ControlFlowBlock {
+ struct LoopEntry {
+ Var loop_var;
+ PrimExpr loop_min;
+ PrimExpr loop_max;
+ Range loop_range;
+ };
+
+ /*! \brief Loop iterators that are active during this block */
+ std::vector<LoopEntry> active_loop_iterators;
+
+ /*! \brief Loop-dependent Let bindings that may appear within the block */
+ Map<Var, PrimExpr> let_bindings_using_loop;
+
+ /*! \brief Predicate that must be true to have reached this block */
+ PrimExpr scope_predicate{Bool(true)};
+
+ /*! \brief All known values prior to executing the block */
+ BufferState known_at_block_start;
+
+ /*! \brief All known values after executing the block */
+ BufferState known_at_block_end;
+
+ /*! \brief Indices whose value at the start of the block is known to be unused */
+ BufferState unused_at_block_start;
+
+ /*! \brief Indices whose value at the end of the block is known to be unused */
+ BufferState unused_at_block_end;
+
+ /*! \brief Buffer touches that occur within the block
+ *
+ * All buffer touches within a block can be treated as occurring
+ * simultaneously.
+ */
+ std::vector<BufferTouch> touch_points;
+
+ /*! \brief The blocks that occur after this block
+ *
+ * Each edge's `index` is a lookup index into `control_flow_`
+ */
+ std::vector<ControlFlowEdge> successors;
+
+ /*! \brief The blocks that occur before this block
+ *
+ * Each edge's `index` is a lookup index into `control_flow_`
+ */
+ std::vector<ControlFlowEdge> predecessors;
+
+ /*! \brief Construct a BufferTouch instance within this
+ * ControlFlowBlock
+ *
+ * \param graph The mutable ControlFlowGraph that owns the buffer
+ * touch. Any free parameters used in the BufferTouch's predicate
+ * will be tracked by the ControlFlowGraph.
+ *
+ * \param buf The Buffer being accessed
+ *
+ * \param indices The indices at which the buffer is accessed, in
+ * terms of the loop variables.
+ *
+ * \param touch_type The type of touch being generated
+ *
+ * \param known_value_expr The value being written to the buffer
+ *
+ * \returns The newly generated BufferTouch
+ */
+ BufferTouch MakeBufferTouch(ControlFlowGraph* graph, const Buffer& buf,
+ const Array<PrimExpr>& indices, BufferTouch::AccessType touch_type,
+ PrimExpr known_value_expr) const;
+
+ /*! \brief Construct a BufferTouch instance as if it occurred in
+ * this ControlFlowBlock
+ *
+ * Used when speculatively checking if a BufferStore could be
+ * inserted.
+ *
+ * \param buf The Buffer being accessed
+ *
+ * \param index_variables The variables representing location
+ * within a buffer, with one variable for each axis of the buffer.
+ *
+ * \param indices The indices at which the buffer is accessed, in
+ * terms of the loop variables.
+ *
+ * \param touch_type The type of touch being generated
+ *
+ * \param known_value_expr The value being written to the buffer
+ *
+ * \returns The newly generated BufferTouch, and a map specifying
+ * all free parameters that may occur in the BufferTouch's
+ * predicate.
+ */
+ std::pair<BufferTouch, Map<Var, Range>> MakeBufferTouch(const Buffer& buf,
+ Array<Var> index_variables,
+ Array<PrimExpr> indices,
+ BufferTouch::AccessType touch_type,
+ PrimExpr known_value_expr) const;
+ };
+ friend std::ostream& operator<<(std::ostream& os, const ControlFlowBlock& pattern);
+
+ /*! \brief The control flow that occurs within the analyzed statement */
+ std::vector<ControlFlowBlock> control_flow_;
+
+ /*! \brief A lookup into control_flow_
+ *
+ * A map to look up the control flow block that contains the
+ * statement.
+ */
+ std::unordered_map<const StmtNode*, size_t> control_flow_lookup_;
+
+ /*! \brief A map from free parameters to their range
+ *
+ * A BufferStore/BufferLoad has indices in terms of loop iterators,
+ * while the internal BufferTouch must have predicate in terms of
+ * the buffer's axes. While converting to the internal BufferTouch,
+ * reduction axes show up as free parameters. Tracking the range of
+ * the free parameters allows them to be removed later, by requiring
+ * a predicate to be true for all values of the free parameters.
+ */
+ Map<Var, Range> free_predicate_parameters_;
+
+ /*! \brief Ranges of iterators found in the analyzed statement */
+ Map<Var, Range> iterator_ranges_;
+
+ /*! \brief A map from buffer to the variables representing positions
+ * along the buffer's axes.
+ *
+ * This is stored here, rather than as part of the BufferState or
+ * BufferTouch, to ensure that all access of a buffer use the same
+ * variables to represent the buffer's axes, reducing the amount of
+ * variable substitution required.
+ */
+ Map<Buffer, Array<Var>> axis_var_lookup_;
+
+ /*! \brief Assumptions that do not depend on buffer values
+ *
+ * These may be collected as part of the handling of `builtin::assume()`, and do not depend on any
+ * buffer. Since TIR only allows mutable values as part of buffers, these assumptions may be used
+ * anywhere, regardless of which control-flow block an expression occurs in.
+ */
+ std::vector<PrimExpr> non_buffer_assumptions_;
+
+ friend class ControlFlowGraphBuilder;
+ };
+
+} // namespace tir
+} // namespace tvm
+#endif // TVM_TIR_ANALYSIS_CONTROL_FLOW_GRAPH_H_
diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc
index 1dbf9e6880..49d3a9ceae 100644
--- a/src/tir/transforms/simplify.cc
+++ b/src/tir/transforms/simplify.cc
@@ -29,7 +29,10 @@
#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>
+#include <optional>
+
#include "../../arith/ir_mutator_with_analyzer.h"
+#include "../../tir/analysis/control_flow_graph.h"
namespace tvm {
namespace arith {
@@ -38,6 +41,8 @@ using namespace tir;
struct SimplifyConfigNode : public tvm::AttrsNode<SimplifyConfigNode> {
bool transitively_prove_inequalities;
+ bool propagate_knowns_to_prove_conditional;
+ bool propagate_knowns_to_simplify_expressions;
bool convert_boolean_to_and_of_ors;
bool apply_constraints_to_boolean_branches;
@@ -47,6 +52,17 @@ struct SimplifyConfigNode : public tvm::AttrsNode<SimplifyConfigNode> {
"If true, simplify conditionals with transitive combinations of scoped constraints")
.set_default(false);
+ TVM_ATTR_FIELD(propagate_knowns_to_prove_conditional)
+ .describe(
+ "If true, known buffer values are propagated and used to statically prove conditionals")
+ .set_default(false);
+
+ TVM_ATTR_FIELD(propagate_knowns_to_simplify_expressions)
+ .describe(
+ "If true, known buffer values are propagated and used to replace BufferLoad wherever "
+ "possible")
+ .set_default(false);
+
TVM_ATTR_FIELD(convert_boolean_to_and_of_ors)
.describe("If true, simplify conditionals into an AND of ORs")
.set_default(false);
@@ -85,16 +101,46 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.Simplify", SimplifyConfig);
class StmtSimplifier : public IRMutatorWithAnalyzer {
public:
- explicit StmtSimplifier(Analyzer* analyzer) : IRMutatorWithAnalyzer(analyzer) {}
+ static Stmt Apply(Stmt stmt, Analyzer* analyzer, Optional<SimplifyConfig> config_opt = NullOpt) {
+ auto config = config_opt.value_or(AttrsWithDefaultValues<arith::SimplifyConfig>());
+ analyzer->rewrite_simplify.SetEnabledExtensions(config->GetEnabledExtensions());
+
+ std::optional<ControlFlowGraph> touch_pattern = std::nullopt;
+ if (config->propagate_knowns_to_prove_conditional ||
+ config->propagate_knowns_to_simplify_expressions) {
+ touch_pattern = ControlFlowGraph(stmt);
+ }
+ StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern));
+ return simplifier(std::move(stmt));
+ }
+
+ private:
+ explicit StmtSimplifier(Analyzer* analyzer, SimplifyConfig config,
+ std::optional<ControlFlowGraph> touch_pattern)
+ : IRMutatorWithAnalyzer(analyzer), config_(config), touch_pattern_(touch_pattern) {}
using Parent = IRMutatorWithAnalyzer;
using Parent::VisitStmt;
using Parent::VisitStmt_;
- PrimExpr VisitExpr(const PrimExpr& expr) final { return analyzer_->Simplify(expr); }
+ PrimExpr VisitExpr(const PrimExpr& expr) final {
+ if (config_->propagate_knowns_to_simplify_expressions) {
+ return touch_pattern_->SimplifyInContext(expr, current_stmt_.value(), analyzer_);
+ } else {
+ return analyzer_->Simplify(expr);
+ }
+ }
Stmt Simplify(Stmt stmt) { return operator()(std::move(stmt)); }
+ Stmt VisitStmt(const Stmt& stmt) override {
+ Optional<Stmt> cache = this->current_stmt_;
+ this->current_stmt_ = stmt;
+ Stmt output = Parent::VisitStmt(stmt);
+ this->current_stmt_ = std::move(cache);
+ return output;
+ }
+
Stmt VisitStmt_(const ForNode* op) final {
analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min);
@@ -111,7 +157,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
return SideEffect(op->value) <= CallEffectKind::kPure;
}
- Stmt VisitStmt_(const LetStmtNode* op) {
+ Stmt VisitStmt_(const LetStmtNode* op) override {
PrimExpr value = this->VisitExpr(op->value);
if (CanInlineLetStmt(op)) {
// it is fine to discard the let binding
@@ -134,26 +180,24 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
}
}
- Stmt VisitStmt_(const IfThenElseNode* op) {
- PrimExpr cond = analyzer_->Simplify(Substitute(op->condition, non_inlined_bindings_));
- if (const int64_t* as_int = as_const_int(cond)) {
- if (*as_int) {
+ Stmt VisitStmt_(const IfThenElseNode* op) override {
+ if (Optional<Bool> cond = ProveCondition(op->condition)) {
+ if (cond.value()->value) {
return this->VisitStmt(op->then_case);
} else if (op->else_case) {
return this->VisitStmt(op->else_case.value());
} else {
return Evaluate(0);
}
+ } else {
+ return Parent::VisitStmt_(op);
}
- return Parent::VisitStmt_(op);
}
- PrimExpr VisitExpr_(const CallNode* op) {
+ PrimExpr VisitExpr_(const CallNode* op) override {
if (op->op.same_as(builtin::if_then_else())) {
- PrimExpr cond = this->VisitExpr(op->args[0]);
- cond = analyzer_->Simplify(Substitute(std::move(cond), non_inlined_bindings_));
- if (const int64_t* as_int = as_const_int(cond)) {
- if (*as_int) {
+ if (Optional<Bool> cond = ProveCondition(op->args[0])) {
+ if (cond.value()->value) {
return this->VisitExpr(op->args[1]);
} else {
return this->VisitExpr(op->args[2]);
@@ -196,23 +240,50 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
return true;
}
+ /*! \brief Internal utility for checking conditionals
+ *
+ * Uses more aggressive optimization, such as performing additional
+ * inlining and tracking known buffer values.
+ */
+ Optional<Bool> ProveCondition(PrimExpr condition) const {
+ condition = Substitute(condition, non_inlined_bindings_);
+ if (config_->propagate_knowns_to_prove_conditional) {
+ ICHECK(touch_pattern_.has_value());
+ condition = touch_pattern_->SimplifyInContext(condition, current_stmt_.value(), analyzer_);
+ } else {
+ condition = analyzer_->Simplify(condition);
+ }
+ if (const int64_t* as_int = as_const_int(condition)) {
+ return Bool(*as_int);
+ } else {
+ return NullOpt;
+ }
+ }
+
+ SimplifyConfig config_;
+ std::optional<ControlFlowGraph> touch_pattern_;
+
Map<Var, PrimExpr> non_inlined_bindings_;
+ Optional<Stmt> current_stmt_{NullOpt};
};
} // namespace arith
namespace tir {
+
+Stmt Simplify(Stmt stmt, arith::Analyzer* analyzer) {
+ return arith::StmtSimplifier::Apply(stmt, analyzer);
+}
+
namespace transform {
Pass Simplify() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
arith::Analyzer analyzer;
- auto cfg = ctx->GetConfig<arith::SimplifyConfig>("tir.Simplify")
- .value_or(AttrsWithDefaultValues<arith::SimplifyConfig>());
- analyzer.rewrite_simplify.SetEnabledExtensions(cfg->GetEnabledExtensions());
+ auto cfg = ctx->GetConfig<arith::SimplifyConfig>("tir.Simplify");
auto* n = f.CopyOnWrite();
- n->body = arith::StmtSimplifier(&analyzer).Simplify(std::move(n->body));
+ n->body = arith::StmtSimplifier::Apply(std::move(n->body), &analyzer, cfg);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.Simplify", {});
diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py
index 4477e1d9c7..4199cb9a56 100644
--- a/tests/python/unittest/test_arith_rewrite_simplify.py
+++ b/tests/python/unittest/test_arith_rewrite_simplify.py
@@ -16,7 +16,7 @@
# under the License.
import pytest
import tvm
-from tvm import te
+from tvm import te, tir
class RewriteChecker:
@@ -873,6 +873,65 @@ def test_cmp_simplify():
ck.verify(fld(x + 2, 4) * 4 >= x - y, tvm.tir.LE(flm(x + 2, 4) + (-2), y))
# End DivMod Rules
+ # merging flm/fld into known value
+ ck.verify(tir.all(fld(x, 8) == 3, flm(x, 8) == 4), x == 28)
+ ck.verify(tir.all(flm(x, 8) == 4, fld(x, 8) == 3), x == 28)
+ ck.verify(tir.all(fld(x, 8) == -3, flm(x, 8) == 4), x == -20)
+ ck.verify(tir.all(flm(x, 8) == 4, fld(x, 8) == -3), x == -20)
+
+ # Rewrite based on definition of integer division
+ ck.verify(tir.all(tvm.runtime.convert(0) <= x - y * 5, x - y * 5 < 5), y == fld(x, 5))
+ ck.verify(tir.all(x - y * 5 < 5, tvm.runtime.convert(0) <= x - y * 5), y == fld(x, 5))
+
+ # Narrow upper bound using floormod
+ ck.verify(tir.all(x < 20, flm(x, 5) < 2), tir.all(x < 17, flm(x, 5) < 2))
+ ck.verify(tir.all(x < 18, flm(x, 5) < 2), tir.all(x < 17, flm(x, 5) < 2))
+ ck.verify(tir.all(x <= 19, flm(x, 5) < 2), tir.all(x < 17, flm(x, 5) < 2))
+ ck.verify(tir.all(x <= 18, flm(x, 5) < 2), tir.all(x < 17, flm(x, 5) < 2))
+ ck.verify(tir.all(x < -20, flm(x, 5) < 2), tir.all(x < -23, flm(x, 5) < 2))
+ ck.verify(tir.all(x < 18 - 40, flm(x, 5) < 2), tir.all(x < 17 - 40, flm(x, 5) < 2))
+ ck.verify(tir.all(x <= -21, flm(x, 5) < 2), tir.all(x < -23, flm(x, 5) < 2))
+ ck.verify(tir.all(x <= -22, flm(x, 5) < 2), tir.all(x < -23, flm(x, 5) < 2))
+ # No change if the floormod cannot help narrow the upper bound
+ ck.verify(tir.all(x < 16, flm(x, 5) < 2), tir.all(x < 16, flm(x, 5) < 2))
+ ck.verify(tir.all(x <= 15, flm(x, 5) < 2), tir.all(x <= 15, flm(x, 5) < 2))
+
+ # Merge a known floordiv and an upper bound of floormod into a value range
+ ck.verify(
+ tir.all(fld(x, 10) == 5, flm(x, 10) < 7),
+ tir.all(tvm.runtime.convert(50) <= x, x < 57),
+ )
+ ck.verify(
+ tir.all(fld(x, 10) == 5, flm(x, 10) <= 7),
+ tir.all(tvm.runtime.convert(50) <= x, x <= 57),
+ )
+ ck.verify(
+ tir.all(fld(x, 10) == -5, flm(x, 10) < 7),
+ tir.all(tvm.runtime.convert(-50) <= x, x < -43),
+ )
+ ck.verify(
+ tir.all(fld(x, 10) == -5, flm(x, 10) <= 7),
+ tir.all(tvm.runtime.convert(-50) <= x, x <= -43),
+ )
+
+ # Merge a known floordiv and a lower bound of floormod into a value range
+ ck.verify(
+ tir.all(fld(x, 10) == 5, tvm.runtime.convert(7) < flm(x, 10)),
+ tir.all(tvm.runtime.convert(57) < x, x < 60),
+ )
+ ck.verify(
+ tir.all(fld(x, 10) == 5, tvm.runtime.convert(7) <= flm(x, 10)),
+ tir.all(tvm.runtime.convert(57) <= x, x < 60),
+ )
+ ck.verify(
+ tir.all(fld(x, 10) == -5, tvm.runtime.convert(7) < flm(x, 10)),
+ tir.all(tvm.runtime.convert(-43) < x, x < -40),
+ )
+ ck.verify(
+ tir.all(fld(x, 10) == -5, tvm.runtime.convert(7) <= flm(x, 10)),
+ tir.all(tvm.runtime.convert(-43) <= x, x < -40),
+ )
+
ck.verify(tvm.te.min(x, 11) < 10, x < 10)
ck.verify(tvm.te.min(x, 8) < 10, tvm.tir.const(1, "bool"))
ck.verify(tvm.te.max(8, x) > 10, tvm.tir.LT(10, x))
diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py
index 8d9c76c6b2..fd98b715a4 100644
--- a/tests/python/unittest/test_tir_transform_simplify.py
+++ b/tests/python/unittest/test_tir_transform_simplify.py
@@ -140,6 +140,8 @@ class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
transitively_prove_inequalities = False
convert_boolean_to_and_of_ors = False
apply_constraints_to_boolean_branches = False
+ propagate_knowns_to_prove_conditional = False
+ propagate_knowns_to_simplify_expressions = False
def transform(self):
def inner(mod):
@@ -148,6 +150,8 @@ class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
"transitively_prove_inequalities": self.transitively_prove_inequalities,
"convert_boolean_to_and_of_ors": self.convert_boolean_to_and_of_ors,
"apply_constraints_to_boolean_branches": self.apply_constraints_to_boolean_branches,
+ "propagate_knowns_to_prove_conditional": self.propagate_knowns_to_prove_conditional,
+ "propagate_knowns_to_simplify_expressions": self.propagate_knowns_to_simplify_expressions,
}
}
with tvm.transform.PassContext(config=config):
@@ -777,7 +781,7 @@ class TestRewriteAsAndOfOrsWithSimplificationBetweenReorderedGroups(BaseBeforeAf
A[0] = (i == 0 or j == 10 or k == 20) and (j == 10 or k != 30 or i == 0)
def expected(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32, k: T.int32):
- A[0] = i == 0 or j == 10 or k == 20
+ A[0] = j == 10 or k == 20 or i == 0
class TestRewriteAsAndOfOrUsingSimplificationAcrossAnd(BaseBeforeAfter):
@@ -794,7 +798,7 @@ class TestRewriteAsAndOfOrUsingSimplificationAcrossAnd(BaseBeforeAfter):
A[0] = (k == 20) and ((i == 0 or j == 10) and (k != 30))
def expected(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32, k: T.int32):
- A[0] = (k == 20) and (i == 0 or j == 10)
+ A[0] = (i == 0 or j == 10) and (k == 20)
class TestRewriteAsAndOfOrUsingSimplificationWithinOr(BaseBeforeAfter):
@@ -815,7 +819,7 @@ class TestRewriteAsAndOfOrUsingSimplificationWithinOr(BaseBeforeAfter):
A[0] = (i == 20) or (j == 0) or (i != 30)
def expected(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32, k: T.int32):
- A[0] = (i != 30) or (j == 0)
+ A[0] = (j == 0) or (i != 30)
class TestConditionalFloorMod(BaseBeforeAfter):
@@ -1049,5 +1053,640 @@ class TestMostRestrictiveConditional(BaseBeforeAfter):
return func
+class TestProvableConditionWithOffset(BaseBeforeAfter):
+ """Use scoped-constraint to prove inequalities"""
+
+ transitively_prove_inequalities = False
+
+ def before(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32):
+ if i < j:
+ A[0] = i < j + 1
+
+ def expected(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32):
+ if i < j:
+ A[0] = True
+
+
+class TestAlteredBufferContents(BaseBeforeAfter):
+ """Propagation of data-dependent conditionals.
+
+ A literal constraint must not be propagated if the values
+ referenced may change. TIR requires single assignment of
+ variables, so Var objects may be assumed constant, but BufferLoad
+ may not.
+ """
+
+ propagate_knowns_to_prove_conditional = True
+
+ def before(A: T.Buffer[(1,), "int32"], n: T.int32):
+ if A[0] == n:
+ A[0] = A[0] + 1
+ # If the simplifier incorrectly uses the invalidated
+ # A[0]==n condition required to reach this point, then it
+ # will incorrectly simplify to the then-case. If the
+ # simplifier correctly determines that A[0] now contains
+ # n+1, then it will correctly simplify to the else-case.
+ if A[0] == n:
+ A[0] = 5
+ else:
+ A[0] = 10
+
+ def expected(A: T.Buffer[(1,), "int32"], n: T.int32):
+ if A[0] == n:
+ A[0] = A[0] + 1
+ A[0] = 10
+
+
+class TestPossiblyAlteredBufferContents(BaseBeforeAfter):
+ """No simplification of data-dependent conditionals.
+
+ Like TestAlteredBufferContents, but the `m==0` conditional
+ prevents the value of `A[0]` from being known at the point of the
+ inner conditional, either as `A[0] == n` from the outer
+ conditional or as `A[0] == n+1` from the write statement.
+ """
+
+ propagate_knowns_to_prove_conditional = True
+
+ def before(A: T.Buffer[(1,), "int32"], n: T.int32, m: T.int32):
+ if A[0] == n:
+ if m == 0:
+ A[0] = A[0] + 1
+
+ if A[0] == n:
+ A[0] = 5
+ else:
+ A[0] = 10
+
+ expected = before
+
+
+class TestSimplifyInputAssumption(BaseBeforeAfter):
+ """A T.assume annotation may be used to simplify"""
+
+ propagate_knowns_to_prove_conditional = True
+
+ def before(A: T.Buffer[1, "int32"], n: T.int32):
+ T.evaluate(T.assume(n == 0))
+ if n == 0:
+ A[0] = 42
+
+ def expected(A: T.Buffer[1, "int32"], n: T.int32):
+ T.evaluate(T.assume(n == 0))
+ A[0] = 42
+
+
+class TestSimplifyInputAssumption(BaseBeforeAfter):
+ """A T.assume annotation may be used to simplify"""
+
+ propagate_knowns_to_prove_conditional = True
+
+ def before(A: T.Buffer[1, "int32"], n: T.int32):
+ T.evaluate(T.assume(n == 0))
+ if n == 0:
+ A[0] = 42
+
+ def expected(A: T.Buffer[1, "int32"], n: T.int32):
+ T.evaluate(T.assume(n == 0))
+ A[0] = 42
+
+
+class TestNoSimplifyFromScopedInputAssumption(BaseBeforeAfter):
+ """A T.assume inside a scope may not apply outside that scope"""
+
+ propagate_knowns_to_prove_conditional = True
+
+ def before(A: T.Buffer[1, "int32"], n: T.int32, m: T.int32):
+ if m == 0:
+ T.evaluate(T.assume(n == 0))
+
+ if n == 0:
+ A[0] = 42
+
+ expected = before
+
+
+class TestSimplifyConditionalUsingBufferValue(BaseBeforeAfter):
+ """Simplify a conditional using the known value in the buffer"""
+
+ propagate_knowns_to_prove_conditional = True
+
+ def before(A: T.Buffer[1, "int32"]):
+ A[0] = 0
+
+ if A[0] == 0:
+ A[0] = 42
+
+ def expected(A: T.Buffer[1, "int32"]):
+ A[0] = 0
+ A[0] = 42
+
+
+class TestKeepExpressionSimplifyUsingBufferValue(BaseBeforeAfter):
+ """Do not simplify expressions in general using known values in the buffer
+
+ For now, because this is equivalent to inlining, this usage is
+ prevented from occurring. Known buffer values may be used to prove
+ conditionals, but should not be used for other simplifications.
+ """
+
+ propagate_knowns_to_prove_conditional = True
+
+ def before(A: T.Buffer[1, "int32"], B: T.Buffer[1, "int32"]):
+ A[0] = 0
+ B[0] = A[0]
+
+ expected = before
+
+
+class TestSimplifyConditionalInLoopUsingBufferValue(BaseBeforeAfter):
+ """Simplify a conditional using the known value in the buffer
+
+ Like TestSimplifyConditionalUsingBufferValue, but the value used
+ to simplify is set in a previous loop.
+ """
+
+ propagate_knowns_to_prove_conditional = True
+
+ def before(A: T.Buffer[16, "int32"], B: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ A[i] = i
+
+ for j in T.serial(16):
+ if A[j] == j:
+ B[j] = 42
+ else:
+ B[j] = 100
+
+ def expected(A: T.Buffer[16, "int32"], B: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ A[i] = i
+
+ for j in T.serial(16):
+ B[j] = 42
+
+
+class TestSimplifyUsingBufferAssumption(BaseBeforeAfter):
+ """A T.assume may apply to a buffer's contents"""
+
+ propagate_knowns_to_prove_conditional = True
+
+ def before(A: T.Buffer[1, "int32"]):
+ T.evaluate(T.assume(A[0] == 0))
+
+ if A[0] == 0:
+ A[0] = 42
+
+ def expected(A: T.Buffer[1, "int32"]):
+ T.evaluate(T.assume(A[0] == 0))
+ A[0] = 42
+
+
+class TestSimplifyUsingBufferAssumptionInLoop(BaseBeforeAfter):
+ """An assumption about buffer contents may apply to a range"""
+
+ propagate_knowns_to_prove_conditional = True
+
+ def before(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ T.evaluate(T.assume(A[i] == i))
+
+ for i in T.serial(16):
+ if A[i] < 100:
+ A[i] = 0
+
+ def expected(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ T.evaluate(T.assume(A[i] == i))
+
+ for i in T.serial(16):
+ A[i] = 0
+
+
+class TestSimplifyUsingPartiallyKnownBufferConditional(BaseBeforeAfter):
+ """An assumption about buffer contents may apply to only part of a buffer"""
+
+ propagate_knowns_to_prove_conditional = True
+
+ def before(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ if 14 <= i:
+ T.evaluate(T.assume(A[i] == 0))
+
+ for i in T.serial(16):
+ if 14 <= i:
+ if A[i] == 0:
+ A[i] = 42
+
+ else:
+ if A[i] == 0:
+ A[i] = 100
+
+ def expected(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ if 14 <= i:
+ T.evaluate(T.assume(A[i] == 0))
+
+ for i in T.serial(16):
+ if 14 <= i:
+ A[i] = 42
+
+ else:
+ if A[i] == 0:
+ A[i] = 100
+
+
+class TestSimplifyUsingPartiallyKnownBufferExpression(BaseBeforeAfter):
+ """An assumption about buffer contents may apply to only part of a buffer
+
+ Like TestSimplifyUsingPartiallyKnownBufferConditional, but the
+ conditional is expressed as part of T.assume, instead of in the
+ control flow.
+ """
+
+ propagate_knowns_to_prove_conditional = True
+
+ def before(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ T.evaluate(T.assume(i < 14 or A[i] == 0))
+
+ for i in T.serial(16):
+ if 14 <= i:
+ if A[i] == 0:
+ A[i] = 42
+
+ def expected(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ T.evaluate(T.assume(i < 14 or A[i] == 0))
+
+ for i in T.serial(16):
+ if 14 <= i:
+ A[i] = 42
+
+
+class TestNoSimplificationIfPredicateNotMet(BaseBeforeAfter):
+ """Assumptions about buffer contents must apply to all cases to be used
+
+ Like TestSimplifyUsingPartiallyKnownBufferConditional, but the
+ predicate in the second loop does not match the predicate in the
+ first loop. Therefore, the `T.assume` refers to a different set
+ of indices.
+ """
+
+ propagate_knowns_to_prove_conditional = True
+
+ def before(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ if 14 <= i:
+ T.evaluate(T.assume(A[i] == 0))
+
+ for i in T.serial(16):
+ if i < 14:
+ if A[i] == 0:
+ A[i] = 42
+
+ expected = before
+
+
+class TestNoSimplifyUsingInvalidatedScopedConstraint(BaseBeforeAfter):
+ """A write may not be used for proofs outside its conditional"""
+
+ propagate_knowns_to_prove_conditional = True
+
+ def before(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ if i == 0:
+ A[i] = 0
+
+ if A[i] == 0:
+ A[i] = 42
+
+ expected = before
+
+
+class TestNoSimplifyUsingOverwrittenValue(BaseBeforeAfter):
+ """A write that may have been overwritten may not be treated as known
+
+ The appearance of "A[i] = 5" must prevent the earlier constraint
+ from being used for simplification.
+ """
+
+ propagate_knowns_to_prove_conditional = True
+
+ def before(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ T.evaluate(T.assume(A[i] == 0))
+
+ for i in T.serial(16):
+ if i == 0:
+ A[i] = 5
+
+ if A[i] == 0:
+ A[i] = 42
+
+ expected = before
+
+
+class TestNoSimplifyUsingLoopDependentBufferValue(BaseBeforeAfter):
+ """Do not simplify assuming reads are invariant
+
+ If a buffer's value changes across loop iterations, the buffer's
+ value before the loop should not be used to simplify conditionals
+ within the loop.
+ """
+
+ propagate_knowns_to_prove_conditional = True
+
+ def before(A: T.Buffer[16, "int32"], B: T.Buffer[1, "int32"]):
+ B[0] = 0
+ for i in T.serial(16):
+ if B[0] < 10:
+ B[0] = A[i] * 2 + B[0]
+ else:
+ B[0] = A[i] + B[0]
+
+ expected = before
+
+
+class TestSimplifyPriorToOverwrittenValue(BaseBeforeAfter):
+ """A known value may be used until it is overwritten
+
+ Like TestNoSimplifyUsingOverwrittenValue, but the use of the
+ known `A[i]` value occurs before it is overwritten.
+
+ Like TestNoSimplifyUsingLoopDependentBufferValue, but the loop
+ iterations are all independent.
+ """
+
+ propagate_knowns_to_prove_conditional = True
+
+ def before(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ T.evaluate(T.assume(A[i] == 0))
+
+ for i in T.serial(16):
+ if A[i] == 0:
+ A[i] = 17
+
+ if i == 0:
+ A[i] = 5
+
+ if A[i] == 0:
+ A[i] = 42
+
+ def expected(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ T.evaluate(T.assume(A[i] == 0))
+
+ for i in T.serial(16):
+ A[i] = 17
+
+ if i == 0:
+ A[i] = 5
+
+ if A[i] == 0:
+ A[i] = 42
+
+
+class TestSimplifyElementWiseUsingPreLoopBufferValue(BaseBeforeAfter):
+ """Simplify using the pre-loop buffer value in an element-wise loop
+
+ If an element-wise loop reads and overwrites a buffer value, the
+ pre-loop buffer value may be used to simplify conditions that
+ occur prior to the write.
+ """
+
+ propagate_knowns_to_prove_conditional = True
+
+ def before(A: T.Buffer[16, "int32"], B: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ B[i] = 0
+
+ for i in T.serial(16):
+ if B[i] < 10:
+ B[i] = A[i] * 2 + B[i]
+ else:
+ B[i] = A[i] + B[i]
+
+ def expected(A: T.Buffer[16, "int32"], B: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ B[i] = 0
+
+ for i in T.serial(16):
+ B[i] = A[i] * 2 + B[i]
+
+
+class TestSimplifyNonConditional(BaseBeforeAfter):
+ """Propagate a known value to later expressions."""
+
+ propagate_knowns_to_simplify_expressions = True
+
+ def before(A: T.Buffer[1, "int32"]):
+ A[0] = 0
+ A[0] = A[0] + 1
+
+ def expected(A: T.Buffer[1, "int32"]):
+ A[0] = 0
+ A[0] = 1
+
+
+class TestSuppressSimplifyNonConditional(BaseBeforeAfter):
+ """Do not propagate a known value when data-propagation is disabled.
+
+ Like TestSimplifyNonConditional, but with data-propagation turned off.
+ """
+
+ propagate_knowns_to_simplify_expressions = False
+
+ def before(A: T.Buffer[1, "int32"]):
+ A[0] = 0
+ A[0] = A[0] + 1
+
+ expected = before
+
+
+class TestSimplifyUsingTransitiveKnownBufferValue(BaseBeforeAfter):
+ """Propagate known buffer values
+
+ If a known value of a buffer depends on another known value, it
+ can be tracked backwards through both.
+ """
+
+ propagate_knowns_to_prove_conditional = True
+
+ def before(A: T.Buffer[1, "int32"]):
+ T.evaluate(T.assume(A[0] == 0))
+
+ A[0] = A[0] + 1
+ A[0] = A[0] + 1
+ A[0] = A[0] + 1
+
+ if A[0] == 3:
+ A[0] = 42
+
+ def expected(A: T.Buffer[1, "int32"]):
+ T.evaluate(T.assume(A[0] == 0))
+
+ A[0] = A[0] + 1
+ A[0] = A[0] + 1
+ A[0] = A[0] + 1
+
+ A[0] = 42
+
+
+class TestSimplifyRampIndexBroadcastValue(BaseBeforeAfter):
+ """Simplifications involving a buffer store with ramp indices and a broadcast value"""
+
+ propagate_knowns_to_prove_conditional = True
+
+ def before(A: T.Buffer[4, "int32"]):
+ A[T.ramp(0, 1, 4)] = T.broadcast(0, 4)
+
+ if A[0] == 0:
+ A[0] = 42
+
+ if A[1] == 0:
+ A[1] = 60
+
+ def expected(A: T.Buffer[4, "int32"]):
+ A[T.ramp(0, 1, 4)] = T.broadcast(0, 4)
+
+ A[0] = 42
+ A[1] = 60
+
+
+class TestSimplifyRampIndexRampValue(BaseBeforeAfter):
+ """Simplifications involving a buffer store with ramp indices and a ramp value"""
+
+ propagate_knowns_to_prove_conditional = True
+
+ def before(A: T.Buffer[4, "int32"]):
+ A[T.ramp(0, 1, 4)] = T.ramp(11, 1, 4)
+
+ if A[0] == 11:
+ A[0] = 42
+
+ if A[1] == 12:
+ A[1] = 60
+
+ def expected(A: T.Buffer[4, "int32"]):
+ A[T.ramp(0, 1, 4)] = T.ramp(11, 1, 4)
+
+ A[0] = 42
+ A[1] = 60
+
+
+class TestSimplifyUsingPartiallyProvenBufferValueGather(BaseBeforeAfter):
+ """Propagate known buffer values in part of buffer.
+
+ Even if a constraint can't be solved for all values in an
+ assignment, it may be provable in part of a buffer. Here, the
+ known 0 values in the padding of A produce known 0 values in the
+ padding of B.
+ """
+
+ transitively_prove_inequalities = True
+ propagate_knowns_to_prove_conditional = True
+
+ def before(A: T.Buffer[24, "int32"], B: T.Buffer[24, "int32"], F: T.Buffer[3, "int32"]):
+ # A has non-zero values only in the range 3 <= i < 17
+ for i in T.serial(24):
+ T.evaluate(T.assume(((3 <= i) and (i < 17)) or A[i] == 0))
+
+ # After convolving with F, B has non-zero values only in the
+ # range 3 <= i < 19.
+ for i in T.serial(24):
+ B[i] = 0
+ for f in T.serial(3):
+ if 0 <= i - f:
+ B[i] = B[i] + A[i - f] * F[f]
+
+ # Which means that this loop is unnecessary. It would be
+ # removed entirely in tir.transform.RemoveNoOp, but here we
+ # want to test that the simplification works as intended.
+ for i in T.serial(24):
+ if i < 3 or 19 <= i:
+ if B[i] != 0:
+ B[i] = 0
+
+ def expected(A: T.Buffer[24, "int32"], B: T.Buffer[24, "int32"], F: T.Buffer[3, "int32"]):
+ for i in T.serial(24):
+ T.evaluate(T.assume(((3 <= i) and (i < 17)) or A[i] == 0))
+
+ for i in T.serial(24):
+ B[i] = 0
+ for f in T.serial(3):
+ if 0 <= i - f:
+ B[i] = B[i] + A[i - f] * F[f]
+
+ for i in T.serial(24):
+ if i < 3 or 19 <= i:
+ T.evaluate(0)
+
+
+class TestSimplifyUsingPartiallyProvenBufferValueScatter(BaseBeforeAfter):
+ """Propagate known buffer values in part of buffer.
+
+ Like TestSimplifyUsingPartiallyProvenBufferValueGather, but the
+ compute loop is over the input buffer A, rather than the output
+ buffer B.
+ """
+
+ propagate_knowns_to_prove_conditional = True
+
+ def before(A: T.Buffer[24, "int32"], B: T.Buffer[24, "int32"], F: T.Buffer[3, "int32"]):
+ # A has non-zero values only in the range 3 <= i < 17
+ for i in T.serial(24):
+ T.evaluate(T.assume(((3 <= i) and (i < 17)) or A[i] == 0))
+
+ for i in T.serial(24):
+ B[i] = 0
+
+ # After convolving with F, B has non-zero values only in the
+ # range 3 <= i < 19.
+ for i in T.serial(24):
+ for f in T.serial(3):
+ if i + f >= 0 and i + f < 24:
+ B[i + f] = B[i + f] + A[i] * F[f]
+
+ # Which means that this loop is unnecessary. It actually gets
+ # removed in tir.transform.RemoveNoOp, but here we want to
+ # test that the simplification works as intended.
+ for i in T.serial(24):
+ if i < 3 or 19 <= i:
+ if B[i] != 0:
+ B[i] = 0
+
+ def expected(A: T.Buffer[24, "int32"], B: T.Buffer[24, "int32"], F: T.Buffer[3, "int32"]):
+ for i in T.serial(24):
+ T.evaluate(T.assume(((3 <= i) and (i < 17)) or A[i] == 0))
+
+ for i in T.serial(24):
+ B[i] = 0
+
+ for i in T.serial(24):
+ for f in T.serial(3):
+ if i + f < 24:
+ B[i + f] = B[i + f] + A[i] * F[f]
+
+ for i in T.serial(24):
+ if i < 3 or 19 <= i:
+ T.evaluate(0)
+
+
+class TestSimplifyBufferStore(BaseBeforeAfter):
+ """Simplification using prior known"""
+
+ propagate_knowns_to_simplify_expressions = True
+
+ def before(A: T.Buffer[1, "int32"]):
+ A[0] = 5
+ A[0] = A[0] + 7
+
+ def expected(A: T.Buffer[1, "int32"]):
+ A[0] = 5
+ A[0] = 12
+
+
if __name__ == "__main__":
tvm.testing.main()