You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ya...@apache.org on 2023/04/06 20:43:04 UTC
[tvm] branch unity updated: [Unity][Graph matching] Improved matching algorithm and implementation (#14501)
This is an automated email from the ASF dual-hosted git repository.
yaxingcai pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new cea447cf37 [Unity][Graph matching] Improved matching algorithm and implementation (#14501)
cea447cf37 is described below
commit cea447cf37176feb52c93853dccbea96289982ac
Author: masahi <ma...@gmail.com>
AuthorDate: Fri Apr 7 05:42:54 2023 +0900
[Unity][Graph matching] Improved matching algorithm and implementation (#14501)
* remove start_hint from MatchGraph
* improve graph matching algorithm
* remove side effect from matching algo
* pylint
* add comments
* add more const now that we can
* cpplint
* fix compile warning
* Update src/relax/ir/dataflow_matcher.cc
Co-authored-by: Jiawei Liu <ja...@gmail.com>
* Update src/relax/ir/dataflow_matcher.cc
Co-authored-by: Jiawei Liu <ja...@gmail.com>
* use insert for merging MatchState
* fix
* parent check is not specific to wildcard
* use map merge
* cpplint
* Pass and check current_match in TryMatch
---------
Co-authored-by: Jiawei Liu <ja...@gmail.com>
---
include/tvm/relax/dataflow_matcher.h | 9 +-
python/tvm/relax/dpl/context.py | 10 +-
src/relax/ir/dataflow_matcher.cc | 294 ++++++++++++++--------------
src/relax/ir/dataflow_pattern.cc | 1 +
tests/python/relax/test_dataflow_pattern.py | 45 -----
5 files changed, 148 insertions(+), 211 deletions(-)
diff --git a/include/tvm/relax/dataflow_matcher.h b/include/tvm/relax/dataflow_matcher.h
index e4268be882..cf7c58f093 100644
--- a/include/tvm/relax/dataflow_matcher.h
+++ b/include/tvm/relax/dataflow_matcher.h
@@ -51,19 +51,12 @@ Optional<Map<DFPattern, Expr>> ExtractMatchedExpr(
/**
* \brief Match a sub-graph in a DataflowBlock with a graph of patterns and return the mapping.
- * \note This algorithm returns the first matched sub-graph. Use `start_hint` to specify the
- * starting point of the matching so that we can distinguish multiple matches.
- *
* \param ctx The graph-wise patterns.
* \param dfb The function to match.
- * \param start_hint The starting point expression to match to distinguish multiple matches.
- * \param must_include_hint If start_hint is given, the return pattern must include start_hint.
* \return Matched patterns and corresponding bound variables
*/
TVM_DLL Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx,
- const DataflowBlock& dfb,
- Optional<Var> start_hint = NullOpt,
- bool must_include_hint = false);
+ const DataflowBlock& dfb);
} // namespace relax
} // namespace tvm
diff --git a/python/tvm/relax/dpl/context.py b/python/tvm/relax/dpl/context.py
index 69a5e70ed0..16d86fb32d 100644
--- a/python/tvm/relax/dpl/context.py
+++ b/python/tvm/relax/dpl/context.py
@@ -17,7 +17,7 @@
"""The Graph Matching Context Manager for Dataflow Pattern Language."""
-from typing import Optional, Dict
+from typing import Dict
import tvm
from ..expr import DataflowBlock, Var
@@ -63,8 +63,6 @@ class PatternContext(tvm.runtime.Object):
def match_dfb(
self,
dfb: DataflowBlock,
- start_hint: Optional[Var] = None,
- must_include_hint: bool = False,
) -> Dict[DFPattern, Var]:
"""
Match a DataflowBlock via a graph of DFPattern and corresponding constraints
@@ -73,14 +71,10 @@ class PatternContext(tvm.runtime.Object):
----------
dfb : DataflowBlock
The DataflowBlock to match
- start_hint : Optional[Var], optional
- Indicating the starting expression to match, by default None
- must_include_hint : bool, optional
- Whether the start_hint expression must be matched, by default False
Returns
-------
Dict[DFPattern, Var]
The mapping from DFPattern to matched expression
"""
- return ffi.match_dfb(self, dfb, start_hint, must_include_hint) # type: ignore
+ return ffi.match_dfb(self, dfb) # type: ignore
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index c1306ff690..6e8211cfd3 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -523,109 +523,114 @@ bool MatchExpr(DFPattern pattern, Expr expr, Optional<Map<Var, Expr>> bindings_o
TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr);
+class MatcherUseDefAnalysis : public relax::ExprVisitor {
+ public:
+ std::vector<const VarNode*> vars;
+ std::map<const VarNode*, std::vector<const VarNode*>> def2use;
+ // caller -> callee table.
+ std::map<const VarNode*, std::vector<const VarNode*>> caller2callees;
+
+ const VarNode* cur_user_;
+
+ void VisitBinding_(const VarBindingNode* binding) override {
+ // init
+ cur_user_ = binding->var.get();
+ this->VisitVarDef(binding->var);
+ this->VisitExpr(binding->value);
+ cur_user_ = nullptr;
+ }
+
+ void VisitExpr_(const VarNode* op) override {
+ if (nullptr == cur_user_) return;
+
+ auto check_and_push = [](std::vector<const VarNode*>& vec, const VarNode* var) {
+ if (std::find(vec.begin(), vec.end(), var) == vec.end()) {
+ vec.push_back(var);
+ }
+ };
+
+ check_and_push(def2use[op], cur_user_);
+ check_and_push(vars, op);
+
+ caller2callees[cur_user_].push_back(op);
+ }
+
+ void VisitExpr_(const DataflowVarNode* op) override {
+ VisitExpr_(static_cast<const VarNode*>(op));
+ }
+};
+
struct PNode {
const DFPatternNode* ptr;
- const VarNode* matched = nullptr;
std::vector<std::pair<PNode*, const std::vector<PairCons>&>> children;
std::vector<std::pair<PNode*, const std::vector<PairCons>&>> parents;
};
struct RNode {
const VarNode* ptr;
- const DFPatternNode* matched = nullptr;
std::vector<RNode*> children;
std::vector<RNode*> parents;
};
-/**
- * \brief This method try to match a real node and a pattern node along with its neighbors.
- */
-using UndoItems = std::vector<std::pair<PNode*, RNode*>>;
-static std::optional<UndoItems> try_match(
- PNode* p, RNode* r, DFPatternMatcher* m,
- const std::map<const VarNode*, std::vector<const VarNode*>>& def2use,
- const std::map<const VarNode*, std::vector<const VarNode*>>& use2def) {
- if (p->matched != nullptr && p->matched == r->ptr) return {}; // matched before.
- if (!m->Match(GetRef<DFPattern>(p->ptr), GetRef<Var>(r->ptr))) return std::nullopt;
-
- UndoItems undo;
-
- const auto commit = [&undo](PNode* p, RNode* r) {
- // match with each other.
- // TODO(ganler, masahi): Why commit on the same p-r pair happens more than once?
- if (p->ptr == r->matched) {
- ICHECK_EQ(p->matched, r->ptr);
- return;
- }
- p->matched = r->ptr;
- r->matched = p->ptr;
- undo.emplace_back(p, r);
- };
+struct MatchState {
+ void add(const PNode* p, const RNode* r) {
+ match_p_r[p] = r;
+ match_r_p[r] = p;
+ }
- const auto quit = [&undo] {
- for (auto& [p_node, r_node] : undo) {
- p_node->matched = nullptr;
- r_node->matched = nullptr;
- }
- return std::nullopt;
- };
+ void add(MatchState&& other) {
+ match_p_r.merge(std::move(other.match_p_r));
+ match_r_p.merge(std::move(other.match_r_p));
+ }
- const auto try_match_update_undo = [&](PNode* p, RNode* r) {
- if (auto undo_more = try_match(p, r, m, def2use, use2def)) {
- undo.insert(undo.end(), undo_more->begin(), undo_more->end());
- return true;
+ const VarNode* matched(const PNode* p) const {
+ if (auto it = match_p_r.find(p); it != match_p_r.end()) {
+ return it->second->ptr;
}
- return false;
- };
+ return nullptr;
+ }
- commit(p, r);
+ const DFPatternNode* matched(const RNode* r) const {
+ if (auto it = match_r_p.find(r); it != match_r_p.end()) {
+ return it->second->ptr;
+ }
+ return nullptr;
+ }
- // match parent patterns.
- for (auto& [pparent, constraints] : p->parents) {
- bool any_cons_sat = false;
- for (auto& rparent : r->parents) {
- // skip if mismatch.
- if (rparent->matched && rparent->matched != pparent->ptr) continue;
+ const VarNode* matched(const PNode& p) const { return matched(&p); }
+ const DFPatternNode* matched(const RNode& r) const { return matched(&r); }
- const auto& uses = def2use.at(rparent->ptr);
+ private:
+ std::unordered_map<const PNode*, const RNode*> match_p_r;
+ std::unordered_map<const RNode*, const PNode*> match_r_p;
+};
- // check edge constraints.
- bool cons_sat = true;
- for (const auto& cons : constraints) {
- if (cons.type == PairCons::kOnlyUsedBy && uses.size() != 1) {
- cons_sat = false;
- break;
- }
+/**
+ * \brief This method tries to match a real node and a pattern node along with its neighbors.
+ */
+static std::optional<MatchState> TryMatch(const PNode& p, const RNode& r,
+ const MatchState& current_match, DFPatternMatcher* m,
+ const MatcherUseDefAnalysis& ud_analysis) {
+ if (!m->Match(GetRef<DFPattern>(p.ptr), GetRef<Var>(r.ptr))) return std::nullopt;
- if (cons.index != -1) {
- const auto& callees = use2def.at(r->ptr);
- if (callees.size() <= static_cast<size_t>(cons.index) ||
- callees[cons.index] != rparent->ptr) {
- cons_sat = false;
- break;
- }
- }
- }
- if (!cons_sat) continue;
- any_cons_sat = true;
+ MatchState new_match;
- // try all parent R nodes that are not matched yet.
- // as long as ppattern can match one node.
- if (!pparent->matched && try_match_update_undo(pparent, rparent)) {
- commit(pparent, rparent);
- break;
- }
- }
- if (!pparent->matched || !any_cons_sat) return quit();
- }
+ new_match.add(&p, &r);
// forward matching;
- for (auto& [pchild, constraints] : p->children) {
+ for (const auto& [pchild, constraints] : p.children) {
bool any_cons_sat = false;
- for (auto& rchild : r->children) {
- if (rchild->matched && rchild->matched != pchild->ptr) continue;
+ for (const auto& rchild : r.children) {
+ if (new_match.matched(rchild)) {
+ // The child variable is already matched to another child pattern in a previous iteration.
+ continue;
+ }
+ if (auto v = current_match.matched(pchild); v && v != rchild->ptr) {
+ // The child pattern is already matched to another variable in an earlier call to TryMatch.
+ continue;
+ }
- const auto& uses = def2use.at(r->ptr);
+ const auto& uses = ud_analysis.def2use.at(r.ptr);
// check edge constraints.
bool all_cons_pass = true;
@@ -636,88 +641,87 @@ static std::optional<UndoItems> try_match(
}
if (cons.index != -1) {
- const auto& callees = use2def.at(rchild->ptr);
- if (callees.size() <= static_cast<size_t>(cons.index) || callees[cons.index] != r->ptr) {
+ const auto& callees = ud_analysis.caller2callees.at(rchild->ptr);
+ if (callees.size() <= static_cast<size_t>(cons.index) || callees[cons.index] != r.ptr) {
all_cons_pass = false;
break;
}
}
}
- if (!all_cons_pass) continue;
+ if (!all_cons_pass || new_match.matched(pchild)) continue;
any_cons_sat = true;
- if (!pchild->matched && try_match_update_undo(pchild, rchild)) {
- commit(pchild, rchild);
- break;
+ if (auto match_rec = TryMatch(*pchild, *rchild, current_match, m, ud_analysis)) {
+ new_match.add(pchild, rchild);
+ new_match.add(std::move(*match_rec));
}
}
- if (!pchild->matched || !any_cons_sat) return quit();
+ if (!new_match.matched(pchild) || !any_cons_sat) return std::nullopt;
}
- return undo;
+
+ return new_match;
}
-class MatcherUseDefAnalysis : public relax::ExprVisitor {
- public:
- std::vector<const VarNode*> vars;
- std::map<const VarNode*, std::vector<const VarNode*>> def2use;
- // caller -> callee table.
- std::map<const VarNode*, std::vector<const VarNode*>> caller2callees;
+static std::optional<MatchState> MatchTree(
+ const MatchState& current_match, size_t current_root_idx,
+ const std::unordered_map<const DFPatternNode*, PNode>& pattern2node,
+ const std::unordered_map<const VarNode*, RNode>& var2node, DFPatternMatcher* matcher,
+ const std::vector<DFPattern>& roots, const MatcherUseDefAnalysis& ud_analysis) {
+ auto get_next_root = [&](size_t root_idx) -> const PNode* {
+ // Look for the next unmatched root node.
+ for (; root_idx < roots.size(); ++root_idx) {
+ const auto& root = pattern2node.at(roots[root_idx].get());
+ if (!current_match.matched(root)) {
+ return &root;
+ }
+ }
+ return nullptr;
+ };
- const VarNode* cur_user_;
+ const auto root = get_next_root(current_root_idx);
- void VisitBinding_(const VarBindingNode* binding) override {
- // init
- cur_user_ = binding->var.get();
- this->VisitVarDef(binding->var);
- this->VisitExpr(binding->value);
- cur_user_ = nullptr;
+ if (!root) {
+ // All root nodes have been matched
+ return current_match;
}
- void VisitExpr_(const VarNode* op) override {
- if (nullptr == cur_user_) return;
+ MatchState new_match = current_match;
- auto check_and_push = [](std::vector<const VarNode*>& vec, const VarNode* var) {
- if (std::find(vec.begin(), vec.end(), var) == vec.end()) {
- vec.push_back(var);
+ for (const auto& var : ud_analysis.vars) {
+ const RNode& r_node = var2node.at(var);
+ if (new_match.matched(r_node)) continue;
+ if (auto match = TryMatch(*root, r_node, new_match, matcher, ud_analysis)) {
+ // Recursively try to match the next subtree.
+ new_match.add(std::move(*match));
+ if (auto match_rec = MatchTree(new_match, current_root_idx + 1, pattern2node, var2node,
+ matcher, roots, ud_analysis)) {
+ new_match.add(std::move(*match_rec));
+ return new_match;
}
- };
-
- check_and_push(def2use[op], cur_user_);
- check_and_push(vars, op);
-
- caller2callees[cur_user_].push_back(op);
- }
-
- void VisitExpr_(const DataflowVarNode* op) override {
- VisitExpr_(static_cast<const VarNode*>(op));
+ // Recursive matching has failed, backtrack.
+ continue;
+ }
}
-};
-Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb,
- Optional<Var> start_hint, bool must_include_hint) {
- if (ctx->src_ordered.size() == 0) {
- return NullOpt;
- }
+ return std::nullopt;
+}
+Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb) {
// TODO(@ganler): Handle non-may external use.
ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is supported yet.";
- ICHECK(!must_include_hint || start_hint.defined())
- << "must_include_hint is only supported with start_hint.";
const auto var2val = AnalyzeVar2Value(dfb);
DFPatternMatcher matcher(var2val);
MatcherUseDefAnalysis ud_analysis;
ud_analysis.VisitBindingBlock_(dfb.get());
- const auto& def2use = ud_analysis.def2use;
- const auto& caller2callees = ud_analysis.caller2callees;
// First construct a graph of PNode and RNode.
std::unordered_map<const VarNode*, RNode> var2node;
var2node.reserve(dfb->bindings.size());
for (const VarNode* cur_var : ud_analysis.vars) {
- const auto& uses = def2use.at(cur_var);
+ const auto& uses = ud_analysis.def2use.at(cur_var);
RNode& cur_node = var2node[cur_var];
cur_node.ptr = cur_var;
for (const VarNode* use : uses) {
@@ -731,8 +735,9 @@ Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx, const Datafl
std::unordered_map<const DFPatternNode*, PNode> pattern2node;
pattern2node.reserve(ctx->constraints.size());
- for (const auto& [def_pattern, uses] : ctx->constraints) {
+ for (const auto& def_pattern : ctx->src_ordered) {
PNode& def_node = pattern2node[def_pattern.get()];
+ const auto& uses = ctx->constraints.at(def_pattern);
def_node.ptr = def_pattern.get();
def_node.children.reserve(uses.size());
for (const auto& [use_pattern, cons] : uses) {
@@ -743,35 +748,24 @@ Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx, const Datafl
}
}
- Map<DFPattern, Var> ret;
-
- if (start_hint) {
- auto rnode_ptr = var2node.at(start_hint.value().get());
- for (auto& p_node : pattern2node) {
- if (try_match(&p_node.second, &rnode_ptr, &matcher, def2use, caller2callees)) {
- for (const auto& [df_pattern, pattern_node] : pattern2node) {
- ret.Set(GetRef<DFPattern>(df_pattern), GetRef<Var>(pattern_node.matched));
- }
- return ret;
- }
+ std::vector<DFPattern> roots;
+ for (const auto& pat : ctx->src_ordered) {
+ if (pattern2node[pat.get()].parents.empty()) {
+ roots.push_back(pat);
}
-
- if (must_include_hint) return ret;
}
- PNode& pnode_start = pattern2node[ctx->src_ordered[0].get()];
+ if (roots.empty()) {
+ return NullOpt;
+ }
- if (!pnode_start.matched) {
- for (const auto& var : ud_analysis.vars) {
- if (start_hint.defined() && start_hint.value().get() == var) continue;
- RNode& r_node = var2node[var];
- if (try_match(&pnode_start, &r_node, &matcher, def2use, caller2callees)) {
- for (const auto& [df_pattern, pattern_node] : pattern2node) {
- ret.Set(GetRef<DFPattern>(df_pattern), GetRef<Var>(pattern_node.matched));
- }
- return ret;
- }
+ if (auto match = MatchTree({}, 0, pattern2node, var2node, &matcher, roots, ud_analysis)) {
+ Map<DFPattern, Var> ret;
+ for (const auto& [pat, p_node] : pattern2node) {
+ ICHECK(match->matched(p_node));
+ ret.Set(GetRef<DFPattern>(pat), GetRef<Var>(match->matched(p_node)));
}
+ return ret;
}
return NullOpt;
diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc
index 5580f6a1ab..4d225ceecf 100644
--- a/src/relax/ir/dataflow_pattern.cc
+++ b/src/relax/ir/dataflow_pattern.cc
@@ -406,6 +406,7 @@ PatternContext::PatternContext(bool incremental) {
<< "Incremental context needs to be built inside a existing context.";
n->allow_extern_use = pattern_ctx_stack().top()->allow_extern_use;
n->constraints = pattern_ctx_stack().top()->constraints;
+ n->src_ordered = pattern_ctx_stack().top()->src_ordered;
}
data_ = std::move(n);
diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py
index b85543cafc..a73a62eeef 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -519,51 +519,6 @@ class CBRx2:
return lv6
-def test_single_cbr():
- with PatternContext() as ctx:
- (
- is_call_dps_packed("conv1x1")
- >> is_call_dps_packed("bias_add")
- >> is_call_dps_packed("my_relu")
- )
- dfb = CBRx2["main"].body.blocks[0]
- matched = ctx.match_dfb(dfb)
- assert matched
-
- with PatternContext() as ctx:
- chain = (
- is_call_dps_packed("conv1x1")
- >> is_call_dps_packed("bias_add")
- >> is_call_dps_packed("my_relu")
- )
- dfb = CBRx2["main"].body.blocks[0]
- # we want to specifically match the first CBR (lv0)
- matched = ctx.match_dfb(dfb, start_hint=dfb.bindings[0].var)
- assert matched
- assert matched[chain[0]] == dfb.bindings[0].var
- # we want to specifically match the second CBR (lv3)
- matched = ctx.match_dfb(dfb, start_hint=dfb.bindings[3].var)
- assert matched
- assert matched[chain[0]] == dfb.bindings[3].var
-
-
-def test_counter_single_crb():
- with PatternContext() as ctx:
- (
- is_call_dps_packed("conv1x1")
- >> is_call_dps_packed("my_relu")
- >> is_call_dps_packed("bias_add")
- )
- dfb = CBRx2["main"].body.blocks[0]
- assert not ctx.match_dfb(dfb)
- # Quickly fails unpromising matches by assuming `start_hint` must be matched by a pattern.
- # This is usually faster than the full match:
- # Full match: let one pattern to match -> all Var: complexity ~ #Var
- # must_include_hint: let `start_hint` to match -> all patterns: complexity ~ #patterns
- # Usually #patterns is much smaller than #Var, so this is faster.
- assert not ctx.match_dfb(dfb, start_hint=dfb.bindings[0].var, must_include_hint=True)
-
-
def test_nested_context():
dfb = CBRx2["main"].body.blocks[0]
with PatternContext() as ctx0: