You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/01/18 06:46:51 UTC

[tvm] branch revert-9880-encode_conditional_accesses_in_read_write_annotations created (now 2e14b85)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a change to branch revert-9880-encode_conditional_accesses_in_read_write_annotations
in repository https://gitbox.apache.org/repos/asf/tvm.git.


      at 2e14b85  Revert "[TIR] Encode conditional accesses info into block read/write regions (#9880)"

This branch includes the following new commits:

     new 2e14b85  Revert "[TIR] Encode conditional accesses info into block read/write regions (#9880)"

The 1 revision listed above as "new" is entirely new to this
repository and will be described in a separate email.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


[tvm] 01/01: Revert "[TIR] Encode conditional accesses info into block read/write regions (#9880)"

Posted by ju...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch revert-9880-encode_conditional_accesses_in_read_write_annotations
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 2e14b8565515b3940f1624de61fc23d3446c4cbe
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Mon Jan 17 22:46:07 2022 -0800

    Revert "[TIR] Encode conditional accesses info into block read/write regions (#9880)"
    
    This reverts commit 6f6fc68f5a028a92607b2907b9e4144543686639.
---
 src/tir/analysis/block_access_region_detector.cc   | 29 ++--------
 src/tir/transforms/compact_buffer_region.cc        | 10 ++--
 src/tir/transforms/ir_utils.cc                     | 62 +++++---------------
 src/tir/transforms/ir_utils.h                      | 18 +++---
 .../test_tir_analysis_get_block_access_region.py   | 66 ----------------------
 .../test_tir_transform_compact_buffer_region.py    |  1 -
 6 files changed, 30 insertions(+), 156 deletions(-)

diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc
index 07dcace..776538a 100644
--- a/src/tir/analysis/block_access_region_detector.cc
+++ b/src/tir/analysis/block_access_region_detector.cc
@@ -56,8 +56,6 @@ class BlockReadWriteDetector : public StmtExprVisitor {
  private:
   /*! \brief Iteration range for loop_vars */
   std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
-  /*! \brief Extra iteration range hint for free vars */
-  std::unordered_map<const VarNode*, arith::IntSet> hint_map_;
   /*! \brief The buffers that the current block reads */
   std::vector<Buffer> read_buffers_;
   /*! \brief The buffers that the current block writes */
@@ -98,9 +96,6 @@ class BlockReadWriteDetector : public StmtExprVisitor {
   /*! \brief Helper function to update a opaque access. */
   void UpdateOpaque(const Var& buffer_var);
 
-  /*! \brief Helper function to relax the buffer indices */
-  arith::IntSet RelaxAccessIndex(const PrimExpr& index);
-
   void VisitStmt_(const ForNode* op) override;
   void VisitStmt_(const IfThenElseNode* op) override;
   void VisitStmt_(const BlockRealizeNode* op) override;
@@ -145,22 +140,10 @@ void BlockReadWriteDetector::VisitExpr_(const LoadNode* op) {
   ExprVisitor::VisitExpr_(op);
 }
 
-arith::IntSet BlockReadWriteDetector::RelaxAccessIndex(const PrimExpr& index) {
-  arith::IntSet relaxed = arith::EvalSet(index, dom_map_);
-  if (!hint_map_.empty()) {
-    // take non-relaxed var bound hints into considerations
-    // eg, if i * 4 + j with i >= 10 and j in [0, 4), only j in domain scope
-    // then the index region can be relaxed to [i*4, i*4+4) ^ [40, inf)
-    arith::IntSet hint_bound = arith::EvalSet(relaxed, hint_map_);
-    relaxed = arith::Intersect({relaxed, hint_bound});
-  }
-  return relaxed;
-}
-
 void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) {
   std::vector<arith::IntSet> relaxed_region;
   for (const PrimExpr& index : op->indices) {
-    relaxed_region.push_back(RelaxAccessIndex(index));
+    relaxed_region.push_back(arith::EvalSet(index, dom_map_));
   }
   Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region);
   ExprVisitor::VisitExpr_(op);
@@ -177,12 +160,12 @@ void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) {
   VisitExpr(op->condition);
   {
     // Visit then branch
-    With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, true);
+    With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, true);
     StmtExprVisitor::VisitStmt(op->then_case);
   }
   if (op->else_case.defined()) {
     // Visit else branch
-    With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, false);
+    With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, false);
     StmtExprVisitor::VisitStmt(op->else_case);
   }
 }
