You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/02/21 21:35:59 UTC
[tvm] branch main updated: [TIR] Specialize MutateArray in
StmtFunctor. (#7486)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new cfe88c1 [TIR] Specialize MutateArray in StmtFunctor. (#7486)
cfe88c1 is described below
commit cfe88c1eee757b49b2837f31f29a79c08101a55c
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Sun Feb 21 16:35:40 2021 -0500
[TIR] Specialize MutateArray in StmtFunctor. (#7486)
StmtFunctor applies context-dependent copy-on-write,
which requires checking the entire dependency chain.
Such a function is better suited to a specialized implementation
to avoid misuse. This PR refactors the code to specialize
the function.
Co-authored-by: Bohan Hou <32...@users.noreply.github.com>
---
src/tir/ir/functor_common.h | 15 +++------------
src/tir/ir/stmt_functor.cc | 44 ++++++++++++++++++++++++++++++++++++++++----
2 files changed, 43 insertions(+), 16 deletions(-)
diff --git a/src/tir/ir/functor_common.h b/src/tir/ir/functor_common.h
index f63dcfe..9ed911f 100644
--- a/src/tir/ir/functor_common.h
+++ b/src/tir/ir/functor_common.h
@@ -34,19 +34,10 @@ inline void VisitArray(const Array<T>& arr, F fvisit) {
}
}
-// Implementation of mutators
template <typename T, typename F>
-inline Array<T> MutateArray(const Array<T>& arr, F fmutate, bool allow_copy_on_write = false) {
- if (allow_copy_on_write) {
- // if we allow copy on write, we can directly
- // call the inplace mutate function.
- const_cast<Array<T>&>(arr).MutateByApply(fmutate);
- return arr;
- } else {
- Array<T> copy = arr;
- copy.MutateByApply(fmutate);
- return copy;
- }
+inline Array<T> MutateArray(Array<T> arr, F fmutate) {
+ arr.MutateByApply(fmutate);
+ return arr;
}
} // namespace tir
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index e0ccb49..e4cc1b7 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -114,14 +114,50 @@ void StmtVisitor::VisitStmt_(const EvaluateNode* op) { this->VisitExpr(op->value
class StmtMutator::Internal {
public:
+ /*!
+ * \brief Mutate array's element by fmutate function.
+ *
+ * \note Use extra care for copy on write setting.
+ *
+ * In particular, consider the following case of two reference chains:
+ * - strongref0 -> loop0 -> loop1 -> loop2
+ * - strongref1 -> loop3 -> loop1 -> loop2
+ *
+ * Think of the case of calling MutateArray on loop1->loop2(as const reference).
+ * When both strongref0 and strongref1 exists, the context does not allow copy
+ * on write, even though loop1 uniquely refers to loop2.
+ *
+ * \param self The pointer to the mutator.
+ * \param arr Array to be mutated, const reference is used to allow copy on write
+ * mutation in a recursive visitor.
+ * \param fmutate The mutator function.
+ * \return The mutated array, a new copy can be created.
+ */
+ template <typename T, typename F>
+ static Array<T> MutateArray(StmtMutator* self, const Array<T>& arr, F fmutate) {
+ if (self->allow_copy_on_write_ && arr.unique()) {
+ // if we allow copy on write, we can directly
+ // call the inplace mutate function.
+ const_cast<Array<T>&>(arr).MutateByApply(fmutate);
+ return arr;
+ } else {
+ bool allow_cow = false;
+ Array<T> copy = arr;
+ std::swap(allow_cow, self->allow_copy_on_write_);
+ copy.MutateByApply(fmutate);
+ std::swap(allow_cow, self->allow_copy_on_write_);
+ return copy;
+ }
+ }
+
static Array<PrimExpr> Mutate(StmtMutator* self, const Array<PrimExpr>& arr) {
auto fmutate = [self](const PrimExpr& e) { return self->VisitExpr(e); };
- return MutateArray(arr, fmutate, self->allow_copy_on_write_);
+ return MutateArray(self, arr, fmutate);
}
static Array<Stmt> Mutate(StmtMutator* self, const Array<Stmt>& arr) {
auto fmutate = [self](const Stmt& s) { return self->VisitStmt(s); };
- return MutateArray(arr, fmutate, self->allow_copy_on_write_);
+ return MutateArray(self, arr, fmutate);
}
static Array<Range> Mutate(StmtMutator* self, const Array<Range>& arr) {
@@ -134,7 +170,7 @@ class StmtMutator::Internal {
return Range::FromMinExtent(min, extent);
}
};
- return MutateArray(arr, fmutate, self->allow_copy_on_write_);
+ return MutateArray(self, arr, fmutate);
}
};
@@ -323,7 +359,7 @@ Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit
}
// function to run the visit.
auto frunvisit = [&](const SeqStmtNode* op) {
- Array<Stmt> seq = fmutate != nullptr ? MutateArray(op->seq, fmutate, allow_copy_on_write_)
+ Array<Stmt> seq = fmutate != nullptr ? Internal::MutateArray(this, op->seq, fmutate)
: Internal::Mutate(this, op->seq);
if (seq.same_as(op->seq)) {
return GetRef<Stmt>(op);