You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wu...@apache.org on 2022/04/14 20:18:14 UTC

[tvm] branch main updated: [TIR] Ignore Allocate/AllocateConst in BufferAllocationLocator (#10998)

This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b94119692e [TIR] Ignore Allocate/AllocateConst in BufferAllocationLocator (#10998)
b94119692e is described below

commit b94119692eaa7307201fbad3e3434f8721c50ede
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Thu Apr 14 15:18:09 2022 -0500

    [TIR] Ignore Allocate/AllocateConst in BufferAllocationLocator (#10998)
    
    * [TIR] Ignore Allocate/AllocateConst in BufferAllocationLocator
    
    Prior to this commit, the BufferAllocationLocator mutator used in the
    PlanAndUpdateBufferAllocationLocation pass would erroneously insert an
    entry into `BlockNode::alloc_buffers` for buffers allocated using
    `Allocate` or `AllocateConst` nodes.  This error was introduced in
    https://github.com/apache/tvm/pull/9727, which deprecated `Load` and
    `Store` nodes, replacing them with `BufferLoad` and `BufferStore`
    nodes.  As a result, BufferAllocationLocator identified these as
    buffers whose allocations should be moved to inner loops, rather than
    as unmanaged allocations that should be ignored.
    
    This commit restores the earlier behavior by only operating on buffer
    allocations in `BlockNode::alloc_buffers`, and explicitly ignoring any
    buffers whose allocation is done with `Allocate` or `AllocateConst`.
    
    * Only inject opaque block if managed buffers exist.
    
    Previously, all buffers found were managed buffers, so this check
    wasn't needed.
---
 .../plan_update_buffer_allocation_location.cc      | 33 ++++++++++++++++------
 .../test_tir_transform_extract_constants.py        |  2 ++
 2 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc
index 6b495b3bf4..81dfceb40d 100644
--- a/src/tir/transforms/plan_update_buffer_allocation_location.cc
+++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc
@@ -61,16 +61,21 @@ class BufferAllocationLocator : public StmtExprMutator {
     for (const Buffer& buf : it->second) {
       buffer_data_to_buffer_.Set(buf->data, buf);
     }
-    Stmt stmt = StmtMutator::VisitStmt_(op);
-    op = stmt.as<ForNode>();
-    ICHECK(op != nullptr);
+    auto node = Downcast<For>(StmtMutator::VisitStmt_(op));
+
+    Array<Buffer> new_block_alloc_bufs;
     for (const Buffer& buf : it->second) {
-      buffer_data_to_buffer_.erase(buf->data);
+      if (!unmanaged_allocations_.count(buf->data.get())) {
+        buffer_data_to_buffer_.erase(buf->data);
+        new_block_alloc_bufs.push_back(buf);
+      }
     }
-    Stmt body = InjectOpaqueBlock(op->body, it->second);
-    ObjectPtr<ForNode> n = CopyOnWrite(op);
-    n->body = std::move(body);
-    return Stmt(n);
+
+    if (new_block_alloc_bufs.size()) {
+      node.CopyOnWrite()->body = InjectOpaqueBlock(node->body, new_block_alloc_bufs);
+    }
+
+    return std::move(node);
   }
 
   Stmt VisitStmt_(const BlockNode* op) final {
@@ -114,6 +119,16 @@ class BufferAllocationLocator : public StmtExprMutator {
     return Stmt(n);
   }
 
+  Stmt VisitStmt_(const AllocateNode* op) final {
+    unmanaged_allocations_.insert(op->buffer_var.get());
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
+  Stmt VisitStmt_(const AllocateConstNode* op) final {
+    unmanaged_allocations_.insert(op->buffer_var.get());
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
   Stmt VisitStmt_(const BufferRealizeNode* op) final {
     ICHECK(false) << "Internal Error: BufferRealizeNode is not allowed in TensorIR.";
     throw;
@@ -151,6 +166,8 @@ class BufferAllocationLocator : public StmtExprMutator {
   std::unordered_map<const StmtNode*, Array<Buffer>> alloc_buffers_;
   /*! \brief The buffer already allocated during recursive visiting. */
   Map<Var, Buffer> buffer_data_to_buffer_;
+  /*! \brief Buffers that are allocated outside of the BlockNode, and should not be moved. */
+  std::unordered_set<const VarNode*> unmanaged_allocations_;
 };
 
 PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) {
diff --git a/tests/python/unittest/test_tir_transform_extract_constants.py b/tests/python/unittest/test_tir_transform_extract_constants.py
index 9636a9bdde..cb49e7286f 100644
--- a/tests/python/unittest/test_tir_transform_extract_constants.py
+++ b/tests/python/unittest/test_tir_transform_extract_constants.py
@@ -59,6 +59,8 @@ def test_const_extraction():
     for n, f in mod.functions.items():
         tvm.tir.stmt_functor.post_order_visit(f.body, _visit)
 
+    tvm.lower(mod)
+
 
 if __name__ == "__main__":
     test_const_extraction()