You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/01/18 06:46:52 UTC

[tvm] 01/01: Revert "[TIR] Encode conditional accesses info into block read/write regions (#9880)"

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch revert-9880-encode_conditional_accesses_in_read_write_annotations
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 2e14b8565515b3940f1624de61fc23d3446c4cbe
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Mon Jan 17 22:46:07 2022 -0800

    Revert "[TIR] Encode conditional accesses info into block read/write regions (#9880)"
    
    This reverts commit 6f6fc68f5a028a92607b2907b9e4144543686639.
---
 src/tir/analysis/block_access_region_detector.cc   | 29 ++--------
 src/tir/transforms/compact_buffer_region.cc        | 10 ++--
 src/tir/transforms/ir_utils.cc                     | 62 +++++---------------
 src/tir/transforms/ir_utils.h                      | 18 +++---
 .../test_tir_analysis_get_block_access_region.py   | 66 ----------------------
 .../test_tir_transform_compact_buffer_region.py    |  1 -
 6 files changed, 30 insertions(+), 156 deletions(-)

diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc
index 07dcace..776538a 100644
--- a/src/tir/analysis/block_access_region_detector.cc
+++ b/src/tir/analysis/block_access_region_detector.cc
@@ -56,8 +56,6 @@ class BlockReadWriteDetector : public StmtExprVisitor {
  private:
   /*! \brief Iteration range for loop_vars */
   std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
-  /*! \brief Extra iteration range hint for free vars */
-  std::unordered_map<const VarNode*, arith::IntSet> hint_map_;
   /*! \brief The buffers that the current block reads */
   std::vector<Buffer> read_buffers_;
   /*! \brief The buffers that the current block writes */
@@ -98,9 +96,6 @@ class BlockReadWriteDetector : public StmtExprVisitor {
   /*! \brief Helper function to update a opaque access. */
   void UpdateOpaque(const Var& buffer_var);
 
-  /*! \brief Helper function to relax the buffer indices */
-  arith::IntSet RelaxAccessIndex(const PrimExpr& index);
-
   void VisitStmt_(const ForNode* op) override;
   void VisitStmt_(const IfThenElseNode* op) override;
   void VisitStmt_(const BlockRealizeNode* op) override;
@@ -145,22 +140,10 @@ void BlockReadWriteDetector::VisitExpr_(const LoadNode* op) {
   ExprVisitor::VisitExpr_(op);
 }
 
-arith::IntSet BlockReadWriteDetector::RelaxAccessIndex(const PrimExpr& index) {
-  arith::IntSet relaxed = arith::EvalSet(index, dom_map_);
-  if (!hint_map_.empty()) {
-    // take non-relaxed var bound hints into considerations
-    // eg, if i * 4 + j with i >= 10 and j in [0, 4), only j in domain scope
-    // then the index region can be relaxed to [i*4, i*4+4) ^ [40, inf)
-    arith::IntSet hint_bound = arith::EvalSet(relaxed, hint_map_);
-    relaxed = arith::Intersect({relaxed, hint_bound});
-  }
-  return relaxed;
-}
-
 void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) {
   std::vector<arith::IntSet> relaxed_region;
   for (const PrimExpr& index : op->indices) {
-    relaxed_region.push_back(RelaxAccessIndex(index));
+    relaxed_region.push_back(arith::EvalSet(index, dom_map_));
   }
   Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region);
   ExprVisitor::VisitExpr_(op);
@@ -177,12 +160,12 @@ void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) {
   VisitExpr(op->condition);
   {
     // Visit then branch
-    With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, true);
+    With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, true);
     StmtExprVisitor::VisitStmt(op->then_case);
   }
   if (op->else_case.defined()) {
     // Visit else branch
-    With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, false);
+    With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, false);
     StmtExprVisitor::VisitStmt(op->else_case);
   }
 }
@@ -192,12 +175,12 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) {
     VisitExpr(op->args[0]);
     {
       // Visit then branch
-      With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, true);
+      With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, true);
       StmtExprVisitor::VisitExpr(op->args[1]);
     }
     {
       // Visit else branch
-      With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, false);
+      With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, false);
       StmtExprVisitor::VisitExpr(op->args[2]);
     }
     return;
