You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/12/13 11:10:19 UTC

[tvm] branch main updated: [TIR] Fix PlanAndUpdateBufferAllocationLocation not visiting constant buffer (#13605)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1d9863470e [TIR] Fix PlanAndUpdateBufferAllocationLocation not visiting constant buffer (#13605)
1d9863470e is described below

commit 1d9863470e0e97413d05b98f2852dc7de60611a0
Author: masahi <ma...@gmail.com>
AuthorDate: Tue Dec 13 20:10:13 2022 +0900

    [TIR] Fix PlanAndUpdateBufferAllocationLocation not visiting constant buffer (#13605)
    
    * Fix PlanAndUpdateBufferAllocationLocation not visiting constant buffer
    
    * add comment
---
 .../plan_update_buffer_allocation_location.cc         | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc
index 4c63d3393f..11d8330ec8 100644
--- a/src/tir/transforms/plan_update_buffer_allocation_location.cc
+++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc
@@ -61,24 +61,35 @@ class BufferAllocateOrderCollector : public StmtExprVisitor {
   }
 
  private:
+  bool find(const Buffer& buf) {
+    return std::find(buffer_alloc_recorder_.begin(), buffer_alloc_recorder_.end(), buf) !=
+           buffer_alloc_recorder_.end();
+  }
+
   void VisitStmt_(const BlockNode* op) final {
     for (const Buffer& buffer : op->alloc_buffers) {
       buffer_alloc_recorder_.push_back(buffer);
     }
+    // Also visit match_buffers to collect constant buffers associated with AllocateConst nodes.
+    // These buffers only appear in read and match_buffer regions.
+    for (const auto& region : op->match_buffers) {
+      if (!find(region->source->buffer)) {
+        buffer_alloc_recorder_.push_back(region->source->buffer);
+      }
+    }
+
     StmtExprVisitor::VisitStmt_(op);
   }
 
   void VisitExpr_(const BufferLoadNode* op) final {
-    if (std::find(buffer_alloc_recorder_.begin(), buffer_alloc_recorder_.end(), op->buffer) ==
-        buffer_alloc_recorder_.end()) {
+    if (!find(op->buffer)) {
       buffer_alloc_recorder_.push_back(op->buffer);
     }
     StmtExprVisitor::VisitExpr_(op);
   }
 
   void VisitStmt_(const BufferStoreNode* op) final {
-    if (std::find(buffer_alloc_recorder_.begin(), buffer_alloc_recorder_.end(), op->buffer) ==
-        buffer_alloc_recorder_.end()) {
+    if (!find(op->buffer)) {
       buffer_alloc_recorder_.push_back(op->buffer);
     }
     StmtExprVisitor::VisitStmt_(op);