You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by "Lunderberg (via GitHub)" <gi...@apache.org> on 2023/04/17 14:32:39 UTC

[GitHub] [tvm] Lunderberg commented on a diff in pull request #14021: [TIR] More flexible buffer compaction

Lunderberg commented on code in PR #14021:
URL: https://github.com/apache/tvm/pull/14021#discussion_r1168786999


##########
src/tir/transforms/compact_buffer_region.cc:
##########
@@ -103,17 +59,62 @@ NDIntSet NDIntSetEval(Region region, PrimExpr predicate,
   return support::NDIntSetEval(support::NDIntSetFromRegion(region), dom_map);
 }
 
+/*!
+ * \brief Collect buffer aliasing information (all Buffer objects seen, keyed by data var).
+ */
+class Var2BufferCollector : public StmtExprVisitor {
+ public:
+  /*! \brief Map the buffer var to all aliased buffers, i.e. all visited buffers sharing it. */
+  std::unordered_map<Var, std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>, ObjectPtrHash,
+                     ObjectPtrEqual>
+      var2buffer_;
+
+ private:
+  void VisitStmt_(const BufferStoreNode* op) final {  // record the buffer being stored to
+    var2buffer_[op->buffer->data].insert(op->buffer);
+    StmtExprVisitor::VisitStmt_(op);  // then delegate to the base visitor
+  }
+
+  void VisitExpr_(const BufferLoadNode* op) final {  // record the buffer being loaded from
+    var2buffer_[op->buffer->data].insert(op->buffer);
+    StmtExprVisitor::VisitExpr_(op);  // then delegate to the base visitor
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    for (const Buffer& buffer : op->alloc_buffers) {  // buffers listed in the block's alloc_buffers
+      var2buffer_[buffer->data].insert(buffer);
+    }
+    for (const MatchBufferRegion& region : op->match_buffers) {
+      var2buffer_[region->buffer->data].insert(region->buffer);  // the matched buffer itself
+      var2buffer_[region->source->buffer->data].insert(region->source->buffer);  // and its source
+    }
+    StmtExprVisitor::VisitStmt_(op);  // then delegate to the base visitor
+  }
+
+  void VisitStmt_(const DeclBufferNode* op) final {  // record the buffer named by a DeclBuffer
+    var2buffer_[op->buffer->data].insert(op->buffer);
+    StmtExprVisitor::VisitStmt_(op);  // then delegate to the base visitor
+  }
+};
+
 /*!
  * \brief Collect the access region of each buffer.
  * \note The param buffer regions will not be collected.
  */
 class BufferAccessRegionCollector : public StmtExprVisitor {
  public:
   static std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> Collect(
-      const PrimFunc& f) {
-    BufferAccessRegionCollector collector;
-    collector(f->body);
-    return std::move(collector.buffer_access_region_);
+      const PrimFunc& f, bool collect_inbound) {

Review Comment:
   The functionality of this class looks very similar to that of the [`DomainTouched`](https://github.com/apache/tvm/blob/main/src/arith/domain_touched.cc#L133) utility.  Should the two implementations be merged?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org