@@ -192,12 +175,12 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) {
     VisitExpr(op->args[0]);
     {
       // Visit then branch
-      With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, true);
+      With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, true);
       StmtExprVisitor::VisitExpr(op->args[1]);
     }
     {
       // Visit else branch
-      With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, false);
+      With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, false);
       StmtExprVisitor::VisitExpr(op->args[2]);
     }
     return;
@@ -213,7 +196,7 @@ void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) {
 void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) {
   std::vector<arith::IntSet> relaxed_region;
   for (const PrimExpr& index : op->indices) {
-    relaxed_region.push_back(RelaxAccessIndex(index));
+    relaxed_region.push_back(arith::EvalSet(index, dom_map_));
   }
   Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region);
   StmtVisitor::VisitStmt_(op);
diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc
index 20ddd7f..07f9778 100644
--- a/src/tir/transforms/compact_buffer_region.cc
+++ b/src/tir/transforms/compact_buffer_region.cc
@@ -123,12 +123,12 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
     StmtExprVisitor::VisitExpr(op->condition);
     {
       // Visit then branch
-      With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, true);
+      With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, true);
       StmtExprVisitor::VisitStmt(op->then_case);
     }
     if (op->else_case.defined()) {
       // Visit else branch
-      With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, false);
+      With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, false);
       StmtExprVisitor::VisitStmt(op->else_case);
     }
   }
@@ -139,12 +139,12 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
       StmtExprVisitor::VisitExpr(op->args[0]);
       {
         // Visit then branch
-        With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, true);
+        With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, true);
         StmtExprVisitor::VisitExpr(op->args[1]);
       }
       {
         // Visit else branch
-        With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, &hint_map_, false);
+        With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, false);
         StmtExprVisitor::VisitExpr(op->args[2]);
       }
       return;
@@ -282,8 +282,6 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
 
   /*! \brief The map from loop vars to their iter range. */
   std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
