Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/10/25 04:33:36 UTC

[GitHub] [tvm] Hzfengsy commented on a change in pull request #9341: [CUDA] Support memory reuse for dynamic shared memory

Hzfengsy commented on a change in pull request #9341:
URL: https://github.com/apache/tvm/pull/9341#discussion_r735244224



##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -27,6 +27,8 @@
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
 
+#include <list>
+#include <map>

Review comment:
       Please delete the unused includes `<list>` and `<map>`.

##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -36,46 +38,250 @@
 namespace tvm {
 namespace tir {
 
+using runtime::StorageRank;
+using runtime::StorageScope;
+
 bool IsDynamicSharedMemory(Var buffer_var) {
-  auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
+  StorageScope storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
   return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn";
 }
 
+/*!
+ * \brief collect the mapping from the buffer var to its allocate
+ */
 class AllocateCollector : public StmtExprVisitor {
  public:
   void VisitStmt_(const AllocateNode* op) final {
     if (IsDynamicSharedMemory(op->buffer_var)) {
-      dyn_shmem_allocs_.insert(op);
+      dyn_shmem_allocs_[op->buffer_var.get()] = op;
     }
     StmtExprVisitor::VisitStmt_(op);
   }
+  // The mapping from the original buffer var to its allocate
+  std::unordered_map<const VarNode*, const AllocateNode*> dyn_shmem_allocs_;
+};
+
+// Find a linear pattern of storage accesses, used for liveness analysis.
+// Composite scopes (loop/thread_launch/IfThenElse) are represented by two points:
+// before_scope -> scope_body -> after_scope
+//
+// The linear_seq_ stores before_scope and after_scope.
+// The accesses to the arrays are stored at the after_scope point.
+//
+// Define "scope" as the body of For/thread_launch/IfThenElse.
+// This pass tries to detect the last point at which the memory needs to be
+// kept alive, under the same scope as the allocate.
+// The storage needs to be kept alive between the allocate and the last access.
+// The free point is only inserted at the same scope as the allocate.
+//
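+// For example, a For loop contributes two entries to linear_seq_:
+//   [..., before_for, <leaf entries from the loop body>, after_for, ...]
+// with scope_pair_offset linking the before_for and after_for entries.
+//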
+class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
+ public:
+  /*! \brief Record the touch history of a statement. */
+  struct StmtEntry {
+    // The statement
+    const Object* stmt;
+    // The offset in linear_seq_ pointing to the other end of the nested scope.
+    // This is only set to non-zero if stmt is a nested scope.
+    // If offset > 0, this is the begin entry and the end entry is at current_index + offset.
+    // If offset < 0, this is the end entry and the begin entry is at current_index + offset.
+    int64_t scope_pair_offset{0};
+    // The buffer variables this statement touched.
+    std::vector<const VarNode*> touched;
+  };
+  // The scope of each allocation
+  struct AllocEntry {
+    // scope level
+    size_t level{0};
+    // allocation stmt
+    const AllocateNode* alloc{nullptr};
+  };
 
-  std::unordered_set<const AllocateNode*> dyn_shmem_allocs_;
+  void VisitStmt_(const AllocateNode* op) final {
+    size_t level = scope_.size();
+    const VarNode* buf = op->buffer_var.get();
+    alloc_info_[buf].alloc = op;
+    alloc_info_[buf].level = level;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+  void VisitStmt_(const StoreNode* op) final {
+    scope_.push_back(StmtEntry());
+    // visit subexpr
+    StmtExprVisitor::VisitStmt_(op);
+    // Add write access.
+    const VarNode* buf = op->buffer_var.get();
+    auto it = alloc_info_.find(buf);
+    if (it != alloc_info_.end() && it->second.alloc) {
+      ICHECK_LT(it->second.level, scope_.size());
+      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
+        scope_[it->second.level].touched.push_back(buf);
+      }
+    }
+    StmtEntry e = scope_.back();
+    scope_.pop_back();
+    if (e.touched.size() != 0) {
+      e.stmt = op;
+      linear_seq_.push_back(e);
+    }
+  }
+  void VisitStmt_(const EvaluateNode* op) final {
+    scope_.push_back(StmtEntry());
+    // visit subexpr
+    StmtExprVisitor::VisitStmt_(op);
+    StmtEntry e = scope_.back();
+    scope_.pop_back();
+    if (e.touched.size() != 0) {
+      e.stmt = op;
+      linear_seq_.push_back(e);
+    }
+  }
+  void VisitExpr_(const LoadNode* op) final {
+    // Add read access.
+    StmtExprVisitor::VisitExpr_(op);
+    const VarNode* buf = op->buffer_var.get();
+    auto it = alloc_info_.find(buf);
+    if (it != alloc_info_.end() && it->second.alloc) {
+      ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store.";
+      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
+        scope_[it->second.level].touched.push_back(buf);
+      }
+    }
+  }
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::address_of())) {
+      const LoadNode* l = op->args[0].as<LoadNode>();
+      this->VisitExpr(l->index);
+    } else {
+      StmtExprVisitor::VisitExpr_(op);
+    }
+  }
+  void VisitExpr_(const VarNode* buf) final {
+    // A direct reference to the variable counts as a read.
+    auto it = alloc_info_.find(buf);
+    if (it != alloc_info_.end() && it->second.alloc) {
+      ICHECK_LT(it->second.level, scope_.size()) << " buf=" << buf->name_hint;
+      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
+        scope_[it->second.level].touched.push_back(buf);
+      }
+    }
+  }
+  template <typename T>
+  void VisitNewScope(const T* op) {
+    scope_.push_back(StmtEntry());
+    StmtEntry e;
+    e.stmt = op;
+    int64_t begin_index = static_cast<int64_t>(linear_seq_.size());
+    // before scope.
+    linear_seq_.push_back(e);
+    StmtExprVisitor::VisitStmt_(op);
+    // after scope.
+    e.touched = std::move(scope_.back().touched);
+    scope_.pop_back();
+    int64_t end_index = static_cast<int64_t>(linear_seq_.size());
+    ICHECK_GT(end_index, begin_index);
+    e.scope_pair_offset = begin_index - end_index;
+    linear_seq_.push_back(e);
+    // record the pointer to end index.
+    ICHECK_NE(end_index, 0U);
+    linear_seq_[begin_index].scope_pair_offset = end_index - begin_index;
+  }
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Only record the outermost thread extent.
+    if (op->attr_key == attr::thread_extent && !in_thread_env_) {
+      in_thread_env_ = true;
+      VisitNewScope(op);
+      in_thread_env_ = false;
+    } else if (op->attr_key == attr::extern_scope) {
+      VisitNewScope(op);
+    } else if (op->attr_key == attr::virtual_thread) {
+      VisitNewScope(op);
+    } else {
+      StmtExprVisitor::VisitStmt_(op);
+    }
+  }
+  void VisitStmt_(const IfThenElseNode* op) final { VisitNewScope(op); }
+
+  void VisitStmt_(const ForNode* op) final { VisitNewScope(op); }
+
+  void VisitStmt_(const WhileNode* op) final { VisitNewScope(op); }
+
+  void VisitStmt_(const AssertStmtNode* op) final { VisitNewScope(op); }
+
+  // linearized access sequence.
+  std::vector<StmtEntry> linear_seq_;
+  // The allocation entry of each buffer
+  std::unordered_map<const VarNode*, AllocEntry> alloc_info_;
+
+ private:
+  // Whether already in thread env.
+  bool in_thread_env_{false};
+  // The scope stack.
+  std::vector<StmtEntry> scope_;
 };
 
+/*!
+ * \brief Merge the buffers whose live ranges do not intersect, and rewrite the body.
+ */
 class DynamicSharedMemoryRewriter : public StmtExprMutator {
  public:
   explicit DynamicSharedMemoryRewriter(
-      const std::unordered_set<const AllocateNode*>& dyn_shmem_allocs)
+      const std::unordered_map<const VarNode*, const AllocateNode*>& dyn_shmem_allocs)
       : dyn_shmem_allocs_{dyn_shmem_allocs} {}
 
+  /*!
+   * \brief plan the memory reuse for all the buffers allocated in the statement
+   * @param stmt the statement

Review comment:
       ```suggestion
       * \param stmt the statement.
       ```

##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -36,46 +38,250 @@
[...]
+  void VisitExpr_(const VarNode* buf) final {
+    // A direct reference to the variable counts as a read.
+    auto it = alloc_info_.find(buf);
+    if (it != alloc_info_.end() && it->second.alloc) {
+      ICHECK_LT(it->second.level, scope_.size()) << " buf=" << buf->name_hint;

Review comment:
       Please either add a complete error message or do not print any extra information.
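
       For example, a more complete message could look something like this (just a sketch; adjust the wording as you see fit):

       ```cpp
       ICHECK_LT(it->second.level, scope_.size())
           << "buffer " << buf->name_hint
           << " is referenced outside the scope where it is allocated"
           << " (allocation level " << it->second.level
           << ", current scope depth " << scope_.size() << ")";
       ```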

##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -36,46 +38,250 @@
[...]
+  void PlanReuse(Stmt stmt) {

Review comment:
       ```suggestion
         void PlanReuse(const Stmt& stmt) {
       ```

##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -36,46 +38,250 @@
[...]
+  // The scope of each allocation
+  struct AllocEntry {
+    // scope level

Review comment:
       Please document `level` in a bit more detail.
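
       For example, something along these lines (a sketch of my reading of the code; please adjust if the intent differs):

       ```cpp
       // The depth of the scope stack at the point where the buffer is allocated.
       // Accesses to the buffer are recorded on the scope entry at this level, so
       // the buffer must only be touched inside that scope.
       size_t level{0};
       ```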

##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -90,42 +296,278 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator {
 
   PrimExpr VisitExpr_(const LoadNode* op) final {
     if (IsDynamicSharedMemory(op->buffer_var)) {
-      auto offset = GetBufferOffset(op->buffer_var, op->dtype);
-      auto index = StmtExprMutator::VisitExpr(op->index);
+      PrimExpr offset = GetBufferOffset(op->buffer_var, op->dtype);
+      PrimExpr index = StmtExprMutator::VisitExpr(op->index);
       return Load(op->dtype, merged_buf_var_, offset + index, op->predicate, op->span);
     }
     return StmtExprMutator::VisitExpr_(op);
   }
 
   Stmt VisitStmt_(const StoreNode* op) final {
     if (IsDynamicSharedMemory(op->buffer_var)) {
-      auto offset = GetBufferOffset(op->buffer_var, op->value->dtype);
-      auto index = StmtExprMutator::VisitExpr(op->index);
-      auto value = StmtExprMutator::VisitExpr(op->value);
+      PrimExpr offset = GetBufferOffset(op->buffer_var, op->value->dtype);
+      PrimExpr index = StmtExprMutator::VisitExpr(op->index);
+      PrimExpr value = StmtExprMutator::VisitExpr(op->value);
       return Store(merged_buf_var_, value, offset + index, op->predicate, op->span);
     }
     return StmtExprMutator::VisitStmt_(op);
   }
 
+  PrimExpr VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::tvm_access_ptr())) {
+      ICHECK_EQ(op->args.size(), 5U);
+      DataType dtype = op->args[0].dtype();
+      Var buffer = Downcast<Var>(op->args[1]);
+      if (!IsDynamicSharedMemory(buffer)) {
+        return StmtExprMutator::VisitExpr_(op);
+      }
+      PrimExpr extra_offset = GetBufferOffset(buffer, dtype);
+
+      PrimExpr offset = this->VisitExpr(op->args[2]);
+      PrimExpr extent = this->VisitExpr(op->args[3]);
+      return Call(op->dtype, op->op,
+                  {op->args[0], merged_buf_var_, extra_offset + offset, extent, op->args[4]});
+    } else {
+      return StmtExprMutator::VisitExpr_(op);
+    }
+  }
+
  private:
   PrimExpr GetBufferOffset(Var buffer_var, DataType dtype) {
     auto it = buffer_byte_offsets_.find(buffer_var.get());
-    ICHECK(it != buffer_byte_offsets_.end());
+    ICHECK(it != buffer_byte_offsets_.end()) << buffer_var;
     return indexdiv(it->second, dtype.bytes());
   }
 
+  using StmtEntry = DynSharedMemLinearAccessPatternFinder::StmtEntry;
+  struct StorageEntry {
+    // The constant size of the buffer in bits, only used if it is constant
+    uint64_t const_nbits{0};
+    // Allocs that shares this entry.
+    // The inner vector means a "layer"
+    // For example, if we need to allocate C in the memory of A and B:
+    // |  A: 4096 bytes |  B: 4096 bytes |
+    // |            C: 8192 bytes        |
+    // Then the allocs = {{A, B}, {C}}
+    std::vector<std::vector<const VarNode*>> allocs;
+  };
+
+  // Event entry in liveness analysis
+  struct EventEntry {
+    // variables we generate
+    std::vector<const VarNode*> gen;
+    // variables we kill
+    std::vector<const VarNode*> kill;
+  };
+
+  /*!
+   * \brief Liveness analysis to find gen and kill point of each variable.
+   * \param seq the linear pattern of storage access
+   */
+  void LivenessAnalysis(const std::vector<StmtEntry>& seq) {
+    // find kill point, do a reverse linear scan.
+    std::unordered_set<const VarNode*> touched;
+    for (size_t i = seq.size(); i != 0; --i) {
+      const StmtEntry& s = seq[i - 1];
+      for (const VarNode* buffer : s.touched) {
+        if (!touched.count(buffer)) {
+          touched.insert(buffer);
+          event_map_[s.stmt].kill.push_back(buffer);
+        }
+      }
+    }
+    // find gen point, do forward scan
+    touched.clear();
+    for (size_t i = 0; i < seq.size(); ++i) {
+      int64_t offset = seq[i].scope_pair_offset;
+      if (offset < 0) continue;
+      const StmtEntry& s = seq[i + offset];
+      for (const VarNode* buffer : s.touched) {
+        if (!touched.count(buffer)) {
+          touched.insert(buffer);
+          event_map_[s.stmt].gen.push_back(buffer);
+        }
+      }
+    }
+  }
+
+  /*!
+   * \brief Memory planning algorithm
+   * \param seq the linear pattern of storage access
+   */
+  void PlanMemory(const std::vector<StmtEntry>& seq) {
+    std::unordered_set<const VarNode*> inplace_flag;
+
+    for (size_t i = 0; i < seq.size(); ++i) {
+      auto it = event_map_.find(seq[i].stmt);
+      // scope_pair_offset >= 0 means it is either
+      // - a leaf stmt (offset = 0)
+      // - the beginning of a scope (offset > 0)
+      // In both cases, we need to handle the gen event correctly
+      if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) {
+        for (const VarNode* var : it->second.gen) {
+          ICHECK(dyn_shmem_allocs_.count(var));
+          const AllocateNode* alloc = dyn_shmem_allocs_[var];
+          StorageEntry* dst_entry = FindAlloc(alloc);
+          alloc_map_[var] = dst_entry;
+        }
+      }
+      // scope_pair_offset <= 0 means it is either
+      // - a leaf stmt (offset = 0)
+      // - the end of a scope (offset < 0)
+      // In both cases, we need to handle the kill event correctly
+      if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
+        for (const VarNode* var : it->second.kill) {
+          this->Free(var);
+        }
+      }
+    }
+  }
+  /*!
+   * \brief Allocate new storage entry.
+   * \param op the allocate node
+   * \param const_nbits the size of the allocation in bits
+   * \return the new storage entry
+   */
+  StorageEntry* NewAlloc(const AllocateNode* op, size_t const_nbits) {

Review comment:
       Why do we need to return a raw pointer here?
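
       If the rewriter is meant to own the entries, one common pattern is to keep them in a container of `std::unique_ptr` and hand out non-owning pointers; a minimal sketch, assuming a member `alloc_vec_` that owns all entries:

       ```cpp
       // Owns every StorageEntry created by this rewriter.
       std::vector<std::unique_ptr<StorageEntry>> alloc_vec_;

       StorageEntry* NewAlloc(const AllocateNode* op, size_t const_nbits) {
         auto entry = std::make_unique<StorageEntry>();
         entry->const_nbits = const_nbits;
         alloc_vec_.push_back(std::move(entry));
         // The returned pointer is a non-owning observer; it remains valid for as
         // long as the rewriter (and hence alloc_vec_) is alive.
         return alloc_vec_.back().get();
       }
       ```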

##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -90,42 +296,278 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator {
 
[...]
  private:
   PrimExpr GetBufferOffset(Var buffer_var, DataType dtype) {
     auto it = buffer_byte_offsets_.find(buffer_var.get());
-    ICHECK(it != buffer_byte_offsets_.end());
+    ICHECK(it != buffer_byte_offsets_.end()) << buffer_var;

Review comment:
       ditto

##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -36,46 +38,250 @@
[...]
+  // linearized access sequence.
+  std::vector<StmtEntry> linear_seq_;
+  // The allocation entry of each buffer
+  std::unordered_map<const VarNode*, AllocEntry> alloc_info_;

Review comment:
       Please make it private if it is never accessed outside the class.
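
       For example (a minimal sketch): move it to the private section, and expose a const accessor only if the pass really needs to read it from outside:

       ```cpp
        public:
         const std::unordered_map<const VarNode*, AllocEntry>& alloc_info() const {
           return alloc_info_;
         }

        private:
         std::unordered_map<const VarNode*, AllocEntry> alloc_info_;
       ```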




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org