Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/10/21 14:36:31 UTC

[GitHub] [tvm] jinhongyii opened a new pull request #9341: [CUDA] Support memory reuse for dynamic shared memory

jinhongyii opened a new pull request #9341:
URL: https://github.com/apache/tvm/pull/9341


   This PR applies part of the storage rewrite algorithm to detect memory reuse opportunities for dynamic shared memory, and rewrites the body accordingly. This is important for matmul, where we may want the shared memory used for data fetch and the shared memory used for write-back to occupy the same location.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] jinhongyii commented on pull request #9341: [CUDA] Support memory reuse for dynamic shared memory

Posted by GitBox <gi...@apache.org>.
jinhongyii commented on pull request #9341:
URL: https://github.com/apache/tvm/pull/9341#issuecomment-948683072


   cc: @junrushao1994 @vinx13 @masahi 





[GitHub] [tvm] jinhongyii commented on pull request #9341: [CUDA] Support memory reuse for dynamic shared memory

Posted by GitBox <gi...@apache.org>.
jinhongyii commented on pull request #9341:
URL: https://github.com/apache/tvm/pull/9341#issuecomment-949269726


   @masahi 
   1. storage_rewrite can't handle storage scopes with a tag, and shared.dyn has a tag.
   2. storage_rewrite can't handle buffers with different dtypes.
   3. storage_rewrite can't allocate a buffer in place of two other buffers. For example, take a look at https://github.com/jinhongyii/tvm/blob/dyn-shared-mem/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py#L80-L104
      C_sh's allocation reuses the memory of A_sh **and** B_sh. However, with the original storage rewrite algorithm, C_sh can only be allocated in place of A_sh **or** B_sh, and the resulting allocation size would be block * block * 6.
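
   The difference between the two strategies can be illustrated with a small standalone sketch. This is not the code from the PR: `Buf`, `OneForOneBytes`, and `OverlayBytes` are hypothetical names, the one-for-one planner is a deliberately simplified stand-in for storage rewrite, and the overlay planner is a plain first-fit offset assignment. It assumes A_sh and B_sh are live at the same time and that C_sh, twice their size, is only used after both are dead.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical buffer description: size in bytes plus the first and last
// statement index (inclusive) at which the buffer is touched.
struct Buf {
  std::size_t bytes;
  int first;
  int last;
};

// Simplified one-for-one reuse: a new buffer may take over a single free
// slot (growing it if needed); otherwise it opens a new slot.
// Assumes `bufs` is sorted by `first`.
std::size_t OneForOneBytes(const std::vector<Buf>& bufs) {
  struct Slot {
    std::size_t bytes;
    int free_after;
  };
  std::vector<Slot> slots;
  for (const Buf& b : bufs) {
    auto it = std::find_if(slots.begin(), slots.end(),
                           [&](const Slot& s) { return s.free_after < b.first; });
    if (it == slots.end()) {
      slots.push_back({b.bytes, b.last});
    } else {
      it->bytes = std::max(it->bytes, b.bytes);
      it->free_after = b.last;
    }
  }
  std::size_t total = 0;
  for (const Slot& s : slots) total += s.bytes;
  return total;
}

// Overlay reuse: each buffer gets the lowest byte offset that does not
// collide with any buffer whose live range overlaps its own, so one large
// buffer can sit on top of several smaller dead ones.
std::size_t OverlayBytes(const std::vector<Buf>& bufs) {
  struct Placed {
    std::size_t offset;
    Buf buf;
  };
  std::vector<Placed> placed;
  std::size_t total = 0;
  for (const Buf& b : bufs) {
    // Buffers that are live at the same time as b.
    std::vector<Placed> conflicts;
    for (const Placed& p : placed) {
      if (p.buf.first <= b.last && b.first <= p.buf.last) conflicts.push_back(p);
    }
    std::sort(conflicts.begin(), conflicts.end(),
              [](const Placed& x, const Placed& y) { return x.offset < y.offset; });
    std::size_t offset = 0;
    for (const Placed& p : conflicts) {
      if (offset + b.bytes <= p.offset) break;  // fits in the gap before p
      offset = std::max(offset, p.offset + p.buf.bytes);
    }
    placed.push_back({offset, b});
    total = std::max(total, offset + b.bytes);
  }
  return total;
}
```

   With A_sh and B_sh at `n` bytes each and C_sh at `2n` bytes, the one-for-one planner needs `3n` bytes (C_sh grows one slot, the other stays), while the overlay planner needs `2n`. If `n` is `block * block * 2`, these are `block * block * 6` and `block * block * 4`.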





[GitHub] [tvm] jinhongyii commented on pull request #9341: [CUDA] Support memory reuse for dynamic shared memory

Posted by GitBox <gi...@apache.org>.
jinhongyii commented on pull request #9341:
URL: https://github.com/apache/tvm/pull/9341#issuecomment-953488082


   @vinx13 @Hzfengsy comments are all addressed. Please take another look.





[GitHub] [tvm] masahi commented on pull request #9341: [CUDA] Support memory reuse for dynamic shared memory

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #9341:
URL: https://github.com/apache/tvm/pull/9341#issuecomment-949314783


   Thanks @jinhongyii, it is very cool that you were able to reduce the alloc size from `block * block * 3 * 4` to `block * block * 4` in the test!





[GitHub] [tvm] vinx13 commented on pull request #9341: [CUDA] Support memory reuse for dynamic shared memory

Posted by GitBox <gi...@apache.org>.
vinx13 commented on pull request #9341:
URL: https://github.com/apache/tvm/pull/9341#issuecomment-948937168


   I think this is to support reusing allocations with heterogeneous data types.





[GitHub] [tvm] jinhongyii edited a comment on pull request #9341: [CUDA] Support memory reuse for dynamic shared memory

Posted by GitBox <gi...@apache.org>.
jinhongyii edited a comment on pull request #9341:
URL: https://github.com/apache/tvm/pull/9341#issuecomment-949269726


   @masahi 
   1. storage_rewrite can't handle buffers with different dtypes.
   2. storage_rewrite can't allocate a buffer in place of two other buffers. For example, take a look at https://github.com/jinhongyii/tvm/blob/dyn-shared-mem/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py#L80-L104
      C_sh's allocation reuses the memory of A_sh **and** B_sh. However, with the original storage rewrite algorithm, C_sh can only be allocated in place of A_sh **or** B_sh, and the resulting allocation size would be block * block * 6.





[GitHub] [tvm] junrushao1994 commented on pull request #9341: [CUDA] Support memory reuse for dynamic shared memory

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #9341:
URL: https://github.com/apache/tvm/pull/9341#issuecomment-950428665


   Let's double check and get it in if there are no further issues :-)





[GitHub] [tvm] Hzfengsy commented on a change in pull request #9341: [CUDA] Support memory reuse for dynamic shared memory

Posted by GitBox <gi...@apache.org>.
Hzfengsy commented on a change in pull request #9341:
URL: https://github.com/apache/tvm/pull/9341#discussion_r735244224



##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -27,6 +27,8 @@
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
 
+#include <list>
+#include <map>

Review comment:
       Please delete the unused headers `list` and `map`.

##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -36,46 +38,250 @@
 namespace tvm {
 namespace tir {
 
+using runtime::StorageRank;
+using runtime::StorageScope;
+
 bool IsDynamicSharedMemory(Var buffer_var) {
-  auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
+  StorageScope storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
   return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn";
 }
 
+/*!
+ * \brief collect the mapping from the buffer var to its allocate
+ */
 class AllocateCollector : public StmtExprVisitor {
  public:
   void VisitStmt_(const AllocateNode* op) final {
     if (IsDynamicSharedMemory(op->buffer_var)) {
-      dyn_shmem_allocs_.insert(op);
+      dyn_shmem_allocs_[op->buffer_var.get()] = op;
     }
     StmtExprVisitor::VisitStmt_(op);
   }
+  // The mapping from the original buffer var to its allocate
+  std::unordered_map<const VarNode*, const AllocateNode*> dyn_shmem_allocs_;
+};
+
+// Find a linear pattern of storage access
+// Used for liveness analysis.
+// Composite scopes (loop/thread_launch/IfThen) are represented by two points:
+// before_scope -> scope_body -> after_scope
+//
+// The linear_seq_ stores before_scope and after_scope.
+// Accesses to the arrays are stored at the after_scope point.
+//
+// Define "scope" as the body of For/thread_launch/IfThenElse.
+// This pass tries to detect the last point at which we need to keep memory
+// alive under the same scope as the allocate.
+// The storage needs to be kept alive between the allocate and the last access.
+// The free point is only inserted at the same scope as the allocate.
+//
+class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
+ public:
+  /*! \brief record the touch history of a statement. */
+  struct StmtEntry {
+    // The statement
+    const Object* stmt;
+    // The index in the linear_seq_ to point to end of the nested scope.
+    // This is only set to non-zero if stmt is a nested scope.
+    // if offset > 0, means this is the begin, the end entry is current_index + offset
+    // if offset < 0, means this is the end, the begin entry is current_index + offset
+    int64_t scope_pair_offset{0};
+    // The buffer variables this statement touched.
+    std::vector<const VarNode*> touched;
+  };
+  // The scope of each allocation
+  struct AllocEntry {
+    // scope level
+    size_t level{0};
+    // allocation stmt
+    const AllocateNode* alloc{nullptr};
+  };
 
-  std::unordered_set<const AllocateNode*> dyn_shmem_allocs_;
+  void VisitStmt_(const AllocateNode* op) final {
+    size_t level = scope_.size();
+    const VarNode* buf = op->buffer_var.get();
+    alloc_info_[buf].alloc = op;
+    alloc_info_[buf].level = level;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+  void VisitStmt_(const StoreNode* op) final {
+    scope_.push_back(StmtEntry());
+    // visit subexpr
+    StmtExprVisitor::VisitStmt_(op);
+    // Add write access.
+    const VarNode* buf = op->buffer_var.get();
+    auto it = alloc_info_.find(buf);
+    if (it != alloc_info_.end() && it->second.alloc) {
+      ICHECK_LT(it->second.level, scope_.size());
+      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
+        scope_[it->second.level].touched.push_back(buf);
+      }
+    }
+    StmtEntry e = scope_.back();
+    scope_.pop_back();
+    if (e.touched.size() != 0) {
+      e.stmt = op;
+      linear_seq_.push_back(e);
+    }
+  }
+  void VisitStmt_(const EvaluateNode* op) final {
+    scope_.push_back(StmtEntry());
+    // visit subexpr
+    StmtExprVisitor::VisitStmt_(op);
+    StmtEntry e = scope_.back();
+    scope_.pop_back();
+    if (e.touched.size() != 0) {
+      e.stmt = op;
+      linear_seq_.push_back(e);
+    }
+  }
+  void VisitExpr_(const LoadNode* op) final {
+    // Add read access.
+    StmtExprVisitor::VisitExpr_(op);
+    const VarNode* buf = op->buffer_var.get();
+    auto it = alloc_info_.find(buf);
+    if (it != alloc_info_.end() && it->second.alloc) {
+      ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store.";
+      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
+        scope_[it->second.level].touched.push_back(buf);
+      }
+    }
+  }
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::address_of())) {
+      const LoadNode* l = op->args[0].as<LoadNode>();
+      this->VisitExpr(l->index);
+    } else {
+      StmtExprVisitor::VisitExpr_(op);
+    }
+  }
+  void VisitExpr_(const VarNode* buf) final {
+    // A direct reference to the variable counts as a read.
+    auto it = alloc_info_.find(buf);
+    if (it != alloc_info_.end() && it->second.alloc) {
+      ICHECK_LT(it->second.level, scope_.size()) << " buf=" << buf->name_hint;
+      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
+        scope_[it->second.level].touched.push_back(buf);
+      }
+    }
+  }
+  template <typename T>
+  void VisitNewScope(const T* op) {
+    scope_.push_back(StmtEntry());
+    StmtEntry e;
+    e.stmt = op;
+    int64_t begin_index = static_cast<int64_t>(linear_seq_.size());
+    // before scope.
+    linear_seq_.push_back(e);
+    StmtExprVisitor::VisitStmt_(op);
+    // after scope.
+    e.touched = std::move(scope_.back().touched);
+    scope_.pop_back();
+    int64_t end_index = static_cast<int64_t>(linear_seq_.size());
+    ICHECK_GT(end_index, begin_index);
+    e.scope_pair_offset = begin_index - end_index;
+    linear_seq_.push_back(e);
+    // record the pointer to end index.
+    ICHECK_NE(end_index, 0U);
+    linear_seq_[begin_index].scope_pair_offset = end_index - begin_index;
+  }
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Only record the outermost thread extent.
+    if (op->attr_key == attr::thread_extent && !in_thread_env_) {
+      in_thread_env_ = true;
+      VisitNewScope(op);
+      in_thread_env_ = false;
+    } else if (op->attr_key == attr::extern_scope) {
+      VisitNewScope(op);
+    } else if (op->attr_key == attr::virtual_thread) {
+      VisitNewScope(op);
+    } else {
+      StmtExprVisitor::VisitStmt_(op);
+    }
+  }
+  void VisitStmt_(const IfThenElseNode* op) final { VisitNewScope(op); }
+
+  void VisitStmt_(const ForNode* op) final { VisitNewScope(op); }
+
+  void VisitStmt_(const WhileNode* op) final { VisitNewScope(op); }
+
+  void VisitStmt_(const AssertStmtNode* op) final { VisitNewScope(op); }
+
+  // linearized access sequence.
+  std::vector<StmtEntry> linear_seq_;
+  // The storage scope of each buffer
+  std::unordered_map<const VarNode*, AllocEntry> alloc_info_;
+
+ private:
+  // Whether already in thread env.
+  bool in_thread_env_{false};
+  // The scope stack.
+  std::vector<StmtEntry> scope_;
 };
 
+/*!
+ * \brief merge the buffers whose live range has no intersection and rewrite the body
+ */
 class DynamicSharedMemoryRewriter : public StmtExprMutator {
  public:
   explicit DynamicSharedMemoryRewriter(
-      const std::unordered_set<const AllocateNode*>& dyn_shmem_allocs)
+      const std::unordered_map<const VarNode*, const AllocateNode*>& dyn_shmem_allocs)
       : dyn_shmem_allocs_{dyn_shmem_allocs} {}
 
+  /*!
+   * \brief plan the memory reuse for all the buffers allocated in the statement
+   * @param stmt the statement

Review comment:
       ```suggestion
      * \param stmt the statement.
   ```

##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -36,46 +38,250 @@
+  void VisitExpr_(const VarNode* buf) final {
+    // Directly reference to the variable count as a read.
+    auto it = alloc_info_.find(buf);
+    if (it != alloc_info_.end() && it->second.alloc) {
+      ICHECK_LT(it->second.level, scope_.size()) << " buf=" << buf->name_hint;

Review comment:
       Please either add a complete error message or do not print any information.

##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -36,46 +38,250 @@
+  /*!
+   * \brief plan the memory reuse for all the buffer allocated in the statement
+   * @param stmt the statement
+   */
+  void PlanReuse(Stmt stmt) {

Review comment:
       ```suggestion
     void PlanReuse(const Stmt& stmt) {
   ```

##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -36,46 +38,250 @@
+  // The scope of each allocation
+  struct AllocEntry {
+    // scope level

Review comment:
       Please document a bit more about `level`
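
       For readers following along, one reading of `level` based on the visitor code in the hunk for this file: it is the depth of the scope stack at the point where the allocate is visited, and later accesses are recorded on the stack entry at that depth rather than on the innermost one. The toy tracker below is only a sketch of that idea; `ScopeTracker` and its methods are hypothetical names, not TVM APIs.

```cpp
#include <cstddef>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Toy stand-in for the visitor's scope stack (hypothetical, not TVM code).
class ScopeTracker {
 public:
  // Entering a statement or nested scope pushes one entry.
  void Push() { stack_.emplace_back(); }

  // Leaving it pops the entry and returns the buffers recorded on it.
  std::vector<std::string> Pop() {
    std::vector<std::string> touched = std::move(stack_.back());
    stack_.pop_back();
    return touched;
  }

  // An allocation remembers the stack depth at which it appeared.
  void Allocate(const std::string& buf) { level_[buf] = stack_.size(); }

  // An access is recorded on the entry at the allocation's depth,
  // not on the innermost entry.
  void Touch(const std::string& buf) {
    auto it = level_.find(buf);
    if (it != level_.end() && it->second < stack_.size()) {
      stack_[it->second].push_back(buf);
    }
  }

 private:
  std::vector<std::vector<std::string>> stack_;
  std::unordered_map<std::string, std::size_t> level_;
};
```

       In this toy, an access made inside a nested loop is attributed to the outermost entry opened after the allocation, which is consistent with the comment in the diff that the free point is only inserted at the same scope as the allocate.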

##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -90,42 +296,278 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator {
 
   PrimExpr VisitExpr_(const LoadNode* op) final {
     if (IsDynamicSharedMemory(op->buffer_var)) {
-      auto offset = GetBufferOffset(op->buffer_var, op->dtype);
-      auto index = StmtExprMutator::VisitExpr(op->index);
+      PrimExpr offset = GetBufferOffset(op->buffer_var, op->dtype);
+      PrimExpr index = StmtExprMutator::VisitExpr(op->index);
       return Load(op->dtype, merged_buf_var_, offset + index, op->predicate, op->span);
     }
     return StmtExprMutator::VisitExpr_(op);
   }
 
   Stmt VisitStmt_(const StoreNode* op) final {
     if (IsDynamicSharedMemory(op->buffer_var)) {
-      auto offset = GetBufferOffset(op->buffer_var, op->value->dtype);
-      auto index = StmtExprMutator::VisitExpr(op->index);
-      auto value = StmtExprMutator::VisitExpr(op->value);
+      PrimExpr offset = GetBufferOffset(op->buffer_var, op->value->dtype);
+      PrimExpr index = StmtExprMutator::VisitExpr(op->index);
+      PrimExpr value = StmtExprMutator::VisitExpr(op->value);
       return Store(merged_buf_var_, value, offset + index, op->predicate, op->span);
     }
     return StmtExprMutator::VisitStmt_(op);
   }
 
+  PrimExpr VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::tvm_access_ptr())) {
+      ICHECK_EQ(op->args.size(), 5U);
+      DataType dtype = op->args[0].dtype();
+      Var buffer = Downcast<Var>(op->args[1]);
+      if (!IsDynamicSharedMemory(buffer)) {
+        return StmtExprMutator::VisitExpr_(op);
+      }
+      PrimExpr extra_offset = GetBufferOffset(buffer, dtype);
+
+      PrimExpr offset = this->VisitExpr(op->args[2]);
+      PrimExpr extent = this->VisitExpr(op->args[3]);
+      return Call(op->dtype, op->op,
+                  {op->args[0], merged_buf_var_, extra_offset + offset, extent, op->args[4]});
+    } else {
+      return StmtExprMutator::VisitExpr_(op);
+    }
+  }
+
  private:
   PrimExpr GetBufferOffset(Var buffer_var, DataType dtype) {
     auto it = buffer_byte_offsets_.find(buffer_var.get());
-    ICHECK(it != buffer_byte_offsets_.end());
+    ICHECK(it != buffer_byte_offsets_.end()) << buffer_var;
     return indexdiv(it->second, dtype.bytes());
   }
 
+  using StmtEntry = DynSharedMemLinearAccessPatternFinder::StmtEntry;
+  struct StorageEntry {
+    // The constant size of the buffer in bits, only used if it is constant
+    uint64_t const_nbits{0};
+    // Allocs that shares this entry.
+    // The inner vector means a "layer"
+    // For example, it we need to allocate C in the memory of A and B:
+    // |  A: 4096 bytes |  B: 4096 bytes |
+    // |            C: 8192 bytes        |
+    // Then the allocs = {{A, B}, {C}}
+    std::vector<std::vector<const VarNode*>> allocs;
+  };
+
+  // Event entry in liveness analysis
+  struct EventEntry {
+    // variables we generate
+    std::vector<const VarNode*> gen;
+    // variables we kill
+    std::vector<const VarNode*> kill;
+  };
+
+  /*!
+   * \brief Liveness analysis to find gen and kill point of each variable.
+   * \param seq the linear pattern of storage access
+   */
+  void LivenessAnalysis(const std::vector<StmtEntry>& seq) {
+    // find kill point, do a reverse linear scan.
+    std::unordered_set<const VarNode*> touched;
+    for (size_t i = seq.size(); i != 0; --i) {
+      const StmtEntry& s = seq[i - 1];
+      for (const VarNode* buffer : s.touched) {
+        if (!touched.count(buffer)) {
+          touched.insert(buffer);
+          event_map_[s.stmt].kill.push_back(buffer);
+        }
+      }
+    }
+    // find gen point, do forward scan
+    touched.clear();
+    for (size_t i = 0; i < seq.size(); ++i) {
+      int64_t offset = seq[i].scope_pair_offset;
+      if (offset < 0) continue;
+      const StmtEntry& s = seq[i + offset];
+      for (const VarNode* buffer : s.touched) {
+        if (!touched.count(buffer)) {
+          touched.insert(buffer);
+          event_map_[s.stmt].gen.push_back(buffer);
+        }
+      }
+    }
+  }
+
+  /*!
+   * \brief Memory plan algorithm
+   * \param seq the linear pattern of storage access
+   * \param alloc_info
+   */
+  void PlanMemory(const std::vector<StmtEntry>& seq) {
+    std::unordered_set<const VarNode*> inplace_flag;
+
+    for (size_t i = 0; i < seq.size(); ++i) {
+      auto it = event_map_.find(seq[i].stmt);
+      // scope_pair_offset >= 0 means it is either
+      // - leaf stmt(offset = 0)
+      // - beginning of scope(offset < 0)
+      // In both cases, we need to handle the gen event correctly
+      if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) {
+        for (const VarNode* var : it->second.gen) {
+          ICHECK(dyn_shmem_allocs_.count(var));
+          const AllocateNode* alloc = dyn_shmem_allocs_[var];
+          StorageEntry* dst_entry = FindAlloc(alloc);
+          alloc_map_[var] = dst_entry;
+        }
+      }
+      // scope_pair_offset <= 0 means it is either
+      // - leaf stmt(offset = 0)
+      // - end of scope(offset < 0)
+      // In both cases, we need to handle the kill event correctly
+      if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
+        for (const VarNode* var : it->second.kill) {
+          this->Free(var);
+        }
+      }
+    }
+  }
+  /*!
+   * \brief Allocate new storage entry.
+   * \param op the allocate node
+   * \param the size of the allocation in bits
+   * \return the new storage entry
+   */
+  StorageEntry* NewAlloc(const AllocateNode* op, size_t const_nbits) {

Review comment:
       Why do we need to return a bare pointer?
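
One possible reason, sketched here as an assumption rather than as the PR's actual code: the diff adds an include of `support/arena.h`, which suggests the entries live in an arena that owns them. In that case `NewAlloc` would only hand out a non-owning handle. The sketch below uses a `std::deque` as a stand-in for the arena; `Planner`, `size()` and the trimmed-down `StorageEntry` are hypothetical names for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <deque>

// Hypothetical, trimmed-down stand-in for the pass's StorageEntry.
struct StorageEntry {
  uint64_t const_nbits{0};
};

// Hypothetical planner: the container owns every entry, callers only
// receive non-owning pointers into it.
class Planner {
 public:
  // Returns a non-owning pointer. std::deque does not relocate existing
  // elements on emplace_back, so the pointer stays valid for as long as
  // the planner itself is alive.
  StorageEntry* NewAlloc(uint64_t const_nbits) {
    entries_.emplace_back();
    StorageEntry* entry = &entries_.back();
    entry->const_nbits = const_nbits;
    return entry;
  }

  std::size_t size() const { return entries_.size(); }

 private:
  // Stand-in for an arena: owns all entries until the planner is destroyed.
  std::deque<StorageEntry> entries_;
};
```

Under this design a smart pointer would add little, because the lifetime of every entry is tied to the owning container rather than to any single caller.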

##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -90,42 +296,278 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator {
 
   PrimExpr VisitExpr_(const LoadNode* op) final {
     if (IsDynamicSharedMemory(op->buffer_var)) {
-      auto offset = GetBufferOffset(op->buffer_var, op->dtype);
-      auto index = StmtExprMutator::VisitExpr(op->index);
+      PrimExpr offset = GetBufferOffset(op->buffer_var, op->dtype);
+      PrimExpr index = StmtExprMutator::VisitExpr(op->index);
       return Load(op->dtype, merged_buf_var_, offset + index, op->predicate, op->span);
     }
     return StmtExprMutator::VisitExpr_(op);
   }
 
   Stmt VisitStmt_(const StoreNode* op) final {
     if (IsDynamicSharedMemory(op->buffer_var)) {
-      auto offset = GetBufferOffset(op->buffer_var, op->value->dtype);
-      auto index = StmtExprMutator::VisitExpr(op->index);
-      auto value = StmtExprMutator::VisitExpr(op->value);
+      PrimExpr offset = GetBufferOffset(op->buffer_var, op->value->dtype);
+      PrimExpr index = StmtExprMutator::VisitExpr(op->index);
+      PrimExpr value = StmtExprMutator::VisitExpr(op->value);
       return Store(merged_buf_var_, value, offset + index, op->predicate, op->span);
     }
     return StmtExprMutator::VisitStmt_(op);
   }
 
+  PrimExpr VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::tvm_access_ptr())) {
+      ICHECK_EQ(op->args.size(), 5U);
+      DataType dtype = op->args[0].dtype();
+      Var buffer = Downcast<Var>(op->args[1]);
+      if (!IsDynamicSharedMemory(buffer)) {
+        return StmtExprMutator::VisitExpr_(op);
+      }
+      PrimExpr extra_offset = GetBufferOffset(buffer, dtype);
+
+      PrimExpr offset = this->VisitExpr(op->args[2]);
+      PrimExpr extent = this->VisitExpr(op->args[3]);
+      return Call(op->dtype, op->op,
+                  {op->args[0], merged_buf_var_, extra_offset + offset, extent, op->args[4]});
+    } else {
+      return StmtExprMutator::VisitExpr_(op);
+    }
+  }
+
  private:
   PrimExpr GetBufferOffset(Var buffer_var, DataType dtype) {
     auto it = buffer_byte_offsets_.find(buffer_var.get());
-    ICHECK(it != buffer_byte_offsets_.end());
+    ICHECK(it != buffer_byte_offsets_.end()) << buffer_var;

Review comment:
       ditto

##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -36,46 +38,250 @@
 namespace tvm {
 namespace tir {
 
+using runtime::StorageRank;
+using runtime::StorageScope;
+
 bool IsDynamicSharedMemory(Var buffer_var) {
-  auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
+  StorageScope storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
   return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn";
 }
 
+/*!
+ * \brief collect the mapping from the buffer var to its allocate
+ */
 class AllocateCollector : public StmtExprVisitor {
  public:
   void VisitStmt_(const AllocateNode* op) final {
     if (IsDynamicSharedMemory(op->buffer_var)) {
-      dyn_shmem_allocs_.insert(op);
+      dyn_shmem_allocs_[op->buffer_var.get()] = op;
     }
     StmtExprVisitor::VisitStmt_(op);
   }
+  // The mapping from the original buffer var to its allocate
+  std::unordered_map<const VarNode*, const AllocateNode*> dyn_shmem_allocs_;
+};
+
+// Find a linear pattern of storage access
+// Used for liveness analysis.
+// Composite scopes(loop/thread_launch/IfThen) is represented by two points:
+// before_scope -> scope_body -> after_scope
+//
+// The linear_seq_ stores before_scope and after_scope.
+// The access to the arrays are stored at the after_scope point.
+//
+// Define "scope" as the body of For/thread_launch/IfThenElse
+// This pass tries to detect last point that we need to keep memory
+// alive under the same scope as allocate.
+// The storage need to be kept alive between allocate and last access.
+// The free point is only inserted at the same scope of allocate.
+//
+class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
+ public:
+  /*! \brief record the touch hist of statment. */
+  struct StmtEntry {
+    // The statment
+    const Object* stmt;
+    // The index in the linear_seq_ to point to end of the nested scope.
+    // This is only set to non-zero if stmt is a nested scope.
+    // if offset > 0, means this is the begin, the end entry is current_index + offset
+    // if offset < 0, means this is the end, the begin entry is current_index + offset
+    int64_t scope_pair_offset{0};
+    // The buffer variables this statment touched.
+    std::vector<const VarNode*> touched;
+  };
+  // The scope of each allocation
+  struct AllocEntry {
+    // scope level
+    size_t level{0};
+    // allocation stmt
+    const AllocateNode* alloc{nullptr};
+  };
 
-  std::unordered_set<const AllocateNode*> dyn_shmem_allocs_;
+  void VisitStmt_(const AllocateNode* op) final {
+    size_t level = scope_.size();
+    const VarNode* buf = op->buffer_var.get();
+    alloc_info_[buf].alloc = op;
+    alloc_info_[buf].level = level;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+  void VisitStmt_(const StoreNode* op) final {
+    scope_.push_back(StmtEntry());
+    // visit subexpr
+    StmtExprVisitor::VisitStmt_(op);
+    // Add write access.
+    const VarNode* buf = op->buffer_var.get();
+    auto it = alloc_info_.find(buf);
+    if (it != alloc_info_.end() && it->second.alloc) {
+      ICHECK_LT(it->second.level, scope_.size());
+      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
+        scope_[it->second.level].touched.push_back(buf);
+      }
+    }
+    StmtEntry e = scope_.back();
+    scope_.pop_back();
+    if (e.touched.size() != 0) {
+      e.stmt = op;
+      linear_seq_.push_back(e);
+    }
+  }
+  void VisitStmt_(const EvaluateNode* op) final {
+    scope_.push_back(StmtEntry());
+    // visit subexpr
+    StmtExprVisitor::VisitStmt_(op);
+    StmtEntry e = scope_.back();
+    scope_.pop_back();
+    if (e.touched.size() != 0) {
+      e.stmt = op;
+      linear_seq_.push_back(e);
+    }
+  }
+  void VisitExpr_(const LoadNode* op) final {
+    // Add write access.
+    StmtExprVisitor::VisitExpr_(op);
+    const VarNode* buf = op->buffer_var.get();
+    auto it = alloc_info_.find(buf);
+    if (it != alloc_info_.end() && it->second.alloc) {
+      ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store.";
+      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
+        scope_[it->second.level].touched.push_back(buf);
+      }
+    }
+  }
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::address_of())) {
+      const LoadNode* l = op->args[0].as<LoadNode>();
+      this->VisitExpr(l->index);
+    } else {
+      StmtExprVisitor::VisitExpr_(op);
+    }
+  }
+  void VisitExpr_(const VarNode* buf) final {
+    // Directly reference to the variable count as a read.
+    auto it = alloc_info_.find(buf);
+    if (it != alloc_info_.end() && it->second.alloc) {
+      ICHECK_LT(it->second.level, scope_.size()) << " buf=" << buf->name_hint;
+      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
+        scope_[it->second.level].touched.push_back(buf);
+      }
+    }
+  }
+  template <typename T>
+  void VisitNewScope(const T* op) {
+    scope_.push_back(StmtEntry());
+    StmtEntry e;
+    e.stmt = op;
+    int64_t begin_index = static_cast<int64_t>(linear_seq_.size());
+    // before scope.
+    linear_seq_.push_back(e);
+    StmtExprVisitor::VisitStmt_(op);
+    // after scope.
+    e.touched = std::move(scope_.back().touched);
+    scope_.pop_back();
+    int64_t end_index = static_cast<int64_t>(linear_seq_.size());
+    ICHECK_GT(end_index, begin_index);
+    e.scope_pair_offset = begin_index - end_index;
+    linear_seq_.push_back(e);
+    // record the pointer to end index.
+    ICHECK_NE(end_index, 0U);
+    linear_seq_[begin_index].scope_pair_offset = end_index - begin_index;
+  }
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Only record the outer most thread extent.
+    if (op->attr_key == attr::thread_extent && !in_thread_env_) {
+      in_thread_env_ = true;
+      VisitNewScope(op);
+      in_thread_env_ = false;
+    } else if (op->attr_key == attr::extern_scope) {
+      VisitNewScope(op);
+    } else if (op->attr_key == attr::virtual_thread) {
+      VisitNewScope(op);
+    } else {
+      StmtExprVisitor::VisitStmt_(op);
+    }
+  }
+  void VisitStmt_(const IfThenElseNode* op) final { VisitNewScope(op); }
+
+  void VisitStmt_(const ForNode* op) final { VisitNewScope(op); }
+
+  void VisitStmt_(const WhileNode* op) final { VisitNewScope(op); }
+
+  void VisitStmt_(const AssertStmtNode* op) final { VisitNewScope(op); }
+
+  // linearized access sequence.
+  std::vector<StmtEntry> linear_seq_;
+  // The storage scope of each buffer
+  std::unordered_map<const VarNode*, AllocEntry> alloc_info_;

Review comment:
       Please keep it private if it is never accessed outside the class.







[GitHub] [tvm] vinx13 commented on a change in pull request #9341: [CUDA] Support memory reuse for dynamic shared memory

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9341:
URL: https://github.com/apache/tvm/pull/9341#discussion_r739594556



##########
File path: tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py
##########
@@ -82,24 +82,25 @@ def test_matmul_ir(A, B, C):
         # Create a dynamic shared memory for the accumulation.
         # This is for testing merging dynamic shared memory alloctions with different data type.
         # In practice, there is no need to allocate a shared memory for C.
+        C_local = ib.allocate(C.dtype, (1), scope="local", name="C_local")

Review comment:
       ```suggestion
           C_local = ib.allocate(C.dtype, (1,), scope="local", name="C_local")
   ```
   

##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -31,51 +31,256 @@
 #include <unordered_set>
 
 #include "../../runtime/thread_storage_scope.h"
+#include "../../support/arena.h"
 #include "ir_utils.h"
 
 namespace tvm {
 namespace tir {
 
+using runtime::StorageRank;
+using runtime::StorageScope;
+
 bool IsDynamicSharedMemory(Var buffer_var) {
-  auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
+  StorageScope storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
   return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn";
 }
 
+/*!
+ * \brief collect the mapping from the buffer var to its allocate
+ */
 class AllocateCollector : public StmtExprVisitor {
  public:
   void VisitStmt_(const AllocateNode* op) final {
     if (IsDynamicSharedMemory(op->buffer_var)) {
-      dyn_shmem_allocs_.insert(op);
+      dyn_shmem_allocs_[op->buffer_var.get()] = op;
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+  // The mapping from the original buffer var to its allocate
+  std::unordered_map<const VarNode*, const AllocateNode*> dyn_shmem_allocs_;
+};
+
+// Find a linear pattern of storage access
+// Used for liveness analysis.
+// "linear" means fitting a complex access pattern into an array of StmtEntry
+//
+// Define "scope" as the body of For/thread_launch/IfThenElse
+// Composite scopes(loop/thread_launch/IfThen) is represented by three StmtEntry:
+// before_scope -> scope_body -> after_scope
+//
+// This pass tries to detect last point that we need to keep memory
+// alive under the same scope as Allocate.
+// The storage need to be kept alive between Allocate and last access.
+// The free point is only inserted at the same scope of Allocate.
+//
+class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
+ public:
+  /*! \brief record the touch list of statment. */

Review comment:
       ```suggestion
     /*! \brief record the touch list of statement. */
   ```
   







[GitHub] [tvm] Hzfengsy merged pull request #9341: [CUDA] Support memory reuse for dynamic shared memory

Posted by GitBox <gi...@apache.org>.
Hzfengsy merged pull request #9341:
URL: https://github.com/apache/tvm/pull/9341


   





[GitHub] [tvm] MasterJH5574 commented on a change in pull request #9341: [CUDA] Support memory reuse for dynamic shared memory

Posted by GitBox <gi...@apache.org>.
MasterJH5574 commented on a change in pull request #9341:
URL: https://github.com/apache/tvm/pull/9341#discussion_r735741835



##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -31,51 +31,256 @@
 #include <unordered_set>
 
 #include "../../runtime/thread_storage_scope.h"
+#include "../../support/arena.h"
 #include "ir_utils.h"
 
 namespace tvm {
 namespace tir {
 
+using runtime::StorageRank;
+using runtime::StorageScope;
+
 bool IsDynamicSharedMemory(Var buffer_var) {
-  auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
+  StorageScope storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
   return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn";
 }
 
+/*!
+ * \brief collect the mapping from the buffer var to its allocate
+ */
 class AllocateCollector : public StmtExprVisitor {
  public:
   void VisitStmt_(const AllocateNode* op) final {
     if (IsDynamicSharedMemory(op->buffer_var)) {
-      dyn_shmem_allocs_.insert(op);
+      dyn_shmem_allocs_[op->buffer_var.get()] = op;
     }
     StmtExprVisitor::VisitStmt_(op);
   }
+  // The mapping from the original buffer var to its allocate
+  std::unordered_map<const VarNode*, const AllocateNode*> dyn_shmem_allocs_;
+};
+
+// Find a linear pattern of storage access
+// Used for liveness analysis.
+// Composite scopes(loop/thread_launch/IfThen) is represented by two points:
+// before_scope -> scope_body -> after_scope
+//
+// The linear_seq_ stores before_scope and after_scope.
+// The access to the arrays are stored at the after_scope point.
+//
+// Define "scope" as the body of For/thread_launch/IfThenElse
+// This pass tries to detect last point that we need to keep memory
+// alive under the same scope as allocate.
+// The storage need to be kept alive between allocate and last access.
+// The free point is only inserted at the same scope of allocate.
+//
+class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
+ public:
+  /*! \brief record the touch hist of statment. */
+  struct StmtEntry {
+    // The statment
+    const Object* stmt;
+    // The index in the linear_seq_ to point to end of the nested scope.
+    // This is only set to non-zero if stmt is a nested scope.
+    // if offset > 0, means this is the begin, the end entry is current_index + offset
+    // if offset < 0, means this is the end, the begin entry is current_index + offset
+    int64_t scope_pair_offset{0};
+    // The buffer variables this statment touched.

Review comment:
       ```suggestion
       // The buffer variables this statement touched.
   ```

##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -36,46 +38,250 @@
 namespace tvm {
 namespace tir {
 
+using runtime::StorageRank;
+using runtime::StorageScope;
+
 bool IsDynamicSharedMemory(Var buffer_var) {
-  auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
+  StorageScope storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
   return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn";
 }
 
+/*!
+ * \brief collect the mapping from the buffer var to its allocate
+ */
 class AllocateCollector : public StmtExprVisitor {
  public:
   void VisitStmt_(const AllocateNode* op) final {
     if (IsDynamicSharedMemory(op->buffer_var)) {
-      dyn_shmem_allocs_.insert(op);
+      dyn_shmem_allocs_[op->buffer_var.get()] = op;
     }
     StmtExprVisitor::VisitStmt_(op);
   }
+  // The mapping from the original buffer var to its allocate
+  std::unordered_map<const VarNode*, const AllocateNode*> dyn_shmem_allocs_;
+};
+
+// Find a linear pattern of storage access
+// Used for liveness analysis.
+// Composite scopes(loop/thread_launch/IfThen) is represented by two points:
+// before_scope -> scope_body -> after_scope
+//
+// The linear_seq_ stores before_scope and after_scope.
+// The access to the arrays are stored at the after_scope point.
+//
+// Define "scope" as the body of For/thread_launch/IfThenElse
+// This pass tries to detect last point that we need to keep memory
+// alive under the same scope as allocate.
+// The storage need to be kept alive between allocate and last access.
+// The free point is only inserted at the same scope of allocate.
+//
+class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
+ public:
+  /*! \brief record the touch hist of statment. */

Review comment:
       Do you mean "list" or "history" here?
   ```suggestion
     /*! \brief record the touch list of statement. */
   ```

##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -31,51 +31,256 @@
 #include <unordered_set>
 
 #include "../../runtime/thread_storage_scope.h"
+#include "../../support/arena.h"
 #include "ir_utils.h"
 
 namespace tvm {
 namespace tir {
 
+using runtime::StorageRank;
+using runtime::StorageScope;
+
 bool IsDynamicSharedMemory(Var buffer_var) {
-  auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
+  StorageScope storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
   return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn";
 }
 
+/*!
+ * \brief collect the mapping from the buffer var to its allocate
+ */
 class AllocateCollector : public StmtExprVisitor {
  public:
   void VisitStmt_(const AllocateNode* op) final {
     if (IsDynamicSharedMemory(op->buffer_var)) {
-      dyn_shmem_allocs_.insert(op);
+      dyn_shmem_allocs_[op->buffer_var.get()] = op;
     }
     StmtExprVisitor::VisitStmt_(op);
   }
+  // The mapping from the original buffer var to its allocate
+  std::unordered_map<const VarNode*, const AllocateNode*> dyn_shmem_allocs_;
+};
+
+// Find a linear pattern of storage access
+// Used for liveness analysis.
+// Composite scopes(loop/thread_launch/IfThen) is represented by two points:
+// before_scope -> scope_body -> after_scope
+//
+// The linear_seq_ stores before_scope and after_scope.
+// The access to the arrays are stored at the after_scope point.
+//
+// Define "scope" as the body of For/thread_launch/IfThenElse
+// This pass tries to detect last point that we need to keep memory
+// alive under the same scope as allocate.
+// The storage need to be kept alive between allocate and last access.
+// The free point is only inserted at the same scope of allocate.
+//
+class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
+ public:
+  /*! \brief record the touch hist of statment. */
+  struct StmtEntry {
+    // The statment

Review comment:
       ```suggestion
       // The statement
   ```

##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -31,51 +31,256 @@
 #include <unordered_set>
 
 #include "../../runtime/thread_storage_scope.h"
+#include "../../support/arena.h"
 #include "ir_utils.h"
 
 namespace tvm {
 namespace tir {
 
+using runtime::StorageRank;
+using runtime::StorageScope;
+
 bool IsDynamicSharedMemory(Var buffer_var) {
-  auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
+  StorageScope storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
   return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn";
 }
 
+/*!
+ * \brief collect the mapping from the buffer var to its allocate
+ */
 class AllocateCollector : public StmtExprVisitor {
  public:
   void VisitStmt_(const AllocateNode* op) final {
     if (IsDynamicSharedMemory(op->buffer_var)) {
-      dyn_shmem_allocs_.insert(op);
+      dyn_shmem_allocs_[op->buffer_var.get()] = op;
     }
     StmtExprVisitor::VisitStmt_(op);
   }
+  // The mapping from the original buffer var to its allocate
+  std::unordered_map<const VarNode*, const AllocateNode*> dyn_shmem_allocs_;
+};
+
+// Find a linear pattern of storage access
+// Used for liveness analysis.
+// Composite scopes(loop/thread_launch/IfThen) is represented by two points:
+// before_scope -> scope_body -> after_scope
+//
+// The linear_seq_ stores before_scope and after_scope.
+// The access to the arrays are stored at the after_scope point.
+//
+// Define "scope" as the body of For/thread_launch/IfThenElse
+// This pass tries to detect last point that we need to keep memory
+// alive under the same scope as allocate.
+// The storage need to be kept alive between allocate and last access.
+// The free point is only inserted at the same scope of allocate.
+//
+class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
+ public:

Review comment:
       I'm confused by the docs here on several points:
   * The definition of "scope" comes after the sentence "Composite scopes(loop/thread_launch/IfThen) is represented by two points...". Should the definition be stated earlier?
   * What does "linear" mean? What is `linear_seq_`?
   * What are `before_scope`, `scope_body` and `after_scope`?
   * "allocate" itself is a verb, so I suggest using "AllocateNode" or "Allocate". Using "allocate" directly is confusing.
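
As a reading aid for the questions above, here is a minimal sketch of what a "linear" access sequence could look like. It is not the PR's code: `Node`, `Entry`, `Visit`, `Linearize`, `GenIndex` and `KillIndex` are hypothetical names, buffers are plain strings, and every buffer is assumed to be allocated at the outermost level, so all touches inside a scope are hoisted onto that scope's after_scope entry.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// One slot of the flat sequence (a simplified StmtEntry).
struct Entry {
  std::string stmt;
  int64_t scope_pair_offset;         // >0: scope begin, <0: scope end, 0: leaf
  std::vector<std::string> touched;  // buffers attributed to this entry
};

// Hypothetical nested statement: a leaf touches buffers, a scope has children.
struct Node {
  std::string name;
  std::vector<std::string> touched;
  std::vector<Node> children;  // non-empty means this node is a scope
};

// Gathers every buffer touched anywhere below `node`.
void CollectTouched(const Node& node, std::vector<std::string>* out) {
  out->insert(out->end(), node.touched.begin(), node.touched.end());
  for (const Node& child : node.children) CollectTouched(child, out);
}

void Visit(const Node& node, int depth, std::vector<Entry>* seq) {
  if (node.children.empty()) {
    // A leaf is emitted only if touches are attributed to it (outermost level).
    if (depth == 0 && !node.touched.empty()) seq->push_back({node.name, 0, node.touched});
    return;
  }
  int64_t begin = static_cast<int64_t>(seq->size());
  seq->push_back({node.name, 0, {}});  // before_scope
  for (const Node& child : node.children) Visit(child, depth + 1, seq);  // scope_body
  int64_t end = static_cast<int64_t>(seq->size());
  Entry e{node.name, begin - end, {}};  // after_scope
  if (depth == 0) CollectTouched(node, &e.touched);
  seq->push_back(e);
  (*seq)[begin].scope_pair_offset = end - begin;
}

std::vector<Entry> Linearize(const std::vector<Node>& top_level) {
  std::vector<Entry> seq;
  for (const Node& node : top_level) Visit(node, 0, &seq);
  return seq;
}

bool Touches(const Entry& e, const std::string& buf) {
  return std::find(e.touched.begin(), e.touched.end(), buf) != e.touched.end();
}

// Forward scan: a scope begin looks at its paired end entry, so a buffer
// first used inside a loop becomes live at the loop's begin. Returns -1 if unused.
int64_t GenIndex(const std::vector<Entry>& seq, const std::string& buf) {
  for (std::size_t i = 0; i < seq.size(); ++i) {
    if (seq[i].scope_pair_offset < 0) continue;
    if (Touches(seq[i + seq[i].scope_pair_offset], buf)) return static_cast<int64_t>(i);
  }
  return -1;
}

// Reverse scan: the last entry that touches the buffer is its kill point.
int64_t KillIndex(const std::vector<Entry>& seq, const std::string& buf) {
  for (std::size_t i = seq.size(); i != 0; --i) {
    if (Touches(seq[i - 1], buf)) return static_cast<int64_t>(i - 1);
  }
  return -1;
}
```

For a program of the shape `store_A; for_outer { store_AB; for_inner { store_B } }; store_C`, this sketch yields six entries: `store_A`, the begin of `for_outer` (offset +3), the begin and end of `for_inner` (offsets +1 and -1), the end of `for_outer` (offset -3, carrying the touches of A and B), and `store_C`. Storing touches on the after_scope entry is what lets a buffer stay live across the whole loop rather than only across one iteration's statement.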

##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -31,51 +31,256 @@
 #include <unordered_set>
 
 #include "../../runtime/thread_storage_scope.h"
+#include "../../support/arena.h"
 #include "ir_utils.h"
 
 namespace tvm {
 namespace tir {
 
+using runtime::StorageRank;
+using runtime::StorageScope;
+
 bool IsDynamicSharedMemory(Var buffer_var) {
-  auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
+  StorageScope storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
   return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn";
 }
 
+/*!
+ * \brief collect the mapping from the buffer var to its allocate
+ */
 class AllocateCollector : public StmtExprVisitor {
  public:
   void VisitStmt_(const AllocateNode* op) final {
     if (IsDynamicSharedMemory(op->buffer_var)) {
-      dyn_shmem_allocs_.insert(op);
+      dyn_shmem_allocs_[op->buffer_var.get()] = op;
     }
     StmtExprVisitor::VisitStmt_(op);
   }
+  // The mapping from the original buffer var to its allocate
+  std::unordered_map<const VarNode*, const AllocateNode*> dyn_shmem_allocs_;
+};
+
+// Find a linear pattern of storage access
+// Used for liveness analysis.
+// Composite scopes(loop/thread_launch/IfThen) is represented by two points:
+// before_scope -> scope_body -> after_scope
+//
+// The linear_seq_ stores before_scope and after_scope.
+// The access to the arrays are stored at the after_scope point.
+//
+// Define "scope" as the body of For/thread_launch/IfThenElse
+// This pass tries to detect last point that we need to keep memory
+// alive under the same scope as allocate.
+// The storage need to be kept alive between allocate and last access.
+// The free point is only inserted at the same scope of allocate.
+//
+class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
+ public:
+  /*! \brief record the touch hist of statment. */
+  struct StmtEntry {
+    // The statment
+    const Object* stmt;
+    // The index in the linear_seq_ to point to end of the nested scope.
+    // This is only set to non-zero if stmt is a nested scope.
+    // if offset > 0, means this is the begin, the end entry is current_index + offset
+    // if offset < 0, means this is the end, the begin entry is current_index + offset
+    int64_t scope_pair_offset{0};
+    // The buffer variables this statment touched.
+    std::vector<const VarNode*> touched;
+  };
+  // The scope of each allocation
+  struct AllocEntry {
+    // the level in the scope stack
+    size_t level{0};
+    // allocation stmt
+    const AllocateNode* alloc{nullptr};
+  };
 
-  std::unordered_set<const AllocateNode*> dyn_shmem_allocs_;
+  void VisitStmt_(const AllocateNode* op) final {
+    size_t level = scope_.size();
+    const VarNode* buf = op->buffer_var.get();
+    alloc_info_[buf].alloc = op;
+    alloc_info_[buf].level = level;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+  void VisitStmt_(const StoreNode* op) final {
+    scope_.push_back(StmtEntry());
+    // visit subexpr
+    StmtExprVisitor::VisitStmt_(op);
+    // Add write access.
+    const VarNode* buf = op->buffer_var.get();
+    auto it = alloc_info_.find(buf);
+    if (it != alloc_info_.end() && it->second.alloc) {
+      ICHECK_LT(it->second.level, scope_.size());
+      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
+        scope_[it->second.level].touched.push_back(buf);
+      }
+    }
+    StmtEntry e = scope_.back();
+    scope_.pop_back();
+    if (e.touched.size() != 0) {
+      e.stmt = op;
+      linear_seq_.push_back(e);
+    }
+  }
+  void VisitStmt_(const EvaluateNode* op) final {
+    scope_.push_back(StmtEntry());
+    // visit subexpr
+    StmtExprVisitor::VisitStmt_(op);
+    StmtEntry e = scope_.back();
+    scope_.pop_back();
+    if (e.touched.size() != 0) {
+      e.stmt = op;
+      linear_seq_.push_back(e);
+    }
+  }
+  void VisitExpr_(const LoadNode* op) final {
+    // Add write access.
+    StmtExprVisitor::VisitExpr_(op);
+    const VarNode* buf = op->buffer_var.get();
+    auto it = alloc_info_.find(buf);
+    if (it != alloc_info_.end() && it->second.alloc) {
+      ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store.";
+      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
+        scope_[it->second.level].touched.push_back(buf);
+      }
+    }
+  }
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::address_of())) {
+      const LoadNode* l = op->args[0].as<LoadNode>();
+      this->VisitExpr(l->index);
+    } else {
+      StmtExprVisitor::VisitExpr_(op);
+    }
+  }
+  void VisitExpr_(const VarNode* buf) final {
+    // Directly reference to the variable count as a read.
+    auto it = alloc_info_.find(buf);
+    if (it != alloc_info_.end() && it->second.alloc) {
+      ICHECK_LT(it->second.level, scope_.size());
+      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
+        scope_[it->second.level].touched.push_back(buf);
+      }
+    }
+  }
+  template <typename T>
+  void VisitNewScope(const T* op) {
+    scope_.push_back(StmtEntry());
+    StmtEntry e;
+    e.stmt = op;
+    int64_t begin_index = static_cast<int64_t>(linear_seq_.size());
+    // before scope.
+    linear_seq_.push_back(e);
+    StmtExprVisitor::VisitStmt_(op);
+    // after scope.
+    e.touched = std::move(scope_.back().touched);
+    scope_.pop_back();
+    int64_t end_index = static_cast<int64_t>(linear_seq_.size());
+    ICHECK_GT(end_index, begin_index);
+    e.scope_pair_offset = begin_index - end_index;
+    linear_seq_.push_back(e);
+    // record the pointer to end index.
+    ICHECK_NE(end_index, 0U);
+    linear_seq_[begin_index].scope_pair_offset = end_index - begin_index;
+  }
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Only record the outer most thread extent.
+    if (op->attr_key == attr::thread_extent && !in_thread_env_) {
+      in_thread_env_ = true;
+      VisitNewScope(op);
+      in_thread_env_ = false;
+    } else if (op->attr_key == attr::extern_scope) {
+      VisitNewScope(op);
+    } else if (op->attr_key == attr::virtual_thread) {
+      VisitNewScope(op);
+    } else {
+      StmtExprVisitor::VisitStmt_(op);
+    }
+  }
+  void VisitStmt_(const IfThenElseNode* op) final { VisitNewScope(op); }
+
+  void VisitStmt_(const ForNode* op) final { VisitNewScope(op); }
+
+  void VisitStmt_(const WhileNode* op) final { VisitNewScope(op); }
+
+  void VisitStmt_(const AssertStmtNode* op) final { VisitNewScope(op); }
+
+  // linearized access sequence.
+  std::vector<StmtEntry> linear_seq_;
+  // The storage scope of each buffer
+  std::unordered_map<const VarNode*, AllocEntry> alloc_info_;
+
+ private:
+  // Whether already in thread env.
+  bool in_thread_env_{false};
+  // The scope stack.
+  std::vector<StmtEntry> scope_;
 };
 
+/*!
+ * \brief merge the buffers whose live ranges do not intersect and rewrite the body
+ */
 class DynamicSharedMemoryRewriter : public StmtExprMutator {
  public:
   explicit DynamicSharedMemoryRewriter(
-      const std::unordered_set<const AllocateNode*>& dyn_shmem_allocs)
+      const std::unordered_map<const VarNode*, const AllocateNode*>& dyn_shmem_allocs)
       : dyn_shmem_allocs_{dyn_shmem_allocs} {}
 
+  /*!
+   * \brief plan the memory reuse for all the buffers allocated in the statement
+   * \param stmt the statement
+   */
+  void PlanReuse(const Stmt& stmt) {
+    DynSharedMemLinearAccessPatternFinder finder;
+    finder(stmt);
+    this->LivenessAnalysis(finder.linear_seq_);
+    this->PlanMemory(finder.linear_seq_);
+  }
+
   Stmt VisitStmt_(const AttrStmtNode* op) final {

Review comment:
       Shouldn't these `VisitStmt_(...)` be private?
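
The doc comment in the hunk above describes merging buffers whose live ranges do not intersect. A minimal sketch of that idea follows, assuming each buffer's first and last use in the linearized sequence is already known. `BufferPlan`, `Slot` and `AssignSlots` are hypothetical names for illustration only; this is a plain greedy assignment, not the PR's `PlanMemory` and not the storage rewrite algorithm.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical record for illustration; not a type from the PR.
struct BufferPlan {
  size_t first_use;   // index of the first touch in the linear sequence
  size_t last_use;    // index of the last touch in the linear sequence
  size_t size_bytes;  // size of the buffer
  size_t slot;        // output: which shared slot the buffer was given
};

// A region of shared memory that several buffers may occupy in turn.
struct Slot {
  size_t free_after;  // last_use of the current occupant
  size_t size_bytes;  // largest buffer ever placed here
};

// Walk buffers in order of first use. Reuse a slot whose previous occupant
// is already dead, otherwise open a new slot. Returns the total bytes needed.
size_t AssignSlots(std::vector<BufferPlan>* bufs) {
  std::vector<size_t> order(bufs->size());
  for (size_t i = 0; i < order.size(); ++i) order[i] = i;
  std::stable_sort(order.begin(), order.end(), [&](size_t a, size_t b) {
    return (*bufs)[a].first_use < (*bufs)[b].first_use;
  });
  std::vector<Slot> slots;
  for (size_t idx : order) {
    BufferPlan& b = (*bufs)[idx];
    assert(b.first_use <= b.last_use);
    bool placed = false;
    for (size_t s = 0; s < slots.size() && !placed; ++s) {
      if (slots[s].free_after < b.first_use) {  // live ranges do not intersect
        slots[s].free_after = b.last_use;
        slots[s].size_bytes = std::max(slots[s].size_bytes, b.size_bytes);
        b.slot = s;
        placed = true;
      }
    }
    if (!placed) {
      b.slot = slots.size();
      slots.push_back(Slot{b.last_use, b.size_bytes});
    }
  }
  size_t total = 0;
  for (const Slot& s : slots) total += s.size_bytes;
  return total;
}
```

With two 4096-byte fetch buffers live over indices 0..5 and one 2048-byte write-back buffer live over 6..8 (made-up numbers echoing the matmul case in the PR description), the write-back buffer lands in the first slot and the total is 8192 bytes rather than 10240.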







[GitHub] [tvm] masahi commented on pull request #9341: [CUDA] Support memory reuse for dynamic shared memory

Posted by GitBox <gi...@apache.org>.
masahi commented on pull request #9341:
URL: https://github.com/apache/tvm/pull/9341#issuecomment-948934302


   Hmm, for reuse of constant-size dynamic shmem, I thought the existing `storage_rewrite` pass would do the job (see https://github.com/apache/tvm/pull/8571#issuecomment-889541656). Why is it not sufficient?





[GitHub] [tvm] jinhongyii commented on a change in pull request #9341: [CUDA] Support memory reuse for dynamic shared memory

Posted by GitBox <gi...@apache.org>.
jinhongyii commented on a change in pull request #9341:
URL: https://github.com/apache/tvm/pull/9341#discussion_r735568966



##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -36,46 +38,250 @@
 namespace tvm {
 namespace tir {
 
+using runtime::StorageRank;
+using runtime::StorageScope;
+
 bool IsDynamicSharedMemory(Var buffer_var) {
-  auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
+  StorageScope storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
   return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn";
 }
 
+/*!
+ * \brief collect the mapping from the buffer var to its allocate
+ */
 class AllocateCollector : public StmtExprVisitor {
  public:
   void VisitStmt_(const AllocateNode* op) final {
     if (IsDynamicSharedMemory(op->buffer_var)) {
-      dyn_shmem_allocs_.insert(op);
+      dyn_shmem_allocs_[op->buffer_var.get()] = op;
     }
     StmtExprVisitor::VisitStmt_(op);
   }
+  // The mapping from the original buffer var to its allocate
+  std::unordered_map<const VarNode*, const AllocateNode*> dyn_shmem_allocs_;
+};
+
+// Find a linear pattern of storage access
+// Used for liveness analysis.
+// Composite scopes (loop/thread_launch/IfThen) are represented by two points:
+// before_scope -> scope_body -> after_scope
+//
+// The linear_seq_ stores before_scope and after_scope.
+// Accesses to the arrays are stored at the after_scope point.
+//
+// Define "scope" as the body of For/thread_launch/IfThenElse
+// This pass tries to detect the last point at which we need to keep memory
+// alive under the same scope as the allocate.
+// The storage needs to be kept alive between the allocate and the last access.
+// The free point is only inserted at the same scope as the allocate.
+//
+class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
+ public:
+  /*! \brief record the touch history of a statement. */
+  struct StmtEntry {
+    // The statement
+    const Object* stmt;
+    // The index in the linear_seq_ to point to end of the nested scope.
+    // This is only set to non-zero if stmt is a nested scope.
+    // if offset > 0, means this is the begin, the end entry is current_index + offset
+    // if offset < 0, means this is the end, the begin entry is current_index + offset
+    int64_t scope_pair_offset{0};
+    // The buffer variables this statement touched.
+    std::vector<const VarNode*> touched;
+  };
+  // The scope of each allocation
+  struct AllocEntry {
+    // scope level
+    size_t level{0};
+    // allocation stmt
+    const AllocateNode* alloc{nullptr};
+  };
 
-  std::unordered_set<const AllocateNode*> dyn_shmem_allocs_;
+  void VisitStmt_(const AllocateNode* op) final {
+    size_t level = scope_.size();
+    const VarNode* buf = op->buffer_var.get();
+    alloc_info_[buf].alloc = op;
+    alloc_info_[buf].level = level;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+  void VisitStmt_(const StoreNode* op) final {
+    scope_.push_back(StmtEntry());
+    // visit subexpr
+    StmtExprVisitor::VisitStmt_(op);
+    // Add write access.
+    const VarNode* buf = op->buffer_var.get();
+    auto it = alloc_info_.find(buf);
+    if (it != alloc_info_.end() && it->second.alloc) {
+      ICHECK_LT(it->second.level, scope_.size());
+      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
+        scope_[it->second.level].touched.push_back(buf);
+      }
+    }
+    StmtEntry e = scope_.back();
+    scope_.pop_back();
+    if (e.touched.size() != 0) {
+      e.stmt = op;
+      linear_seq_.push_back(e);
+    }
+  }
+  void VisitStmt_(const EvaluateNode* op) final {
+    scope_.push_back(StmtEntry());
+    // visit subexpr
+    StmtExprVisitor::VisitStmt_(op);
+    StmtEntry e = scope_.back();
+    scope_.pop_back();
+    if (e.touched.size() != 0) {
+      e.stmt = op;
+      linear_seq_.push_back(e);
+    }
+  }
+  void VisitExpr_(const LoadNode* op) final {
+    // Add read access.
+    StmtExprVisitor::VisitExpr_(op);
+    const VarNode* buf = op->buffer_var.get();
+    auto it = alloc_info_.find(buf);
+    if (it != alloc_info_.end() && it->second.alloc) {
+      ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store.";
+      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
+        scope_[it->second.level].touched.push_back(buf);
+      }
+    }
+  }
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::address_of())) {
+      const LoadNode* l = op->args[0].as<LoadNode>();
+      this->VisitExpr(l->index);
+    } else {
+      StmtExprVisitor::VisitExpr_(op);
+    }
+  }
+  void VisitExpr_(const VarNode* buf) final {
+    // A direct reference to the variable counts as a read.
+    auto it = alloc_info_.find(buf);
+    if (it != alloc_info_.end() && it->second.alloc) {
+      ICHECK_LT(it->second.level, scope_.size()) << " buf=" << buf->name_hint;
+      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
+        scope_[it->second.level].touched.push_back(buf);
+      }
+    }
+  }
+  template <typename T>
+  void VisitNewScope(const T* op) {
+    scope_.push_back(StmtEntry());
+    StmtEntry e;
+    e.stmt = op;
+    int64_t begin_index = static_cast<int64_t>(linear_seq_.size());
+    // before scope.
+    linear_seq_.push_back(e);
+    StmtExprVisitor::VisitStmt_(op);
+    // after scope.
+    e.touched = std::move(scope_.back().touched);
+    scope_.pop_back();
+    int64_t end_index = static_cast<int64_t>(linear_seq_.size());
+    ICHECK_GT(end_index, begin_index);
+    e.scope_pair_offset = begin_index - end_index;
+    linear_seq_.push_back(e);
+    // record the pointer to end index.
+    ICHECK_NE(end_index, 0U);
+    linear_seq_[begin_index].scope_pair_offset = end_index - begin_index;
+  }
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Only record the outermost thread extent.
+    if (op->attr_key == attr::thread_extent && !in_thread_env_) {
+      in_thread_env_ = true;
+      VisitNewScope(op);
+      in_thread_env_ = false;
+    } else if (op->attr_key == attr::extern_scope) {
+      VisitNewScope(op);
+    } else if (op->attr_key == attr::virtual_thread) {
+      VisitNewScope(op);
+    } else {
+      StmtExprVisitor::VisitStmt_(op);
+    }
+  }
+  void VisitStmt_(const IfThenElseNode* op) final { VisitNewScope(op); }
+
+  void VisitStmt_(const ForNode* op) final { VisitNewScope(op); }
+
+  void VisitStmt_(const WhileNode* op) final { VisitNewScope(op); }
+
+  void VisitStmt_(const AssertStmtNode* op) final { VisitNewScope(op); }
+
+  // linearized access sequence.
+  std::vector<StmtEntry> linear_seq_;
+  // The storage scope of each buffer
+  std::unordered_map<const VarNode*, AllocEntry> alloc_info_;

Review comment:
       it is accessed in the rewriter
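
The comment block in the quoted diff says a composite scope is stored as a before_scope and an after_scope entry, with accesses recorded at the after_scope point and `scope_pair_offset` linking the two. The sketch below shows one way live ranges could be read off such a sequence. It assumes a simplified `Entry` in place of `StmtEntry`, with buffers named by strings; `Entry`, `LiveRange`, `ComputeLiveRanges` and `Disjoint` are hypothetical names, and this is not the PR's `LivenessAnalysis`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Simplified stand-in for StmtEntry; buffers are named by strings here.
struct Entry {
  int64_t scope_pair_offset;         // >0: begin of scope, <0: end of scope, 0: plain statement
  std::vector<std::string> touched;  // for a scope, recorded on its end entry
};

// Inclusive range of indices in the linear sequence where a buffer is live.
struct LiveRange {
  size_t first;
  size_t last;
};

// A buffer touched inside a scope is treated as live from the scope's begin
// entry to its end entry. End entries are skipped because they are reached
// from their begin entry through scope_pair_offset.
std::unordered_map<std::string, LiveRange> ComputeLiveRanges(const std::vector<Entry>& seq) {
  std::unordered_map<std::string, LiveRange> ranges;
  for (size_t i = 0; i < seq.size(); ++i) {
    int64_t offset = seq[i].scope_pair_offset;
    if (offset < 0) continue;
    size_t end = i + static_cast<size_t>(offset);
    assert(end < seq.size());
    for (const std::string& buf : seq[end].touched) {
      auto it = ranges.find(buf);
      if (it == ranges.end()) {
        ranges.emplace(buf, LiveRange{i, end});
      } else {
        it->second.last = end;
      }
    }
  }
  return ranges;
}

// Two buffers can share storage only if their live ranges do not intersect.
bool Disjoint(const LiveRange& a, const LiveRange& b) {
  return a.last < b.first || b.last < a.first;
}
```

For a sequence of two back-to-back scopes, the first touching `A` and the second touching `B`, the ranges come out as [0, 1] and [2, 3] and are disjoint. Appending a plain statement that touches `A` again extends its range to index 4, so the two buffers overlap and can no longer share storage.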



