You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/02/21 21:35:59 UTC

[tvm] branch main updated: [TIR] Specialize MutateArray in StmtFunctor. (#7486)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new cfe88c1  [TIR] Specialize MutateArray in StmtFunctor. (#7486)
cfe88c1 is described below

commit cfe88c1eee757b49b2837f31f29a79c08101a55c
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Sun Feb 21 16:35:40 2021 -0500

    [TIR] Specialize MutateArray in StmtFunctor. (#7486)
    
    StmtFunctor applies context-dependent copy-on-write,
    which requires checking the entire dependency chain.
    Such a function is better suited to a specialized implementation
    to avoid misuse. This PR refactors the code to specialize
    the function.
    
    Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
---
 src/tir/ir/functor_common.h | 15 +++------------
 src/tir/ir/stmt_functor.cc  | 44 ++++++++++++++++++++++++++++++++++++++++----
 2 files changed, 43 insertions(+), 16 deletions(-)

diff --git a/src/tir/ir/functor_common.h b/src/tir/ir/functor_common.h
index f63dcfe..9ed911f 100644
--- a/src/tir/ir/functor_common.h
+++ b/src/tir/ir/functor_common.h
@@ -34,19 +34,10 @@ inline void VisitArray(const Array<T>& arr, F fvisit) {
   }
 }
 
-// Implementation of mutators
 template <typename T, typename F>
-inline Array<T> MutateArray(const Array<T>& arr, F fmutate, bool allow_copy_on_write = false) {
-  if (allow_copy_on_write) {
-    // if we allow copy on write, we can directly
-    // call the inplace mutate function.
-    const_cast<Array<T>&>(arr).MutateByApply(fmutate);
-    return arr;
-  } else {
-    Array<T> copy = arr;
-    copy.MutateByApply(fmutate);
-    return copy;
-  }
+inline Array<T> MutateArray(Array<T> arr, F fmutate) {
+  arr.MutateByApply(fmutate);
+  return arr;
 }
 
 }  // namespace tir
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index e0ccb49..e4cc1b7 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -114,14 +114,50 @@ void StmtVisitor::VisitStmt_(const EvaluateNode* op) { this->VisitExpr(op->value
 
 class StmtMutator::Internal {
  public:
+  /*!
+   * \brief Mutate the elements of an array with the fmutate function.
+   *
+   * \note Take extra care with the copy-on-write setting.
+   *
+   * In particular, consider the following case of two reference chains:
+   * - strongref0 -> loop0 -> loop1 -> loop2
+   * - strongref1 -> loop3 -> loop1 -> loop2
+   *
+   * Think of the case of calling MutateArray on loop1->loop2 (as const reference).
+   * When both strongref0 and strongref1 exist, the context does not allow copy
+   * on write, even though loop1 uniquely refers to loop2.
+   *
+   * \param self The pointer to the mutator.
+   * \param arr Array to be mutated; a const reference is used to allow copy-on-write
+   *            mutation in a recursive visitor.
+   * \param fmutate The mutator function.
+   * \return The mutated array; it may be a new copy.
+   */
+  template <typename T, typename F>
+  static Array<T> MutateArray(StmtMutator* self, const Array<T>& arr, F fmutate) {
+    if (self->allow_copy_on_write_ && arr.unique()) {
+      // The mutator allows copy on write and arr.unique() holds, so we can
+      // directly call the in-place mutate function through the const reference.
+      const_cast<Array<T>&>(arr).MutateByApply(fmutate);
+      return arr;
+    } else {
+      bool allow_cow = false;
+      Array<T> copy = arr;
+      std::swap(allow_cow, self->allow_copy_on_write_);  // flag is false while fmutate runs
+      copy.MutateByApply(fmutate);
+      std::swap(allow_cow, self->allow_copy_on_write_);  // restore the saved flag value
+      return copy;
+    }
+  }
+
   static Array<PrimExpr> Mutate(StmtMutator* self, const Array<PrimExpr>& arr) {
     auto fmutate = [self](const PrimExpr& e) { return self->VisitExpr(e); };
-    return MutateArray(arr, fmutate, self->allow_copy_on_write_);
+    return MutateArray(self, arr, fmutate);
   }
 
   static Array<Stmt> Mutate(StmtMutator* self, const Array<Stmt>& arr) {
     auto fmutate = [self](const Stmt& s) { return self->VisitStmt(s); };
-    return MutateArray(arr, fmutate, self->allow_copy_on_write_);
+    return MutateArray(self, arr, fmutate);
   }
 
   static Array<Range> Mutate(StmtMutator* self, const Array<Range>& arr) {
@@ -134,7 +170,7 @@ class StmtMutator::Internal {
         return Range::FromMinExtent(min, extent);
       }
     };
-    return MutateArray(arr, fmutate, self->allow_copy_on_write_);
+    return MutateArray(self, arr, fmutate);
   }
 };
 
@@ -323,7 +359,7 @@ Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit
   }
   // function to run the visit.
   auto frunvisit = [&](const SeqStmtNode* op) {
-    Array<Stmt> seq = fmutate != nullptr ? MutateArray(op->seq, fmutate, allow_copy_on_write_)
+    Array<Stmt> seq = fmutate != nullptr ? Internal::MutateArray(this, op->seq, fmutate)
                                          : Internal::Mutate(this, op->seq);
     if (seq.same_as(op->seq)) {
       return GetRef<Stmt>(op);