You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ya...@apache.org on 2023/04/06 20:43:04 UTC

[tvm] branch unity updated: [Unity][Graph matching] Improved matching algorithm and implementation (#14501)

This is an automated email from the ASF dual-hosted git repository.

yaxingcai pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new cea447cf37 [Unity][Graph matching] Improved matching algorithm and implementation (#14501)
cea447cf37 is described below

commit cea447cf37176feb52c93853dccbea96289982ac
Author: masahi <ma...@gmail.com>
AuthorDate: Fri Apr 7 05:42:54 2023 +0900

    [Unity][Graph matching] Improved matching algorithm and implementation (#14501)
    
    * remove start_hint from MatchGraph
    
    * improve graph matching algorithm
    
    * remove side effect from matching algo
    
    * pylint
    
    * add comments
    
    * add more const now that we can
    
    * cpplint
    
    * fix compile warning
    
    * Update src/relax/ir/dataflow_matcher.cc
    
    Co-authored-by: Jiawei Liu <ja...@gmail.com>
    
    * Update src/relax/ir/dataflow_matcher.cc
    
    Co-authored-by: Jiawei Liu <ja...@gmail.com>
    
    * use insert for merging MatchState
    
    * fix
    
    * parent check is not specific to wildcard
    
    * use map merge
    
    * cpplint
    
    * Pass and check current_match in TryMatch
    
    ---------
    
    Co-authored-by: Jiawei Liu <ja...@gmail.com>
---
 include/tvm/relax/dataflow_matcher.h        |   9 +-
 python/tvm/relax/dpl/context.py             |  10 +-
 src/relax/ir/dataflow_matcher.cc            | 294 ++++++++++++++--------------
 src/relax/ir/dataflow_pattern.cc            |   1 +
 tests/python/relax/test_dataflow_pattern.py |  45 -----
 5 files changed, 148 insertions(+), 211 deletions(-)

diff --git a/include/tvm/relax/dataflow_matcher.h b/include/tvm/relax/dataflow_matcher.h
index e4268be882..cf7c58f093 100644
--- a/include/tvm/relax/dataflow_matcher.h
+++ b/include/tvm/relax/dataflow_matcher.h
@@ -51,19 +51,12 @@ Optional<Map<DFPattern, Expr>> ExtractMatchedExpr(
 
 /**
  * \brief Match a sub-graph in a DataflowBlock with a graph of patterns and return the mapping.
- * \note This algorithm returns the first matched sub-graph. Use `start_hint` to specify the
- * starting point of the matching so that we can distinguish multiple matches.
- *
  * \param ctx The graph-wise patterns.
  * \param dfb The function to match.
- * \param start_hint The starting point expression to match to distinguish multiple matches.
- * \param must_include_hint If start_hint is given, the return pattern must include start_hint.
  * \return Matched patterns and corresponding bound variables
  */
 TVM_DLL Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx,
-                                                 const DataflowBlock& dfb,
-                                                 Optional<Var> start_hint = NullOpt,
-                                                 bool must_include_hint = false);
+                                                 const DataflowBlock& dfb);
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/python/tvm/relax/dpl/context.py b/python/tvm/relax/dpl/context.py
index 69a5e70ed0..16d86fb32d 100644
--- a/python/tvm/relax/dpl/context.py
+++ b/python/tvm/relax/dpl/context.py
@@ -17,7 +17,7 @@
 
 """The Graph Matching Context Manager for Dataflow Pattern Language."""
 
-from typing import Optional, Dict
+from typing import Dict
 
 import tvm
 from ..expr import DataflowBlock, Var
@@ -63,8 +63,6 @@ class PatternContext(tvm.runtime.Object):
     def match_dfb(
         self,
         dfb: DataflowBlock,
-        start_hint: Optional[Var] = None,
-        must_include_hint: bool = False,
     ) -> Dict[DFPattern, Var]:
         """
         Match a DataflowBlock via a graph of DFPattern and corresponding constraints
@@ -73,14 +71,10 @@ class PatternContext(tvm.runtime.Object):
         ----------
         dfb : DataflowBlock
             The DataflowBlock to match
-        start_hint : Optional[Var], optional
-            Indicating the starting expression to match, by default None
-        must_include_hint : bool, optional
-            Whether the start_hint expression must be matched, by default False
 
         Returns
         -------
         Dict[DFPattern, Var]
             The mapping from DFPattern to matched expression
         """
-        return ffi.match_dfb(self, dfb, start_hint, must_include_hint)  # type: ignore
+        return ffi.match_dfb(self, dfb)  # type: ignore
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index c1306ff690..6e8211cfd3 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -523,109 +523,114 @@ bool MatchExpr(DFPattern pattern, Expr expr, Optional<Map<Var, Expr>> bindings_o
 
 TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr);
 
+class MatcherUseDefAnalysis : public relax::ExprVisitor {
+ public:
+  std::vector<const VarNode*> vars;
+  std::map<const VarNode*, std::vector<const VarNode*>> def2use;
+  // caller -> callee table.
+  std::map<const VarNode*, std::vector<const VarNode*>> caller2callees;
+
+  const VarNode* cur_user_;
+
+  void VisitBinding_(const VarBindingNode* binding) override {
+    // init
+    cur_user_ = binding->var.get();
+    this->VisitVarDef(binding->var);
+    this->VisitExpr(binding->value);
+    cur_user_ = nullptr;
+  }
+
+  void VisitExpr_(const VarNode* op) override {
+    if (nullptr == cur_user_) return;
+
+    auto check_and_push = [](std::vector<const VarNode*>& vec, const VarNode* var) {
+      if (std::find(vec.begin(), vec.end(), var) == vec.end()) {
+        vec.push_back(var);
+      }
+    };
+
+    check_and_push(def2use[op], cur_user_);
+    check_and_push(vars, op);
+
+    caller2callees[cur_user_].push_back(op);
+  }
+
+  void VisitExpr_(const DataflowVarNode* op) override {
+    VisitExpr_(static_cast<const VarNode*>(op));
+  }
+};
+
 struct PNode {
   const DFPatternNode* ptr;
-  const VarNode* matched = nullptr;
   std::vector<std::pair<PNode*, const std::vector<PairCons>&>> children;
   std::vector<std::pair<PNode*, const std::vector<PairCons>&>> parents;
 };
 
 struct RNode {
   const VarNode* ptr;
-  const DFPatternNode* matched = nullptr;
   std::vector<RNode*> children;
   std::vector<RNode*> parents;
 };
 
-/**
- * \brief This method try to match a real node and a pattern node along with its neighbors.
- */
-using UndoItems = std::vector<std::pair<PNode*, RNode*>>;
-static std::optional<UndoItems> try_match(
-    PNode* p, RNode* r, DFPatternMatcher* m,
-    const std::map<const VarNode*, std::vector<const VarNode*>>& def2use,
-    const std::map<const VarNode*, std::vector<const VarNode*>>& use2def) {
-  if (p->matched != nullptr && p->matched == r->ptr) return {};  // matched before.
-  if (!m->Match(GetRef<DFPattern>(p->ptr), GetRef<Var>(r->ptr))) return std::nullopt;
-
-  UndoItems undo;
-
-  const auto commit = [&undo](PNode* p, RNode* r) {
-    // match with each other.
-    // TODO(ganler, masahi): Why commit on the same p-r pair happens more than once?
-    if (p->ptr == r->matched) {
-      ICHECK_EQ(p->matched, r->ptr);
-      return;
-    }
-    p->matched = r->ptr;
-    r->matched = p->ptr;
-    undo.emplace_back(p, r);
-  };
+struct MatchState {
+  void add(const PNode* p, const RNode* r) {
+    match_p_r[p] = r;
+    match_r_p[r] = p;
+  }
 
-  const auto quit = [&undo] {
-    for (auto& [p_node, r_node] : undo) {
-      p_node->matched = nullptr;
-      r_node->matched = nullptr;
-    }
-    return std::nullopt;
-  };
+  void add(MatchState&& other) {
+    match_p_r.merge(std::move(other.match_p_r));
+    match_r_p.merge(std::move(other.match_r_p));
+  }
 
-  const auto try_match_update_undo = [&](PNode* p, RNode* r) {
-    if (auto undo_more = try_match(p, r, m, def2use, use2def)) {
-      undo.insert(undo.end(), undo_more->begin(), undo_more->end());
-      return true;
+  const VarNode* matched(const PNode* p) const {
+    if (auto it = match_p_r.find(p); it != match_p_r.end()) {
+      return it->second->ptr;
     }
-    return false;
-  };
+    return nullptr;
+  }
 
-  commit(p, r);
+  const DFPatternNode* matched(const RNode* r) const {
+    if (auto it = match_r_p.find(r); it != match_r_p.end()) {
+      return it->second->ptr;
+    }
+    return nullptr;
+  }
 
-  // match parent patterns.
-  for (auto& [pparent, constraints] : p->parents) {
-    bool any_cons_sat = false;
-    for (auto& rparent : r->parents) {
-      // skip if mismatch.
-      if (rparent->matched && rparent->matched != pparent->ptr) continue;
+  const VarNode* matched(const PNode& p) const { return matched(&p); }
+  const DFPatternNode* matched(const RNode& r) const { return matched(&r); }
 
-      const auto& uses = def2use.at(rparent->ptr);
+ private:
+  std::unordered_map<const PNode*, const RNode*> match_p_r;
+  std::unordered_map<const RNode*, const PNode*> match_r_p;
+};
 
-      // check edge constraints.
-      bool cons_sat = true;
-      for (const auto& cons : constraints) {
-        if (cons.type == PairCons::kOnlyUsedBy && uses.size() != 1) {
-          cons_sat = false;
-          break;
-        }
+/**
+ * \brief This method tries to match a real node and a pattern node along with its neighbors.
+ */
+static std::optional<MatchState> TryMatch(const PNode& p, const RNode& r,
+                                          const MatchState& current_match, DFPatternMatcher* m,
+                                          const MatcherUseDefAnalysis& ud_analysis) {
+  if (!m->Match(GetRef<DFPattern>(p.ptr), GetRef<Var>(r.ptr))) return std::nullopt;
 
-        if (cons.index != -1) {
-          const auto& callees = use2def.at(r->ptr);
-          if (callees.size() <= static_cast<size_t>(cons.index) ||
-              callees[cons.index] != rparent->ptr) {
-            cons_sat = false;
-            break;
-          }
-        }
-      }
-      if (!cons_sat) continue;
-      any_cons_sat = true;
+  MatchState new_match;
 
-      // try all parent R nodes that are not matched yet.
-      // as long as ppattern can match one node.
-      if (!pparent->matched && try_match_update_undo(pparent, rparent)) {
-        commit(pparent, rparent);
-        break;
-      }
-    }
-    if (!pparent->matched || !any_cons_sat) return quit();
-  }
+  new_match.add(&p, &r);
 
   // forward matching;
-  for (auto& [pchild, constraints] : p->children) {
+  for (const auto& [pchild, constraints] : p.children) {
     bool any_cons_sat = false;
-    for (auto& rchild : r->children) {
-      if (rchild->matched && rchild->matched != pchild->ptr) continue;
+    for (const auto& rchild : r.children) {
+      if (new_match.matched(rchild)) {
+        // The child variable is already matched to another child pattern in a previous iteration.
+        continue;
+      }
+      if (auto v = current_match.matched(pchild); v && v != rchild->ptr) {
+        // The child pattern is already matched to another variable in an earlier call to TryMatch.
+        continue;
+      }
 
-      const auto& uses = def2use.at(r->ptr);
+      const auto& uses = ud_analysis.def2use.at(r.ptr);
 
       // check edge constraints.
       bool all_cons_pass = true;
@@ -636,88 +641,87 @@ static std::optional<UndoItems> try_match(
         }
 
         if (cons.index != -1) {
-          const auto& callees = use2def.at(rchild->ptr);
-          if (callees.size() <= static_cast<size_t>(cons.index) || callees[cons.index] != r->ptr) {
+          const auto& callees = ud_analysis.caller2callees.at(rchild->ptr);
+          if (callees.size() <= static_cast<size_t>(cons.index) || callees[cons.index] != r.ptr) {
             all_cons_pass = false;
             break;
           }
         }
       }
-      if (!all_cons_pass) continue;
+      if (!all_cons_pass || new_match.matched(pchild)) continue;
       any_cons_sat = true;
 
-      if (!pchild->matched && try_match_update_undo(pchild, rchild)) {
-        commit(pchild, rchild);
-        break;
+      if (auto match_rec = TryMatch(*pchild, *rchild, current_match, m, ud_analysis)) {
+        new_match.add(pchild, rchild);
+        new_match.add(std::move(*match_rec));
       }
     }
-    if (!pchild->matched || !any_cons_sat) return quit();
+    if (!new_match.matched(pchild) || !any_cons_sat) return std::nullopt;
   }
-  return undo;
+
+  return new_match;
 }
 
-class MatcherUseDefAnalysis : public relax::ExprVisitor {
- public:
-  std::vector<const VarNode*> vars;
-  std::map<const VarNode*, std::vector<const VarNode*>> def2use;
-  // caller -> callee table.
-  std::map<const VarNode*, std::vector<const VarNode*>> caller2callees;
+static std::optional<MatchState> MatchTree(
+    const MatchState& current_match, size_t current_root_idx,
+    const std::unordered_map<const DFPatternNode*, PNode>& pattern2node,
+    const std::unordered_map<const VarNode*, RNode>& var2node, DFPatternMatcher* matcher,
+    const std::vector<DFPattern>& roots, const MatcherUseDefAnalysis& ud_analysis) {
+  auto get_next_root = [&](size_t root_idx) -> const PNode* {
+    // Look for the next unmatched root node.
+    for (; root_idx < roots.size(); ++root_idx) {
+      const auto& root = pattern2node.at(roots[root_idx].get());
+      if (!current_match.matched(root)) {
+        return &root;
+      }
+    }
+    return nullptr;
+  };
 
-  const VarNode* cur_user_;
+  const auto root = get_next_root(current_root_idx);
 
-  void VisitBinding_(const VarBindingNode* binding) override {
-    // init
-    cur_user_ = binding->var.get();
-    this->VisitVarDef(binding->var);
-    this->VisitExpr(binding->value);
-    cur_user_ = nullptr;
+  if (!root) {
+    // All root nodes have been matched
+    return current_match;
   }
 
-  void VisitExpr_(const VarNode* op) override {
-    if (nullptr == cur_user_) return;
+  MatchState new_match = current_match;
 
-    auto check_and_push = [](std::vector<const VarNode*>& vec, const VarNode* var) {
-      if (std::find(vec.begin(), vec.end(), var) == vec.end()) {
-        vec.push_back(var);
+  for (const auto& var : ud_analysis.vars) {
+    const RNode& r_node = var2node.at(var);
+    if (new_match.matched(r_node)) continue;
+    if (auto match = TryMatch(*root, r_node, new_match, matcher, ud_analysis)) {
+      // Recursively try to match the next subtree.
+      new_match.add(std::move(*match));
+      if (auto match_rec = MatchTree(new_match, current_root_idx + 1, pattern2node, var2node,
+                                     matcher, roots, ud_analysis)) {
+        new_match.add(std::move(*match_rec));
+        return new_match;
       }
-    };
-
-    check_and_push(def2use[op], cur_user_);
-    check_and_push(vars, op);
-
-    caller2callees[cur_user_].push_back(op);
-  }
-
-  void VisitExpr_(const DataflowVarNode* op) override {
-    VisitExpr_(static_cast<const VarNode*>(op));
+      // Recursive matching has failed, backtrack.
+      continue;
+    }
   }
-};
 
-Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb,
-                                         Optional<Var> start_hint, bool must_include_hint) {
-  if (ctx->src_ordered.size() == 0) {
-    return NullOpt;
-  }
+  return std::nullopt;
+}
 
+Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb) {
   // TODO(@ganler): Handle non-may external use.
   ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is supported yet.";
-  ICHECK(!must_include_hint || start_hint.defined())
-      << "must_include_hint is only supported with start_hint.";
 
   const auto var2val = AnalyzeVar2Value(dfb);
   DFPatternMatcher matcher(var2val);
 
   MatcherUseDefAnalysis ud_analysis;
   ud_analysis.VisitBindingBlock_(dfb.get());
-  const auto& def2use = ud_analysis.def2use;
-  const auto& caller2callees = ud_analysis.caller2callees;
 
   // First construct a graph of PNode and RNode.
   std::unordered_map<const VarNode*, RNode> var2node;
   var2node.reserve(dfb->bindings.size());
 
   for (const VarNode* cur_var : ud_analysis.vars) {
-    const auto& uses = def2use.at(cur_var);
+    const auto& uses = ud_analysis.def2use.at(cur_var);
     RNode& cur_node = var2node[cur_var];
     cur_node.ptr = cur_var;
     for (const VarNode* use : uses) {
@@ -731,8 +735,9 @@ Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx, const Datafl
   std::unordered_map<const DFPatternNode*, PNode> pattern2node;
   pattern2node.reserve(ctx->constraints.size());
 
-  for (const auto& [def_pattern, uses] : ctx->constraints) {
+  for (const auto& def_pattern : ctx->src_ordered) {
     PNode& def_node = pattern2node[def_pattern.get()];
+    const auto& uses = ctx->constraints.at(def_pattern);
     def_node.ptr = def_pattern.get();
     def_node.children.reserve(uses.size());
     for (const auto& [use_pattern, cons] : uses) {
@@ -743,35 +748,24 @@ Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx, const Datafl
     }
   }
 
-  Map<DFPattern, Var> ret;
-
-  if (start_hint) {
-    auto rnode_ptr = var2node.at(start_hint.value().get());
-    for (auto& p_node : pattern2node) {
-      if (try_match(&p_node.second, &rnode_ptr, &matcher, def2use, caller2callees)) {
-        for (const auto& [df_pattern, pattern_node] : pattern2node) {
-          ret.Set(GetRef<DFPattern>(df_pattern), GetRef<Var>(pattern_node.matched));
-        }
-        return ret;
-      }
+  std::vector<DFPattern> roots;
+  for (const auto& pat : ctx->src_ordered) {
+    if (pattern2node[pat.get()].parents.empty()) {
+      roots.push_back(pat);
     }
-
-    if (must_include_hint) return ret;
   }
 
-  PNode& pnode_start = pattern2node[ctx->src_ordered[0].get()];
+  if (roots.empty()) {
+    return NullOpt;
+  }
 
-  if (!pnode_start.matched) {
-    for (const auto& var : ud_analysis.vars) {
-      if (start_hint.defined() && start_hint.value().get() == var) continue;
-      RNode& r_node = var2node[var];
-      if (try_match(&pnode_start, &r_node, &matcher, def2use, caller2callees)) {
-        for (const auto& [df_pattern, pattern_node] : pattern2node) {
-          ret.Set(GetRef<DFPattern>(df_pattern), GetRef<Var>(pattern_node.matched));
-        }
-        return ret;
-      }
+  if (auto match = MatchTree({}, 0, pattern2node, var2node, &matcher, roots, ud_analysis)) {
+    Map<DFPattern, Var> ret;
+    for (const auto& [pat, p_node] : pattern2node) {
+      ICHECK(match->matched(p_node));
+      ret.Set(GetRef<DFPattern>(pat), GetRef<Var>(match->matched(p_node)));
     }
+    return ret;
   }
 
   return NullOpt;
diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc
index 5580f6a1ab..4d225ceecf 100644
--- a/src/relax/ir/dataflow_pattern.cc
+++ b/src/relax/ir/dataflow_pattern.cc
@@ -406,6 +406,7 @@ PatternContext::PatternContext(bool incremental) {
         << "Incremental context needs to be built inside a existing context.";
     n->allow_extern_use = pattern_ctx_stack().top()->allow_extern_use;
     n->constraints = pattern_ctx_stack().top()->constraints;
+    n->src_ordered = pattern_ctx_stack().top()->src_ordered;
   }
 
   data_ = std::move(n);
diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py
index b85543cafc..a73a62eeef 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -519,51 +519,6 @@ class CBRx2:
         return lv6
 
 
-def test_single_cbr():
-    with PatternContext() as ctx:
-        (
-            is_call_dps_packed("conv1x1")
-            >> is_call_dps_packed("bias_add")
-            >> is_call_dps_packed("my_relu")
-        )
-        dfb = CBRx2["main"].body.blocks[0]
-        matched = ctx.match_dfb(dfb)
-        assert matched
-
-    with PatternContext() as ctx:
-        chain = (
-            is_call_dps_packed("conv1x1")
-            >> is_call_dps_packed("bias_add")
-            >> is_call_dps_packed("my_relu")
-        )
-        dfb = CBRx2["main"].body.blocks[0]
-        # we want to specifically match the first CBR (lv0)
-        matched = ctx.match_dfb(dfb, start_hint=dfb.bindings[0].var)
-        assert matched
-        assert matched[chain[0]] == dfb.bindings[0].var
-        # we want to specifically match the second CBR (lv3)
-        matched = ctx.match_dfb(dfb, start_hint=dfb.bindings[3].var)
-        assert matched
-        assert matched[chain[0]] == dfb.bindings[3].var
-
-
-def test_counter_single_crb():
-    with PatternContext() as ctx:
-        (
-            is_call_dps_packed("conv1x1")
-            >> is_call_dps_packed("my_relu")
-            >> is_call_dps_packed("bias_add")
-        )
-        dfb = CBRx2["main"].body.blocks[0]
-        assert not ctx.match_dfb(dfb)
-        # Quickly fails unpromising matches by assuming `start_hint` must be matched by a pattern.
-        # This is usually faster than the full match:
-        # Full match: let one pattern to match -> all Var: complexity ~ #Var
-        # must_include_hint: let `start_hint` to match -> all patterns: complexity ~ #patterns
-        # Usually #patterns is much smaller than #Var, so this is faster.
-        assert not ctx.match_dfb(dfb, start_hint=dfb.bindings[0].var, must_include_hint=True)
-
-
 def test_nested_context():
     dfb = CBRx2["main"].body.blocks[0]
     with PatternContext() as ctx0: