Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/10/25 16:01:20 UTC

[GitHub] [tvm] MasterJH5574 commented on a change in pull request #9341: [CUDA] Support memory reuse for dynamic shared memory

MasterJH5574 commented on a change in pull request #9341:
URL: https://github.com/apache/tvm/pull/9341#discussion_r735741835



##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -31,51 +31,256 @@
 #include <unordered_set>
 
 #include "../../runtime/thread_storage_scope.h"
+#include "../../support/arena.h"
 #include "ir_utils.h"
 
 namespace tvm {
 namespace tir {
 
+using runtime::StorageRank;
+using runtime::StorageScope;
+
 bool IsDynamicSharedMemory(Var buffer_var) {
-  auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
+  StorageScope storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
   return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn";
 }
 
+/*!
+ * \brief collect the mapping from the buffer var to its allocate
+ */
 class AllocateCollector : public StmtExprVisitor {
  public:
   void VisitStmt_(const AllocateNode* op) final {
     if (IsDynamicSharedMemory(op->buffer_var)) {
-      dyn_shmem_allocs_.insert(op);
+      dyn_shmem_allocs_[op->buffer_var.get()] = op;
     }
     StmtExprVisitor::VisitStmt_(op);
   }
+  // The mapping from the original buffer var to its allocate
+  std::unordered_map<const VarNode*, const AllocateNode*> dyn_shmem_allocs_;
+};
+
+// Find a linear pattern of storage access
+// Used for liveness analysis.
+// Composite scopes(loop/thread_launch/IfThen) is represented by two points:
+// before_scope -> scope_body -> after_scope
+//
+// The linear_seq_ stores before_scope and after_scope.
+// The access to the arrays are stored at the after_scope point.
+//
+// Define "scope" as the body of For/thread_launch/IfThenElse
+// This pass tries to detect last point that we need to keep memory
+// alive under the same scope as allocate.
+// The storage need to be kept alive between allocate and last access.
+// The free point is only inserted at the same scope of allocate.
+//
+class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
+ public:
+  /*! \brief record the touch hist of statment. */
+  struct StmtEntry {
+    // The statment
+    const Object* stmt;
+    // The index in the linear_seq_ to point to end of the nested scope.
+    // This is only set to non-zero if stmt is a nested scope.
+    // if offset > 0, means this is the begin, the end entry is current_index + offset
+    // if offset < 0, means this is the end, the begin entry is current_index + offset
+    int64_t scope_pair_offset{0};
+    // The buffer variables this statment touched.

Review comment:
    ```suggestion
        // The buffer variables this statement touched.
    ```
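
    As an aside on the `scope_pair_offset` doc just above, a tiny worked example with hypothetical indices (not from the patch): if a scope's begin entry sits at index 3 of `linear_seq_` and its end entry at index 7, then

    ```
    // begin entry: linear_seq_[3].scope_pair_offset = +4  ->  end   at 3 + 4 = 7
    // end entry:   linear_seq_[7].scope_pair_offset = -4  ->  begin at 7 - 4 = 3
    // entries whose stmt is not a nested scope keep scope_pair_offset == 0
    ```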

##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -36,46 +38,250 @@
 namespace tvm {
 namespace tir {
 
+using runtime::StorageRank;
+using runtime::StorageScope;
+
 bool IsDynamicSharedMemory(Var buffer_var) {
-  auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
+  StorageScope storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
   return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn";
 }
 
+/*!
+ * \brief collect the mapping from the buffer var to its allocate
+ */
 class AllocateCollector : public StmtExprVisitor {
  public:
   void VisitStmt_(const AllocateNode* op) final {
     if (IsDynamicSharedMemory(op->buffer_var)) {
-      dyn_shmem_allocs_.insert(op);
+      dyn_shmem_allocs_[op->buffer_var.get()] = op;
     }
     StmtExprVisitor::VisitStmt_(op);
   }
+  // The mapping from the original buffer var to its allocate
+  std::unordered_map<const VarNode*, const AllocateNode*> dyn_shmem_allocs_;
+};
+
+// Find a linear pattern of storage access
+// Used for liveness analysis.
+// Composite scopes(loop/thread_launch/IfThen) is represented by two points:
+// before_scope -> scope_body -> after_scope
+//
+// The linear_seq_ stores before_scope and after_scope.
+// The access to the arrays are stored at the after_scope point.
+//
+// Define "scope" as the body of For/thread_launch/IfThenElse
+// This pass tries to detect last point that we need to keep memory
+// alive under the same scope as allocate.
+// The storage need to be kept alive between allocate and last access.
+// The free point is only inserted at the same scope of allocate.
+//
+class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
+ public:
+  /*! \brief record the touch hist of statment. */

Review comment:
       Do you mean "list" or "history" here?
   ```suggestion
     /*! \brief record the touch list of statement. */
   ```

##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -31,51 +31,256 @@
 #include <unordered_set>
 
 #include "../../runtime/thread_storage_scope.h"
+#include "../../support/arena.h"
 #include "ir_utils.h"
 
 namespace tvm {
 namespace tir {
 
+using runtime::StorageRank;
+using runtime::StorageScope;
+
 bool IsDynamicSharedMemory(Var buffer_var) {
-  auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
+  StorageScope storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
   return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn";
 }
 
+/*!
+ * \brief collect the mapping from the buffer var to its allocate
+ */
 class AllocateCollector : public StmtExprVisitor {
  public:
   void VisitStmt_(const AllocateNode* op) final {
     if (IsDynamicSharedMemory(op->buffer_var)) {
-      dyn_shmem_allocs_.insert(op);
+      dyn_shmem_allocs_[op->buffer_var.get()] = op;
     }
     StmtExprVisitor::VisitStmt_(op);
   }
+  // The mapping from the original buffer var to its allocate
+  std::unordered_map<const VarNode*, const AllocateNode*> dyn_shmem_allocs_;
+};
+
+// Find a linear pattern of storage access
+// Used for liveness analysis.
+// Composite scopes(loop/thread_launch/IfThen) is represented by two points:
+// before_scope -> scope_body -> after_scope
+//
+// The linear_seq_ stores before_scope and after_scope.
+// The access to the arrays are stored at the after_scope point.
+//
+// Define "scope" as the body of For/thread_launch/IfThenElse
+// This pass tries to detect last point that we need to keep memory
+// alive under the same scope as allocate.
+// The storage need to be kept alive between allocate and last access.
+// The free point is only inserted at the same scope of allocate.
+//
+class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
+ public:
+  /*! \brief record the touch hist of statment. */
+  struct StmtEntry {
+    // The statment

Review comment:
    ```suggestion
        // The statement
    ```

##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -31,51 +31,256 @@
 #include <unordered_set>
 
 #include "../../runtime/thread_storage_scope.h"
+#include "../../support/arena.h"
 #include "ir_utils.h"
 
 namespace tvm {
 namespace tir {
 
+using runtime::StorageRank;
+using runtime::StorageScope;
+
 bool IsDynamicSharedMemory(Var buffer_var) {
-  auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
+  StorageScope storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
   return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn";
 }
 
+/*!
+ * \brief collect the mapping from the buffer var to its allocate
+ */
 class AllocateCollector : public StmtExprVisitor {
  public:
   void VisitStmt_(const AllocateNode* op) final {
     if (IsDynamicSharedMemory(op->buffer_var)) {
-      dyn_shmem_allocs_.insert(op);
+      dyn_shmem_allocs_[op->buffer_var.get()] = op;
     }
     StmtExprVisitor::VisitStmt_(op);
   }
+  // The mapping from the original buffer var to its allocate
+  std::unordered_map<const VarNode*, const AllocateNode*> dyn_shmem_allocs_;
+};
+
+// Find a linear pattern of storage access
+// Used for liveness analysis.
+// Composite scopes(loop/thread_launch/IfThen) is represented by two points:
+// before_scope -> scope_body -> after_scope
+//
+// The linear_seq_ stores before_scope and after_scope.
+// The access to the arrays are stored at the after_scope point.
+//
+// Define "scope" as the body of For/thread_launch/IfThenElse
+// This pass tries to detect last point that we need to keep memory
+// alive under the same scope as allocate.
+// The storage need to be kept alive between allocate and last access.
+// The free point is only inserted at the same scope of allocate.
+//
+class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
+ public:

Review comment:
    I'm confused by the docs here on several points:
    * It seems the definition of "scope" comes after the sentence "Composite scopes(loop/thread_launch/IfThen) is represented by two points...". Should the definition be stated earlier?
    * What does "linear" mean? What is `linear_seq_`? (See the sketch below for my current reading.)
    * What are `before_scope`, `scope_body`, and `after_scope`?
    * "allocate" by itself is a verb, so I suggest writing "AllocateNode" or "Allocate"; using bare "allocate" is confusing.

##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -31,51 +31,256 @@
 #include <unordered_set>
 
 #include "../../runtime/thread_storage_scope.h"
+#include "../../support/arena.h"
 #include "ir_utils.h"
 
 namespace tvm {
 namespace tir {
 
+using runtime::StorageRank;
+using runtime::StorageScope;
+
 bool IsDynamicSharedMemory(Var buffer_var) {
-  auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
+  StorageScope storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
   return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn";
 }
 
+/*!
+ * \brief collect the mapping from the buffer var to its allocate
+ */
 class AllocateCollector : public StmtExprVisitor {
  public:
   void VisitStmt_(const AllocateNode* op) final {
     if (IsDynamicSharedMemory(op->buffer_var)) {
-      dyn_shmem_allocs_.insert(op);
+      dyn_shmem_allocs_[op->buffer_var.get()] = op;
     }
     StmtExprVisitor::VisitStmt_(op);
   }
+  // The mapping from the original buffer var to its allocate
+  std::unordered_map<const VarNode*, const AllocateNode*> dyn_shmem_allocs_;
+};
+
+// Find a linear pattern of storage access
+// Used for liveness analysis.
+// Composite scopes(loop/thread_launch/IfThen) is represented by two points:
+// before_scope -> scope_body -> after_scope
+//
+// The linear_seq_ stores before_scope and after_scope.
+// The access to the arrays are stored at the after_scope point.
+//
+// Define "scope" as the body of For/thread_launch/IfThenElse
+// This pass tries to detect last point that we need to keep memory
+// alive under the same scope as allocate.
+// The storage need to be kept alive between allocate and last access.
+// The free point is only inserted at the same scope of allocate.
+//
+class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
+ public:
+  /*! \brief record the touch hist of statment. */
+  struct StmtEntry {
+    // The statment
+    const Object* stmt;
+    // The index in the linear_seq_ to point to end of the nested scope.
+    // This is only set to non-zero if stmt is a nested scope.
+    // if offset > 0, means this is the begin, the end entry is current_index + offset
+    // if offset < 0, means this is the end, the begin entry is current_index + offset
+    int64_t scope_pair_offset{0};
+    // The buffer variables this statment touched.
+    std::vector<const VarNode*> touched;
+  };
+  // The scope of each allocation
+  struct AllocEntry {
+    // the level in the scope stack
+    size_t level{0};
+    // allocation stmt
+    const AllocateNode* alloc{nullptr};
+  };
 
-  std::unordered_set<const AllocateNode*> dyn_shmem_allocs_;
+  void VisitStmt_(const AllocateNode* op) final {
+    size_t level = scope_.size();
+    const VarNode* buf = op->buffer_var.get();
+    alloc_info_[buf].alloc = op;
+    alloc_info_[buf].level = level;
+    StmtExprVisitor::VisitStmt_(op);
+  }
+  void VisitStmt_(const StoreNode* op) final {
+    scope_.push_back(StmtEntry());
+    // visit subexpr
+    StmtExprVisitor::VisitStmt_(op);
+    // Add write access.
+    const VarNode* buf = op->buffer_var.get();
+    auto it = alloc_info_.find(buf);
+    if (it != alloc_info_.end() && it->second.alloc) {
+      ICHECK_LT(it->second.level, scope_.size());
+      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
+        scope_[it->second.level].touched.push_back(buf);
+      }
+    }
+    StmtEntry e = scope_.back();
+    scope_.pop_back();
+    if (e.touched.size() != 0) {
+      e.stmt = op;
+      linear_seq_.push_back(e);
+    }
+  }
+  void VisitStmt_(const EvaluateNode* op) final {
+    scope_.push_back(StmtEntry());
+    // visit subexpr
+    StmtExprVisitor::VisitStmt_(op);
+    StmtEntry e = scope_.back();
+    scope_.pop_back();
+    if (e.touched.size() != 0) {
+      e.stmt = op;
+      linear_seq_.push_back(e);
+    }
+  }
+  void VisitExpr_(const LoadNode* op) final {
+    // Add write access.
+    StmtExprVisitor::VisitExpr_(op);
+    const VarNode* buf = op->buffer_var.get();
+    auto it = alloc_info_.find(buf);
+    if (it != alloc_info_.end() && it->second.alloc) {
+      ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store.";
+      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
+        scope_[it->second.level].touched.push_back(buf);
+      }
+    }
+  }
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::address_of())) {
+      const LoadNode* l = op->args[0].as<LoadNode>();
+      this->VisitExpr(l->index);
+    } else {
+      StmtExprVisitor::VisitExpr_(op);
+    }
+  }
+  void VisitExpr_(const VarNode* buf) final {
+    // Directly reference to the variable count as a read.
+    auto it = alloc_info_.find(buf);
+    if (it != alloc_info_.end() && it->second.alloc) {
+      ICHECK_LT(it->second.level, scope_.size());
+      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
+        scope_[it->second.level].touched.push_back(buf);
+      }
+    }
+  }
+  template <typename T>
+  void VisitNewScope(const T* op) {
+    scope_.push_back(StmtEntry());
+    StmtEntry e;
+    e.stmt = op;
+    int64_t begin_index = static_cast<int64_t>(linear_seq_.size());
+    // before scope.
+    linear_seq_.push_back(e);
+    StmtExprVisitor::VisitStmt_(op);
+    // after scope.
+    e.touched = std::move(scope_.back().touched);
+    scope_.pop_back();
+    int64_t end_index = static_cast<int64_t>(linear_seq_.size());
+    ICHECK_GT(end_index, begin_index);
+    e.scope_pair_offset = begin_index - end_index;
+    linear_seq_.push_back(e);
+    // record the pointer to end index.
+    ICHECK_NE(end_index, 0U);
+    linear_seq_[begin_index].scope_pair_offset = end_index - begin_index;
+  }
+  void VisitStmt_(const AttrStmtNode* op) final {
+    // Only record the outer most thread extent.
+    if (op->attr_key == attr::thread_extent && !in_thread_env_) {
+      in_thread_env_ = true;
+      VisitNewScope(op);
+      in_thread_env_ = false;
+    } else if (op->attr_key == attr::extern_scope) {
+      VisitNewScope(op);
+    } else if (op->attr_key == attr::virtual_thread) {
+      VisitNewScope(op);
+    } else {
+      StmtExprVisitor::VisitStmt_(op);
+    }
+  }
+  void VisitStmt_(const IfThenElseNode* op) final { VisitNewScope(op); }
+
+  void VisitStmt_(const ForNode* op) final { VisitNewScope(op); }
+
+  void VisitStmt_(const WhileNode* op) final { VisitNewScope(op); }
+
+  void VisitStmt_(const AssertStmtNode* op) final { VisitNewScope(op); }
+
+  // linearized access sequence.
+  std::vector<StmtEntry> linear_seq_;
+  // The storage scope of each buffer
+  std::unordered_map<const VarNode*, AllocEntry> alloc_info_;
+
+ private:
+  // Whether already in thread env.
+  bool in_thread_env_{false};
+  // The scope stack.
+  std::vector<StmtEntry> scope_;
 };
 
+/*!
+ * \brief merge the buffers whose live range has no intersection and rewrite the body
+ */
 class DynamicSharedMemoryRewriter : public StmtExprMutator {
  public:
   explicit DynamicSharedMemoryRewriter(
-      const std::unordered_set<const AllocateNode*>& dyn_shmem_allocs)
+      const std::unordered_map<const VarNode*, const AllocateNode*>& dyn_shmem_allocs)
       : dyn_shmem_allocs_{dyn_shmem_allocs} {}
 
+  /*!
+   * \brief plan the memory reuse for all the buffer allocated in the statement
+   * \param stmt the statement
+   */
+  void PlanReuse(const Stmt& stmt) {
+    DynSharedMemLinearAccessPatternFinder finder;
+    finder(stmt);
+    this->LivenessAnalysis(finder.linear_seq_);
+    this->PlanMemory(finder.linear_seq_);
+  }
+
   Stmt VisitStmt_(const AttrStmtNode* op) final {

Review comment:
       Shouldn't these `VisitStmt_(...)` be private?
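
    For reference: in C++, access control is checked against the static type at the call site, so the overrides can be made private and will still be reached through virtual dispatch from the base visitor. A toy sketch with made-up names (not TVM's actual classes):

    ```
    #include <iostream>

    struct StmtVisitorBase {
      virtual void VisitStmt() { std::cout << "base\n"; }
      virtual ~StmtVisitorBase() = default;
    };

    class Rewriter : public StmtVisitorBase {
     private:  // private override: still reachable via virtual dispatch
      void VisitStmt() final { std::cout << "Rewriter::VisitStmt\n"; }
    };

    int main() {
      Rewriter r;
      StmtVisitorBase& base = r;
      base.VisitStmt();   // prints "Rewriter::VisitStmt"
      // r.VisitStmt();   // error: 'VisitStmt' is a private member of 'Rewriter'
    }
    ```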