@@ -213,7 +196,7 @@ void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) {
 void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) {
   std::vector<arith::IntSet> relaxed_region;
   for (const PrimExpr& index : op->indices) {
-    relaxed_region.push_back(RelaxAccessIndex(index));
+    relaxed_region.push_back(arith::EvalSet(index, dom_map_));
   }
   Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region);
   StmtVisitor::VisitStmt_(op);
diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc
index 20ddd7f..07f9778 100644
--- a/src/tir/transforms/compact_buffer_region.cc
+++ b/src/tir/transforms/compact_buffer_region.cc
@@ -123,12 +123,12 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
     StmtExprVisitor::VisitExpr(op->condition);
     {
       // Visit then branch
-      With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, true);
+      With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, true);
       StmtExprVisitor::VisitStmt(op->then_case);
     }
     if (op->else_case.defined()) {
       // Visit else branch
-      With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, false);
+      With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, false);
       StmtExprVisitor::VisitStmt(op->else_case);
     }
   }
@@ -139,12 +139,12 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
       StmtExprVisitor::VisitExpr(op->args[0]);
       {
         // Visit then branch
-        With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, true);
+        With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, true);
         StmtExprVisitor::VisitExpr(op->args[1]);
       }
       {
         // Visit else branch
-        With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, false);
+        With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, false);
         StmtExprVisitor::VisitExpr(op->args[2]);
       }
       return;
@@ -282,8 +282,6 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
 
   /*! \brief The map from loop vars to their iter range. */
   std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
