You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2023/08/04 22:27:40 UTC

[tvm] branch main updated: [Relay] [Bugfix] Fix some bugs of dominator pattern (#15473)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ac99367aa4 [Relay] [Bugfix] Fix some bugs of dominator pattern (#15473)
ac99367aa4 is described below

commit ac99367aa4f4ad05ba9926c614159e7430936502
Author: 电线杆 <44...@qq.com>
AuthorDate: Sat Aug 5 06:27:33 2023 +0800

    [Relay] [Bugfix] Fix some bugs of dominator pattern (#15473)
    
    * fix some bugs
    
    * add test
---
 src/relay/ir/dataflow_matcher.cc            | 15 +++++++++-----
 tests/python/relay/test_dataflow_pattern.py | 31 +++++++++++++++++++++++++++++
 2 files changed, 41 insertions(+), 5 deletions(-)

diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index 185a92898d..249f4ccf7a 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -302,12 +302,12 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex
 bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) {
   auto call_node = expr.as<CallNode>();
   auto index_node = expr_to_node(expr);
+  size_t arg_counter{0};
   for (auto node : index_node->inputs_) {
     if (!(call_node && node->ref() == call_node->op)) {
+      arg_counter += 1;
       memoize_ = true;
-      if (VisitDFPattern(op->parent, node->ref())) {
-        return true;
-      } else {
+      if (!VisitDFPattern(op->parent, node->ref())) {
         memoize_ = false;
         if (!VisitDFPattern(op->path, node->ref())) {
           return false;
@@ -318,6 +318,9 @@ bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& e
       }
     }
   }
+  if (!arg_counter) {
+    return false;
+  }
   return true;
 }
 
@@ -605,8 +608,10 @@ void PatternGrouper::CreateGroup(const Expr& expr) {
     // Don't treat fuzzy Dominator patterns input variables for partition
     if (auto op = node->ref().as<DominatorPatternNode>()) {
       for (auto fuzzy_op : {op->parent, op->path}) {
-        for (auto match : node_map[fuzzy_op]) {
-          fuzzy_matches.insert(match);
+        if (node_map.count(fuzzy_op)) {
+          for (auto match : node_map[fuzzy_op]) {
+            fuzzy_matches.insert(match);
+          }
         }
       }
     }
diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py
index bcb665121b..c4a83735ce 100644
--- a/tests/python/relay/test_dataflow_pattern.py
+++ b/tests/python/relay/test_dataflow_pattern.py
@@ -750,6 +750,37 @@ def test_not_match_dominator():
     assert not diamond.match(out)
 
 
+def test_not_match_dominator2():
+    # Pattern
+    P = is_op("nn.conv2d")(wildcard(), wildcard())  # 'parent'
+    I = is_op("nn.relu")(wildcard())  # 'intermediate' ('path' in the code)
+    C = is_op("add")(wildcard(), wildcard())  # 'child'
+    pattern = dominates(P, I, C)
+
+    #       n6(P)
+    #      /  \
+    #     n7   \
+    #    /      \
+    #    n8(P)  n9(I)
+    #    \      /
+    #     \    /
+    #      \  /
+    #      n10(C)
+
+    x = relay.var("x")
+    w = relay.var("w")
+    n6 = relay.op.nn.conv2d(x, w)  # matches P
+    n7 = relay.op.tanh(n6)  # does not match I
+    n8 = relay.op.nn.conv2d(n7, w)  # matches P
+    n9 = relay.op.nn.relu(n6)  # matches I
+    n10 = relay.add(n8, n9)  # matches C
+
+    # Does not match: Can't match the parent pattern P at both 8 and 6.
+    # Note that if we did allow P to be used twice the implementation would
+    # need to be changed to not 'jump over' n7.
+    assert not pattern.match(n10)
+
+
 def test_match_typed_dominator():
     # Pattern
     is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())