You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/11/25 22:05:44 UTC

[tvm] branch main updated: [TIR][Transform] Optional data-flow analysis in RemoveNoOp (#13217)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 101e3a4ade [TIR][Transform] Optional data-flow analysis in RemoveNoOp (#13217)
101e3a4ade is described below

commit 101e3a4ade226a2b9cdef6437a285af18aef9cf8
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Fri Nov 25 16:05:37 2022 -0600

    [TIR][Transform] Optional data-flow analysis in RemoveNoOp (#13217)
    
    * [TIR][Transform] Optional data-flow analysis in RemoveNoOp
    
    Previously, `RemoveNoOp` would remove statements that could be locally
    analyzed as having no effect (e.g. `For` with empty loop extents).
    This commit adds opt-in use of data-flow analysis (enabled through the
    `use_dataflow_analysis` field of the `tir.RemoveNoOp` pass config) to
    identify two types of statements that are no-ops based on their context:
    
    * Buffer stores that are overwritten without ever being read.
    
      ```python
      buf[i] = 5 # Overwritten by next statement
      buf[i] = 10
      ```
    
    * Storing a value that is already known to be present.
    
      ```python
      buf[0:16] = T.ramp(0, 16, 1)
      buf[5] = 5 # Previous store already wrote this value
      ```
    
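    * Avoid dangling pointers in var_range_map_ (a loop's range is now
      recorded only after the zero-extent early return, so the entry is
      always erased again)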
---
 src/arith/rewrite_simplify.cc                      |   7 +
 src/tir/analysis/control_flow_graph.cc             | 117 +++--
 src/tir/analysis/control_flow_graph.h              |  12 +-
 src/tir/transforms/remove_no_op.cc                 | 230 +++++++--
 .../unittest/test_tir_transform_remove_no_op.py    | 521 +++++++++++++++++++++
 .../python/unittest/test_tir_transform_simplify.py |   1 +
 6 files changed, 796 insertions(+), 92 deletions(-)

diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index e6d876cf5a..90c448f4ea 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -1644,6 +1644,11 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) {
     TVM_TRY_RECURSIVE_REWRITE(x + c1 < c2, x < c2 - c1);
     TVM_TRY_RECURSIVE_REWRITE(x - c1 < c2, x < c2 + c1);
     TVM_TRY_REWRITE(x - c1 < 0, x < c1);
+
+    TVM_TRY_RECURSIVE_REWRITE(x - 1 < y, x <= y);
+    TVM_TRY_RECURSIVE_REWRITE(x < y + 1, x <= y);
+    TVM_TRY_RECURSIVE_REWRITE(x + (-1) < y, x <= y);
+    TVM_TRY_RECURSIVE_REWRITE(x < y - (-1), x <= y);
     // clang-format on
   }
   return std::move(ret);
@@ -1886,6 +1891,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) {
   TVM_TRY_REWRITE(x <= y || y < x, ctrue);
   TVM_TRY_REWRITE(y < x || x <= y, ctrue);
 
+  TVM_TRY_REWRITE(x < y || y < x, x != y);
+
   TVM_TRY_REWRITE_IF(x < c1 || c2 < x, ctrue, c2.Eval()->value < c1.Eval()->value);
   TVM_TRY_REWRITE_IF(c2 < x || x < c1, ctrue, c2.Eval()->value < c1.Eval()->value);
 
diff --git a/src/tir/analysis/control_flow_graph.cc b/src/tir/analysis/control_flow_graph.cc
index 42c5c8bb82..2e537450d2 100644
--- a/src/tir/analysis/control_flow_graph.cc
+++ b/src/tir/analysis/control_flow_graph.cc
@@ -31,6 +31,7 @@
 #include <tvm/tir/op.h>
 #include <tvm/tir/stmt_functor.h>
 
+#include <algorithm>
 #include <numeric>
 #include <optional>
 #include <queue>
@@ -819,10 +820,30 @@ BufferTouch ControlFlowGraph::ControlFlowBlock::MakeBufferTouch(ControlFlowGraph
   return buffer_touch;
 }
 
-ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, size_t max_revisits) {
+ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, size_t max_revisits)
+    : max_revisits_(max_revisits) {
   ControlFlowGraphBuilder::Build(this, stmt);
-  ForwardPropagateKnownValues(max_revisits);
-  BackwardPropagateUnusedValues(max_revisits);
+  ForwardPropagateKnownValues();
+  BackwardPropagateUnusedValues();
+}
+
+void ControlFlowGraph::RemoveStore(const tir::BufferStore& store) {
+  size_t context_index = [&]() {
+    auto it = control_flow_lookup_.find(store.get());
+    ICHECK(it != control_flow_lookup_.end())
+        << "BufferStore did not occur in the Stmt provided to ControlFlowGraph's constructor";
+    return it->second;
+  }();
+  // Remove this block's writes, then incrementally re-flow dataflow from it.
+  auto& touch_points = control_flow_[context_index].touch_points;
+  // NOTE: every Write touch in the block is dropped, not only the one for `store`.
+  touch_points.erase(std::remove_if(touch_points.begin(), touch_points.end(),
+                                    [](const BufferTouch& touch) {
+                                      return touch.touch_type == BufferTouch::AccessType::Write;
+                                    }),
+                     touch_points.end());
+  ForwardPropagateKnownValues(context_index);
+  BackwardPropagateUnusedValues(context_index);
 }
 
 std::ostream& operator<<(std::ostream& os, const ControlFlowGraph::ControlFlowEdge& edge) {
@@ -1327,33 +1348,38 @@ Array<Var> ControlFlowGraph::GetIndexVariables(const Buffer& buf, const Array<Pr
   return vars;
 }
 
-void ControlFlowGraph::ForwardPropagateKnownValues(size_t max_revisits) {
+void ControlFlowGraph::ForwardPropagateKnownValues(std::optional<size_t> flow_from) {
   // Values to visit when searching.  Using a std::set to
   // preferentially visit nodes near the start of the control flow.
   std::set<size_t> to_visit;
 
-  // Map from a block's index
-  std::unordered_map<size_t, size_t> visit_count_lookup;
-
-  // Initiatize the locations to search from, propagating values
-  // forward from all locations that have a known value.
-  for (size_t i = 0; i < control_flow_.size(); i++) {
-    bool has_known_value = false;
-    for (const auto& touch : control_flow_[i].touch_points) {
-      if (!HasBufferLoad(touch.value)) {
-        has_known_value = true;
-        break;
+  if (flow_from.has_value()) {
+    to_visit.insert(flow_from.value());
+  } else {
+    // Initiatize the locations to search from, propagating values
+    // forward from all locations that have a known value.
+    for (size_t i = 0; i < control_flow_.size(); i++) {
+      bool has_known_value = false;
+      for (const auto& touch : control_flow_[i].touch_points) {
+        if (!HasBufferLoad(touch.value)) {
+          has_known_value = true;
+          break;
+        }
       }
-    }
 
-    if (has_known_value) {
-      to_visit.insert(i);
+      if (has_known_value) {
+        to_visit.insert(i);
+      }
     }
   }
 
+  // Map from a block's index
+  std::unordered_map<size_t, size_t> visit_count_lookup;
+
   Analyzer analyzer;
   analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension(
       arith::RewriteSimplifier::kTransitivelyProveInequalities |
+      arith::RewriteSimplifier::kConvertBooleanToAndOfOrs |
       arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches));
 
   analyzer.Bind(iterator_ranges_);
@@ -1369,7 +1395,7 @@ void ControlFlowGraph::ForwardPropagateKnownValues(size_t max_revisits) {
 
     // Step 1: Collect known values provided from each predecessor
     block.known_at_block_start = [&]() -> BufferState {
-      if (num_previous_visits >= max_revisits) {
+      if (num_previous_visits >= max_revisits_) {
         return BufferState();
       }
 
@@ -1437,7 +1463,7 @@ void ControlFlowGraph::ForwardPropagateKnownValues(size_t max_revisits) {
 
     // Step 2: Collect knowns provided as a result of executing this block
     auto post_state = [&]() {
-      if (num_previous_visits >= max_revisits) {
+      if (num_previous_visits >= max_revisits_) {
         return BufferState();
       }
       auto post_state = block.known_at_block_start;
@@ -1459,29 +1485,35 @@ void ControlFlowGraph::ForwardPropagateKnownValues(size_t max_revisits) {
   }
 }
 
-void ControlFlowGraph::BackwardPropagateUnusedValues(size_t max_revisits) {
+void ControlFlowGraph::BackwardPropagateUnusedValues(std::optional<size_t> flow_from) {
   // Values to visit when searching.  Using a std::set to
   // preferentially visit nodes near the end of the control flow.
   std::set<size_t> to_visit;
 
-  // Map from a block's index
-  std::unordered_map<size_t, size_t> visit_count_lookup;
-
-  // Initiatize the locations to search from, propagating values
-  // backward from anywhere that performs a write.
-  for (size_t i = 0; i < control_flow_.size(); i++) {
-    const auto& touch_points = control_flow_[i].touch_points;
-    bool performs_write = std::any_of(
-        touch_points.begin(), touch_points.end(),
-        [](const auto& touch) { return touch.touch_type == BufferTouch::AccessType::Write; });
-    if (performs_write) {
-      to_visit.insert(i);
+  if (flow_from.has_value()) {
+    to_visit.insert(flow_from.value());
+  } else {
+    // Initiatize the locations to search from, propagating values
+    // backward from anywhere that performs a write.
+    for (size_t i = 0; i < control_flow_.size(); i++) {
+      const auto& touch_points = control_flow_[i].touch_points;
+      bool performs_write = std::any_of(
+          touch_points.begin(), touch_points.end(),
+          [](const auto& touch) { return touch.touch_type == BufferTouch::AccessType::Write; });
+      if (performs_write) {
+        to_visit.insert(i);
+      }
     }
   }
 
+  // Map from a block's index
+  std::unordered_map<size_t, size_t> visit_count_lookup;
+
   Analyzer analyzer;
-  analyzer.rewrite_simplify.SetEnabledExtensions(
-      arith::RewriteSimplifier::kTransitivelyProveInequalities);
+  analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension(
+      arith::RewriteSimplifier::kTransitivelyProveInequalities |
+      arith::RewriteSimplifier::kConvertBooleanToAndOfOrs |
+      arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches));
 
   analyzer.Bind(iterator_ranges_);
   analyzer.Bind(free_predicate_parameters_);
@@ -1496,7 +1528,7 @@ void ControlFlowGraph::BackwardPropagateUnusedValues(size_t max_revisits) {
 
     // Step 1: Collect known unused indices provided by each successor
     block.unused_at_block_end = [&]() -> BufferState {
-      if (num_previous_visits >= max_revisits) {
+      if (num_previous_visits >= max_revisits_) {
         return BufferState();
       }
       ICHECK_LE(block.successors.size(), 2)
@@ -1561,7 +1593,7 @@ void ControlFlowGraph::BackwardPropagateUnusedValues(size_t max_revisits) {
 
     // Step 2: Collect knowns provided as a result of executing this block
     auto unused_at_block_start = [&]() {
-      if (num_previous_visits >= max_revisits) {
+      if (num_previous_visits >= max_revisits_) {
         return BufferState();
       }
       auto prior_state = block.unused_at_block_end;
@@ -1603,8 +1635,10 @@ bool ControlFlowGraph::IsOverwrittenWithoutEffect(const tir::BufferStore& store,
   local_analyzer.Bind(free_predicate_parameters_);
   local_analyzer.Bind(iterator_ranges_);
   local_analyzer.Bind(free_params);
-  local_analyzer.rewrite_simplify.SetEnabledExtensions(
-      RewriteSimplifier::kTransitivelyProveInequalities);
+  local_analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension(
+      arith::RewriteSimplifier::kTransitivelyProveInequalities |
+      arith::RewriteSimplifier::kConvertBooleanToAndOfOrs |
+      arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches));
 
   PrimExpr predicate = store_touch.predicate && store_touch.AtLoopIteration();
 
@@ -1630,13 +1664,16 @@ PrimExpr ControlFlowGraph::SimplifyInContext(PrimExpr expr, const tir::Stmt& con
     return it->second;
   }();
 
+  const auto& control_flow_block = control_flow_[context_index];
+
   PrimExpr constraint = Bool(true);
   for (const auto& known : non_buffer_assumptions_) {
     constraint = constraint && known;
   }
   With<ConstraintContext> constraint_context(analyzer, constraint);
+  With<ConstraintContext> control_flow_scope(analyzer, control_flow_block.scope_predicate);
 
-  expr = control_flow_[context_index].known_at_block_start.SubstituteKnownBufferValues(
+  expr = control_flow_block.known_at_block_start.SubstituteKnownBufferValues(
       std::move(expr), axis_var_lookup_, analyzer);
 
   expr = analyzer->Simplify(std::move(expr));
diff --git a/src/tir/analysis/control_flow_graph.h b/src/tir/analysis/control_flow_graph.h
index aa9023ba29..590392cf65 100644
--- a/src/tir/analysis/control_flow_graph.h
+++ b/src/tir/analysis/control_flow_graph.h
@@ -29,6 +29,7 @@
 #include <tvm/tir/stmt.h>
 #include <tvm/tir/var.h>
 
+#include <optional>
 #include <unordered_map>
 #include <utility>
 #include <vector>
@@ -474,13 +475,17 @@ class ControlFlowGraph {
 
   /*! \brief Propagate known values from known BufferStore/assume
    *  subsequent control flow blocks
+   *
+   * \param flow_from If specified, re-flow only from that block.
    */
-  void ForwardPropagateKnownValues(size_t max_revisits);
+  void ForwardPropagateKnownValues(std::optional<size_t> flow_from = std::nullopt);
 
   /*! \brief Propagate overwritten/unused indices to preceding control
    *  flow blocks
+   *
+   * \param flow_from If specified, re-flow only from that block.
    */
-  void BackwardPropagateUnusedValues(size_t max_revisits);
+  void BackwardPropagateUnusedValues(std::optional<size_t> flow_from = std::nullopt);
 
   struct ControlFlowEdge {
     /* \brief The source block of the control flow edge
@@ -646,6 +651,9 @@ class ControlFlowGraph {
   std::vector<PrimExpr> non_buffer_assumptions_;
 
   friend class ControlFlowGraphBuilder;
+
+  /*! \brief The maximum number of revisits while flowing constraints */
+  size_t max_revisits_;
 };
 
 }  // namespace tir
diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc
index 41250408a7..3374f975f5 100644
--- a/src/tir/transforms/remove_no_op.cc
+++ b/src/tir/transforms/remove_no_op.cc
@@ -29,21 +29,71 @@
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
 
+#include <optional>
 #include <unordered_map>
 
 #include "../../arith/const_fold.h"
+#include "../../arith/ir_mutator_with_analyzer.h"
+#include "../analysis/control_flow_graph.h"
 #include "ir_utils.h"
 
 namespace tvm {
 namespace tir {
 
+struct RemoveNoOpConfigNode : public tvm::AttrsNode<RemoveNoOpConfigNode> {
+  bool use_dataflow_analysis;
+
+  TVM_DECLARE_ATTRS(RemoveNoOpConfigNode, "tir.transform.RemoveNoOpConfig") {
+    TVM_ATTR_FIELD(use_dataflow_analysis)
+        .describe(
+            "If true, known buffer values are propagated and used "
+            "to statically prove statements as no-ops.")
+        .set_default(false);
+  }
+};
+
+class RemoveNoOpConfig : public Attrs {
+ public:
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RemoveNoOpConfig, Attrs, RemoveNoOpConfigNode);
+};
+
+TVM_REGISTER_NODE_TYPE(RemoveNoOpConfigNode);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.RemoveNoOp", RemoveNoOpConfig);
+
 // Mark the statement of each stage.
-class NoOpRemover : public StmtMutator {
+class NoOpRemover : public arith::IRMutatorWithAnalyzer {
  public:
+  static Stmt Apply(Stmt stmt, arith::Analyzer* analyzer,
+                    std::optional<ControlFlowGraph> touch_pattern, const StmtNode* context) {
+    NoOpRemover visitor(analyzer, touch_pattern, context);
+    return visitor(std::move(stmt));
+  }
+
+ private:
+  using Parent = IRMutatorWithAnalyzer;
+  using Parent::VisitStmt;
+  using Parent::VisitStmt_;
+
+  NoOpRemover(arith::Analyzer* analyzer, std::optional<ControlFlowGraph> touch_pattern,
+              const StmtNode* context)
+      : Parent(analyzer), touch_pattern_(touch_pattern), context_(context) {}
+
   Stmt VisitStmt_(const LetStmtNode* op) final {
-    Stmt stmt = StmtMutator::VisitStmt_(op);
+    Stmt stmt = Parent::VisitStmt_(op);
     op = stmt.as<LetStmtNode>();
-    return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt;
+    if (is_no_op(op->body)) {
+      return MakeEvaluate(op->value);
+    }
+    // A binding the body never reads is dropped; its value survives only for side effects.
+    bool body_ignores_bound_variable =
+        !UsesVar(op->body, [&](const VarNode* var) { return var == op->var.get(); });
+    if (body_ignores_bound_variable && HasSideEffect(op->value)) {
+      return SeqStmt({MakeEvaluate(op->value), op->body});
+    } else if (body_ignores_bound_variable) {
+      return op->body;
+    } else {
+      return stmt;
+    }
   }
   Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == "pragma_debug_skip_region") {
@@ -58,24 +108,26 @@ class NoOpRemover : public StmtMutator {
         // We assume that such wait is a nop.
         auto inner = op->body.as<AttrStmtNode>();
         ICHECK(inner);
-        return StmtMutator::VisitStmt(inner->body);
+        return Parent::VisitStmt(inner->body);
       }
     }
 
-    Stmt stmt = StmtMutator::VisitStmt_(op);
+    Stmt stmt = Parent::VisitStmt_(op);
     op = stmt.as<AttrStmtNode>();
     return is_no_op(op->body) ? MakeEvaluate(op->value) : stmt;
   }
   Stmt VisitStmt_(const IfThenElseNode* op) final {
-    Stmt stmt = StmtMutator::VisitStmt_(op);
+    Stmt stmt = Parent::VisitStmt_(op);
     op = stmt.as<IfThenElseNode>();
     if (op->else_case) {
-      if (is_no_op(op->else_case.value())) {
-        if (is_no_op(op->then_case)) {
-          return MakeEvaluate(op->condition);
-        } else {
-          return IfThenElse(op->condition, op->then_case);
-        }
+      bool no_op_else = is_no_op(op->else_case.value());
+      bool no_op_then = is_no_op(op->then_case);
+      if (no_op_else && no_op_then) {
+        return MakeEvaluate(op->condition);
+      } else if (no_op_else) {
+        return IfThenElse(op->condition, op->then_case);
+      } else if (no_op_then) {
+        return IfThenElse(!op->condition, op->else_case.value());
       } else {
         return stmt;
       }
@@ -88,13 +140,13 @@ class NoOpRemover : public StmtMutator {
     }
   }
   Stmt VisitStmt_(const ForNode* op) final {
-    var_range_map_[op->loop_var.get()] = arith::IntSet::FromMinExtent(op->min, op->extent);
     auto extent_range = arith::EvalSet(op->extent, var_range_map_);
     if (!arith::is_neg_inf(extent_range.max()) && !arith::is_pos_inf(extent_range.max()) &&
-        analyzer_.CanProve(extent_range.max() <= 0)) {
+        analyzer_->CanProve(extent_range.max() <= 0)) {
       return Evaluate(0);
     }
-    Stmt stmt = StmtMutator::VisitStmt_(op);
+    var_range_map_[op->loop_var.get()] = arith::IntSet::FromMinExtent(op->min, op->extent);
+    Stmt stmt = Parent::VisitStmt_(op);
     var_range_map_.erase(op->loop_var.get());
     op = stmt.as<ForNode>();
     if (is_zero(op->extent)) {
@@ -114,42 +166,104 @@ class NoOpRemover : public StmtMutator {
     return is_no_op(op->body) ? op->body : stmt;
   }
   Stmt VisitStmt_(const EvaluateNode* op) final {
-    if (SideEffect(op->value) > CallEffectKind::kReadState) return GetRef<Stmt>(op);
-    return Evaluate(0);
+    if (HasSideEffect(op->value)) {
+      return GetRef<Stmt>(op);
+    } else {
+      return Evaluate(0);
+    }
   }
 
   Stmt VisitStmt_(const SeqStmtNode* op) final {
-    Stmt ret = StmtMutator::VisitSeqStmt_(op, true);
-    op = ret.as<SeqStmtNode>();
-    ICHECK(op != nullptr);
-    bool need_compact = false;
-    for (size_t i = 0; i < op->size(); ++i) {
-      if (is_no_op(op->seq[i])) need_compact = true;
-    }
+    auto ret = Downcast<SeqStmt>(StmtMutator::VisitSeqStmt_(op, true));
+
+    bool need_compact = std::any_of(ret->seq.begin(), ret->seq.end(),
+                                    [](const auto& stmt) { return is_no_op(stmt); });
+
     if (need_compact) {
-      auto n = CopyOnWrite(op);
-      size_t top = 0;
-      for (size_t i = 0; i < n->seq.size(); ++i) {
-        if (!is_no_op(n->seq[i])) {
-          n->seq.Set(top++, n->seq[i]);
+      Array<Stmt> filtered;
+      for (Stmt stmt : ret->seq) {
+        if (!is_no_op(stmt)) {
+          filtered.push_back(std::move(stmt));
         }
       }
-      if (top == 1) {
-        return n->seq[0];
-      } else {
-        n->seq.resize(top);
-        return Stmt(n);
-      }
+      ret = SeqStmt(filtered);
+    }
+
+    if (ret->size() == 0) {
+      return Evaluate(0);
+    } else if (ret->size() == 1) {
+      return ret->seq[0];
     } else {
-      if (op->size() == 1) {
-        return op->seq[0];
-      } else {
-        return ret;
+      return std::move(ret);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    BufferStore store = GetRef<BufferStore>(op);
+
+    // Helper function that returns a statement containing only the
+    // side effects of evaluating this BufferStore, but not the store
+    // itself.
+    auto only_side_effects = [&]() {
+      Array<Stmt> statements;
+      statements.push_back(MakeEvaluate(store->value));
+      for (const auto& index : store->indices) {
+        statements.push_back(MakeEvaluate(index));
+      }
+      return this->VisitStmt(SeqStmt(statements));
+    };
+
+    if (touch_pattern_.has_value()) {
+      // A write that is later overwritten is a no-op.
+      Stmt context = context_ ? GetRef<Stmt>(context_) : store;
+      if (touch_pattern_->IsOverwrittenWithoutEffect(store, context)) {
+        touch_pattern_->RemoveStore(store);
+        return only_side_effects();
+      }
+
+      // A write whose destination is known to already contain the
+      // values to be written is a no-op.
+      PrimExpr stores_existing_value = store->value == BufferLoad(store->buffer, store->indices);
+
+      PrimExpr simplified =
+          touch_pattern_->SimplifyInContext(stores_existing_value, context, analyzer_);
+      if (auto* as_int = as_const_int(simplified); as_int && *as_int) {
+        return only_side_effects();
       }
     }
+
+    // If the stored value is a load from the same location, the
+    // statement is a no-op, regardless of contextual information.
+    if (const BufferLoadNode* load = store->value.as<BufferLoadNode>()) {
+      if (load->buffer->data.same_as(store->buffer->data) &&
+          analyzer_->CanProveEqual(load->buffer->elem_offset, store->buffer->elem_offset) &&
+          ArrayValueEqual(load->buffer->shape, store->buffer->shape) &&
+          ArrayValueEqual(load->buffer->strides, store->buffer->strides) &&
+          ArrayValueEqual(load->indices, store->indices)) {
+        return only_side_effects();
+      }
+    }
+
+    return std::move(store);
   }
 
  private:
+  bool ArrayValueEqual(const Array<PrimExpr>& a, const Array<PrimExpr>& b) {
+    if (a.size() != b.size()) {
+      return false;
+    }
+    for (size_t i = 0; i < a.size(); i++) {
+      if (!analyzer_->CanProveEqual(a[i], b[i])) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  bool HasSideEffect(const PrimExpr& value) {
+    return SideEffect(value) > CallEffectKind::kReadState;
+  }
+
   Stmt MakeEvaluate(PrimExpr value) {
     if (SideEffect(value) > CallEffectKind::kReadState) {
       return Evaluate(value);
@@ -158,31 +272,47 @@ class NoOpRemover : public StmtMutator {
     }
   }
   Stmt MakeEvaluate(const Array<PrimExpr>& values) {
-    Stmt stmt;
+    Array<Stmt> stmts;
     for (PrimExpr e : values) {
       if (SideEffect(e) > CallEffectKind::kReadState) {
-        if (stmt.defined()) {
-          stmt = SeqStmt({stmt, Evaluate(e)});
-        } else {
-          stmt = Evaluate(e);
-        }
+        stmts.push_back(Evaluate(e));
       }
     }
-    return stmt.defined() ? stmt : Evaluate(0);
+
+    if (stmts.size() == 0) {
+      return Evaluate(0);
+    } else if (stmts.size() == 1) {
+      return stmts[0];
+    } else {
+      return SeqStmt(stmts);
+    }
   }
 
   std::unordered_map<const VarNode*, arith::IntSet> var_range_map_;
-  arith::Analyzer analyzer_;
+  std::optional<ControlFlowGraph> touch_pattern_;
+  const StmtNode* context_;
 };
 
-Stmt RemoveNoOp(Stmt stmt) { return NoOpRemover()(std::move(stmt)); }
-
 namespace transform {
 
 Pass RemoveNoOp() {
   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    std::optional<ControlFlowGraph> touch_pattern = std::nullopt;
+
+    RemoveNoOpConfig config = ctx->GetConfig<RemoveNoOpConfig>("tir.RemoveNoOp")
+                                  .value_or(AttrsWithDefaultValues<RemoveNoOpConfig>());
+    if (config->use_dataflow_analysis) {
+      touch_pattern.emplace(f->body);
+    }
+
+    arith::Analyzer analyzer;
+    analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension(
+        arith::RewriteSimplifier::kTransitivelyProveInequalities |
+        arith::RewriteSimplifier::kConvertBooleanToAndOfOrs |
+        arith::RewriteSimplifier::kApplyConstraintsToBooleanBranches));
+
     auto* n = f.CopyOnWrite();
-    n->body = NoOpRemover()(std::move(n->body));
+    n->body = NoOpRemover::Apply(std::move(n->body), &analyzer, std::move(touch_pattern), nullptr);
     return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.RemoveNoOp", {});
diff --git a/tests/python/unittest/test_tir_transform_remove_no_op.py b/tests/python/unittest/test_tir_transform_remove_no_op.py
index 820e32eb7e..ce37329b7e 100644
--- a/tests/python/unittest/test_tir_transform_remove_no_op.py
+++ b/tests/python/unittest/test_tir_transform_remove_no_op.py
@@ -19,6 +19,8 @@ from tvm import te
 from tvm.script import tir as T
 import tvm.testing
 
+import pytest
+
 
 def nop():
     return tvm.tir.Evaluate(0)
@@ -82,5 +84,524 @@ def test_remove_no_op_with_invalid_extent():
     assert isinstance(ret, tvm.tir.Evaluate)
 
 
+class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
+    use_dataflow_analysis = False
+
+    def transform(self):
+        def inner(mod):
+            config = {
+                "tir.RemoveNoOp": {
+                    "use_dataflow_analysis": self.use_dataflow_analysis,
+                }
+            }
+            with tvm.transform.PassContext(config=config):
+                mod = tvm.tir.transform.RemoveNoOp()(mod)
+            return mod
+
+        return inner
+
+
+class TestRemoveEmptyForLoop(BaseBeforeAfter):
+    """A for-loop whose body is a no-op is itself a no-op."""
+
+    def before():
+        for i in T.serial(16):
+            T.evaluate(0)
+
+    def expected():
+        T.evaluate(0)
+
+
+class TestRemoveZeroExtentLoop(BaseBeforeAfter):
+    """A for-loop with no extent is a no-op."""
+
+    def before(A: T.Buffer[16, "int32"]):
+        for i in T.serial(0):
+            A[i] = 42
+
+    def expected(A: T.Buffer[16, "int32"]):
+        T.evaluate(0)
+
+
+class TestRemoveUnusedLet(BaseBeforeAfter):
+    """A let statement that is never used is a no-op."""
+
+    def before(A: T.Buffer[16, "int32"]):
+        x = 5
+        for i in T.serial(16):
+            A[i] = 0
+
+    def expected(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            A[i] = 0
+
+
+class TestRemoveLetUsedOnlyInNoOp(BaseBeforeAfter):
+    """A let statement that is never used is a no-op.
+
+    Similar to TestRemoveUnusedLet, but the usage of the let binding
+    may have been removed by an earlier removal of another no-op.
+    """
+
+    def before(A: T.Buffer[16, "int32"]):
+        x = 5
+        for i in T.serial(0):
+            A[i] = x
+
+    def expected(A: T.Buffer[16, "int32"]):
+        T.evaluate(0)
+
+
+class TestKeepSideEffectsOfLet(BaseBeforeAfter):
+    """The side effects of a no-op let must be kept."""
+
+    def before():
+        x = T.call_extern("extern_func", dtype="int32")
+        T.evaluate(0)
+
+    def expected():
+        T.evaluate(T.call_extern("extern_func", dtype="int32"))
+
+
+class TestRemoveEmptyThenCase(BaseBeforeAfter):
+    """A no-op then_case can be removed."""
+
+    def before(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            if i < 8:
+                T.evaluate(0)
+            else:
+                A[i] = 42
+
+    def expected(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            if not (i < 8):
+                A[i] = 42
+
+
+class TestRemoveEmptyElseCase(BaseBeforeAfter):
+    """A no-op else_case can be removed."""
+
+    def before(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            if i < 8:
+                A[i] = 42
+            else:
+                T.evaluate(0)
+
+    def expected(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            if i < 8:
+                A[i] = 42
+
+
+class TestRemoveUnusedWrite(BaseBeforeAfter):
+    """For two sequential writes, the first is a no-op"""
+
+    use_dataflow_analysis = True
+
+    def before(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            A[i] = 100
+            A[i] = 42
+
+    def expected(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            A[i] = 42
+
+
+class TestSuppressRemovalOfUnusedWrite(BaseBeforeAfter):
+    """Dataflow analysis requires the config to opt-in
+
+    Like TestRemoveUnusedWrite, but dataflow analysis isn't enabled.
+    """
+
+    use_dataflow_analysis = False
+
+    def before(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            A[i] = 100
+            A[i] = 42
+
+    expected = before
+
+
+class TestKeepSideEffectsOfUnusedWrite(BaseBeforeAfter):
+    """For two sequential writes, the first value may have side effects"""
+
+    use_dataflow_analysis = True
+
+    def before(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            A[i] = T.call_extern("extern_func", dtype="int32")
+            A[i] = 42
+
+    def expected(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            T.evaluate(T.call_extern("extern_func", dtype="int32"))
+            A[i] = 42
+
+
+class TestKeepFirstWriteWhenUsed(BaseBeforeAfter):
+    """For two sequential writes, keep the first if it is used"""
+
+    def before(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            A[i] = 100
+            A[i] = A[i] + 1
+
+    expected = before
+
+
+class TestRemoveOverwrittenLoop(BaseBeforeAfter):
+    """Remove repeated writes to the same region
+
+    If two loops write to the same region, the first is a no-op.
+    """
+
+    use_dataflow_analysis = True
+
+    def before(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            A[i] = 100
+
+        for i in T.serial(16):
+            A[i] = 42
+
+    def expected(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            A[i] = 42
+
+
+class TestRemoveOverwrittenSubloop(BaseBeforeAfter):
+    """Remove repeated writes to the same region
+
+    If the first loop writes to a subset of the region, the first loop
+    is a no-op.  Similar to TestRemoveOverwrittenLoop, but the first
+    loop's extents are a subset of the second loop.
+    """
+
+    use_dataflow_analysis = True
+
+    def before(A: T.Buffer[16, "int32"]):
+        for i in T.serial(4, 12):
+            A[i] = 100
+
+        for i in T.serial(16):
+            A[i] = 42
+
+    def expected(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            A[i] = 42
+
+
+class TestKeepPartiallyOverwrittenLoop(BaseBeforeAfter):
+    """Keep partially overwritten regions
+
+    If the second loop doesn't entirely overwrite the first, the first
+    loop must not be removed.
+    """
+    # NOTE(review): use_dataflow_analysis keeps its default (False) here.
+    def before(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            A[i] = 100
+
+        for i in T.serial(16):
+            if i < 12:
+                A[i] = 42
+
+    expected = before
+
+
+class TestRemoveOverwrittenPredicatedLoopWithIdenticalCondition(BaseBeforeAfter):
+    """Remove repeated writes to the same predicated region.
+
+    Similar to TestKeepPartiallyOverwrittenLoop, except the first loop
+    has the same predicate as the second, and can therefore be
+    removed.
+    """
+
+    use_dataflow_analysis = True
+
+    def before(A: T.Buffer[16, "int32"]):
+        # Writes A[0:12] under the predicate i < 12.
+        for i in T.serial(16):
+            if i < 12:
+                A[i] = 100
+
+        # Same predicate, so exactly the same elements are overwritten.
+        for i in T.serial(16):
+            if i < 12:
+                A[i] = 42
+
+    def expected(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            if i < 12:
+                A[i] = 42
+
+
+class TestRemoveOverwrittenPredicatedLoopWithProvableCondition(BaseBeforeAfter):
+    """Remove repeated writes to the same predicated region.
+
+    Similar to
+    TestRemoveOverwrittenPredicatedLoopWithIdenticalCondition, except
+    the first loop's predicate is not a precise match for the second
+    loop's predicate.  So long as the regions written in the first
+    loop are a subset of those written in the second loop, they can be
+    removed.
+    """
+
+    use_dataflow_analysis = True
+
+    def before(A: T.Buffer[16, "int32"]):
+        # Writes A[0:10].
+        for i in T.serial(16):
+            if i < 10:
+                A[i] = 100
+
+        # For 0 <= i < 16, `i // 4 < 3` holds exactly when i < 12, so this
+        # loop overwrites A[0:12], a superset of the elements written above.
+        for i in T.serial(16):
+            if i // 4 < 3:
+                A[i] = 42
+
+    def expected(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            if i // 4 < 3:
+                A[i] = 42
+
+
+class TestRemoveSeparatedOverwrites(BaseBeforeAfter):
+    """Remove repeated writes to the same region.
+
+    Similar to TestRemoveOverwrittenLoop, but with an
+    independent loop between the first and second write of the buffer.
+    """
+
+    use_dataflow_analysis = True
+
+    def before(A: T.Buffer[16, "int32"], B: T.Buffer[16, "int32"]):
+        # Fully overwritten by the last loop, with no read of A in between.
+        for i in T.serial(16):
+            A[i] = 100
+
+        # Touches only B; it neither reads nor writes A.
+        for i in T.serial(16):
+            B[i] = 0
+
+        for i in T.serial(16):
+            A[i] = 42
+
+    def expected(A: T.Buffer[16, "int32"], B: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            B[i] = 0
+
+        for i in T.serial(16):
+            A[i] = 42
+
+
+@pytest.mark.xfail(reason="Not implemented yet")
+class TestRemoveSeparatedOverwriteOfPredicatedLoop(BaseBeforeAfter):
+    """Remove repeated writes to the same predicated region.
+
+    Similar to TestRemoveSeparatedOverwrites, but the independent loop
+    between the first and second writes writes to a different subset
+    of the same buffer.
+    """
+
+    use_dataflow_analysis = True
+
+    def before(A: T.Buffer[16, "int32"]):
+        # Writes A[0:12].
+        for i in T.serial(16):
+            if i < 12:
+                A[i] = 100
+
+        # Writes A[13:16], which does not overlap A[0:12].
+        for i in T.serial(16):
+            if i > 12:
+                A[i] = 15
+
+        # Overwrites A[0:12] again, making the first loop a no-op.
+        for i in T.serial(16):
+            if i < 12:
+                A[i] = 42
+
+    def expected(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            if i > 12:
+                A[i] = 15
+
+        for i in T.serial(16):
+            if i < 12:
+                A[i] = 42
+
+
+class TestRemoveReadWrite(BaseBeforeAfter):
+    """Writing a value to the same location as was just read is a no-op."""
+
+    def before(A: T.Buffer[1, "int32"]):
+        # Self-assignment: A[0] holds the same value afterwards.
+        A[0] = A[0]
+
+    def expected(A: T.Buffer[1, "int32"]):
+        T.evaluate(0)
+
+
+class TestKeepReadWriteToDifferentIndices(BaseBeforeAfter):
+    """Writing a value to a different index should not be removed"""
+
+    def before(A: T.Buffer[16, "int32"]):
+        # Copies A[i + 1] into A[i]: the read and write indices differ, so
+        # this is not a self-assignment.
+        for i in T.serial(15):
+            A[i] = A[i + 1]
+
+    expected = before
+
+
+class TestRemoveReadWriteSameIndexDifferentExpression(BaseBeforeAfter):
+    """Writing a value to the same location as the read is a no-op.
+
+    If the value of the index can be proven to be the same, then the
+    no-op can be removed, even if they have different forms of the
+    expression.
+    """
+
+    def before(A: T.Buffer[16, "int32"]):
+        for io, ii in T.grid(4, 4):
+            # The read goes through the named value `i`, while the store
+            # index spells out the same expression 4 * io + ii directly.
+            i = 4 * io + ii
+            A[4 * io + ii] = A[i]
+
+    def expected(A: T.Buffer[16, "int32"]):
+        T.evaluate(0)
+
+
+class TestRemoveReadWriteSameIndexUsingConstraint(BaseBeforeAfter):
+    """Writing a value to the same location as the read is a no-op.
+
+    If the value of the index can be proven to be the same, then the
+    no-op can be removed.  This may require using a constraint
+    that is known from a conditional containing the read/write.
+    """
+
+    def before(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            if i != 0:
+                # A[i - 1] is a different element from A[i], so this is kept.
+                A[i] = A[i - 1]
+            else:
+                # In this branch i == 0, so this is the self-assignment
+                # A[0] = A[0] and can be removed.
+                A[i] = A[0]
+
+    def expected(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            if i != 0:
+                A[i] = A[i - 1]
+
+
+class TestRemoveWritingOfKnownValue(BaseBeforeAfter):
+    """Writing a value that already exists at that index is a no-op"""
+
+    use_dataflow_analysis = True
+
+    def before(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            A[i] = i
+
+        # The loop above already left A[4] == 4.
+        A[4] = 4
+
+    def expected(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            A[i] = i
+
+
+class TestKeepOneOfDuplicateLoops(BaseBeforeAfter):
+    """Must not reason based on a touch point after removing it.
+
+    If the first loop is removed because it is overwritten by the
+    second loop, and the second loop is removed because it writes the
+    same value as the first loop, the overall transformation is no
+    longer valid.  In this case, only one of the two should be
+    removed.
+    """
+
+    use_dataflow_analysis = True
+
+    def before(A: T.Buffer[16, "int32"]):
+        # Two identical loops: each one, on its own, justifies removing
+        # the other, but they must not both be removed.
+        for i in T.serial(16):
+            A[i] = i
+
+        for i in T.serial(16):
+            A[i] = i
+
+    def expected(A: T.Buffer[16, "int32"]):
+        # Exactly one copy of the loop survives.
+        for i in T.serial(16):
+            A[i] = i
+
+
+class TestRemoveEmptyTemporary(BaseBeforeAfter):
+    """An allocation with a no-op body is a no-op."""
+
+    def before():
+        # `A` is never referenced, and the only statement is T.evaluate(0).
+        A = T.allocate([16], "int32", "local")
+        T.evaluate(0)
+
+    def expected():
+        T.evaluate(0)
+
+
+@pytest.mark.xfail(reason="Not implemented yet")
+class TestRemoveUnusedTemporary(BaseBeforeAfter):
+    """An unused allocation is a no-op."""
+
+    def before(A: T.Buffer[16, "int32"]):
+        # `B` is allocated but never read or written; unlike
+        # TestRemoveEmptyTemporary, the body here is not a no-op.
+        B = T.allocate([16], "int32", "local")
+        for i in T.serial(16):
+            A[i] = 1
+
+    def expected(A: T.Buffer[16, "int32"]):
+        for i in T.serial(16):
+            A[i] = 1
+
+
+@pytest.mark.xfail(reason="Not implemented yet")
+class TestRemoveUnusedWriteIntoTemporary(BaseBeforeAfter):
+    """A write that only impacts a temporary allocation is a no-op."""
+
+    def before():
+        # `A` is declared locally and never read, so these stores have
+        # no observable effect.
+        A = T.decl_buffer([16], "int32", scope="local")
+        for i in T.serial(16):
+            A[i] = 0
+
+    def expected():
+        T.evaluate(0)
+
+
+class TestKeepUsedWriteIntoTemporary(BaseBeforeAfter):
+    """A write into a temporary that is used later must be kept."""
+
+    # Stores are only judged by their context when the opt-in data-flow
+    # analysis is enabled.  Without it, this test would pass trivially
+    # and would not catch over-eager removal of writes into temporaries.
+    use_dataflow_analysis = True
+
+    def before(B: T.Buffer[16, "int32"]):
+        A = T.decl_buffer([16], "int32", scope="local")
+        for i in T.serial(16):
+            A[i] = 0
+
+        # Every element written above is read here, so the writes are live.
+        for i in T.serial(16):
+            B[i] = A[i]
+
+    expected = before
+
+
+@pytest.mark.xfail(reason="Not implemented yet")
+class TestRemoveWriteIntoTemporary(BaseBeforeAfter):
+    """A write into a temporary after its last read is a no-op."""
+
+    def before(A: T.Buffer[16, "int32"], C: T.Buffer[1, "int32"]):
+        B = T.decl_buffer([16], "int32", scope="local")
+        # Live: B is read by the reduction into C below.
+        for i in T.serial(16):
+            B[i] = A[i]
+
+        C[0] = 0
+        for i in T.serial(16):
+            C[0] = C[0] + B[i]
+
+        # Dead: B is never read after these stores.
+        for i in T.serial(16):
+            B[i] = 0
+
+    def expected(A: T.Buffer[16, "int32"], C: T.Buffer[1, "int32"]):
+        B = T.decl_buffer([16], "int32", scope="local")
+        for i in T.serial(16):
+            B[i] = A[i]
+
+        C[0] = 0
+        for i in T.serial(16):
+            C[0] = C[0] + B[i]
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py
index fd98b715a4..1ddc0e50d9 100644
--- a/tests/python/unittest/test_tir_transform_simplify.py
+++ b/tests/python/unittest/test_tir_transform_simplify.py
@@ -1267,6 +1267,7 @@ class TestSimplifyUsingPartiallyKnownBufferConditional(BaseBeforeAfter):
     """An assumption about buffer contents may apply to only part of a buffer"""
 
     propagate_knowns_to_prove_conditional = True
+    apply_constraints_to_boolean_branches = True
 
     def before(A: T.Buffer[16, "int32"]):
         for i in T.serial(16):