You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by cs...@apache.org on 2022/08/12 20:16:32 UTC

[tvm] branch main updated: [LowerVTCMAlloc] Move LowerVtcmAlloc to after StorageRewrite (#12364)

This is an automated email from the ASF dual-hosted git repository.

csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3eb673478b [LowerVTCMAlloc] Move LowerVtcmAlloc to after StorageRewrite (#12364)
3eb673478b is described below

commit 3eb673478bc444daf24ee8d6308a42a71c81b74f
Author: Tristan Konolige <tk...@octoml.ai>
AuthorDate: Fri Aug 12 13:16:23 2022 -0700

    [LowerVTCMAlloc] Move LowerVtcmAlloc to after StorageRewrite (#12364)
    
    * [LowerVTCMAlloc] Move LowerVtcmAlloc to after StorageRewrite
    
    VTCM allocations were being moved inside loops even if they were
    originally allocated outside of the loops. Normally,
    PlanAndUpdateBufferAllocationLocation moves allocations as close to their
    use as possible, and then StorageRewrite moves them back out as far as
    possible. However, with VTCM allocations,
    PlanAndUpdateBufferAllocationLocation would move the VTCM allocation
    close to the compute, and then LowerVtcmAlloc would convert the allocation
    to a LetStmt. StorageRewrite would not move this LetStmt back out, because
    it only handles allocations. Moving LowerVtcmAlloc to after StorageRewrite
    ensures that the VTCM allocations are in their final location before they
    are converted to a LetStmt.
    
    * Fix issues with storage-scope tagging and StorageRewrite
---
 src/driver/driver_api.cc              | 3 ++-
 src/tir/transforms/storage_rewrite.cc | 6 +++---
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index cbf809a267..9bd2e8a812 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -204,7 +204,6 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
   pass_list.push_back(tir::transform::InjectSoftwarePipeline());
   pass_list.push_back(tir::transform::LowerOpaqueBlock());
   pass_list.push_back(tir::transform::FlattenBuffer());
-  pass_list.push_back(tir::transform::LowerVtcmAlloc());
   pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
@@ -223,6 +222,8 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
   if (!disable_storage_rewrite) {
     pass_list.push_back(tir::transform::StorageRewrite());
   }
+  // LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
+  pass_list.push_back(tir::transform::LowerVtcmAlloc());
   pass_list.push_back(tir::transform::UnrollLoop());
 
   // Add user-defined phase-2 passes
diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index 5a326d9fac..d15bed56fd 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -583,8 +583,10 @@ class StoragePlanRewriter : public StmtExprMutator {
   };
 
   // Checks whether the storage_scope is especially tagged for a specific memory.
+  // Special memory is all combined into a single allocation.
   bool IsSpecialTaggedMemory(const StorageScope& scope) {
-    return scope.tag.length() != 0 && scope.tag != ".dyn" && scope.tag != ".workspace";
+    return scope.tag.length() != 0 && scope.tag != ".dyn" && scope.tag != ".workspace" &&
+           scope.tag != ".vtcm";
   }
 
   // Alllocate entry of node.
@@ -655,8 +657,6 @@ class StoragePlanRewriter : public StmtExprMutator {
 
         if (e->allocs.size() == 1) {
           // simply use the original allocation.
-          PrimExpr sz = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
-                              make_const(DataType::Int(32), 1), e->allocs[0]->extents);
           e->new_alloc = Allocate(e->alloc_var, alloc_type, e->allocs[0]->extents,
                                   e->allocs[0]->condition, Evaluate(0));
           if (IsSpecialTaggedMemory(e->scope)) {