You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/12/08 08:47:31 UTC

[GitHub] [tvm] Hzfengsy commented on a diff in pull request #13560: [TIR][Transform] Keep the allocate buffers order after update buffer allocation location

Hzfengsy commented on code in PR #13560:
URL: https://github.com/apache/tvm/pull/13560#discussion_r1043076105


##########
src/tir/transforms/plan_update_buffer_allocation_location.cc:
##########
@@ -48,10 +48,51 @@ class CollectUnmanagedAllocations : public StmtExprVisitor {
   std::unordered_set<const VarNode*> unmanaged_allocations;
 };
 
+/*! \brief Collect the allocate buffer order. */
+class BufferAllocateOrderCollector : public StmtExprVisitor {
+ public:
+  static Array<Buffer> Collect(const PrimFunc& func) {
+    BufferAllocateOrderCollector collector;
+    for (const auto& kv : func->buffer_map) {
+      collector.buffer_alloc_recorder_.push_back(kv.second);
+    }
+    collector(func->body);
+    return std::move(collector.buffer_alloc_recorder_);
+  }
+
+ private:
+  void VisitStmt_(const BlockNode* op) final {
+    for (const Buffer& buffer : op->alloc_buffers) {
+      buffer_alloc_recorder_.push_back(buffer);
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const BufferLoadNode* op) final {
+    if (std::find(buffer_alloc_recorder_.begin(), buffer_alloc_recorder_.end(), op->buffer) ==
+        buffer_alloc_recorder_.end()) {
+      buffer_alloc_recorder_.push_back(op->buffer);
+    }
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const BufferStoreNode* op) final {
+    if (std::find(buffer_alloc_recorder_.begin(), buffer_alloc_recorder_.end(), op->buffer) ==
+        buffer_alloc_recorder_.end()) {
+      buffer_alloc_recorder_.push_back(op->buffer);
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  /*! \brief The buffer allocated order recorder. */
+  Array<Buffer> buffer_alloc_recorder_;
+};
+
 class BufferAllocationLocator : public StmtExprMutator {
  public:
   explicit BufferAllocationLocator(const PrimFunc& func) {
     Map<Buffer, Optional<Stmt>> buffer_lca = DetectBufferAccessLCA(func);
+    Array<Buffer> buffer_alloc_recorder = BufferAllocateOrderCollector::Collect(func);

Review Comment:
   Please add a comment explaining that `buffer_alloc_recorder` is used to preserve the buffer allocation order, since `Map` is unordered.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org