You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/07/08 09:39:56 UTC

[GitHub] [tvm] manupa-arm commented on a diff in pull request #12029: [microNPU] Add MergeConstants pass

manupa-arm commented on code in PR #12029:
URL: https://github.com/apache/tvm/pull/12029#discussion_r915742054


##########
python/tvm/relay/backend/contrib/ethosu/tir/passes.py:
##########
@@ -938,3 +938,38 @@ def CopyComputeReordering(max_copy_movements: Optional[int] = None) -> tvm.IRMod
         The new module with copy and compute nodes reordered.
     """
     return _ffi_api.CopyComputeReordering(max_copy_movements)
+
+
+def MergeConstants(const_dict):
+    """
+    This pass looks for the constants used by each compute operator
+    and merges them into a single buffer.
+    Constants written to a buffer with local scope are not merged.
+    """
+
+    def mergeConstantsPass(mod):
+        nonlocal const_dict
+        try:
+            mod["main"]
+        except:
+            raise tvm.TVMError(
+                "Expected a single primitive function called 'main'. "
+                "Please run the MergeConstants pass in conjunction with the LowerToTIR() pass."
+            )
+
+        new_const_dict = {}
+        for param in const_dict.keys():
+            new_const_dict[tvm.tir.IntImm("int64", param)] = tvm.nd.array(const_dict[param])
+        mod["main"] = mod["main"].with_attr("ethos-u.const-dict", new_const_dict)

Review Comment:
   nit: let's stick to the name const_dict (rather than new_const_dict).



##########
python/tvm/relay/backend/contrib/ethosu/tir/compiler.py:
##########
@@ -90,6 +90,8 @@ def lower_ethosu(sch, args, const_dict, name="main"):
         mod = tvm.tir.transform.RemoveNoOp()(mod)
         mod, const_dict = ethosu_passes.EncodeConstants(const_dict)(mod)
         mod = ethosu_passes.HoistAllocates()(mod)
+        if not util.is_striping_enabled():

Review Comment:
   Please leave a comment noting that the MergeConstants pass currently does not support striped schedules and that this requires further investigation.



##########
src/tir/contrib/ethosu/passes.cc:
##########
@@ -223,6 +224,572 @@ tvm::transform::Pass CopyComputeReordering(Optional<Integer> max_copy_movements)
 TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering")
     .set_body_typed(CopyComputeReordering);
 
+/*!
+ * \brief This pass looks for the constants used by each compute operator
+ * and merges them into a single buffer.
+ * Constants written to a buffer with local scope are not merged.
+ */
+class MergeConstantsMutator : public StmtExprMutator {
+ public:
+  MergeConstantsMutator() {}
+
+  PrimFunc operator()(PrimFunc main_func, const Map<IntImm, runtime::NDArray>& const_dict) {
+    // Analyze
+    Stmt new_body{this->VisitStmt(main_func->body)};
+
+    // Rewrite
+    analyze = false;
+    new_body = rewrite_prim_func_body(new_body);
+    std::set<ObjectRef> params_to_delete{};
+    auto new_buffer_map{make_new_buffer_map(main_func->buffer_map, &params_to_delete)};
+    auto new_params{make_new_params(main_func->params, params_to_delete)};
+
+    // Make the new const dict
+    auto args_to_merge{get_args_to_merge(main_func->buffer_map, main_func->params)};
+    auto buffers_to_merge{
+        get_args_to_merge_without_args_not_in_const_dict(args_to_merge, const_dict)};
+    auto new_const_dict{make_new_const_dict(buffers_to_merge, const_dict)};
+
+    // Make the new prim func
+    auto prim_func_node{main_func.CopyOnWrite()};
+    prim_func_node->body = std::move(new_body);
+    prim_func_node->buffer_map = std::move(new_buffer_map);
+    prim_func_node->params = std::move(new_params);
+    prim_func_node->preflattened_buffer_map = {};
+    PrimFunc f{GetRef<PrimFunc>(prim_func_node)};
+
+    // Add the new const dict as an attribute
+    f = WithAttr(std::move(f), "ethos-u.const-dict", new_const_dict);
+
+    return f;
+  }
+
+ private:
+  /*! Indicates whether the pass is analyzing or rewriting */
+  bool analyze = true;
+
+  /*! A stack to store allocates as they are visited. */
+  std::vector<Allocate> allocates{};
+
+  /*! A list that contains in the i-th position the write buffer of the i-th statement
+   * if that statement is a copy to a buffer with global scope  */
+  std::vector<Optional<Buffer>> copy_write_buffers{};
+
+  /*! Maps a copy's write buffer to an index representing the
+   * new buffer and an offset in that buffer */
+  std::map<Buffer, std::pair<int /* new buffer index */, int /* offset */>>
+      old_to_new_write_buffer{};
+
+  /*! Maps an index representing a new buffer to the length of that buffer */
+  std::map<int /* new buffer index */, int /* length */> new_buffers_length{};
+
+  /*! Maps an index representing a new buffer to the new buffer */
+  std::map<int /* new buffer index */, Buffer> new_buffers{};
+
+  /*! Maps an index representing a new buffer to the cycle_counts needed to copy that buffer */
+  std::map<int /* new buffer index */, int64_t> cycle_counts{};
+
+  /*! Maps a copy's read buffer to the new copy's read buffer */
+  std::map<Buffer, Buffer> old_to_new_read_buffers{};
+
+  /*! Maps an index representing a new buffer to the list of buffers to be merged in the new buffer
+   */
+  std::map<int /* new buffer index */, std::vector<Buffer>> buffers_to_merge{};
+
+  /*! A set of buffers to delete */
+  std::set<Buffer> buffers_to_delete{};
+
+  // Visit
+
+  Stmt VisitStmt_(const AllocateNode* op) override {
+    if (analyze) {
+      allocates.push_back(GetRef<Allocate>(op));
+      return VisitStmt(op->body);
+    } else {
+      auto allocate{CopyOnWrite(op)};
+      allocate->body = this->VisitStmt(op->body);
+      return Stmt(allocate);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* op) override {
+    if (op->size() <= 1) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+    return analyze ? analyze_seq_stmt(op) : rewrite_seq_stmt(op);
+  }
+
+  Stmt analyze_seq_stmt(const SeqStmtNode* op) {
+    auto seq_stmt{GetRef<SeqStmt>(op)};
+
+    for (size_t i = 0; i < seq_stmt.size(); ++i) {
+      Stmt stmt{seq_stmt[i]};
+
+      switch (get_stmt_type(stmt)) {
+        case StmtType::global_copy: {
+          Buffer write_buffer{get_copy_write_buffer(stmt)};
+          copy_write_buffers.push_back(write_buffer);
+          old_to_new_write_buffer[write_buffer] = std::make_pair(-1, -1);
+          break;
+        }
+        case StmtType::local_copy: {
+          copy_write_buffers.push_back(Optional<Buffer>{});
+          break;
+        }
+        case StmtType::compute: {
+          copy_write_buffers.push_back(Optional<Buffer>{});
+          auto buffers{get_copied_buffers_used_by_stmt(stmt)};
+          if (buffers.empty()) {
+            continue;
+          }
+          new_buffers_length[i] = 0;
+          for (auto buffer : buffers) {
+            for (size_t j{i - 1}; j >= 0; --j) {
+              if (copy_write_buffers[j] == buffer) {
+                old_to_new_write_buffer[buffer] = std::make_pair(i, new_buffers_length[i]);
+                new_buffers_length[i] += get_copy_length(seq_stmt[j]);
+                cycle_counts[i] += get_stmt_cycle_counts(seq_stmt[j]);
+                break;
+              }
+            }
+          }
+          break;
+        }
+      }
+    }
+    return seq_stmt;
+  }
+
+  Stmt rewrite_prim_func_body(Stmt body) {
+    std::map<const VarNode*, Allocate> var_to_allocate{};
+
+    // Rewrite old allocates
+    std::set<ObjectRef> buffer_vars{get_vars_for_written_copy_buffers()};
+    for (auto it{allocates.rbegin()}; it != allocates.rend(); ++it) {
+      Allocate alloc{*it};
+      var_to_allocate[alloc->buffer_var.get()] = alloc;
+      if (buffer_vars.count(alloc->buffer_var) == 0) {
+        body = Allocate(alloc->buffer_var, alloc->dtype, alloc->extents, alloc->condition, body,
+                        alloc->annotations, alloc->span);
+      }
+    }
+
+    // Rewrite new allocates
+    for (auto it{copy_write_buffers.rbegin()}; it != copy_write_buffers.rend(); ++it) {
+      if (auto buffer_opt = *it) {

Review Comment:
   Let's use the explicit type here instead of `auto`.



##########
src/tir/contrib/ethosu/passes.cc:
##########
@@ -223,6 +224,572 @@ tvm::transform::Pass CopyComputeReordering(Optional<Integer> max_copy_movements)
 TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering")
     .set_body_typed(CopyComputeReordering);
 
+/*!
+ * \brief This pass looks for the constants used by each compute operator
+ * and merges them into a single buffer.
+ * Constants written to a buffer with local scope are not merged.
+ */
+class MergeConstantsMutator : public StmtExprMutator {
+ public:
+  MergeConstantsMutator() {}
+
+  PrimFunc operator()(PrimFunc main_func, const Map<IntImm, runtime::NDArray>& const_dict) {
+    // Analyze
+    Stmt new_body{this->VisitStmt(main_func->body)};
+
+    // Rewrite
+    analyze = false;
+    new_body = rewrite_prim_func_body(new_body);
+    std::set<ObjectRef> params_to_delete{};
+    auto new_buffer_map{make_new_buffer_map(main_func->buffer_map, &params_to_delete)};

Review Comment:
   Let's not use `auto` here.
   It is fine to use `auto` when the type is clear from the line itself. In this case, one needs to look into the called function to know the type.
   
   Quoting the following:
   "For example, you can assume that the return type of make_unique<Foo>() is obvious, but the return type of MyWidgetFactory() probably isn't."
   
   Reference: https://google.github.io/styleguide/cppguide.html#Type_deduction
   
   (This applies to all such lines.)



##########
python/tvm/relay/backend/contrib/ethosu/tir/passes.py:
##########
@@ -938,3 +938,38 @@ def CopyComputeReordering(max_copy_movements: Optional[int] = None) -> tvm.IRMod
         The new module with copy and compute nodes reordered.
     """
     return _ffi_api.CopyComputeReordering(max_copy_movements)
+
+
+def MergeConstants(const_dict):
+    """
+    This pass looks for the constants used by each compute operator
+    and merges them into a single buffer.
+    Constants written to a buffer with local scope are not merged.
+    """
+
+    def mergeConstantsPass(mod):

Review Comment:
   Style: this should be lowercase. https://peps.python.org/pep-0008/#function-and-variable-names
   
   (The only exception is the passes themselves, because we want them aligned with their C++ counterparts; from an API point of view, I think the publicly accessible pass function should be MergeConstants.)



##########
src/tir/contrib/ethosu/passes.cc:
##########
@@ -223,6 +224,572 @@ tvm::transform::Pass CopyComputeReordering(Optional<Integer> max_copy_movements)
 TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering")
     .set_body_typed(CopyComputeReordering);
 
+/*!
+ * \brief This pass looks for the constants used by each compute operator
+ * and merges them into a single buffer.
+ * Constants written to a buffer with local scope are not merged.
+ */
+class MergeConstantsMutator : public StmtExprMutator {
+ public:
+  MergeConstantsMutator() {}
+
+  PrimFunc operator()(PrimFunc main_func, const Map<IntImm, runtime::NDArray>& const_dict) {
+    // Analyze
+    Stmt new_body{this->VisitStmt(main_func->body)};
+
+    // Rewrite
+    analyze = false;
+    new_body = rewrite_prim_func_body(new_body);
+    std::set<ObjectRef> params_to_delete{};
+    auto new_buffer_map{make_new_buffer_map(main_func->buffer_map, &params_to_delete)};
+    auto new_params{make_new_params(main_func->params, params_to_delete)};
+
+    // Make the new const dict
+    auto args_to_merge{get_args_to_merge(main_func->buffer_map, main_func->params)};
+    auto buffers_to_merge{
+        get_args_to_merge_without_args_not_in_const_dict(args_to_merge, const_dict)};
+    auto new_const_dict{make_new_const_dict(buffers_to_merge, const_dict)};
+
+    // Make the new prim func
+    auto prim_func_node{main_func.CopyOnWrite()};
+    prim_func_node->body = std::move(new_body);
+    prim_func_node->buffer_map = std::move(new_buffer_map);
+    prim_func_node->params = std::move(new_params);
+    prim_func_node->preflattened_buffer_map = {};
+    PrimFunc f{GetRef<PrimFunc>(prim_func_node)};
+
+    // Add the new const dict as an attribute
+    f = WithAttr(std::move(f), "ethos-u.const-dict", new_const_dict);
+
+    return f;
+  }
+
+ private:
+  /*! Indicates whether the pass is analyzing or rewriting */
+  bool analyze = true;
+
+  /*! A stack to store allocates as they are visited. */
+  std::vector<Allocate> allocates{};
+
+  /*! A list that contains in the i-th position the write buffer of the i-th statement
+   * if that statement is a copy to a buffer with global scope  */
+  std::vector<Optional<Buffer>> copy_write_buffers{};
+
+  /*! Maps a copy's write buffer to an index representing the
+   * new buffer and an offset in that buffer */
+  std::map<Buffer, std::pair<int /* new buffer index */, int /* offset */>>
+      old_to_new_write_buffer{};
+
+  /*! Maps an index representing a new buffer to the length of that buffer */
+  std::map<int /* new buffer index */, int /* length */> new_buffers_length{};

Review Comment:
   std::map is ordered, so its access time is O(log n) rather than constant. Is there a reason we need ordering on the keys?
   If not, I'd suggest we switch to std::unordered_map.



##########
src/tir/contrib/ethosu/passes.cc:
##########
@@ -223,6 +224,572 @@ tvm::transform::Pass CopyComputeReordering(Optional<Integer> max_copy_movements)
 TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering")
     .set_body_typed(CopyComputeReordering);
 
+/*!
+ * \brief This pass looks for the constants used by each compute operator
+ * and merges them into a single buffer.
+ * Constants written to a buffer with local scope are not merged.
+ */
+class MergeConstantsMutator : public StmtExprMutator {
+ public:
+  MergeConstantsMutator() {}
+
+  PrimFunc operator()(PrimFunc main_func, const Map<IntImm, runtime::NDArray>& const_dict) {
+    // Analyze
+    Stmt new_body{this->VisitStmt(main_func->body)};
+
+    // Rewrite
+    analyze = false;
+    new_body = rewrite_prim_func_body(new_body);
+    std::set<ObjectRef> params_to_delete{};
+    auto new_buffer_map{make_new_buffer_map(main_func->buffer_map, &params_to_delete)};
+    auto new_params{make_new_params(main_func->params, params_to_delete)};
+
+    // Make the new const dict
+    auto args_to_merge{get_args_to_merge(main_func->buffer_map, main_func->params)};
+    auto buffers_to_merge{
+        get_args_to_merge_without_args_not_in_const_dict(args_to_merge, const_dict)};
+    auto new_const_dict{make_new_const_dict(buffers_to_merge, const_dict)};
+
+    // Make the new prim func
+    auto prim_func_node{main_func.CopyOnWrite()};
+    prim_func_node->body = std::move(new_body);
+    prim_func_node->buffer_map = std::move(new_buffer_map);
+    prim_func_node->params = std::move(new_params);
+    prim_func_node->preflattened_buffer_map = {};
+    PrimFunc f{GetRef<PrimFunc>(prim_func_node)};
+
+    // Add the new const dict as an attribute
+    f = WithAttr(std::move(f), "ethos-u.const-dict", new_const_dict);
+
+    return f;
+  }
+
+ private:
+  /*! Indicates whether the pass is analyzing or rewriting */
+  bool analyze = true;
+
+  /*! A stack to store allocates as they are visited. */
+  std::vector<Allocate> allocates{};
+
+  /*! A list that contains in the i-th position the write buffer of the i-th statement
+   * if that statement is a copy to a buffer with global scope  */
+  std::vector<Optional<Buffer>> copy_write_buffers{};
+
+  /*! Maps a copy's write buffer to an index representing the
+   * new buffer and an offset in that buffer */
+  std::map<Buffer, std::pair<int /* new buffer index */, int /* offset */>>
+      old_to_new_write_buffer{};
+
+  /*! Maps an index representing a new buffer to the length of that buffer */
+  std::map<int /* new buffer index */, int /* length */> new_buffers_length{};
+
+  /*! Maps an index representing a new buffer to the new buffer */
+  std::map<int /* new buffer index */, Buffer> new_buffers{};
+
+  /*! Maps an index representing a new buffer to the cycle_counts needed to copy that buffer */
+  std::map<int /* new buffer index */, int64_t> cycle_counts{};
+
+  /*! Maps a copy's read buffer to the new copy's read buffer */
+  std::map<Buffer, Buffer> old_to_new_read_buffers{};
+
+  /*! Maps an index representing a new buffer to the list of buffers to be merged in the new buffer
+   */
+  std::map<int /* new buffer index */, std::vector<Buffer>> buffers_to_merge{};
+
+  /*! A set of buffers to delete */
+  std::set<Buffer> buffers_to_delete{};
+
+  // Visit

Review Comment:
   nit: remove this comment.



##########
src/tir/contrib/ethosu/passes.cc:
##########
@@ -223,6 +224,572 @@ tvm::transform::Pass CopyComputeReordering(Optional<Integer> max_copy_movements)
 TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering")
     .set_body_typed(CopyComputeReordering);
 
+/*!
+ * \brief This pass looks for the constants used by each compute operator
+ * and merges them into a single buffer.
+ * Constants written to a buffer with local scope are not merged.
+ */
+class MergeConstantsMutator : public StmtExprMutator {
+ public:
+  MergeConstantsMutator() {}
+
+  PrimFunc operator()(PrimFunc main_func, const Map<IntImm, runtime::NDArray>& const_dict) {
+    // Analyze
+    Stmt new_body{this->VisitStmt(main_func->body)};
+
+    // Rewrite
+    analyze = false;
+    new_body = rewrite_prim_func_body(new_body);
+    std::set<ObjectRef> params_to_delete{};
+    auto new_buffer_map{make_new_buffer_map(main_func->buffer_map, &params_to_delete)};
+    auto new_params{make_new_params(main_func->params, params_to_delete)};
+
+    // Make the new const dict
+    auto args_to_merge{get_args_to_merge(main_func->buffer_map, main_func->params)};
+    auto buffers_to_merge{
+        get_args_to_merge_without_args_not_in_const_dict(args_to_merge, const_dict)};
+    auto new_const_dict{make_new_const_dict(buffers_to_merge, const_dict)};
+
+    // Make the new prim func
+    auto prim_func_node{main_func.CopyOnWrite()};
+    prim_func_node->body = std::move(new_body);
+    prim_func_node->buffer_map = std::move(new_buffer_map);
+    prim_func_node->params = std::move(new_params);
+    prim_func_node->preflattened_buffer_map = {};
+    PrimFunc f{GetRef<PrimFunc>(prim_func_node)};
+
+    // Add the new const dict as an attribute
+    f = WithAttr(std::move(f), "ethos-u.const-dict", new_const_dict);
+
+    return f;
+  }
+
+ private:
+  /*! Indicates whether the pass is analyzing or rewriting */
+  bool analyze = true;
+
+  /*! A stack to store allocates as they are visited. */
+  std::vector<Allocate> allocates{};
+
+  /*! A list that contains in the i-th position the write buffer of the i-th statement
+   * if that statement is a copy to a buffer with global scope  */
+  std::vector<Optional<Buffer>> copy_write_buffers{};
+
+  /*! Maps a copy's write buffer to an index representing the
+   * new buffer and an offset in that buffer */
+  std::map<Buffer, std::pair<int /* new buffer index */, int /* offset */>>
+      old_to_new_write_buffer{};
+
+  /*! Maps an index representing a new buffer to the length of that buffer */
+  std::map<int /* new buffer index */, int /* length */> new_buffers_length{};
+
+  /*! Maps an index representing a new buffer to the new buffer */
+  std::map<int /* new buffer index */, Buffer> new_buffers{};
+
+  /*! Maps an index representing a new buffer to the cycle_counts needed to copy that buffer */
+  std::map<int /* new buffer index */, int64_t> cycle_counts{};
+
+  /*! Maps a copy's read buffer to the new copy's read buffer */
+  std::map<Buffer, Buffer> old_to_new_read_buffers{};
+
+  /*! Maps an index representing a new buffer to the list of buffers to be merged in the new buffer
+   */
+  std::map<int /* new buffer index */, std::vector<Buffer>> buffers_to_merge{};
+
+  /*! A set of buffers to delete */
+  std::set<Buffer> buffers_to_delete{};
+
+  // Visit
+
+  Stmt VisitStmt_(const AllocateNode* op) override {
+    if (analyze) {
+      allocates.push_back(GetRef<Allocate>(op));
+      return VisitStmt(op->body);
+    } else {
+      auto allocate{CopyOnWrite(op)};
+      allocate->body = this->VisitStmt(op->body);
+      return Stmt(allocate);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* op) override {
+    if (op->size() <= 1) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+    return analyze ? analyze_seq_stmt(op) : rewrite_seq_stmt(op);
+  }
+
+  Stmt analyze_seq_stmt(const SeqStmtNode* op) {
+    auto seq_stmt{GetRef<SeqStmt>(op)};
+
+    for (size_t i = 0; i < seq_stmt.size(); ++i) {
+      Stmt stmt{seq_stmt[i]};
+
+      switch (get_stmt_type(stmt)) {
+        case StmtType::global_copy: {
+          Buffer write_buffer{get_copy_write_buffer(stmt)};
+          copy_write_buffers.push_back(write_buffer);
+          old_to_new_write_buffer[write_buffer] = std::make_pair(-1, -1);
+          break;
+        }
+        case StmtType::local_copy: {
+          copy_write_buffers.push_back(Optional<Buffer>{});
+          break;
+        }
+        case StmtType::compute: {
+          copy_write_buffers.push_back(Optional<Buffer>{});
+          auto buffers{get_copied_buffers_used_by_stmt(stmt)};
+          if (buffers.empty()) {
+            continue;
+          }
+          new_buffers_length[i] = 0;
+          for (auto buffer : buffers) {
+            for (size_t j{i - 1}; j >= 0; --j) {
+              if (copy_write_buffers[j] == buffer) {
+                old_to_new_write_buffer[buffer] = std::make_pair(i, new_buffers_length[i]);
+                new_buffers_length[i] += get_copy_length(seq_stmt[j]);
+                cycle_counts[i] += get_stmt_cycle_counts(seq_stmt[j]);
+                break;
+              }
+            }
+          }
+          break;
+        }
+      }
+    }
+    return std::move(seq_stmt);
+  }
+
+  Stmt rewrite_prim_func_body(Stmt body) {
+    std::map<const VarNode*, Allocate> var_to_allocate{};
+
+    // Rewrite old allocates
+    std::set<ObjectRef> buffer_vars{get_vars_for_written_copy_buffers()};
+    for (auto it{allocates.rbegin()}; it != allocates.rend(); ++it) {
+      Allocate alloc{*it};
+      var_to_allocate[alloc->buffer_var.get()] = alloc;
+      if (buffer_vars.count(alloc->buffer_var) == 0) {
+        body = Allocate(alloc->buffer_var, alloc->dtype, alloc->extents, alloc->condition, body,
+                        alloc->annotations, alloc->span);
+      }
+    }
+
+    // Rewrite new allocates
+    for (auto it{copy_write_buffers.rbegin()}; it != copy_write_buffers.rend(); ++it) {
+      if (auto buffer_opt = *it) {
+        Buffer old_write_buffer{buffer_opt.value()};
+        int new_buffer_index{old_to_new_write_buffer[old_write_buffer].first};
+
+        // Check if the allocate has already been created
+        if (new_buffers.count(new_buffer_index) == 0) {
+          BufferNode* new_buffer{old_write_buffer.CopyOnWrite()};
+          new_buffer->shape = {new_buffers_length[new_buffer_index]};
+
+          new_buffers[new_buffer_index] = GetRef<Buffer>(new_buffer);
+
+          auto old_allocate{var_to_allocate[old_write_buffer->data.get()]};
+          body = Allocate(new_buffer->data, new_buffer->dtype, new_buffer->shape, tir::const_true(),
+                          body, old_allocate->annotations, old_allocate->span);
+        }
+      }
+    }
+
+    // Rewrite operators
+    return this->VisitStmt(body);
+  }
+
+  Stmt rewrite_seq_stmt(const SeqStmtNode* op) {
+    Array<Stmt> new_seq{};
+
+    auto seq_stmt{GetRef<SeqStmt>(op)};
+    for (size_t i{0}; i < seq_stmt.size(); ++i) {
+      Stmt stmt{seq_stmt[i]};
+
+      switch (get_stmt_type(stmt)) {

Review Comment:
   We need a `default` case here.
   Ref : https://google.github.io/styleguide/cppguide.html#Loops_and_Switch_Statements



##########
src/tir/contrib/ethosu/passes.cc:
##########
@@ -223,6 +224,572 @@ tvm::transform::Pass CopyComputeReordering(Optional<Integer> max_copy_movements)
 TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering")
     .set_body_typed(CopyComputeReordering);
 
+/*!
+ * \brief This pass looks for the constants used by each compute operator
+ * and merges them into a single buffer.
+ * Constants written to a buffer with local scope are not merged.
+ */
+class MergeConstantsMutator : public StmtExprMutator {
+ public:
+  MergeConstantsMutator() {}
+
+  PrimFunc operator()(PrimFunc main_func, const Map<IntImm, runtime::NDArray>& const_dict) {
+    // Analyze
+    Stmt new_body{this->VisitStmt(main_func->body)};
+
+    // Rewrite
+    analyze = false;
+    new_body = rewrite_prim_func_body(new_body);
+    std::set<ObjectRef> params_to_delete{};
+    auto new_buffer_map{make_new_buffer_map(main_func->buffer_map, &params_to_delete)};
+    auto new_params{make_new_params(main_func->params, params_to_delete)};
+
+    // Make the new const dict
+    auto args_to_merge{get_args_to_merge(main_func->buffer_map, main_func->params)};
+    auto buffers_to_merge{
+        get_args_to_merge_without_args_not_in_const_dict(args_to_merge, const_dict)};
+    auto new_const_dict{make_new_const_dict(buffers_to_merge, const_dict)};
+
+    // Make the new prim func
+    auto prim_func_node{main_func.CopyOnWrite()};
+    prim_func_node->body = std::move(new_body);
+    prim_func_node->buffer_map = std::move(new_buffer_map);
+    prim_func_node->params = std::move(new_params);
+    prim_func_node->preflattened_buffer_map = {};
+    PrimFunc f{GetRef<PrimFunc>(prim_func_node)};
+
+    // Add the new const dict as an attribute
+    f = WithAttr(std::move(f), "ethos-u.const-dict", new_const_dict);
+
+    return f;
+  }
+
+ private:
+  /*! Indicates whether the pass is analyzing or rewriting */
+  bool analyze = true;
+
+  /*! A stack to store allocates as they are visited. */
+  std::vector<Allocate> allocates{};
+
+  /*! A list that contains in the i-th position the write buffer of the i-th statement
+   * if that statement is a copy to a buffer with global scope  */
+  std::vector<Optional<Buffer>> copy_write_buffers{};
+
+  /*! Maps a copy's write buffer to an index representing the
+   * new buffer and an offset in that buffer */
+  std::map<Buffer, std::pair<int /* new buffer index */, int /* offset */>>
+      old_to_new_write_buffer{};
+
+  /*! Maps an index representing a new buffer to the length of that buffer */
+  std::map<int /* new buffer index */, int /* length */> new_buffers_length{};
+
+  /*! Maps an index representing a new buffer to the new buffer */
+  std::map<int /* new buffer index */, Buffer> new_buffers{};
+
+  /*! Maps an index representing a new buffer to the cycle_counts needed to copy that buffer */
+  std::map<int /* new buffer index */, int64_t> cycle_counts{};
+
+  /*! Maps a copy's read buffer to the new copy's read buffer */
+  std::map<Buffer, Buffer> old_to_new_read_buffers{};
+
+  /*! Maps an index representing a new buffer to the list of buffers to be merged in the new buffer
+   */
+  std::map<int /* new buffer index */, std::vector<Buffer>> buffers_to_merge{};
+
+  /*! A set of buffers to delete */
+  std::set<Buffer> buffers_to_delete{};
+
+  // Visit
+
+  Stmt VisitStmt_(const AllocateNode* op) override {
+    if (analyze) {
+      allocates.push_back(GetRef<Allocate>(op));
+      return VisitStmt(op->body);
+    } else {
+      auto allocate{CopyOnWrite(op)};
+      allocate->body = this->VisitStmt(op->body);
+      return Stmt(allocate);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* op) override {
+    if (op->size() <= 1) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+    return analyze ? analyze_seq_stmt(op) : rewrite_seq_stmt(op);
+  }
+
+  Stmt analyze_seq_stmt(const SeqStmtNode* op) {
+    auto seq_stmt{GetRef<SeqStmt>(op)};
+
+    for (size_t i = 0; i < seq_stmt.size(); ++i) {
+      Stmt stmt{seq_stmt[i]};
+
+      switch (get_stmt_type(stmt)) {
+        case StmtType::global_copy: {
+          Buffer write_buffer{get_copy_write_buffer(stmt)};
+          copy_write_buffers.push_back(write_buffer);
+          old_to_new_write_buffer[write_buffer] = std::make_pair(-1, -1);
+          break;
+        }
+        case StmtType::local_copy: {
+          copy_write_buffers.push_back(Optional<Buffer>{});
+          break;
+        }
+        case StmtType::compute: {
+          copy_write_buffers.push_back(Optional<Buffer>{});
+          auto buffers{get_copied_buffers_used_by_stmt(stmt)};
+          if (buffers.empty()) {
+            continue;
+          }
+          new_buffers_length[i] = 0;
+          for (auto buffer : buffers) {
+            for (size_t j{i - 1}; j >= 0; --j) {
+              if (copy_write_buffers[j] == buffer) {
+                old_to_new_write_buffer[buffer] = std::make_pair(i, new_buffers_length[i]);
+                new_buffers_length[i] += get_copy_length(seq_stmt[j]);
+                cycle_counts[i] += get_stmt_cycle_counts(seq_stmt[j]);
+                break;
+              }
+            }
+          }
+          break;
+        }
+      }
+    }
+    return std::move(seq_stmt);
+  }
+
+  Stmt rewrite_prim_func_body(Stmt body) {
+    std::map<const VarNode*, Allocate> var_to_allocate{};
+
+    // Rewrite old allocates
+    std::set<ObjectRef> buffer_vars{get_vars_for_written_copy_buffers()};

Review Comment:
   Do we need an ordered set here? (Note that std::set has a higher access time than std::unordered_set.)



##########
src/tir/contrib/ethosu/passes.cc:
##########
@@ -223,6 +224,572 @@ tvm::transform::Pass CopyComputeReordering(Optional<Integer> max_copy_movements)
 TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering")
     .set_body_typed(CopyComputeReordering);
 
+/*!
+ * \brief This pass looks for the constants used by each compute operator
+ * and merges them into a single buffer.
+ * Constants written to a buffer with local scope are not merged.
+ */
+class MergeConstantsMutator : public StmtExprMutator {
+ public:
+  MergeConstantsMutator() {}
+
+  PrimFunc operator()(PrimFunc main_func, const Map<IntImm, runtime::NDArray>& const_dict) {
+    // Analyze
+    Stmt new_body{this->VisitStmt(main_func->body)};
+
+    // Rewrite
+    analyze = false;
+    new_body = rewrite_prim_func_body(new_body);
+    std::set<ObjectRef> params_to_delete{};
+    auto new_buffer_map{make_new_buffer_map(main_func->buffer_map, &params_to_delete)};
+    auto new_params{make_new_params(main_func->params, params_to_delete)};
+
+    // Make the new const dict
+    auto args_to_merge{get_args_to_merge(main_func->buffer_map, main_func->params)};
+    auto buffers_to_merge{
+        get_args_to_merge_without_args_not_in_const_dict(args_to_merge, const_dict)};
+    auto new_const_dict{make_new_const_dict(buffers_to_merge, const_dict)};
+
+    // Make the new prim func
+    auto prim_func_node{main_func.CopyOnWrite()};
+    prim_func_node->body = std::move(new_body);
+    prim_func_node->buffer_map = std::move(new_buffer_map);
+    prim_func_node->params = std::move(new_params);
+    prim_func_node->preflattened_buffer_map = {};
+    PrimFunc f{GetRef<PrimFunc>(prim_func_node)};
+
+    // Add the new const dict as an attribute
+    f = WithAttr(std::move(f), "ethos-u.const-dict", new_const_dict);
+
+    return f;
+  }
+
+ private:
+  /*! Indicates whether the pass is analyzing or rewriting */
+  bool analyze = true;
+
+  /*! A stack to store allocates as they are visited. */
+  std::vector<Allocate> allocates{};
+
+  /*! A list that contains in the i-th position the write buffer of the i-th statement
+   * if that statement is a copy to a buffer with global scope  */
+  std::vector<Optional<Buffer>> copy_write_buffers{};
+
+  /*! Maps a copy's write buffer to an index representing the
+   * new buffer and an offset in that buffer */
+  std::map<Buffer, std::pair<int /* new buffer index */, int /* offset */>>
+      old_to_new_write_buffer{};
+
+  /*! Maps an index representing a new buffer to the length of that buffer */
+  std::map<int /* new buffer index */, int /* length */> new_buffers_length{};
+
+  /*! Maps an index representing a new buffer to the new buffer */
+  std::map<int /* new buffer index */, Buffer> new_buffers{};
+
+  /*! Maps an index representing a new buffer to the cycle_counts needed to copy that buffer */
+  std::map<int /* new buffer index */, int64_t> cycle_counts{};
+
+  /*! Maps a copy's read buffer to the new copy's read buffer */
+  std::map<Buffer, Buffer> old_to_new_read_buffers{};
+
+  /*! Maps an index representing a new buffer to the list of buffers to be merged in the new buffer
+   */
+  std::map<int /* new buffer index */, std::vector<Buffer>> buffers_to_merge{};
+
+  /*! A set of buffers to delete */
+  std::set<Buffer> buffers_to_delete{};
+
+  // Visit
+
+  Stmt VisitStmt_(const AllocateNode* op) override {
+    if (analyze) {
+      allocates.push_back(GetRef<Allocate>(op));
+      return VisitStmt(op->body);
+    } else {
+      auto allocate{CopyOnWrite(op)};
+      allocate->body = this->VisitStmt(op->body);
+      return Stmt(allocate);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* op) override {
+    if (op->size() <= 1) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+    return analyze ? analyze_seq_stmt(op) : rewrite_seq_stmt(op);
+  }
+
+  Stmt analyze_seq_stmt(const SeqStmtNode* op) {
+    auto seq_stmt{GetRef<SeqStmt>(op)};
+
+    for (size_t i = 0; i < seq_stmt.size(); ++i) {
+      Stmt stmt{seq_stmt[i]};
+
+      switch (get_stmt_type(stmt)) {

Review Comment:
   We need a `default` case here.
   Ref : https://google.github.io/styleguide/cppguide.html#Loops_and_Switch_Statements



##########
src/tir/contrib/ethosu/passes.cc:
##########
@@ -223,6 +224,572 @@ tvm::transform::Pass CopyComputeReordering(Optional<Integer> max_copy_movements)
 TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering")
     .set_body_typed(CopyComputeReordering);
 
+/*!
+ * \brief This pass looks for the constants used by each compute operator
+ * and merges them into a single buffer.
+ * Constants written to a buffer with local scope are not merged.
+ */
+class MergeConstantsMutator : public StmtExprMutator {
+ public:
+  MergeConstantsMutator() {}
+
+  PrimFunc operator()(PrimFunc main_func, const Map<IntImm, runtime::NDArray>& const_dict) {
+    // Analyze
+    Stmt new_body{this->VisitStmt(main_func->body)};

Review Comment:
   It is generally not good practice to mix analysis and transformation in a single pass.
   For an example of a standalone analysis pass, see: (https://github.com/apache/tvm/blob/main/src/tir/usmp/analysis/extract_buffer_info.cc)
   However, since this analysis is simpler, we can keep the analysis pass internal to this source file.
   
   Can we break this out into a separate pass (in this source file) that returns the analysis information, and then run a subsequent pass that performs the rewrite, without resorting to the `analyze` mode variable?
   
   
   



##########
src/tir/contrib/ethosu/passes.cc:
##########
@@ -223,6 +224,572 @@ tvm::transform::Pass CopyComputeReordering(Optional<Integer> max_copy_movements)
 TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering")
     .set_body_typed(CopyComputeReordering);
 
+/*!
+ * \brief This pass looks for the constants used by each compute operator
+ * and merges them into a single buffer.
+ * Constants written to a buffer with local scope are not merged.
+ */
+class MergeConstantsMutator : public StmtExprMutator {
+ public:
+  MergeConstantsMutator() {}
+
+  PrimFunc operator()(PrimFunc main_func, const Map<IntImm, runtime::NDArray>& const_dict) {
+    // Analyze
+    Stmt new_body{this->VisitStmt(main_func->body)};
+
+    // Rewrite
+    analyze = false;
+    new_body = rewrite_prim_func_body(new_body);
+    std::set<ObjectRef> params_to_delete{};
+    auto new_buffer_map{make_new_buffer_map(main_func->buffer_map, &params_to_delete)};
+    auto new_params{make_new_params(main_func->params, params_to_delete)};
+
+    // Make the new const dict
+    auto args_to_merge{get_args_to_merge(main_func->buffer_map, main_func->params)};
+    auto buffers_to_merge{
+        get_args_to_merge_without_args_not_in_const_dict(args_to_merge, const_dict)};
+    auto new_const_dict{make_new_const_dict(buffers_to_merge, const_dict)};
+
+    // Make the new prim func
+    auto prim_func_node{main_func.CopyOnWrite()};
+    prim_func_node->body = std::move(new_body);
+    prim_func_node->buffer_map = std::move(new_buffer_map);
+    prim_func_node->params = std::move(new_params);
+    prim_func_node->preflattened_buffer_map = {};
+    PrimFunc f{GetRef<PrimFunc>(prim_func_node)};
+
+    // Add the new const dict as an attribute
+    f = WithAttr(std::move(f), "ethos-u.const-dict", new_const_dict);
+
+    return f;
+  }
+
+ private:
+  /*! Indicates whether the pass is analyzing or rewriting */
+  bool analyze = true;
+
+  /*! A stack to store allocates as they are visited. */
+  std::vector<Allocate> allocates{};
+
+  /*! A list that contains in the i-th position the write buffer of the i-th statement
+   * if that statement is a copy to a buffer with global scope  */
+  std::vector<Optional<Buffer>> copy_write_buffers{};
+
+  /*! Maps a copy's write buffer to an index representing the
+   * new buffer and an offset in that buffer */
+  std::map<Buffer, std::pair<int /* new buffer index */, int /* offset */>>
+      old_to_new_write_buffer{};
+
+  /*! Maps an index representing a new buffer to the length of that buffer */
+  std::map<int /* new buffer index */, int /* length */> new_buffers_length{};
+
+  /*! Maps an index representing a new buffer to the new buffer */
+  std::map<int /* new buffer index */, Buffer> new_buffers{};
+
+  /*! Maps an index representing a new buffer to the cycle_counts needed to copy that buffer */
+  std::map<int /* new buffer index */, int64_t> cycle_counts{};
+
+  /*! Maps a copy's read buffer to the new copy's read buffer */
+  std::map<Buffer, Buffer> old_to_new_read_buffers{};
+
+  /*! Maps an index representing a new buffer to the list of buffers to be merged in the new buffer
+   */
+  std::map<int /* new buffer index */, std::vector<Buffer>> buffers_to_merge{};
+
+  /*! A set of buffers to delete */
+  std::set<Buffer> buffers_to_delete{};
+
+  // Visit
+
+  Stmt VisitStmt_(const AllocateNode* op) override {
+    if (analyze) {
+      allocates.push_back(GetRef<Allocate>(op));
+      return VisitStmt(op->body);
+    } else {
+      auto allocate{CopyOnWrite(op)};
+      allocate->body = this->VisitStmt(op->body);
+      return Stmt(allocate);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* op) override {
+    if (op->size() <= 1) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+    return analyze ? analyze_seq_stmt(op) : rewrite_seq_stmt(op);

Review Comment:
   Refer to the comment above about the separation of analysis and transformation passes.



##########
src/tir/contrib/ethosu/passes.cc:
##########
@@ -223,6 +224,572 @@ tvm::transform::Pass CopyComputeReordering(Optional<Integer> max_copy_movements)
 TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.CopyComputeReordering")
     .set_body_typed(CopyComputeReordering);
 
+/*!
+ * \brief This pass looks for the constants used by each compute operator
+ * and merges them into a single buffer.
+ * Constants written to a buffer with local scope are not merged.
+ */
+class MergeConstantsMutator : public StmtExprMutator {
+ public:
+  MergeConstantsMutator() {}
+
+  PrimFunc operator()(PrimFunc main_func, const Map<IntImm, runtime::NDArray>& const_dict) {
+    // Analyze
+    Stmt new_body{this->VisitStmt(main_func->body)};
+
+    // Rewrite
+    analyze = false;
+    new_body = rewrite_prim_func_body(new_body);
+    std::set<ObjectRef> params_to_delete{};
+    auto new_buffer_map{make_new_buffer_map(main_func->buffer_map, &params_to_delete)};
+    auto new_params{make_new_params(main_func->params, params_to_delete)};
+
+    // Make the new const dict
+    auto args_to_merge{get_args_to_merge(main_func->buffer_map, main_func->params)};
+    auto buffers_to_merge{
+        get_args_to_merge_without_args_not_in_const_dict(args_to_merge, const_dict)};
+    auto new_const_dict{make_new_const_dict(buffers_to_merge, const_dict)};
+
+    // Make the new prim func
+    auto prim_func_node{main_func.CopyOnWrite()};
+    prim_func_node->body = std::move(new_body);
+    prim_func_node->buffer_map = std::move(new_buffer_map);
+    prim_func_node->params = std::move(new_params);
+    prim_func_node->preflattened_buffer_map = {};
+    PrimFunc f{GetRef<PrimFunc>(prim_func_node)};
+
+    // Add the new const dict as an attribute
+    f = WithAttr(std::move(f), "ethos-u.const-dict", new_const_dict);
+
+    return f;
+  }
+
+ private:
+  /*! Indicates whether the pass is analyzing or rewriting */
+  bool analyze = true;
+
+  /*! A stack to store allocates as they are visited. */
+  std::vector<Allocate> allocates{};
+
+  /*! A list that contains in the i-th position the write buffer of the i-th statement
+   * if that statement is a copy to a buffer with global scope  */
+  std::vector<Optional<Buffer>> copy_write_buffers{};
+
+  /*! Maps a copy's write buffer to an index representing the
+   * new buffer and an offset in that buffer */
+  std::map<Buffer, std::pair<int /* new buffer index */, int /* offset */>>
+      old_to_new_write_buffer{};
+
+  /*! Maps an index representing a new buffer to the length of that buffer */
+  std::map<int /* new buffer index */, int /* length */> new_buffers_length{};
+
+  /*! Maps an index representing a new buffer to the new buffer */
+  std::map<int /* new buffer index */, Buffer> new_buffers{};
+
+  /*! Maps an index representing a new buffer to the cycle_counts needed to copy that buffer */
+  std::map<int /* new buffer index */, int64_t> cycle_counts{};
+
+  /*! Maps a copy's read buffer to the new copy's read buffer */
+  std::map<Buffer, Buffer> old_to_new_read_buffers{};
+
+  /*! Maps an index representing a new buffer to the list of buffers to be merged in the new buffer
+   */
+  std::map<int /* new buffer index */, std::vector<Buffer>> buffers_to_merge{};
+
+  /*! A set of buffers to delete */
+  std::set<Buffer> buffers_to_delete{};
+
+  // Visit
+
+  Stmt VisitStmt_(const AllocateNode* op) override {
+    if (analyze) {
+      allocates.push_back(GetRef<Allocate>(op));
+      return VisitStmt(op->body);
+    } else {
+      auto allocate{CopyOnWrite(op)};
+      allocate->body = this->VisitStmt(op->body);
+      return Stmt(allocate);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* op) override {
+    if (op->size() <= 1) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+    return analyze ? analyze_seq_stmt(op) : rewrite_seq_stmt(op);
+  }
+
+  Stmt analyze_seq_stmt(const SeqStmtNode* op) {

Review Comment:
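   Please follow the Google C++ style for function names (https://google.github.io/styleguide/cppguide.html#Function_Names), as required by https://tvm.apache.org/docs/contribute/code_guide.html#c-code-styles, e.g. `AnalyzeSeqStmt` instead of `analyze_seq_stmt`.
   
   (This applies to the functions that follow as well.)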



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org