You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/07/08 16:34:09 UTC

[GitHub] [tvm] NicolaLancellotti commented on a diff in pull request #12029: [microNPU] Add MergeConstants pass

NicolaLancellotti commented on code in PR #12029:
URL: https://github.com/apache/tvm/pull/12029#discussion_r916976523


##########
src/tir/contrib/ethosu/passes.cc:
##########
@@ -223,6 +224,572 @@ tvm::transform::Pass CopyComputeReordering(Optional<Integer> max_copy_movements)
 TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering")
     .set_body_typed(CopyComputeReordering);
 
+/*!
+ * \brief This pass looks for the constants used by each compute operator
+ * and merges them into a single buffer.
+ * Constants written to a buffer with local scope are not merged.
+ */
+class MergeConstantsMutator : public StmtExprMutator {
+ public:
+  MergeConstantsMutator() {}
+
+  PrimFunc operator()(PrimFunc main_func, const Map<IntImm, runtime::NDArray>& const_dict) {
+    // Analyze
+    Stmt new_body{this->VisitStmt(main_func->body)};
+
+    // Rewrite
+    analyze = false;
+    new_body = rewrite_prim_func_body(new_body);
+    std::set<ObjectRef> params_to_delete{};
+    auto new_buffer_map{make_new_buffer_map(main_func->buffer_map, &params_to_delete)};
+    auto new_params{make_new_params(main_func->params, params_to_delete)};
+
+    // Make the new const dict
+    auto args_to_merge{get_args_to_merge(main_func->buffer_map, main_func->params)};
+    auto buffers_to_merge{
+        get_args_to_merge_without_args_not_in_const_dict(args_to_merge, const_dict)};
+    auto new_const_dict{make_new_const_dict(buffers_to_merge, const_dict)};
+
+    // Make the new prim func
+    auto prim_func_node{main_func.CopyOnWrite()};
+    prim_func_node->body = std::move(new_body);
+    prim_func_node->buffer_map = std::move(new_buffer_map);
+    prim_func_node->params = std::move(new_params);
+    prim_func_node->preflattened_buffer_map = {};
+    PrimFunc f{GetRef<PrimFunc>(prim_func_node)};
+
+    // Add the new const dict as an attribute
+    f = WithAttr(std::move(f), "ethos-u.const-dict", new_const_dict);
+
+    return f;
+  }
+
+ private:
+  /*! Indicates whether the pass is analyzing or rewriting */
+  bool analyze = true;
+
+  /*! A stack to store allocates as they are visited. */
+  std::vector<Allocate> allocates{};
+
+  /*! A list that contains in the i-th position the write buffer of the i-th statement
+   * if that statement is a copy to a buffer with global scope  */
+  std::vector<Optional<Buffer>> copy_write_buffers{};
+
+  /*! Maps a copy's write buffer to an index representing the
+   * new buffer and an offset in that buffer */
+  std::map<Buffer, std::pair<int /* new buffer index */, int /* offset */>>
+      old_to_new_write_buffer{};
+
+  /*! Maps an index representing a new buffer to the length of that buffer */
+  std::map<int /* new buffer index */, int /* length */> new_buffers_length{};
+
+  /*! Maps an index representing a new buffer to the new buffer */
+  std::map<int /* new buffer index */, Buffer> new_buffers{};
+
+  /*! Maps an index representing a new buffer to the cycle_counts needed to copy that buffer */
+  std::map<int /* new buffer index */, int64_t> cycle_counts{};
+
+  /*! Maps a copy's read buffer to the new copy's read buffer */
+  std::map<Buffer, Buffer> old_to_new_read_buffers{};
+
+  /*! Maps an index representing a new buffer to the list of buffers to be merged in the new buffer
+   */
+  std::map<int /* new buffer index */, std::vector<Buffer>> buffers_to_merge{};
+
+  /*! A set of buffers to delete */
+  std::set<Buffer> buffers_to_delete{};
+
+  // Visit
+
+  Stmt VisitStmt_(const AllocateNode* op) override {
+    if (analyze) {
+      allocates.push_back(GetRef<Allocate>(op));
+      return VisitStmt(op->body);
+    } else {
+      auto allocate{CopyOnWrite(op)};
+      allocate->body = this->VisitStmt(op->body);
+      return Stmt(allocate);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* op) override {
+    if (op->size() <= 1) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+    return analyze ? analyze_seq_stmt(op) : rewrite_seq_stmt(op);
+  }
+
+  Stmt analyze_seq_stmt(const SeqStmtNode* op) {
+    auto seq_stmt{GetRef<SeqStmt>(op)};
+
+    for (size_t i = 0; i < seq_stmt.size(); ++i) {
+      Stmt stmt{seq_stmt[i]};
+
+      switch (get_stmt_type(stmt)) {
+        case StmtType::global_copy: {
+          Buffer write_buffer{get_copy_write_buffer(stmt)};
+          copy_write_buffers.push_back(write_buffer);
+          old_to_new_write_buffer[write_buffer] = std::make_pair(-1, -1);
+          break;
+        }
+        case StmtType::local_copy: {
+          copy_write_buffers.push_back(Optional<Buffer>{});
+          break;
+        }
+        case StmtType::compute: {
+          copy_write_buffers.push_back(Optional<Buffer>{});
+          auto buffers{get_copied_buffers_used_by_stmt(stmt)};
+          if (buffers.empty()) {
+            continue;
+          }
+          new_buffers_length[i] = 0;
+          for (auto buffer : buffers) {
+            for (size_t j{i - 1}; j >= 0; --j) {
+              if (copy_write_buffers[j] == buffer) {
+                old_to_new_write_buffer[buffer] = std::make_pair(i, new_buffers_length[i]);
+                new_buffers_length[i] += get_copy_length(seq_stmt[j]);
+                cycle_counts[i] += get_stmt_cycle_counts(seq_stmt[j]);
+                break;
+              }
+            }
+          }
+          break;
+        }
+      }
+    }
+    return seq_stmt;
+  }
+
+  Stmt rewrite_prim_func_body(Stmt body) {
+    std::map<const VarNode*, Allocate> var_to_allocate{};
+
+    // Rewrite old allocates
+    std::set<ObjectRef> buffer_vars{get_vars_for_written_copy_buffers()};
+    for (auto it{allocates.rbegin()}; it != allocates.rend(); ++it) {
+      Allocate alloc{*it};
+      var_to_allocate[alloc->buffer_var.get()] = alloc;
+      if (buffer_vars.count(alloc->buffer_var) == 0) {
+        body = Allocate(alloc->buffer_var, alloc->dtype, alloc->extents, alloc->condition, body,
+                        alloc->annotations, alloc->span);
+      }
+    }
+
+    // Rewrite new allocates
+    for (auto it{copy_write_buffers.rbegin()}; it != copy_write_buffers.rend(); ++it) {
+      if (auto buffer_opt = *it) {

Review Comment:
   Done.



##########
src/tir/contrib/ethosu/passes.cc:
##########
@@ -223,6 +224,572 @@ tvm::transform::Pass CopyComputeReordering(Optional<Integer> max_copy_movements)
 TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering")
     .set_body_typed(CopyComputeReordering);
 
+/*!
+ * \brief This pass looks for the constants used by each compute operator
+ * and merges them into a single buffer.
+ * Constants written to a buffer with local scope are not merged.
+ */
+class MergeConstantsMutator : public StmtExprMutator {
+ public:
+  MergeConstantsMutator() {}
+
+  PrimFunc operator()(PrimFunc main_func, const Map<IntImm, runtime::NDArray>& const_dict) {
+    // Analyze
+    Stmt new_body{this->VisitStmt(main_func->body)};

Review Comment:
   Done.



##########
src/tir/contrib/ethosu/passes.cc:
##########
@@ -223,6 +224,572 @@ tvm::transform::Pass CopyComputeReordering(Optional<Integer> max_copy_movements)
 TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering")
     .set_body_typed(CopyComputeReordering);
 
+/*!
+ * \brief This pass looks for the constants used by each compute operator
+ * and merges them into a single buffer.
+ * Constants written to a buffer with local scope are not merged.
+ */
+class MergeConstantsMutator : public StmtExprMutator {
+ public:
+  MergeConstantsMutator() {}
+
+  PrimFunc operator()(PrimFunc main_func, const Map<IntImm, runtime::NDArray>& const_dict) {
+    // Analyze
+    Stmt new_body{this->VisitStmt(main_func->body)};
+
+    // Rewrite
+    analyze = false;
+    new_body = rewrite_prim_func_body(new_body);
+    std::set<ObjectRef> params_to_delete{};
+    auto new_buffer_map{make_new_buffer_map(main_func->buffer_map, &params_to_delete)};

Review Comment:
   Done.



##########
src/tir/contrib/ethosu/passes.cc:
##########
@@ -223,6 +224,572 @@ tvm::transform::Pass CopyComputeReordering(Optional<Integer> max_copy_movements)
 TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering")
     .set_body_typed(CopyComputeReordering);
 
+/*!
+ * \brief This pass looks for the constants used by each compute operator
+ * and merges them into a single buffer.
+ * Constants written to a buffer with local scope are not merged.
+ */
+class MergeConstantsMutator : public StmtExprMutator {
+ public:
+  MergeConstantsMutator() {}
+
+  PrimFunc operator()(PrimFunc main_func, const Map<IntImm, runtime::NDArray>& const_dict) {
+    // Analyze
+    Stmt new_body{this->VisitStmt(main_func->body)};
+
+    // Rewrite
+    analyze = false;
+    new_body = rewrite_prim_func_body(new_body);
+    std::set<ObjectRef> params_to_delete{};
+    auto new_buffer_map{make_new_buffer_map(main_func->buffer_map, &params_to_delete)};
+    auto new_params{make_new_params(main_func->params, params_to_delete)};
+
+    // Make the new const dict
+    auto args_to_merge{get_args_to_merge(main_func->buffer_map, main_func->params)};
+    auto buffers_to_merge{
+        get_args_to_merge_without_args_not_in_const_dict(args_to_merge, const_dict)};
+    auto new_const_dict{make_new_const_dict(buffers_to_merge, const_dict)};
+
+    // Make the new prim func
+    auto prim_func_node{main_func.CopyOnWrite()};
+    prim_func_node->body = std::move(new_body);
+    prim_func_node->buffer_map = std::move(new_buffer_map);
+    prim_func_node->params = std::move(new_params);
+    prim_func_node->preflattened_buffer_map = {};
+    PrimFunc f{GetRef<PrimFunc>(prim_func_node)};
+
+    // Add the new const dict as an attribute
+    f = WithAttr(std::move(f), "ethos-u.const-dict", new_const_dict);
+
+    return f;
+  }
+
+ private:
+  /*! Indicates whether the pass is analyzing or rewriting */
+  bool analyze = true;
+
+  /*! A stack to store allocates as they are visited. */
+  std::vector<Allocate> allocates{};
+
+  /*! A list that contains in the i-th position the write buffer of the i-th statement
+   * if that statement is a copy to a buffer with global scope  */
+  std::vector<Optional<Buffer>> copy_write_buffers{};
+
+  /*! Maps a copy's write buffer to an index representing the
+   * new buffer and an offset in that buffer */
+  std::map<Buffer, std::pair<int /* new buffer index */, int /* offset */>>
+      old_to_new_write_buffer{};
+
+  /*! Maps an index representing a new buffer to the length of that buffer */
+  std::map<int /* new buffer index */, int /* length */> new_buffers_length{};
+
+  /*! Maps an index representing a new buffer to the new buffer */
+  std::map<int /* new buffer index */, Buffer> new_buffers{};
+
+  /*! Maps an index representing a new buffer to the cycle_counts needed to copy that buffer */
+  std::map<int /* new buffer index */, int64_t> cycle_counts{};
+
+  /*! Maps a copy's read buffer to the new copy's read buffer */
+  std::map<Buffer, Buffer> old_to_new_read_buffers{};
+
+  /*! Maps an index representing a new buffer to the list of buffers to be merged in the new buffer
+   */
+  std::map<int /* new buffer index */, std::vector<Buffer>> buffers_to_merge{};
+
+  /*! A set of buffers to delete */
+  std::set<Buffer> buffers_to_delete{};
+
+  // Visit
+
+  Stmt VisitStmt_(const AllocateNode* op) override {
+    if (analyze) {
+      allocates.push_back(GetRef<Allocate>(op));
+      return VisitStmt(op->body);
+    } else {
+      auto allocate{CopyOnWrite(op)};
+      allocate->body = this->VisitStmt(op->body);
+      return Stmt(allocate);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* op) override {
+    if (op->size() <= 1) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+    return analyze ? analyze_seq_stmt(op) : rewrite_seq_stmt(op);

Review Comment:
   Done.



##########
python/tvm/relay/backend/contrib/ethosu/tir/passes.py:
##########
@@ -938,3 +938,38 @@ def CopyComputeReordering(max_copy_movements: Optional[int] = None) -> tvm.IRMod
         The new module with copy and compute nodes reordered.
     """
     return _ffi_api.CopyComputeReordering(max_copy_movements)
+
+
+def MergeConstants(const_dict):
+    """
+    This pass looks for the constants used by each compute operator
+    and merges them into a single buffer.
+    Constants written to a buffer with local scope are not merged.
+    """
+
+    def mergeConstantsPass(mod):
+        nonlocal const_dict
+        try:
+            mod["main"]
+        except:
+            raise tvm.TVMError(
+                "Expected a single primitive function called 'main'. "
+                "Please run the MergeConstants pass in conjunction with the LowerToTIR() pass."
+            )
+
+        new_const_dict = {}
+        for param in const_dict.keys():
+            new_const_dict[tvm.tir.IntImm("int64", param)] = tvm.nd.array(const_dict[param])
+        mod["main"] = mod["main"].with_attr("ethos-u.const-dict", new_const_dict)

Review Comment:
   Done.



##########
src/tir/contrib/ethosu/passes.cc:
##########
@@ -223,6 +224,572 @@ tvm::transform::Pass CopyComputeReordering(Optional<Integer> max_copy_movements)
 TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering")
     .set_body_typed(CopyComputeReordering);
 
+/*!
+ * \brief This pass looks for the constants used by each compute operator
+ * and merges them into a single buffer.
+ * Constants written to a buffer with local scope are not merged.
+ */
+class MergeConstantsMutator : public StmtExprMutator {
+ public:
+  MergeConstantsMutator() {}
+
+  PrimFunc operator()(PrimFunc main_func, const Map<IntImm, runtime::NDArray>& const_dict) {
+    // Analyze
+    Stmt new_body{this->VisitStmt(main_func->body)};
+
+    // Rewrite
+    analyze = false;
+    new_body = rewrite_prim_func_body(new_body);
+    std::set<ObjectRef> params_to_delete{};
+    auto new_buffer_map{make_new_buffer_map(main_func->buffer_map, &params_to_delete)};
+    auto new_params{make_new_params(main_func->params, params_to_delete)};
+
+    // Make the new const dict
+    auto args_to_merge{get_args_to_merge(main_func->buffer_map, main_func->params)};
+    auto buffers_to_merge{
+        get_args_to_merge_without_args_not_in_const_dict(args_to_merge, const_dict)};
+    auto new_const_dict{make_new_const_dict(buffers_to_merge, const_dict)};
+
+    // Make the new prim func
+    auto prim_func_node{main_func.CopyOnWrite()};
+    prim_func_node->body = std::move(new_body);
+    prim_func_node->buffer_map = std::move(new_buffer_map);
+    prim_func_node->params = std::move(new_params);
+    prim_func_node->preflattened_buffer_map = {};
+    PrimFunc f{GetRef<PrimFunc>(prim_func_node)};
+
+    // Add the new const dict as an attribute
+    f = WithAttr(std::move(f), "ethos-u.const-dict", new_const_dict);
+
+    return f;
+  }
+
+ private:
+  /*! Indicates whether the pass is analyzing or rewriting */
+  bool analyze = true;
+
+  /*! A stack to store allocates as they are visited. */
+  std::vector<Allocate> allocates{};
+
+  /*! A list that contains in the i-th position the write buffer of the i-th statement
+   * if that statement is a copy to a buffer with global scope  */
+  std::vector<Optional<Buffer>> copy_write_buffers{};
+
+  /*! Maps a copy's write buffer to an index representing the
+   * new buffer and an offset in that buffer */
+  std::map<Buffer, std::pair<int /* new buffer index */, int /* offset */>>
+      old_to_new_write_buffer{};
+
+  /*! Maps an index representing a new buffer to the length of that buffer */
+  std::map<int /* new buffer index */, int /* length */> new_buffers_length{};
+
+  /*! Maps an index representing a new buffer to the new buffer */
+  std::map<int /* new buffer index */, Buffer> new_buffers{};
+
+  /*! Maps an index representing a new buffer to the cycle_counts needed to copy that buffer */
+  std::map<int /* new buffer index */, int64_t> cycle_counts{};
+
+  /*! Maps a copy's read buffer to the new copy's read buffer */
+  std::map<Buffer, Buffer> old_to_new_read_buffers{};
+
+  /*! Maps an index representing a new buffer to the list of buffers to be merged in the new buffer
+   */
+  std::map<int /* new buffer index */, std::vector<Buffer>> buffers_to_merge{};
+
+  /*! A set of buffers to delete */
+  std::set<Buffer> buffers_to_delete{};
+
+  // Visit
+
+  Stmt VisitStmt_(const AllocateNode* op) override {
+    if (analyze) {
+      allocates.push_back(GetRef<Allocate>(op));
+      return VisitStmt(op->body);
+    } else {
+      auto allocate{CopyOnWrite(op)};
+      allocate->body = this->VisitStmt(op->body);
+      return Stmt(allocate);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* op) override {
+    if (op->size() <= 1) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+    return analyze ? analyze_seq_stmt(op) : rewrite_seq_stmt(op);
+  }
+
+  Stmt analyze_seq_stmt(const SeqStmtNode* op) {

Review Comment:
   Done.



##########
python/tvm/relay/backend/contrib/ethosu/tir/passes.py:
##########
@@ -938,3 +938,38 @@ def CopyComputeReordering(max_copy_movements: Optional[int] = None) -> tvm.IRMod
         The new module with copy and compute nodes reordered.
     """
     return _ffi_api.CopyComputeReordering(max_copy_movements)
+
+
+def MergeConstants(const_dict):
+    """
+    This pass looks for the constants used by each compute operator
+    and merges them into a single buffer.
+    Constants written to a buffer with local scope are not merged.
+    """
+
+    def mergeConstantsPass(mod):

Review Comment:
   Done.



##########
src/tir/contrib/ethosu/passes.cc:
##########
@@ -223,6 +224,572 @@ tvm::transform::Pass CopyComputeReordering(Optional<Integer> max_copy_movements)
 TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering")
     .set_body_typed(CopyComputeReordering);
 
+/*!
+ * \brief This pass looks for the constants used by each compute operator
+ * and merges them into a single buffer.
+ * Constants written to a buffer with local scope are not merged.
+ */
+class MergeConstantsMutator : public StmtExprMutator {
+ public:
+  MergeConstantsMutator() {}
+
+  PrimFunc operator()(PrimFunc main_func, const Map<IntImm, runtime::NDArray>& const_dict) {
+    // Analyze
+    Stmt new_body{this->VisitStmt(main_func->body)};
+
+    // Rewrite
+    analyze = false;
+    new_body = rewrite_prim_func_body(new_body);
+    std::set<ObjectRef> params_to_delete{};
+    auto new_buffer_map{make_new_buffer_map(main_func->buffer_map, &params_to_delete)};
+    auto new_params{make_new_params(main_func->params, params_to_delete)};
+
+    // Make the new const dict
+    auto args_to_merge{get_args_to_merge(main_func->buffer_map, main_func->params)};
+    auto buffers_to_merge{
+        get_args_to_merge_without_args_not_in_const_dict(args_to_merge, const_dict)};
+    auto new_const_dict{make_new_const_dict(buffers_to_merge, const_dict)};
+
+    // Make the new prim func
+    auto prim_func_node{main_func.CopyOnWrite()};
+    prim_func_node->body = std::move(new_body);
+    prim_func_node->buffer_map = std::move(new_buffer_map);
+    prim_func_node->params = std::move(new_params);
+    prim_func_node->preflattened_buffer_map = {};
+    PrimFunc f{GetRef<PrimFunc>(prim_func_node)};
+
+    // Add the new const dict as an attribute
+    f = WithAttr(std::move(f), "ethos-u.const-dict", new_const_dict);
+
+    return f;
+  }
+
+ private:
+  /*! Indicates whether the pass is analyzing or rewriting */
+  bool analyze = true;
+
+  /*! A stack to store allocates as they are visited. */
+  std::vector<Allocate> allocates{};
+
+  /*! A list that contains in the i-th position the write buffer of the i-th statement
+   * if that statement is a copy to a buffer with global scope  */
+  std::vector<Optional<Buffer>> copy_write_buffers{};
+
+  /*! Maps a copy's write buffer to an index representing the
+   * new buffer and an offset in that buffer */
+  std::map<Buffer, std::pair<int /* new buffer index */, int /* offset */>>
+      old_to_new_write_buffer{};
+
+  /*! Maps an index representing a new buffer to the length of that buffer */
+  std::map<int /* new buffer index */, int /* length */> new_buffers_length{};

Review Comment:
   I switched to unordered_maps.



##########
src/tir/contrib/ethosu/passes.cc:
##########
@@ -223,6 +224,572 @@ tvm::transform::Pass CopyComputeReordering(Optional<Integer> max_copy_movements)
 TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering")
     .set_body_typed(CopyComputeReordering);
 
+/*!
+ * \brief This pass looks for the constants used by each compute operator
+ * and merges them into a single buffer.
+ * Constants written to a buffer with local scope are not merged.
+ */
+class MergeConstantsMutator : public StmtExprMutator {
+ public:
+  MergeConstantsMutator() {}
+
+  PrimFunc operator()(PrimFunc main_func, const Map<IntImm, runtime::NDArray>& const_dict) {
+    // Analyze
+    Stmt new_body{this->VisitStmt(main_func->body)};
+
+    // Rewrite
+    analyze = false;
+    new_body = rewrite_prim_func_body(new_body);
+    std::set<ObjectRef> params_to_delete{};
+    auto new_buffer_map{make_new_buffer_map(main_func->buffer_map, &params_to_delete)};
+    auto new_params{make_new_params(main_func->params, params_to_delete)};
+
+    // Make the new const dict
+    auto args_to_merge{get_args_to_merge(main_func->buffer_map, main_func->params)};
+    auto buffers_to_merge{
+        get_args_to_merge_without_args_not_in_const_dict(args_to_merge, const_dict)};
+    auto new_const_dict{make_new_const_dict(buffers_to_merge, const_dict)};
+
+    // Make the new prim func
+    auto prim_func_node{main_func.CopyOnWrite()};
+    prim_func_node->body = std::move(new_body);
+    prim_func_node->buffer_map = std::move(new_buffer_map);
+    prim_func_node->params = std::move(new_params);
+    prim_func_node->preflattened_buffer_map = {};
+    PrimFunc f{GetRef<PrimFunc>(prim_func_node)};
+
+    // Add the new const dict as an attribute
+    f = WithAttr(std::move(f), "ethos-u.const-dict", new_const_dict);
+
+    return f;
+  }
+
+ private:
+  /*! Indicates whether the pass is analyzing or rewriting */
+  bool analyze = true;
+
+  /*! A stack to store allocates as they are visited. */
+  std::vector<Allocate> allocates{};
+
+  /*! A list that contains in the i-th position the write buffer of the i-th statement
+   * if that statement is a copy to a buffer with global scope  */
+  std::vector<Optional<Buffer>> copy_write_buffers{};
+
+  /*! Maps a copy's write buffer to an index representing the
+   * new buffer and an offset in that buffer */
+  std::map<Buffer, std::pair<int /* new buffer index */, int /* offset */>>
+      old_to_new_write_buffer{};
+
+  /*! Maps an index representing a new buffer to the length of that buffer */
+  std::map<int /* new buffer index */, int /* length */> new_buffers_length{};
+
+  /*! Maps an index representing a new buffer to the new buffer */
+  std::map<int /* new buffer index */, Buffer> new_buffers{};
+
+  /*! Maps an index representing a new buffer to the cycle_counts needed to copy that buffer */
+  std::map<int /* new buffer index */, int64_t> cycle_counts{};
+
+  /*! Maps a copy's read buffer to the new copy's read buffer */
+  std::map<Buffer, Buffer> old_to_new_read_buffers{};
+
+  /*! Maps an index representing a new buffer to the list of buffers to be merged in the new buffer
+   */
+  std::map<int /* new buffer index */, std::vector<Buffer>> buffers_to_merge{};
+
+  /*! A set of buffers to delete */
+  std::set<Buffer> buffers_to_delete{};
+
+  // Visit
+
+  Stmt VisitStmt_(const AllocateNode* op) override {
+    if (analyze) {
+      allocates.push_back(GetRef<Allocate>(op));
+      return VisitStmt(op->body);
+    } else {
+      auto allocate{CopyOnWrite(op)};
+      allocate->body = this->VisitStmt(op->body);
+      return Stmt(allocate);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* op) override {
+    if (op->size() <= 1) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+    return analyze ? analyze_seq_stmt(op) : rewrite_seq_stmt(op);
+  }
+
+  Stmt analyze_seq_stmt(const SeqStmtNode* op) {
+    auto seq_stmt{GetRef<SeqStmt>(op)};
+
+    for (size_t i = 0; i < seq_stmt.size(); ++i) {
+      Stmt stmt{seq_stmt[i]};
+
+      switch (get_stmt_type(stmt)) {

Review Comment:
   It is a switch on an enumerated value, and all cases are covered, so we don't need a default, do we?



##########
src/tir/contrib/ethosu/passes.cc:
##########
@@ -223,6 +224,572 @@ tvm::transform::Pass CopyComputeReordering(Optional<Integer> max_copy_movements)
 TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering")
     .set_body_typed(CopyComputeReordering);
 
+/*!
+ * \brief This pass looks for the constants used by each compute operator
+ * and merges them into a single buffer.
+ * Constants written to a buffer with local scope are not merged.
+ */
+class MergeConstantsMutator : public StmtExprMutator {
+ public:
+  MergeConstantsMutator() {}
+
+  PrimFunc operator()(PrimFunc main_func, const Map<IntImm, runtime::NDArray>& const_dict) {
+    // Analyze
+    Stmt new_body{this->VisitStmt(main_func->body)};
+
+    // Rewrite
+    analyze = false;
+    new_body = rewrite_prim_func_body(new_body);
+    std::set<ObjectRef> params_to_delete{};
+    auto new_buffer_map{make_new_buffer_map(main_func->buffer_map, &params_to_delete)};
+    auto new_params{make_new_params(main_func->params, params_to_delete)};
+
+    // Make the new const dict
+    auto args_to_merge{get_args_to_merge(main_func->buffer_map, main_func->params)};
+    auto buffers_to_merge{
+        get_args_to_merge_without_args_not_in_const_dict(args_to_merge, const_dict)};
+    auto new_const_dict{make_new_const_dict(buffers_to_merge, const_dict)};
+
+    // Make the new prim func
+    auto prim_func_node{main_func.CopyOnWrite()};
+    prim_func_node->body = std::move(new_body);
+    prim_func_node->buffer_map = std::move(new_buffer_map);
+    prim_func_node->params = std::move(new_params);
+    prim_func_node->preflattened_buffer_map = {};
+    PrimFunc f{GetRef<PrimFunc>(prim_func_node)};
+
+    // Add the new const dict as an attribute
+    f = WithAttr(std::move(f), "ethos-u.const-dict", new_const_dict);
+
+    return f;
+  }
+
+ private:
+  /*! Indicates whether the pass is analyzing or rewriting */
+  bool analyze = true;
+
+  /*! A stack to store allocates as they are visited. */
+  std::vector<Allocate> allocates{};
+
+  /*! A list that contains in the i-th position the write buffer of the i-th statement
+   * if that statement is a copy to a buffer with global scope  */
+  std::vector<Optional<Buffer>> copy_write_buffers{};
+
+  /*! Maps a copy's write buffer to an index representing the
+   * new buffer and an offset in that buffer */
+  std::map<Buffer, std::pair<int /* new buffer index */, int /* offset */>>
+      old_to_new_write_buffer{};
+
+  /*! Maps an index representing a new buffer to the length of that buffer */
+  std::map<int /* new buffer index */, int /* length */> new_buffers_length{};
+
+  /*! Maps an index representing a new buffer to the new buffer */
+  std::map<int /* new buffer index */, Buffer> new_buffers{};
+
+  /*! Maps an index representing a new buffer to the cycle_counts needed to copy that buffer */
+  std::map<int /* new buffer index */, int64_t> cycle_counts{};
+
+  /*! Maps a copy's read buffer to the new copy's read buffer */
+  std::map<Buffer, Buffer> old_to_new_read_buffers{};
+
+  /*! Maps an index representing a new buffer to the list of buffers to be merged in the new buffer
+   */
+  std::map<int /* new buffer index */, std::vector<Buffer>> buffers_to_merge{};
+
+  /*! A set of buffers to delete */
+  std::set<Buffer> buffers_to_delete{};
+
+  // Visit
+
+  Stmt VisitStmt_(const AllocateNode* op) override {
+    if (analyze) {
+      allocates.push_back(GetRef<Allocate>(op));
+      return VisitStmt(op->body);
+    } else {
+      auto allocate{CopyOnWrite(op)};
+      allocate->body = this->VisitStmt(op->body);
+      return Stmt(allocate);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* op) override {
+    if (op->size() <= 1) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+    return analyze ? analyze_seq_stmt(op) : rewrite_seq_stmt(op);
+  }
+
+  Stmt analyze_seq_stmt(const SeqStmtNode* op) {
+    auto seq_stmt{GetRef<SeqStmt>(op)};
+
+    for (size_t i = 0; i < seq_stmt.size(); ++i) {
+      Stmt stmt{seq_stmt[i]};
+
+      switch (get_stmt_type(stmt)) {
+        case StmtType::global_copy: {
+          Buffer write_buffer{get_copy_write_buffer(stmt)};
+          copy_write_buffers.push_back(write_buffer);
+          old_to_new_write_buffer[write_buffer] = std::make_pair(-1, -1);
+          break;
+        }
+        case StmtType::local_copy: {
+          copy_write_buffers.push_back(Optional<Buffer>{});
+          break;
+        }
+        case StmtType::compute: {
+          copy_write_buffers.push_back(Optional<Buffer>{});
+          auto buffers{get_copied_buffers_used_by_stmt(stmt)};
+          if (buffers.empty()) {
+            continue;
+          }
+          new_buffers_length[i] = 0;
+          for (auto buffer : buffers) {
+            for (size_t j{i - 1}; j >= 0; --j) {
+              if (copy_write_buffers[j] == buffer) {
+                old_to_new_write_buffer[buffer] = std::make_pair(i, new_buffers_length[i]);
+                new_buffers_length[i] += get_copy_length(seq_stmt[j]);
+                cycle_counts[i] += get_stmt_cycle_counts(seq_stmt[j]);
+                break;
+              }
+            }
+          }
+          break;
+        }
+      }
+    }
+    return std::move(seq_stmt);
+  }
+
+  Stmt rewrite_prim_func_body(Stmt body) {
+    std::map<const VarNode*, Allocate> var_to_allocate{};
+
+    // Rewrite old allocates
+    std::set<ObjectRef> buffer_vars{get_vars_for_written_copy_buffers()};

Review Comment:
   I switched to unordered_set.



##########
src/tir/contrib/ethosu/passes.cc:
##########
@@ -223,6 +224,572 @@ tvm::transform::Pass CopyComputeReordering(Optional<Integer> max_copy_movements)
 TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering")
     .set_body_typed(CopyComputeReordering);
 
+/*!
+ * \brief This pass looks for the constants used by each compute operator
+ * and merges them into a single buffer.
+ * Constants written to a buffer with local scope are not merged.
+ */
+class MergeConstantsMutator : public StmtExprMutator {
+ public:
+  MergeConstantsMutator() {}
+
+  PrimFunc operator()(PrimFunc main_func, const Map<IntImm, runtime::NDArray>& const_dict) {
+    // Analyze
+    Stmt new_body{this->VisitStmt(main_func->body)};
+
+    // Rewrite
+    analyze = false;
+    new_body = rewrite_prim_func_body(new_body);
+    std::set<ObjectRef> params_to_delete{};
+    auto new_buffer_map{make_new_buffer_map(main_func->buffer_map, &params_to_delete)};
+    auto new_params{make_new_params(main_func->params, params_to_delete)};
+
+    // Make the new const dict
+    auto args_to_merge{get_args_to_merge(main_func->buffer_map, main_func->params)};
+    auto buffers_to_merge{
+        get_args_to_merge_without_args_not_in_const_dict(args_to_merge, const_dict)};
+    auto new_const_dict{make_new_const_dict(buffers_to_merge, const_dict)};
+
+    // Make the new prim func
+    auto prim_func_node{main_func.CopyOnWrite()};
+    prim_func_node->body = std::move(new_body);
+    prim_func_node->buffer_map = std::move(new_buffer_map);
+    prim_func_node->params = std::move(new_params);
+    prim_func_node->preflattened_buffer_map = {};
+    PrimFunc f{GetRef<PrimFunc>(prim_func_node)};
+
+    // Add the new const dict as an attribute
+    f = WithAttr(std::move(f), "ethos-u.const-dict", new_const_dict);
+
+    return f;
+  }
+
+ private:
+  /*! Indicates whether the pass is analyzing or rewriting */
+  bool analyze = true;
+
+  /*! A stack to store allocates as they are visited. */
+  std::vector<Allocate> allocates{};
+
+  /*! A list that contains in the i-th position the write buffer of the i-th statement
+   * if that statement is a copy to a buffer with global scope  */
+  std::vector<Optional<Buffer>> copy_write_buffers{};
+
+  /*! Maps a copy's write buffer to an index representing the
+   * new buffer and an offset in that buffer */
+  std::map<Buffer, std::pair<int /* new buffer index */, int /* offset */>>
+      old_to_new_write_buffer{};
+
+  /*! Maps an index representing a new buffer to the length of that buffer */
+  std::map<int /* new buffer index */, int /* length */> new_buffers_length{};
+
+  /*! Maps an index representing a new buffer to the new buffer */
+  std::map<int /* new buffer index */, Buffer> new_buffers{};
+
+  /*! Maps an index representing a new buffer to the cycle_counts needed to copy that buffer */
+  std::map<int /* new buffer index */, int64_t> cycle_counts{};
+
+  /*! Maps a copy's read buffer to the new copy's read buffer */
+  std::map<Buffer, Buffer> old_to_new_read_buffers{};
+
+  /*! Maps an index representing a new buffer to the list of buffers to be merged in the new buffer
+   */
+  std::map<int /* new buffer index */, std::vector<Buffer>> buffers_to_merge{};
+
+  /*! A set of buffers to delete */
+  std::set<Buffer> buffers_to_delete{};
+
+  // Visit
+
+  Stmt VisitStmt_(const AllocateNode* op) override {
+    if (analyze) {
+      allocates.push_back(GetRef<Allocate>(op));
+      return VisitStmt(op->body);
+    } else {
+      auto allocate{CopyOnWrite(op)};
+      allocate->body = this->VisitStmt(op->body);
+      return Stmt(allocate);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* op) override {
+    if (op->size() <= 1) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+    return analyze ? analyze_seq_stmt(op) : rewrite_seq_stmt(op);
+  }
+
+  Stmt analyze_seq_stmt(const SeqStmtNode* op) {
+    auto seq_stmt{GetRef<SeqStmt>(op)};
+
+    for (size_t i = 0; i < seq_stmt.size(); ++i) {
+      Stmt stmt{seq_stmt[i]};
+
+      switch (get_stmt_type(stmt)) {
+        case StmtType::global_copy: {
+          Buffer write_buffer{get_copy_write_buffer(stmt)};
+          copy_write_buffers.push_back(write_buffer);
+          old_to_new_write_buffer[write_buffer] = std::make_pair(-1, -1);
+          break;
+        }
+        case StmtType::local_copy: {
+          copy_write_buffers.push_back(Optional<Buffer>{});
+          break;
+        }
+        case StmtType::compute: {
+          copy_write_buffers.push_back(Optional<Buffer>{});
+          auto buffers{get_copied_buffers_used_by_stmt(stmt)};
+          if (buffers.empty()) {
+            continue;
+          }
+          new_buffers_length[i] = 0;
+          for (auto buffer : buffers) {
+            for (size_t j{i - 1}; j >= 0; --j) {
+              if (copy_write_buffers[j] == buffer) {
+                old_to_new_write_buffer[buffer] = std::make_pair(i, new_buffers_length[i]);
+                new_buffers_length[i] += get_copy_length(seq_stmt[j]);
+                cycle_counts[i] += get_stmt_cycle_counts(seq_stmt[j]);
+                break;
+              }
+            }
+          }
+          break;
+        }
+      }
+    }
+    return std::move(seq_stmt);
+  }
+
+  Stmt rewrite_prim_func_body(Stmt body) {
+    std::map<const VarNode*, Allocate> var_to_allocate{};
+
+    // Rewrite old allocates
+    std::set<ObjectRef> buffer_vars{get_vars_for_written_copy_buffers()};
+    for (auto it{allocates.rbegin()}; it != allocates.rend(); ++it) {
+      Allocate alloc{*it};
+      var_to_allocate[alloc->buffer_var.get()] = alloc;
+      if (buffer_vars.count(alloc->buffer_var) == 0) {
+        body = Allocate(alloc->buffer_var, alloc->dtype, alloc->extents, alloc->condition, body,
+                        alloc->annotations, alloc->span);
+      }
+    }
+
+    // Rewrite new allocates
+    for (auto it{copy_write_buffers.rbegin()}; it != copy_write_buffers.rend(); ++it) {
+      if (auto buffer_opt = *it) {
+        Buffer old_write_buffer{buffer_opt.value()};
+        int new_buffer_index{old_to_new_write_buffer[old_write_buffer].first};
+
+        // Check if the allocate has already been created
+        if (new_buffers.count(new_buffer_index) == 0) {
+          BufferNode* new_buffer{old_write_buffer.CopyOnWrite()};
+          new_buffer->shape = {new_buffers_length[new_buffer_index]};
+
+          new_buffers[new_buffer_index] = GetRef<Buffer>(new_buffer);
+
+          auto old_allocate{var_to_allocate[old_write_buffer->data.get()]};
+          body = Allocate(new_buffer->data, new_buffer->dtype, new_buffer->shape, tir::const_true(),
+                          body, old_allocate->annotations, old_allocate->span);
+        }
+      }
+    }
+
+    // Rewrite operators
+    return this->VisitStmt(body);
+  }
+
+  Stmt rewrite_seq_stmt(const SeqStmtNode* op) {
+    Array<Stmt> new_seq{};
+
+    auto seq_stmt{GetRef<SeqStmt>(op)};
+    for (size_t i{0}; i < seq_stmt.size(); ++i) {
+      Stmt stmt{seq_stmt[i]};
+
+      switch (get_stmt_type(stmt)) {

Review Comment:
   It is a switch on an enumerated value, and all cases are covered, so we don't need a default, do we?



##########
src/tir/contrib/ethosu/passes.cc:
##########
@@ -223,6 +224,572 @@ tvm::transform::Pass CopyComputeReordering(Optional<Integer> max_copy_movements)
 TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering")
     .set_body_typed(CopyComputeReordering);
 
+/*!
+ * \brief This pass looks for the constants used by each compute operator
+ * and merges them into a single buffer.
+ * Constants written to a buffer with local scope are not merged.
+ */
+class MergeConstantsMutator : public StmtExprMutator {
+ public:
+  MergeConstantsMutator() {}
+
+  PrimFunc operator()(PrimFunc main_func, const Map<IntImm, runtime::NDArray>& const_dict) {
+    // Analyze
+    Stmt new_body{this->VisitStmt(main_func->body)};
+
+    // Rewrite
+    analyze = false;
+    new_body = rewrite_prim_func_body(new_body);
+    std::set<ObjectRef> params_to_delete{};
+    auto new_buffer_map{make_new_buffer_map(main_func->buffer_map, &params_to_delete)};
+    auto new_params{make_new_params(main_func->params, params_to_delete)};
+
+    // Make the new const dict
+    auto args_to_merge{get_args_to_merge(main_func->buffer_map, main_func->params)};
+    auto buffers_to_merge{
+        get_args_to_merge_without_args_not_in_const_dict(args_to_merge, const_dict)};
+    auto new_const_dict{make_new_const_dict(buffers_to_merge, const_dict)};
+
+    // Make the new prim func
+    auto prim_func_node{main_func.CopyOnWrite()};
+    prim_func_node->body = std::move(new_body);
+    prim_func_node->buffer_map = std::move(new_buffer_map);
+    prim_func_node->params = std::move(new_params);
+    prim_func_node->preflattened_buffer_map = {};
+    PrimFunc f{GetRef<PrimFunc>(prim_func_node)};
+
+    // Add the new const dict as an attribute
+    f = WithAttr(std::move(f), "ethos-u.const-dict", new_const_dict);
+
+    return f;
+  }
+
+ private:
+  /*! Indicates whether the pass is analyzing or rewriting */
+  bool analyze = true;
+
+  /*! A stack to store allocates as they are visited. */
+  std::vector<Allocate> allocates{};
+
+  /*! A list that contains in the i-th position the write buffer of the i-th statement
+   * if that statement is a copy to a buffer with global scope  */
+  std::vector<Optional<Buffer>> copy_write_buffers{};
+
+  /*! Maps a copy's write buffer to an index representing the
+   * new buffer and an offset in that buffer */
+  std::map<Buffer, std::pair<int /* new buffer index */, int /* offset */>>
+      old_to_new_write_buffer{};
+
+  /*! Maps an index representing a new buffer to the length of that buffer */
+  std::map<int /* new buffer index */, int /* length */> new_buffers_length{};
+
+  /*! Maps an index representing a new buffer to the new buffer */
+  std::map<int /* new buffer index */, Buffer> new_buffers{};
+
+  /*! Maps an index representing a new buffer to the cycle_counts needed to copy that buffer */
+  std::map<int /* new buffer index */, int64_t> cycle_counts{};
+
+  /*! Maps a copy's read buffer to the new copy's read buffer */
+  std::map<Buffer, Buffer> old_to_new_read_buffers{};
+
+  /*! Maps an index representing a new buffer to the list of buffers to be merged in the new buffer
+   */
+  std::map<int /* new buffer index */, std::vector<Buffer>> buffers_to_merge{};
+
+  /*! A set of buffers to delete */
+  std::set<Buffer> buffers_to_delete{};
+
+  // Visit

Review Comment:
   Done.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org