You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/12/13 11:10:19 UTC
[tvm] branch main updated: [TIR] Fix PlanAndUpdateBufferAllocationLocation not visiting constant buffer (#13605)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1d9863470e [TIR] Fix PlanAndUpdateBufferAllocationLocation not visiting constant buffer (#13605)
1d9863470e is described below
commit 1d9863470e0e97413d05b98f2852dc7de60611a0
Author: masahi <ma...@gmail.com>
AuthorDate: Tue Dec 13 20:10:13 2022 +0900
[TIR] Fix PlanAndUpdateBufferAllocationLocation not visiting constant buffer (#13605)
* Fix PlanAndUpdateBufferAllocationLocation not visiting constant buffer
* add comment
---
.../plan_update_buffer_allocation_location.cc | 19 +++++++++++++++----
1 file changed, 15 insertions(+), 4 deletions(-)
diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc
index 4c63d3393f..11d8330ec8 100644
--- a/src/tir/transforms/plan_update_buffer_allocation_location.cc
+++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc
@@ -61,24 +61,35 @@ class BufferAllocateOrderCollector : public StmtExprVisitor {
}
private:
+ bool find(const Buffer& buf) {  // True if `buf` is already present in buffer_alloc_recorder_.
+ return std::find(buffer_alloc_recorder_.begin(), buffer_alloc_recorder_.end(), buf) !=
+ buffer_alloc_recorder_.end();  // NOTE(review): linear std::find scan on every call.
+ }
+
void VisitStmt_(const BlockNode* op) final {
for (const Buffer& buffer : op->alloc_buffers) {
buffer_alloc_recorder_.push_back(buffer);
}
+ // Also record the source buffer of each match_buffer region, so that constant buffers associated
+ // with AllocateConst nodes are collected: these only appear in read and match_buffer regions.
+ for (const auto& region : op->match_buffers) {
+ if (!find(region->source->buffer)) {  // Skip source buffers that were already recorded.
+ buffer_alloc_recorder_.push_back(region->source->buffer);
+ }
+ }
+ // Recording happens before recursing, so these buffers are ordered ahead of any first seen in the body.
StmtExprVisitor::VisitStmt_(op);
}
void VisitExpr_(const BufferLoadNode* op) final {
- if (std::find(buffer_alloc_recorder_.begin(), buffer_alloc_recorder_.end(), op->buffer) ==
- buffer_alloc_recorder_.end()) {
+ if (!find(op->buffer)) {  // Record a loaded buffer only the first time it is seen.
buffer_alloc_recorder_.push_back(op->buffer);
}
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const BufferStoreNode* op) final {
- if (std::find(buffer_alloc_recorder_.begin(), buffer_alloc_recorder_.end(), op->buffer) ==
- buffer_alloc_recorder_.end()) {
+ if (!find(op->buffer)) {
buffer_alloc_recorder_.push_back(op->buffer);
}
StmtExprVisitor::VisitStmt_(op);