You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/11/25 22:05:44 UTC
[tvm] branch main updated: [TIR][Transform] Optional data-flow analysis in RemoveNoOp (#13217)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 101e3a4ade [TIR][Transform] Optional data-flow analysis in RemoveNoOp (#13217)
101e3a4ade is described below
commit 101e3a4ade226a2b9cdef6437a285af18aef9cf8
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Fri Nov 25 16:05:37 2022 -0600
[TIR][Transform] Optional data-flow analysis in RemoveNoOp (#13217)
* [TIR][Transform] Optional data-flow analysis in RemoveNoOp
Previously, `RemoveNoOp` would remove statements that could be locally
analyzed as having no effect (e.g. `For` with empty loop extents).
This commit adds opt-in use of data-flow analysis to identify
two types of statements that are no-ops based on their context:
* Buffer stores that are overwritten without ever being read.
```python
buf[i] = 5 # Overwritten by next statement
buf[i] = 10
```
* Storing a value that is already known to be present.
```python
buf[0:16] = T.ramp(0, 16, 1)
buf[5] = 5 # Previous store already wrote this value
```
* Avoid dangling pointers in var_range_map_
---
src/arith/rewrite_simplify.cc | 7 +
src/tir/analysis/control_flow_graph.cc | 117 +++--
src/tir/analysis/control_flow_graph.h | 12 +-
src/tir/transforms/remove_no_op.cc | 230 +++++++--
.../unittest/test_tir_transform_remove_no_op.py | 521 +++++++++++++++++++++
.../python/unittest/test_tir_transform_simplify.py | 1 +
6 files changed, 796 insertions(+), 92 deletions(-)
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index e6d876cf5a..90c448f4ea 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -1644,6 +1644,11 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) {
TVM_TRY_RECURSIVE_REWRITE(x + c1 < c2, x < c2 - c1);
TVM_TRY_RECURSIVE_REWRITE(x - c1 < c2, x < c2 + c1);
TVM_TRY_REWRITE(x - c1 < 0, x < c1);
+
+ TVM_TRY_RECURSIVE_REWRITE(x - 1 < y, x <= y);
+ TVM_TRY_RECURSIVE_REWRITE(x < y + 1, x <= y);
+ TVM_TRY_RECURSIVE_REWRITE(x + (-1) < y, x <= y);
+ TVM_TRY_RECURSIVE_REWRITE(x < y - (-1), x <= y);
// clang-format on
}
return std::move(ret);
@@ -1886,6 +1891,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) {
TVM_TRY_REWRITE(x <= y || y < x, ctrue);
TVM_TRY_REWRITE(y < x || x <= y, ctrue);
+ TVM_TRY_REWRITE(x < y || y < x, x != y);
+
TVM_TRY_REWRITE_IF(x < c1 || c2 < x, ctrue, c2.Eval()->value < c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 < x || x < c1, ctrue, c2.Eval()->value < c1.Eval()->value);
diff --git a/src/tir/analysis/control_flow_graph.cc b/src/tir/analysis/control_flow_graph.cc
index 42c5c8bb82..2e537450d2 100644
--- a/src/tir/analysis/control_flow_graph.cc
+++ b/src/tir/analysis/control_flow_graph.cc
@@ -31,6 +31,7 @@
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
+#include <algorithm>
#include <numeric>
#include <optional>
#include <queue>
@@ -819,10 +820,30 @@ BufferTouch ControlFlowGraph::ControlFlowBlock::MakeBufferTouch(ControlFlowGraph
return buffer_touch;
}
-ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, size_t max_revisits) {
+ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, size_t max_revisits)
+ : max_revisits_(max_revisits) {
ControlFlowGraphBuilder::Build(this, stmt);
- ForwardPropagateKnownValues(max_revisits);
- BackwardPropagateUnusedValues(max_revisits);
+ ForwardPropagateKnownValues();
+ BackwardPropagateUnusedValues();
+}
+
+void ControlFlowGraph::RemoveStore(const tir::BufferStore& store) {
+ size_t context_index = [&]() {
+ auto it = control_flow_lookup_.find(store.get());
+ ICHECK(it != control_flow_lookup_.end())
+ << "BufferStore did not occur in the Stmt provided to BufferTouchPattern's constructor";
+ return it->second;
+ }();
+
+ auto& touch_points = control_flow_[context_index].touch_points;
+
+ touch_points.erase(std::remove_if(touch_points.begin(), touch_points.end(),
+ [](const BufferTouch& touch) {
+ return touch.touch_type == BufferTouch::AccessType::Write;
+ }),
+ touch_points.end());
+ ForwardPropagateKnownValues(context_index);
+ BackwardPropagateUnusedValues(context_index);
}
std::ostream& operator<<(std::ostream& os, const ControlFlowGraph::ControlFlowEdge& edge) {
@@ -1327,33 +1348,38 @@ Array<Var> ControlFlowGraph::GetIndexVariables(const Buffer& buf, const Array<Pr
return vars;
}
-void ControlFlowGraph::ForwardPropagateKnownValues(size_t max_revisits) {
+void ControlFlowGraph::ForwardPropagateKnownValues(std::optional<size_t> flow_from) {
// Values to visit when searching. Using a std::set to
// preferentially visit nodes near the start of the control flow.
std::set<size_t> to_visit;
- // Map from a block's index
- std::unordered_map<size_t, size_t> visit_count_lookup;
-
- // Initiatize the locations to search from, propagating values
- // forward from all locations that have a known value.
- for (size_t i = 0; i < control_flow_.size(); i++) {
- bool has_known_value = false;
- for (const auto& touch : control_flow_[i].touch_points) {
- if (!HasBufferLoad(touch.value)) {
- has_known_value = true;
- break;
+ if (flow_from.has_value()) {
+ to_visit.insert(flow_from.value());
+ } else {
+ // Initialize the locations to search from, propagating values
+ // forward from all locations that have a known value.
+ for (size_t i = 0; i < control_flow_.size(); i++) {
+ bool has_known_value = false;
+ for (const auto& touch : control_flow_[i].touch_points) {
+ if (!HasBufferLoad(touch.value)) {
+ has_known_value = true;
+ break;
+ }
}
- }
- if (has_known_value) {
- to_visit.insert(i);
+ if (has_known_value) {
+ to_visit.insert(i);
+ }
}
}
+ // Map from a block's index to the number of times it has been visited
+ std::unordered_map<size_t, size_t> visit_count_lookup;
+
Analyzer analyzer;
analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension(
arith::RewriteSimplifier::kTransitivelyProveInequalities |
+ arith::RewriteSimplifier::kConvertBooleanToAndOfOrs |
arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches));
analyzer.Bind(iterator_ranges_);
@@ -1369,7 +1395,7 @@ void ControlFlowGraph::ForwardPropagateKnownValues(size_t max_revisits) {
// Step 1: Collect known values provided from each predecessor
block.known_at_block_start = [&]() -> BufferState {
- if (num_previous_visits >= max_revisits) {
+ if (num_previous_visits >= max_revisits_) {
return BufferState();
}
@@ -1437,7 +1463,7 @@ void ControlFlowGraph::ForwardPropagateKnownValues(size_t max_revisits) {
// Step 2: Collect knowns provided as a result of executing this block
auto post_state = [&]() {
- if (num_previous_visits >= max_revisits) {
+ if (num_previous_visits >= max_revisits_) {
return BufferState();
}
auto post_state = block.known_at_block_start;
@@ -1459,29 +1485,35 @@ void ControlFlowGraph::ForwardPropagateKnownValues(size_t max_revisits) {
}
}
-void ControlFlowGraph::BackwardPropagateUnusedValues(size_t max_revisits) {
+void ControlFlowGraph::BackwardPropagateUnusedValues(std::optional<size_t> flow_from) {
// Values to visit when searching. Using a std::set to
// preferentially visit nodes near the end of the control flow.
std::set<size_t> to_visit;
- // Map from a block's index
- std::unordered_map<size_t, size_t> visit_count_lookup;
-
- // Initiatize the locations to search from, propagating values
- // backward from anywhere that performs a write.
- for (size_t i = 0; i < control_flow_.size(); i++) {
- const auto& touch_points = control_flow_[i].touch_points;
- bool performs_write = std::any_of(
- touch_points.begin(), touch_points.end(),
- [](const auto& touch) { return touch.touch_type == BufferTouch::AccessType::Write; });
- if (performs_write) {
- to_visit.insert(i);
+ if (flow_from.has_value()) {
+ to_visit.insert(flow_from.value());
+ } else {
+ // Initialize the locations to search from, propagating values
+ // backward from anywhere that performs a write.
+ for (size_t i = 0; i < control_flow_.size(); i++) {
+ const auto& touch_points = control_flow_[i].touch_points;
+ bool performs_write = std::any_of(
+ touch_points.begin(), touch_points.end(),
+ [](const auto& touch) { return touch.touch_type == BufferTouch::AccessType::Write; });
+ if (performs_write) {
+ to_visit.insert(i);
+ }
}
}
+ // Map from a block's index to the number of times it has been visited
+ std::unordered_map<size_t, size_t> visit_count_lookup;
+
Analyzer analyzer;
- analyzer.rewrite_simplify.SetEnabledExtensions(
- arith::RewriteSimplifier::kTransitivelyProveInequalities);
+ analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension(
+ arith::RewriteSimplifier::kTransitivelyProveInequalities |
+ arith::RewriteSimplifier::kConvertBooleanToAndOfOrs |
+ arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches));
analyzer.Bind(iterator_ranges_);
analyzer.Bind(free_predicate_parameters_);
@@ -1496,7 +1528,7 @@ void ControlFlowGraph::BackwardPropagateUnusedValues(size_t max_revisits) {
// Step 1: Collect known unused indices provided by each successor
block.unused_at_block_end = [&]() -> BufferState {
- if (num_previous_visits >= max_revisits) {
+ if (num_previous_visits >= max_revisits_) {
return BufferState();
}
ICHECK_LE(block.successors.size(), 2)
@@ -1561,7 +1593,7 @@ void ControlFlowGraph::BackwardPropagateUnusedValues(size_t max_revisits) {
// Step 2: Collect knowns provided as a result of executing this block
auto unused_at_block_start = [&]() {
- if (num_previous_visits >= max_revisits) {
+ if (num_previous_visits >= max_revisits_) {
return BufferState();
}
auto prior_state = block.unused_at_block_end;
@@ -1603,8 +1635,10 @@ bool ControlFlowGraph::IsOverwrittenWithoutEffect(const tir::BufferStore& store,
local_analyzer.Bind(free_predicate_parameters_);
local_analyzer.Bind(iterator_ranges_);
local_analyzer.Bind(free_params);
- local_analyzer.rewrite_simplify.SetEnabledExtensions(
- RewriteSimplifier::kTransitivelyProveInequalities);
+ local_analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension(
+ arith::RewriteSimplifier::kTransitivelyProveInequalities |
+ arith::RewriteSimplifier::kConvertBooleanToAndOfOrs |
+ arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches));
PrimExpr predicate = store_touch.predicate && store_touch.AtLoopIteration();
@@ -1630,13 +1664,16 @@ PrimExpr ControlFlowGraph::SimplifyInContext(PrimExpr expr, const tir::Stmt& con
return it->second;
}();
+ const auto& control_flow_block = control_flow_[context_index];
+
PrimExpr constraint = Bool(true);
for (const auto& known : non_buffer_assumptions_) {
constraint = constraint && known;
}
With<ConstraintContext> constraint_context(analyzer, constraint);
+ With<ConstraintContext> control_flow_scope(analyzer, control_flow_block.scope_predicate);
- expr = control_flow_[context_index].known_at_block_start.SubstituteKnownBufferValues(
+ expr = control_flow_block.known_at_block_start.SubstituteKnownBufferValues(
std::move(expr), axis_var_lookup_, analyzer);
expr = analyzer->Simplify(std::move(expr));
diff --git a/src/tir/analysis/control_flow_graph.h b/src/tir/analysis/control_flow_graph.h
index aa9023ba29..590392cf65 100644
--- a/src/tir/analysis/control_flow_graph.h
+++ b/src/tir/analysis/control_flow_graph.h
@@ -29,6 +29,7 @@
#include <tvm/tir/stmt.h>
#include <tvm/tir/var.h>
+#include <optional>
#include <unordered_map>
#include <utility>
#include <vector>
@@ -474,13 +475,17 @@ class ControlFlowGraph {
/*! \brief Propagate known values from known BufferStore/assume
* subsequent control flow blocks
+ *
+ * \param flow_from If specified, re-flow only from that block.
*/
- void ForwardPropagateKnownValues(size_t max_revisits);
+ void ForwardPropagateKnownValues(std::optional<size_t> flow_from = std::nullopt);
/*! \brief Propagate overwritten/unused indices to preceding control
* flow blocks
+ *
+ * \param flow_from If specified, re-flow only from that block.
*/
- void BackwardPropagateUnusedValues(size_t max_revisits);
+ void BackwardPropagateUnusedValues(std::optional<size_t> flow_from = std::nullopt);
struct ControlFlowEdge {
/* \brief The source block of the control flow edge
@@ -646,6 +651,9 @@ class ControlFlowGraph {
std::vector<PrimExpr> non_buffer_assumptions_;
friend class ControlFlowGraphBuilder;
+
+ /*! \brief The maximum number of revisits while flowing constraints */
+ size_t max_revisits_;
};
} // namespace tir
diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc
index 41250408a7..3374f975f5 100644
--- a/src/tir/transforms/remove_no_op.cc
+++ b/src/tir/transforms/remove_no_op.cc
@@ -29,21 +29,71 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
+#include <optional>
#include <unordered_map>
#include "../../arith/const_fold.h"
+#include "../../arith/ir_mutator_with_analyzer.h"
+#include "../analysis/control_flow_graph.h"
#include "ir_utils.h"
namespace tvm {
namespace tir {
+struct RemoveNoOpConfigNode : public tvm::AttrsNode<RemoveNoOpConfigNode> {
+ bool use_dataflow_analysis;
+
+ TVM_DECLARE_ATTRS(RemoveNoOpConfigNode, "tir.transform.RemoveNoOpConfig") {
+ TVM_ATTR_FIELD(use_dataflow_analysis)
+ .describe(
+ "If true, known buffer values are propagated and used "
+ "to statically prove statements as no-ops.")
+ .set_default(false);
+ }
+};
+
+class RemoveNoOpConfig : public Attrs {
+ public:
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RemoveNoOpConfig, Attrs, RemoveNoOpConfigNode);
+};
+
+TVM_REGISTER_NODE_TYPE(RemoveNoOpConfigNode);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.RemoveNoOp", RemoveNoOpConfig);
+
// Mark the statement of each stage.
-class NoOpRemover : public StmtMutator {
+class NoOpRemover : public arith::IRMutatorWithAnalyzer {
public:
+ static Stmt Apply(Stmt stmt, arith::Analyzer* analyzer,
+ std::optional<ControlFlowGraph> touch_pattern, const StmtNode* context) {
+ NoOpRemover visitor(analyzer, touch_pattern, context);
+ return visitor(std::move(stmt));
+ }
+
+ private:
+ using Parent = IRMutatorWithAnalyzer;
+ using Parent::VisitStmt;
+ using Parent::VisitStmt_;
+
+ NoOpRemover(arith::Analyzer* analyzer, std::optional<ControlFlowGraph> touch_pattern,
+ const StmtNode* context)
+ : Parent(analyzer), touch_pattern_(touch_pattern), context_(context) {}
+
Stmt VisitStmt_(const LetStmtNode* op) final {
- Stmt stmt = StmtMutator::VisitStmt_(op);
+ Stmt stmt = Parent::VisitStmt_(op);
op = stmt.as<LetStmtNode>();
- return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt;
+ if (is_no_op(op->body)) {
+ return MakeEvaluate(op->value);
+ }
+
+ bool body_uses_bound_variable =
+ !UsesVar(op->body, [&](const VarNode* var) { return var == op->var.get(); });
+ if (body_uses_bound_variable && HasSideEffect(op->value)) {
+ return SeqStmt({MakeEvaluate(op->value), op->body});
+ } else if (body_uses_bound_variable) {
+ return op->body;
+ } else {
+ return stmt;
+ }
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == "pragma_debug_skip_region") {
@@ -58,24 +108,26 @@ class NoOpRemover : public StmtMutator {
// We assume that such wait is a nop.
auto inner = op->body.as<AttrStmtNode>();
ICHECK(inner);
- return StmtMutator::VisitStmt(inner->body);
+ return Parent::VisitStmt(inner->body);
}
}
- Stmt stmt = StmtMutator::VisitStmt_(op);
+ Stmt stmt = Parent::VisitStmt_(op);
op = stmt.as<AttrStmtNode>();
return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt;
}
Stmt VisitStmt_(const IfThenElseNode* op) final {
- Stmt stmt = StmtMutator::VisitStmt_(op);
+ Stmt stmt = Parent::VisitStmt_(op);
op = stmt.as<IfThenElseNode>();
if (op->else_case) {
- if (is_no_op(op->else_case.value())) {
- if (is_no_op(op->then_case)) {
- return MakeEvaluate(op->condition);
- } else {
- return IfThenElse(op->condition, op->then_case);
- }
+ bool no_op_else = is_no_op(op->else_case.value());
+ bool no_op_then = is_no_op(op->then_case);
+ if (no_op_else && no_op_then) {
+ return MakeEvaluate(op->condition);
+ } else if (no_op_else) {
+ return IfThenElse(op->condition, op->then_case);
+ } else if (no_op_then) {
+ return IfThenElse(!op->condition, op->else_case.value());
} else {
return stmt;
}
@@ -88,13 +140,13 @@ class NoOpRemover : public StmtMutator {
}
}
Stmt VisitStmt_(const ForNode* op) final {
- var_range_map_[op->loop_var.get()] = arith::IntSet::FromMinExtent(op->min, op->extent);
auto extent_range = arith::EvalSet(op->extent, var_range_map_);
if (!arith::is_neg_inf(extent_range.max()) && !arith::is_pos_inf(extent_range.max()) &&
- analyzer_.CanProve(extent_range.max() <= 0)) {
+ analyzer_->CanProve(extent_range.max() <= 0)) {
return Evaluate(0);
}
- Stmt stmt = StmtMutator::VisitStmt_(op);
+ var_range_map_[op->loop_var.get()] = arith::IntSet::FromMinExtent(op->min, op->extent);
+ Stmt stmt = Parent::VisitStmt_(op);
var_range_map_.erase(op->loop_var.get());
op = stmt.as<ForNode>();
if (is_zero(op->extent)) {
@@ -114,42 +166,104 @@ class NoOpRemover : public StmtMutator {
return is_no_op(op->body) ? op->body : stmt;
}
Stmt VisitStmt_(const EvaluateNode* op) final {
- if (SideEffect(op->value) > CallEffectKind::kReadState) return GetRef<Stmt>(op);
- return Evaluate(0);
+ if (HasSideEffect(op->value)) {
+ return GetRef<Stmt>(op);
+ } else {
+ return Evaluate(0);
+ }
}
Stmt VisitStmt_(const SeqStmtNode* op) final {
- Stmt ret = StmtMutator::VisitSeqStmt_(op, true);
- op = ret.as<SeqStmtNode>();
- ICHECK(op != nullptr);
- bool need_compact = false;
- for (size_t i = 0; i < op->size(); ++i) {
- if (is_no_op(op->seq[i])) need_compact = true;
- }
+ auto ret = Downcast<SeqStmt>(StmtMutator::VisitSeqStmt_(op, true));
+
+ bool need_compact = std::any_of(ret->seq.begin(), ret->seq.end(),
+ [](const auto& stmt) { return is_no_op(stmt); });
+
if (need_compact) {
- auto n = CopyOnWrite(op);
- size_t top = 0;
- for (size_t i = 0; i < n->seq.size(); ++i) {
- if (!is_no_op(n->seq[i])) {
- n->seq.Set(top++, n->seq[i]);
+ Array<Stmt> filtered;
+ for (Stmt stmt : ret->seq) {
+ if (!is_no_op(stmt)) {
+ filtered.push_back(std::move(stmt));
}
}
- if (top == 1) {
- return n->seq[0];
- } else {
- n->seq.resize(top);
- return Stmt(n);
- }
+ ret = SeqStmt(filtered);
+ }
+
+ if (ret->size() == 0) {
+ return Evaluate(0);
+ } else if (ret->size() == 1) {
+ return ret->seq[0];
} else {
- if (op->size() == 1) {
- return op->seq[0];
- } else {
- return ret;
+ return std::move(ret);
+ }
+ }
+
+ Stmt VisitStmt_(const BufferStoreNode* op) final {
+ BufferStore store = GetRef<BufferStore>(op);
+
+ // Helper function that returns a statement containing only the
+ // side effects of evaluating this BufferStore, but not the store
+ // itself.
+ auto only_side_effects = [&]() {
+ Array<Stmt> statements;
+ statements.push_back(MakeEvaluate(store->value));
+ for (const auto& index : store->indices) {
+ statements.push_back(MakeEvaluate(index));
+ }
+ return this->VisitStmt(SeqStmt(statements));
+ };
+
+ if (touch_pattern_.has_value()) {
+ // A write that is later overwritten is a no-op.
+ Stmt context = context_ ? GetRef<Stmt>(context_) : store;
+ if (touch_pattern_->IsOverwrittenWithoutEffect(store, context)) {
+ touch_pattern_->RemoveStore(store);
+ return only_side_effects();
+ }
+
+ // A write whose destination is known to already contain the
+ // values to be written is a no-op.
+ PrimExpr stores_existing_value = store->value == BufferLoad(store->buffer, store->indices);
+
+ PrimExpr simplified =
+ touch_pattern_->SimplifyInContext(stores_existing_value, context, analyzer_);
+ if (auto* as_int = as_const_int(simplified); as_int && *as_int) {
+ return only_side_effects();
}
}
+
+ // If the stored value is a load from the same location, the
+ // statement is a no-op, regardless of contextual information.
+ if (const BufferLoadNode* load = store->value.as<BufferLoadNode>()) {
+ if (load->buffer->data.same_as(store->buffer->data) &&
+ analyzer_->CanProveEqual(load->buffer->elem_offset, store->buffer->elem_offset) &&
+ ArrayValueEqual(load->buffer->shape, store->buffer->shape) &&
+ ArrayValueEqual(load->buffer->strides, store->buffer->strides) &&
+ ArrayValueEqual(load->indices, store->indices)) {
+ return only_side_effects();
+ }
+ }
+
+ return std::move(store);
}
private:
+ bool ArrayValueEqual(const Array<PrimExpr>& a, const Array<PrimExpr>& b) {
+ if (a.size() != b.size()) {
+ return false;
+ }
+ for (size_t i = 0; i < a.size(); i++) {
+ if (!analyzer_->CanProveEqual(a[i], b[i])) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ bool HasSideEffect(const PrimExpr& value) {
+ return SideEffect(value) > CallEffectKind::kReadState;
+ }
+
Stmt MakeEvaluate(PrimExpr value) {
if (SideEffect(value) > CallEffectKind::kReadState) {
return Evaluate(value);
@@ -158,31 +272,47 @@ class NoOpRemover : public StmtMutator {
}
}
Stmt MakeEvaluate(const Array<PrimExpr>& values) {
- Stmt stmt;
+ Array<Stmt> stmts;
for (PrimExpr e : values) {
if (SideEffect(e) > CallEffectKind::kReadState) {
- if (stmt.defined()) {
- stmt = SeqStmt({stmt, Evaluate(e)});
- } else {
- stmt = Evaluate(e);
- }
+ stmts.push_back(Evaluate(e));
}
}
- return stmt.defined() ? stmt : Evaluate(0);
+
+ if (stmts.size() == 0) {
+ return Evaluate(0);
+ } else if (stmts.size() == 1) {
+ return stmts[0];
+ } else {
+ return SeqStmt(stmts);
+ }
}
std::unordered_map<const VarNode*, arith::IntSet> var_range_map_;
- arith::Analyzer analyzer_;
+ std::optional<ControlFlowGraph> touch_pattern_;
+ const StmtNode* context_;
};
-Stmt RemoveNoOp(Stmt stmt) { return NoOpRemover()(std::move(stmt)); }
-
namespace transform {
Pass RemoveNoOp() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+ std::optional<ControlFlowGraph> touch_pattern = std::nullopt;
+
+ RemoveNoOpConfig config = ctx->GetConfig<RemoveNoOpConfig>("tir.RemoveNoOp")
+ .value_or(AttrsWithDefaultValues<RemoveNoOpConfig>());
+ if (config->use_dataflow_analysis) {
+ touch_pattern.emplace(f->body);
+ }
+
+ arith::Analyzer analyzer;
+ analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension(
+ arith::RewriteSimplifier::kTransitivelyProveInequalities |
+ arith::RewriteSimplifier::kConvertBooleanToAndOfOrs |
+ arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches));
+
auto* n = f.CopyOnWrite();
- n->body = NoOpRemover()(std::move(n->body));
+ n->body = NoOpRemover::Apply(std::move(n->body), &analyzer, std::move(touch_pattern), nullptr);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.RemoveNoOp", {});
diff --git a/tests/python/unittest/test_tir_transform_remove_no_op.py b/tests/python/unittest/test_tir_transform_remove_no_op.py
index 820e32eb7e..ce37329b7e 100644
--- a/tests/python/unittest/test_tir_transform_remove_no_op.py
+++ b/tests/python/unittest/test_tir_transform_remove_no_op.py
@@ -19,6 +19,8 @@ from tvm import te
from tvm.script import tir as T
import tvm.testing
+import pytest
+
def nop():
return tvm.tir.Evaluate(0)
@@ -82,5 +84,524 @@ def test_remove_no_op_with_invalid_extent():
assert isinstance(ret, tvm.tir.Evaluate)
+class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
+ use_dataflow_analysis = False
+
+ def transform(self):
+ def inner(mod):
+ config = {
+ "tir.RemoveNoOp": {
+ "use_dataflow_analysis": self.use_dataflow_analysis,
+ }
+ }
+ with tvm.transform.PassContext(config=config):
+ mod = tvm.tir.transform.RemoveNoOp()(mod)
+ return mod
+
+ return inner
+
+
+class TestRemoveEmptyForLoop(BaseBeforeAfter):
+ """A for-loop whose body is a no-op is itself a no-op."""
+
+ def before():
+ for i in T.serial(16):
+ T.evaluate(0)
+
+ def expected():
+ T.evaluate(0)
+
+
+class TestRemoveZeroExtentLoop(BaseBeforeAfter):
+ """A for-loop with no extent is a no-op."""
+
+ def before(A: T.Buffer[16, "int32"]):
+ for i in T.serial(0):
+ A[i] = 42
+
+ def expected(A: T.Buffer[16, "int32"]):
+ T.evaluate(0)
+
+
+class TestRemoveUnusedLet(BaseBeforeAfter):
+ """A let statement that is never used is a no-op."""
+
+ def before(A: T.Buffer[16, "int32"]):
+ x = 5
+ for i in T.serial(16):
+ A[i] = 0
+
+ def expected(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ A[i] = 0
+
+
+class TestRemoveLetUsedOnlyInNoOp(BaseBeforeAfter):
+ """A let statement that is never used is a no-op.
+
+ Similar to TestRemoveUnusedLet, but the usage of the let binding
+ may have been removed by an earlier removal of another no-op.
+ """
+
+ def before(A: T.Buffer[16, "int32"]):
+ x = 5
+ for i in T.serial(0):
+ A[i] = x
+
+ def expected(A: T.Buffer[16, "int32"]):
+ T.evaluate(0)
+
+
+class TestKeepSideEffectsOfLet(BaseBeforeAfter):
+ """The side effects of a no-op let must be kept."""
+
+ def before():
+ x = T.call_extern("extern_func", dtype="int32")
+ T.evaluate(0)
+
+ def expected():
+ T.evaluate(T.call_extern("extern_func", dtype="int32"))
+
+
+class TestRemoveEmptyThenCase(BaseBeforeAfter):
+ """A no-op then_case can be removed."""
+
+ def before(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ if i < 8:
+ T.evaluate(0)
+ else:
+ A[i] = 42
+
+ def expected(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ if not (i < 8):
+ A[i] = 42
+
+
+class TestRemoveEmptyElseCase(BaseBeforeAfter):
+ """A no-op else_case can be removed."""
+
+ def before(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ if i < 8:
+ A[i] = 42
+ else:
+ T.evaluate(0)
+
+ def expected(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ if i < 8:
+ A[i] = 42
+
+
+class TestRemoveUnusedWrite(BaseBeforeAfter):
+ """For two sequential writes, the first is a no-op"""
+
+ use_dataflow_analysis = True
+
+ def before(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ A[i] = 100
+ A[i] = 42
+
+ def expected(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ A[i] = 42
+
+
+class TestSuppressRemovalOfUnusedWrite(BaseBeforeAfter):
+ """Dataflow analysis requires the config to opt in
+
+ Like TestRemoveUnusedWrite, but dataflow analysis isn't enabled.
+ """
+
+ use_dataflow_analysis = False
+
+ def before(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ A[i] = 100
+ A[i] = 42
+
+ expected = before
+
+
+class TestKeepSideEffectsOfUnusedWrite(BaseBeforeAfter):
+ """For two sequential writes, the first value may have side effects"""
+
+ use_dataflow_analysis = True
+
+ def before(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ A[i] = T.call_extern("extern_func", dtype="int32")
+ A[i] = 42
+
+ def expected(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ T.evaluate(T.call_extern("extern_func", dtype="int32"))
+ A[i] = 42
+
+
+class TestKeepFirstWriteWhenUsed(BaseBeforeAfter):
+ """For two sequential writes, keep the first if it is used"""
+
+ def before(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ A[i] = 100
+ A[i] = A[i] + 1
+
+ expected = before
+
+
+class TestRemoveOverwrittenLoop(BaseBeforeAfter):
+ """Remove repeated writes to the same region
+
+ If two loops write to the same region, the first is a no-op.
+ """
+
+ use_dataflow_analysis = True
+
+ def before(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ A[i] = 100
+
+ for i in T.serial(16):
+ A[i] = 42
+
+ def expected(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ A[i] = 42
+
+
+class TestRemoveOverwrittenSubloop(BaseBeforeAfter):
+ """Remove repeated writes to the same region
+
+ If the first loop writes to a subset of the region, the first loop
+ is a no-op. Similar to TestRemoveOverwrittenLoop, but the first
+ loop's extents are a subset of the second loop.
+ """
+
+ use_dataflow_analysis = True
+
+ def before(A: T.Buffer[16, "int32"]):
+ for i in T.serial(4, 12):
+ A[i] = 100
+
+ for i in T.serial(16):
+ A[i] = 42
+
+ def expected(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ A[i] = 42
+
+
+class TestKeepPartiallyOverwrittenLoop(BaseBeforeAfter):
+ """Keep partially overwritten regions
+
+ If the second loop doesn't entirely overwrite the first, the first
+ may not be removed and must be kept.
+ """
+
+ def before(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ A[i] = 100
+
+ for i in T.serial(16):
+ if i < 12:
+ A[i] = 42
+
+ expected = before
+
+
+class TestRemoveOverwrittenPredicatedLoopWithIdenticalCondition(BaseBeforeAfter):
+ """Remove repeated writes to the same predicated region.
+
+ Similar to TestKeepPartiallyOverwrittenLoop, except the first loop
+ has the same predicate as the second, and can therefore be
+ removed.
+ """
+
+ use_dataflow_analysis = True
+
+ def before(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ if i < 12:
+ A[i] = 100
+
+ for i in T.serial(16):
+ if i < 12:
+ A[i] = 42
+
+ def expected(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ if i < 12:
+ A[i] = 42
+
+
+class TestRemoveOverwrittenPredicatedLoopWithProvableCondition(BaseBeforeAfter):
+ """Remove repeated writes to the same predicated region.
+
+ Similar to
+ TestRemoveOverwrittenPredicatedLoopWithIdenticalCondition, except
+ the first loop's predicate is not a precise match for the second
+ loop's predicate. So long as the regions written in the first
+ loop are a subset of those written in the second loop, they can be
+ removed.
+ """
+
+ use_dataflow_analysis = True
+
+ def before(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ if i < 10:
+ A[i] = 100
+
+ for i in T.serial(16):
+ if i // 4 < 3:
+ A[i] = 42
+
+ def expected(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ if i // 4 < 3:
+ A[i] = 42
+
+
+class TestRemoveSeparatedOverwrites(BaseBeforeAfter):
+ """Remove repeated writes to the same region.
+
+ Similar to TestRemoveOverwrittenLoop, but with an
+ independent loop between the first and second write of the buffer.
+ """
+
+ use_dataflow_analysis = True
+
+ def before(A: T.Buffer[16, "int32"], B: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ A[i] = 100
+
+ for i in T.serial(16):
+ B[i] = 0
+
+ for i in T.serial(16):
+ A[i] = 42
+
+ def expected(A: T.Buffer[16, "int32"], B: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ B[i] = 0
+
+ for i in T.serial(16):
+ A[i] = 42
+
+
+@pytest.mark.xfail(reason="Not implemented yet")
+class TestRemoveSeparatedOverwriteOfPredicatedLoop(BaseBeforeAfter):
+ """Remove repeated writes to the same predicated region.
+
+ Similar to TestRemoveSeparatedOverwrites, but the independent loop
+ between the first and second writes writes to a different subset
+ of the same buffer.
+ """
+
+ use_dataflow_analysis = True
+
+ def before(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ if i < 12:
+ A[i] = 100
+
+ for i in T.serial(16):
+ if i > 12:
+ A[i] = 15
+
+ for i in T.serial(16):
+ if i < 12:
+ A[i] = 42
+
+ def expected(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ if i > 12:
+ A[i] = 15
+
+ for i in T.serial(16):
+ if i < 12:
+ A[i] = 42
+
+
+class TestRemoveReadWrite(BaseBeforeAfter):
+ """Writing a value to the same location as was just read is a no-op."""
+
+ def before(A: T.Buffer[1, "int32"]):
+ A[0] = A[0]
+
+ def expected(A: T.Buffer[1, "int32"]):
+ T.evaluate(0)
+
+
+class TestKeepReadWriteToDifferentIndices(BaseBeforeAfter):
+ """Writing a value to a different index should not be removed"""
+
+ def before(A: T.Buffer[16, "int32"]):
+ for i in T.serial(15):
+ A[i] = A[i + 1]
+
+ expected = before
+
+
+class TestRemoveReadWriteSameIndexDifferentExpression(BaseBeforeAfter):
+ """Writing a value to the same location as the read is a no-op.
+
+ If the value of the index can be proven to be the same, then the
+ no-op can be removed, even if they have different forms of the
+ expression.
+ """
+
+ def before(A: T.Buffer[16, "int32"]):
+ for io, ii in T.grid(4, 4):
+ i = 4 * io + ii
+ A[4 * io + ii] = A[i]
+
+ def expected(A: T.Buffer[16, "int32"]):
+ T.evaluate(0)
+
+
+class TestRemoveReadWriteSameIndexUsingConstraint(BaseBeforeAfter):
+ """Writing a value to the same location as the read is a no-op.
+
+ If the value of the index can be proven to be the same, then the
+ no-op can be removed. This may require using a constraint
+ that is known from a conditional containing the read/write.
+ """
+
+ def before(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ if i != 0:
+ A[i] = A[i - 1]
+ else:
+ A[i] = A[0]
+
+ def expected(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ if i != 0:
+ A[i] = A[i - 1]
+
+
+class TestRemoveWritingOfKnownValue(BaseBeforeAfter):
+ """Writing a value that already exists at that index is a no-op"""
+
+ use_dataflow_analysis = True
+
+ def before(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ A[i] = i
+
+ A[4] = 4
+
+ def expected(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ A[i] = i
+
+
+class TestKeepOneOfDuplicateLoops(BaseBeforeAfter):
+ """Must not reason based on a touch point after removing it.
+
+ If the first loop is removed because it is overwritten by the
+ second loop, and the second loop is removed because it writes the
+ same value as the first loop, the overall transformation is no
+ longer valid. In this case, only one of the two should be
+ removed.
+ """
+
+ use_dataflow_analysis = True
+
+ def before(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ A[i] = i
+
+ for i in T.serial(16):
+ A[i] = i
+
+ def expected(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ A[i] = i
+
+
+class TestRemoveEmptyTemporary(BaseBeforeAfter):
+ """An allocation with a no-op body is a no-op."""
+
+ def before():
+ A = T.allocate([16], "int32", "local")
+ T.evaluate(0)
+
+ def expected():
+ T.evaluate(0)
+
+
+@pytest.mark.xfail(reason="Not implemented yet")
+class TestRemoveUnusedTemporary(BaseBeforeAfter):
+ """An unused allocation is a no-op."""
+
+ def before(A: T.Buffer[16, "int32"]):
+ B = T.allocate([16], "int32", "local")
+ for i in T.serial(16):
+ A[i] = 1
+
+ def expected(A: T.Buffer[16, "int32"]):
+ for i in T.serial(16):
+ A[i] = 1
+
+
+@pytest.mark.xfail(reason="Not implemented yet")
+class TestRemoveUnusedWriteIntoTemporary(BaseBeforeAfter):
+ """A write that only impacts a temporary allocation is a no-op."""
+
+ def before():
+ A = T.decl_buffer([16], "int32", scope="local")
+ for i in T.serial(16):
+ A[i] = 0
+
+ def expected():
+ T.evaluate(0)
+
+
+class TestKeepUsedWriteIntoTemporary(BaseBeforeAfter):
+ """A write into a temporary that is used later must be kept."""
+
+ def before(B: T.Buffer[16, "int32"]):
+ A = T.decl_buffer([16], "int32", scope="local")
+ for i in T.serial(16):
+ A[i] = 0
+
+ for i in T.serial(16):
+ B[i] = A[i]
+
+ expected = before
+
+
+@pytest.mark.xfail(reason="Not implemented yet")
+class TestRemoveWriteIntoTemporary(BaseBeforeAfter):
+ """A write into a temporary that is never read afterward is a no-op."""
+
+ def before(A: T.Buffer[16, "int32"], C: T.Buffer[1, "int32"]):
+ B = T.decl_buffer([16], "int32", scope="local")
+ for i in T.serial(16):
+ B[i] = A[i]
+
+ C[0] = 0
+ for i in T.serial(16):
+ C[0] = C[0] + B[i]
+
+ for i in T.serial(16):
+ B[i] = 0
+
+ def expected(A: T.Buffer[16, "int32"], C: T.Buffer[1, "int32"]):
+ B = T.decl_buffer([16], "int32", scope="local")
+ for i in T.serial(16):
+ B[i] = A[i]
+
+ C[0] = 0
+ for i in T.serial(16):
+ C[0] = C[0] + B[i]
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py
index fd98b715a4..1ddc0e50d9 100644
--- a/tests/python/unittest/test_tir_transform_simplify.py
+++ b/tests/python/unittest/test_tir_transform_simplify.py
@@ -1267,6 +1267,7 @@ class TestSimplifyUsingPartiallyKnownBufferConditional(BaseBeforeAfter):
"""An assumption about buffer contents may apply to only part of a buffer"""
propagate_knowns_to_prove_conditional = True
+ apply_constraints_to_boolean_branches = True
def before(A: T.Buffer[16, "int32"]):
for i in T.serial(16):