-  /*! \brief Extra map from free vars to their iter range hints. */
-  std::unordered_map<const VarNode*, arith::IntSet> hint_map_;
   /*! \brief The analyzer aware of loop domains. */
   arith::Analyzer dom_analyzer_;
   /*! \brief The map from Buffer to it's relaxed access set. */
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index bc2f7ad..2423b09 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -300,18 +300,11 @@ Map<Var, Range> ConditionalBoundsContext::GetVarBoundsFromCondition() {
   Array<Var> vars = Array<Var>(var_set.begin(), var_set.end());
   Map<Var, Range> ranges;
   for (const Var& v : vars) {
-    arith::IntSet dom;
-    auto relax_it = relax_map_->find(v.get());
-    if (relax_it != relax_map_->end()) {
-      dom = relax_it->second;
-    } else {
-      auto hint_it = hint_map_->find(v.get());
-      if (hint_it != hint_map_->end()) {
-        dom = hint_it->second;
-      }
-    }
-    if (dom.defined()) {
-      ranges.Set(v, Range::FromMinExtent(dom.min(), analyzer.Simplify(dom.max() - dom.min() + 1)));
+    auto it = dom_map_->find(v.get());
+    if (it != dom_map_->end()) {
+      const auto& int_set = it->second;
+      ranges.Set(v, Range::FromMinExtent(int_set.min(),
+                                         analyzer.Simplify(int_set.max() - int_set.min() + 1)));
     }
   }
   // solve constraints
@@ -321,53 +314,24 @@ Map<Var, Range> ConditionalBoundsContext::GetVarBoundsFromCondition() {
 }
 
 ConditionalBoundsContext::ConditionalBoundsContext(
-    const PrimExpr& condition, std::unordered_map<const VarNode*, arith::IntSet>* relax_map,
-    std::unordered_map<const VarNode*, arith::IntSet>* hint_map, bool is_true_branch)
-    : condition_(condition),
-      relax_map_(relax_map),
-      hint_map_(hint_map),
-      is_true_branch_(is_true_branch) {}
+    const PrimExpr& condition, std::unordered_map<const VarNode*, arith::IntSet>* dom_map,
+    bool is_true_branch)
+    : condition_(condition), dom_map_(dom_map), is_true_branch_(is_true_branch) {}
 
 void ConditionalBoundsContext::EnterWithScope() {
   for (const auto& p : GetVarBoundsFromCondition()) {
     const auto* var = p.first.get();
-    arith::IntSet new_dom = arith::IntSet::FromRange(p.second);
-    auto relax_it = relax_map_->find(var);
-    if (relax_it != relax_map_->end()) {
-      // this is a bound for relaxed var
-      origin_map_.emplace(var, relax_it->second);
-      relax_it->second = arith::Intersect({relax_it->second, new_dom});
-    } else {
-      // this is a bound for free var
-      auto hint_it = hint_map_->find(var);
-      if (hint_it != hint_map_->end()) {
-        origin_map_.emplace(var, hint_it->second);
-        hint_it->second = arith::Intersect({hint_it->second, new_dom});
-      } else {
-        origin_map_.emplace(var, arith::IntSet::Nothing());
-        hint_map_->insert(hint_it, {var, new_dom});
-      }
+    auto it = dom_map_->find(var);
+    if (it != dom_map_->end()) {
+      origin_map_.emplace(var, it->second);
+      it->second = arith::Intersect({it->second, arith::IntSet::FromRange(p.second)});
     }
   }
 }
 
 void ConditionalBoundsContext::ExitWithScope() {
   for (const auto& p : origin_map_) {
-    const auto* var = p.first;
-    auto relax_it = relax_map_->find(var);
-    if (relax_it != relax_map_->end()) {
-      // recover bound for relaxed var
-      relax_it->second = p.second;
-    } else {
-      // recover bound for free var
-      auto hint_it = hint_map_->find(var);
-      ICHECK(hint_it != hint_map_->end());
-      if (p.second.IsNothing()) {
-        hint_map_->erase(hint_it);
-      } else {
-        hint_it->second = p.second;
-      }
-    }
+    (*dom_map_)[p.first] = p.second;
   }
 }
 
diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h
index da52a82..7b1d34c 100644
--- a/src/tir/transforms/ir_utils.h
+++ b/src/tir/transforms/ir_utils.h
@@ -231,9 +231,9 @@ Bool IsFromLegacyTESchedule(PrimFunc f);
  *\brief Context helper to update domain map within conditional scope.
  *
  * Assume the condition is `0 <= i && i < 9` and global domain of i is [0, 20], thus `bounds[i]` is
- * [0, 8]. Then `With<ConditionalBoundsContext> ctx(condition, &relax_map, &hint_map, true)` step
- *into scope where dom_map[i] is [0, 8] and `With<ConditionalBoundsContext> ctx(condition,
- *&relax_map, &hint_map, false)` step into scope where dom_map[i] is [9, 20]
+ * [0, 8]. Then `With<ConditionalBoundsContext> ctx(condition, &dom_map, true)` steps into a scope
+ * where dom_map[i] is [0, 8], and `With<ConditionalBoundsContext> ctx(condition, &dom_map, false)`
+ * steps into a scope where dom_map[i] is [9, 20].
  */
 class ConditionalBoundsContext {
  private:
@@ -241,13 +241,11 @@ class ConditionalBoundsContext {
   /*!
    * \brief Construct a condition bounds context.
    * \param condition The condition holds on true branch.
-   * \param relax_map The domain map for relaxed vars to update.
-   * \param hint_map The domain map for free vars to update.
+   * \param dom_map The global domain map to be updated.
    * \param is_true_branch Whether step into the branch where condition bounds holds.
    */
   ConditionalBoundsContext(const PrimExpr& condition,
-                           std::unordered_map<const VarNode*, arith::IntSet>* relax_map,
-                           std::unordered_map<const VarNode*, arith::IntSet>* hint_map,
+                           std::unordered_map<const VarNode*, arith::IntSet>* dom_map,
                            bool is_true_branch);
   void EnterWithScope();
   void ExitWithScope();
@@ -257,10 +255,8 @@ class ConditionalBoundsContext {
 
   /*! \brief the condition holds on true branch. */
   const PrimExpr& condition_;
-  /*! \brief domain map for relaxed vars to update */
-  std::unordered_map<const VarNode*, arith::IntSet>* relax_map_;
-  /*! \brief domain map for free vars to update */
-  std::unordered_map<const VarNode*, arith::IntSet>* hint_map_;
+  /*! \brief global domain map to be updated */
+  std::unordered_map<const VarNode*, arith::IntSet>* dom_map_;
   /*! \brief whether is on true branch */
   bool is_true_branch_;
   /*! \brief used to record and restore original var bounds */
diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py
index 5403754..e508fbb 100644
--- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py
+++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py
@@ -130,41 +130,6 @@ def access_in_branch_func() -> None:
                 B[i] = A[i - 1]
 
 
-@T.prim_func
-def access_of_padding_pattern() -> None:
-    X = T.alloc_buffer([28, 28])
-    X_pad = T.alloc_buffer([32, 32])
-    Y = T.alloc_buffer([28, 28])
-    for i, j in T.grid(32, 32):
-        with T.block("padding"):
-            vi, vj = T.axis.remap("SS", [i, j])
-            T.reads(
-                [
-                    X[
-                        T.max(vi - 2, 0) : T.min(vi - 2, 27) + 1,
-                        T.max(vj - 2, 0) : T.min(vj - 2, 27) + 1,
-                    ]
-                ]
-            )
-            T.writes([X_pad[vi, vj]])
-            X_pad[vi, vj] = T.if_then_else(
-                2 <= vi and vi < 30 and 2 <= vj and vj < 30, X[vi - 2, vj - 2], 0.0, dtype="float32"
-            )
-        with T.block("padding_reverse"):
-            vi, vj = T.axis.remap("SS", [i, j])
-            T.reads([X_pad[T.max(vi, 2) : T.min(vi, 29) + 1, T.max(vj, 2) : T.min(vj, 29) + 1]])
-            T.writes(
-                [
-                    Y[
-                        T.max(vi - 2, 0) : T.min(vi - 2, 27) + 1,
-                        T.max(vj - 2, 0) : T.min(vj - 2, 27) + 1,
-                    ]
-                ]
-            )
-            if 2 <= vi and vi < 30 and 2 <= vj and vj < 30:
-                Y[vi - 2, vj - 2] = X_pad[vi, vj]
-
-
 def test_block_access_region_detector():
     block = func.body.block.body.block
     alloc_buffers = func.body.block.alloc_buffers
@@ -255,36 +220,6 @@ def test_access_in_branch_func():
     tvm.ir.assert_structural_equal(ret0[1], ret1[1])
 
 
-def test_access_of_padding_pattern():
-    s = tvm.tir.schedule.Schedule(access_of_padding_pattern)
-    alloc_buffers = s.get_sref(s.get_block("root")).stmt.alloc_buffers
-    buffer_var_map = {buf.data: buf for buf in alloc_buffers}
-
-    def do_compare_buffer_region(region, expect):
-        assert region.buffer == expect.buffer
-        analyzer = tvm.arith.Analyzer()
-        for k, rng in enumerate(region.region):
-            tvm.ir.assert_structural_equal(
-                analyzer.simplify(rng.min), analyzer.simplify(expect.region[k].min)
-            )
-            tvm.ir.assert_structural_equal(
-                analyzer.simplify(rng.extent), analyzer.simplify(expect.region[k].extent)
-            )
-
-    def do_check_block(block_name):
-        block = s.get_sref(s.get_block(block_name)).stmt
-        expect_reads = block.reads
-        expect_writes = block.writes
-        ret = tir.analysis.get_block_access_region(block, buffer_var_map)
-        for i, read in enumerate(ret[0]):
-            do_compare_buffer_region(read, expect_reads[i])
-        for i, write in enumerate(ret[1]):
-            do_compare_buffer_region(write, expect_writes[i])
-
-    do_check_block("padding")
-    do_check_block("padding_reverse")
-
-
 if __name__ == "__main__":
     test_block_access_region_detector()
     test_opaque_block()
@@ -292,4 +227,3 @@ if __name__ == "__main__":
     test_match_buffer()
     test_access_in_if_then_else_func()
     test_access_in_branch_func()
-    test_access_of_padding_pattern()
diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
index 9b84485..57c87e5 100644
--- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py
+++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
@@ -24,7 +24,6 @@ def _check(original, transformed):
     mod = tvm.IRModule.from_expr(func)
     mod = tvm.tir.transform.CompactBufferAllocation()(mod)
     mod = tvm.tir.transform.Simplify()(mod)
-    transformed = tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(transformed))["main"]
     tvm.ir.assert_structural_equal(mod["main"], transformed)