-  /*! \brief Extra map from free vars to their iter range hints. */
-  std::unordered_map<const VarNode*, arith::IntSet> hint_map_;
   /*! \brief The analyzer aware of loop domains. */
   arith::Analyzer dom_analyzer_;
   /*! \brief The map from Buffer to it's relaxed access set. */
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index bc2f7ad..2423b09 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -300,18 +300,11 @@ Map<Var, Range> ConditionalBoundsContext::GetVarBoundsFromCondition() {
   Array<Var> vars = Array<Var>(var_set.begin(), var_set.end());
   Map<Var, Range> ranges;
   for (const Var& v : vars) {
-    arith::IntSet dom;
-    auto relax_it = relax_map_->find(v.get());
-    if (relax_it != relax_map_->end()) {
-      dom = relax_it->second;
-    } else {
-      auto hint_it = hint_map_->find(v.get());
-      if (hint_it != hint_map_->end()) {
-        dom = hint_it->second;
-      }
-    }
-    if (dom.defined()) {
-      ranges.Set(v, Range::FromMinExtent(dom.min(), analyzer.Simplify(dom.max() - dom.min() + 1)));
+    auto it = dom_map_->find(v.get());
+    if (it != dom_map_->end()) {
+      const auto& int_set = it->second;
+      ranges.Set(v, Range::FromMinExtent(int_set.min(),
+                                         analyzer.Simplify(int_set.max() - int_set.min() + 1)));
     }
   }
   // solve constraints
@@ -321,53 +314,24 @@ Map<Var, Range> ConditionalBoundsContext::GetVarBoundsFromCondition() {
 }
 
 ConditionalBoundsContext::ConditionalBoundsContext(
-    const PrimExpr& condition, std::unordered_map<const VarNode*, arith::IntSet>* relax_map,
-    std::unordered_map<const VarNode*, arith::IntSet>* hint_map, bool is_true_branch)
-    : condition_(condition),
-      relax_map_(relax_map),
-      hint_map_(hint_map),
-      is_true_branch_(is_true_branch) {}
+    const PrimExpr& condition, std::unordered_map<const VarNode*, arith::IntSet>* dom_map,
+    bool is_true_branch)
+    : condition_(condition), dom_map_(dom_map), is_true_branch_(is_true_branch) {}
 
 void ConditionalBoundsContext::EnterWithScope() {
   for (const auto& p : GetVarBoundsFromCondition()) {
     const auto* var = p.first.get();
-    arith::IntSet new_dom = arith::IntSet::FromRange(p.second);
-    auto relax_it = relax_map_->find(var);
-    if (relax_it != relax_map_->end()) {
-      // this is a bound for relaxed var
-      origin_map_.emplace(var, relax_it->second);
-      relax_it->second = arith::Intersect({relax_it->second, new_dom});
-    } else {
-      // this is a bound for free var
-      auto hint_it = hint_map_->find(var);
-      if (hint_it != hint_map_->end()) {
-        origin_map_.emplace(var, hint_it->second);
-        hint_it->second = arith::Intersect({hint_it->second, new_dom});
-      } else {
-        origin_map_.emplace(var, arith::IntSet::Nothing());
-        hint_map_->insert(hint_it, {var, new_dom});
-      }
+    auto it = dom_map_->find(var);
+    if (it != dom_map_->end()) {
+      origin_map_.emplace(var, it->second);
+      it->second = arith::Intersect({it->second, arith::IntSet::FromRange(p.second)});
     }
   }
 }
 
 void ConditionalBoundsContext::ExitWithScope() {
   for (const auto& p : origin_map_) {
-    const auto* var = p.first;
-    auto relax_it = relax_map_->find(var);
-    if (relax_it != relax_map_->end()) {
-      // recover bound for relaxed var
-      relax_it->second = p.second;
-    } else {
-      // recover bound for free var
-      auto hint_it = hint_map_->find(var);
-      ICHECK(hint_it != hint_map_->end());
-      if (p.second.IsNothing()) {
-        hint_map_->erase(hint_it);
-      } else {
-        hint_it->second = p.second;
-      }
-    }
+    (*dom_map_)[p.first] = p.second;
   }
 }
 
diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h
index da52a82..7b1d34c 100644
--- a/src/tir/transforms/ir_utils.h
+++ b/src/tir/transforms/ir_utils.h
@@ -231,9 +231,9 @@ Bool IsFromLegacyTESchedule(PrimFunc f);
  *\brief Context helper to update domain map within conditional scope.
  *
  * Assume the condition is `0 <= i && i < 9` and global domain of i is [0, 20], thus `bounds[i]` is
- * [0, 8]. Then `With<ConditionalBoundsContext> ctx(condition, &relax_map, &hint_map, true)` step
- *into scope where dom_map[i] is [0, 8] and `With<ConditionalBoundsContext> ctx(condition,
- *&relax_map, &hint_map, false)` step into scope where dom_map[i] is [9, 20]
+ *[0, 8]. Then `With<ConditionalBoundsContext> ctx(&dom_map, bounds, true)` step into scope where
+ *dom_map[i] is [0, 8] and `With<ConditionalBoundsContext> ctx(&dom_map, bounds, false)` step into
+ *scope where dom_map[i] is [9, 20]
  */
 class ConditionalBoundsContext {
  private:
@@ -241,13 +241,11 @@ class ConditionalBoundsContext {
   /*!
    * \brief Construct a condition bounds context.
    * \param condition The condition holds on true branch.
-   * \param relax_map The domain map for relaxed vars to update.
-   * \param hint_map The domain map for free vars to update.
+   * \param dom_map The global domain map to be updated.
    * \param is_true_branch Whether step into the branch where condition bounds holds.
    */
   ConditionalBoundsContext(const PrimExpr& condition,
-                           std::unordered_map<const VarNode*, arith::IntSet>* relax_map,
-                           std::unordered_map<const VarNode*, arith::IntSet>* hint_map,
+                           std::unordered_map<const VarNode*, arith::IntSet>* dom_map,
                            bool is_true_branch);
   void EnterWithScope();
   void ExitWithScope();
@@ -257,10 +255,8 @@ class ConditionalBoundsContext {
 
   /*! \brief the condition holds on true branch. */
   const PrimExpr& condition_;
-  /*! \brief domain map for relaxed vars to update */
-  std::unordered_map<const VarNode*, arith::IntSet>* relax_map_;
-  /*! \brief domain map for free vars to update */
-  std::unordered_map<const VarNode*, arith::IntSet>* hint_map_;
+  /*! \brief global domain map to updated */
+  std::unordered_map<const VarNode*, arith::IntSet>* dom_map_;
   /*! \brief whether is on true branch */
   bool is_true_branch_;
   /*! \brief used to record and restore original var bounds */
diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py
index 5403754..e508fbb 100644
--- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py
+++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py
@@ -130,41 +130,6 @@ def access_in_branch_func() -> None:
                 B[i] = A[i - 1]
 
 
-@T.prim_func
-def access_of_padding_pattern() -> None:
-    X = T.alloc_buffer([28, 28])
-    X_pad = T.alloc_buffer([32, 32])
-    Y = T.alloc_buffer([28, 28])
-    for i, j in T.grid(32, 32):
-        with T.block("padding"):
-            vi, vj = T.axis.remap("SS", [i, j])
-            T.reads(
-                [
-                    X[
-                        T.max(vi - 2, 0) : T.min(vi - 2, 27) + 1,
-                        T.max(vj - 2, 0) : T.min(vj - 2, 27) + 1,
-                    ]
-                ]
-            )
-            T.writes([X_pad[vi, vj]])
-            X_pad[vi, vj] = T.if_then_else(
-                2 <= vi and vi < 30 and 2 <= vj and vj < 30, X[vi - 2, vj - 2], 0.0, dtype="float32"
-            )
-        with T.block("padding_reverse"):
-            vi, vj = T.axis.remap("SS", [i, j])
-            T.reads([X_pad[T.max(vi, 2) : T.min(vi, 29) + 1, T.max(vj, 2) : T.min(vj, 29) + 1]])
-            T.writes(
-                [
-                    Y[
-                        T.max(vi - 2, 0) : T.min(vi - 2, 27) + 1,
-                        T.max(vj - 2, 0) : T.min(vj - 2, 27) + 1,
-                    ]
-                ]
-            )
-            if 2 <= vi and vi < 30 and 2 <= vj and vj < 30:
-                Y[vi - 2, vj - 2] = X_pad[vi, vj]
-
-
 def test_block_access_region_detector():
     block = func.body.block.body.block
     alloc_buffers = func.body.block.alloc_buffers
@@ -255,36 +220,6 @@ def test_access_in_branch_func():
     tvm.ir.assert_structural_equal(ret0[1], ret1[1])
 
 
-def test_access_of_padding_pattern():
-    s = tvm.tir.schedule.Schedule(access_of_padding_pattern)
-    alloc_buffers = s.get_sref(s.get_block("root")).stmt.alloc_buffers
-    buffer_var_map = {buf.data: buf for buf in alloc_buffers}
-
-    def do_compare_buffer_region(region, expect):
-        assert region.buffer == expect.buffer
-        analyzer = tvm.arith.Analyzer()
-        for k, rng in enumerate(region.region):
-            tvm.ir.assert_structural_equal(
-                analyzer.simplify(rng.min), analyzer.simplify(expect.region[k].min)
-            )
-            tvm.ir.assert_structural_equal(
-                analyzer.simplify(rng.extent), analyzer.simplify(expect.region[k].extent)
-            )
-
-    def do_check_block(block_name):
-        block = s.get_sref(s.get_block(block_name)).stmt
-        expect_reads = block.reads
-        expect_writes = block.writes
-        ret = tir.analysis.get_block_access_region(block, buffer_var_map)
-        for i, read in enumerate(ret[0]):
-            do_compare_buffer_region(read, expect_reads[i])
-        for i, write in enumerate(ret[1]):
-            do_compare_buffer_region(write, expect_writes[i])
-
-    do_check_block("padding")
-    do_check_block("padding_reverse")
-
-
 if __name__ == "__main__":
     test_block_access_region_detector()
     test_opaque_block()
@@ -292,4 +227,3 @@ if __name__ == "__main__":
     test_match_buffer()
     test_access_in_if_then_else_func()
     test_access_in_branch_func()
-    test_access_of_padding_pattern()
diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
index 9b84485..57c87e5 100644
--- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py
+++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
@@ -24,7 +24,6 @@ def _check(original, transformed):
     mod = tvm.IRModule.from_expr(func)
     mod = tvm.tir.transform.CompactBufferAllocation()(mod)
     mod = tvm.tir.transform.Simplify()(mod)
-    transformed = tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(transformed))["main"]
     tvm.ir.assert_structural_equal(mod["main"], transformed)