You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2023/08/04 22:27:40 UTC
[tvm] branch main updated: [Relay] [Bugfix] Fix some bugs of dominator pattern (#15473)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ac99367aa4 [Relay] [Bugfix] Fix some bugs of dominator pattern (#15473)
ac99367aa4 is described below
commit ac99367aa4f4ad05ba9926c614159e7430936502
Author: 电线杆 <44...@qq.com>
AuthorDate: Sat Aug 5 06:27:33 2023 +0800
[Relay] [Bugfix] Fix some bugs of dominator pattern (#15473)
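* fix dominator-pattern bugs in DFPatternMatcher::MatchesPath and PatternGrouper::CreateGroup
* add regression test test_not_match_dominator2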
---
src/relay/ir/dataflow_matcher.cc | 15 +++++++++-----
tests/python/relay/test_dataflow_pattern.py | 31 +++++++++++++++++++++++++++++
2 files changed, 41 insertions(+), 5 deletions(-)
diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index 185a92898d..249f4ccf7a 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -302,12 +302,12 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex
bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) {
auto call_node = expr.as<CallNode>();
auto index_node = expr_to_node(expr);
+ size_t arg_counter{0};
for (auto node : index_node->inputs_) {
if (!(call_node && node->ref() == call_node->op)) {
+ arg_counter += 1;
memoize_ = true;
- if (VisitDFPattern(op->parent, node->ref())) {
- return true;
- } else {
+ if (!VisitDFPattern(op->parent, node->ref())) {
memoize_ = false;
if (!VisitDFPattern(op->path, node->ref())) {
return false;
@@ -318,6 +318,9 @@ bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& e
}
}
}
+ if (!arg_counter) {
+ return false;
+ }
return true;
}
@@ -605,8 +608,10 @@ void PatternGrouper::CreateGroup(const Expr& expr) {
// Don't treat fuzzy Dominator patterns input variables for partition
if (auto op = node->ref().as<DominatorPatternNode>()) {
for (auto fuzzy_op : {op->parent, op->path}) {
- for (auto match : node_map[fuzzy_op]) {
- fuzzy_matches.insert(match);
+ if (node_map.count(fuzzy_op)) {
+ for (auto match : node_map[fuzzy_op]) {
+ fuzzy_matches.insert(match);
+ }
}
}
}
diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py
index bcb665121b..c4a83735ce 100644
--- a/tests/python/relay/test_dataflow_pattern.py
+++ b/tests/python/relay/test_dataflow_pattern.py
@@ -750,6 +750,37 @@ def test_not_match_dominator():
assert not diamond.match(out)
+def test_not_match_dominator2():
+ # Pattern
+ P = is_op("nn.conv2d")(wildcard(), wildcard()) # 'parent'
+ I = is_op("nn.relu")(wildcard()) # 'intermediate' ('path' in the code)
+ C = is_op("add")(wildcard(), wildcard()) # 'child'
+ pattern = dominates(P, I, C)
+
+ #        n6(P)
+ #       /     \
+ #      n7      \
+ #     /         \
+ #  n8(P)       n9(I)
+ #     \         /
+ #      \       /
+ #       \     /
+ #       n10(C)
+
+ x = relay.var("x")
+ w = relay.var("w")
+ n6 = relay.op.nn.conv2d(x, w) # matches P
+ n7 = relay.op.tanh(n6) # does not match I
+ n8 = relay.op.nn.conv2d(n7, w) # matches P
+ n9 = relay.op.nn.relu(n6) # matches I
+ n10 = relay.add(n8, n9) # matches C
+
+ # Does not match: the parent pattern P cannot match at both n6 and n8.
+ # Note that if we did allow P to be used twice, the implementation would
+ # need to be changed so that it does not 'jump over' n7 (which matches neither P nor I).
+ assert not pattern.match(n10)
+
+
def test_match_typed_dominator():
# Pattern
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())