You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/04/12 15:06:21 UTC

[GitHub] [tvm] NicolaLancellotti commented on a diff in pull request #10959: [microNPU] Add a pass to reorder copy and compute nodes

NicolaLancellotti commented on code in PR #10959:
URL: https://github.com/apache/tvm/pull/10959#discussion_r848549306


##########
src/tir/contrib/ethosu/passes.cc:
##########
@@ -110,6 +110,71 @@ tvm::transform::Pass HoistAllocates() {
 
 TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.HoistAllocates").set_body_typed(HoistAllocates);
 
+/*!
+ * \brief Reorders copy and compute nodes in such a way that independent DMA copies,
+ * and computes happen in parallel.
+ */
+class CopyComputeReorderingMutator : public StmtExprMutator {
+ public:
+  CopyComputeReorderingMutator() {}
+
+  PrimFunc operator()(PrimFunc main_func) {
+    auto n{main_func.CopyOnWrite()};
+    n->body = this->VisitStmt(main_func->body);
+    return GetRef<PrimFunc>(n);
+  }
+
+ private:
+  Stmt VisitStmt_(const SeqStmtNode* op) override {
+    if (op->size() <= 1) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+
+    auto seq_stmt{GetRef<SeqStmt>(op)};
+    std::vector<Stmt> new_seq(seq_stmt->size());
+    std::copy(seq_stmt->seq.begin(), seq_stmt->seq.end(), new_seq.begin());
+    bool previous_stmt_is_copy{true};  // Do not move the first stmt if it is a copy
+
+    for (size_t i{}; i < seq_stmt->size(); ++i) {
+      auto stmt{seq_stmt[i]};

Review Comment:
   Done.



##########
src/tir/contrib/ethosu/passes.cc:
##########
@@ -110,6 +110,71 @@ tvm::transform::Pass HoistAllocates() {
 
 TVM_REGISTER_GLOBAL("tir.contrib.ethos-u.HoistAllocates").set_body_typed(HoistAllocates);
 
+/*!
+ * \brief Reorders copy and compute nodes in such a way that independent DMA copies,
+ * and computes happen in parallel.
+ */
+class CopyComputeReorderingMutator : public StmtExprMutator {
+ public:
+  CopyComputeReorderingMutator() {}
+
+  PrimFunc operator()(PrimFunc main_func) {
+    auto n{main_func.CopyOnWrite()};
+    n->body = this->VisitStmt(main_func->body);
+    return GetRef<PrimFunc>(n);
+  }
+
+ private:
+  Stmt VisitStmt_(const SeqStmtNode* op) override {
+    if (op->size() <= 1) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+
+    auto seq_stmt{GetRef<SeqStmt>(op)};
+    std::vector<Stmt> new_seq(seq_stmt->size());
+    std::copy(seq_stmt->seq.begin(), seq_stmt->seq.end(), new_seq.begin());
+    bool previous_stmt_is_copy{true};  // Do not move the first stmt if it is a copy
+
+    for (size_t i{}; i < seq_stmt->size(); ++i) {
+      auto stmt{seq_stmt[i]};
+      auto args{stmt.as<EvaluateNode>()->value.as<CallNode>()->args};

Review Comment:
   Done.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org