You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/06/11 06:06:54 UTC
[tvm] branch main updated: support matching attributes with more
complex objects (#8240)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4e9760b support matching attributes with more complex objects (#8240)
4e9760b is described below
commit 4e9760bd477c8351c23c3cf482f3d0c8fd09efa0
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Fri Jun 11 00:06:23 2021 -0600
support matching attributes with more complex objects (#8240)
---
docs/langref/relay_pattern.rst | 11 +++++++++++
src/relay/ir/dataflow_matcher.cc | 9 +++++++++
tests/python/relay/test_dataflow_pattern.py | 11 +++++++++++
3 files changed, 31 insertions(+)
diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst
index 257fe08..49d3a42 100644
--- a/docs/langref/relay_pattern.rst
+++ b/docs/langref/relay_pattern.rst
@@ -80,6 +80,17 @@ Here is another example to match an op with a specific attribute:
y = relay.var('y')
assert not is_conv2d.match(relay.op.nn.conv2d(x, y))
+Or a convolution with a specific kernel size:
+
+.. code-block:: python
+
+ def test_match_kernel_size():
+ is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"kernel_size": [3, 3]})
+ x = relay.var('x')
+ y = relay.var('y')
+ assert is_conv2d.match(relay.op.nn.conv2d(x, y, kernel_size=[3, 3]))
+
+
Matching an Optional Op
***********************
diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index 6ed24d5..5ce06d9 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -131,6 +131,8 @@ bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) {
return rhs.operator std::string() == val->value;
} else if (auto* val = lhs.as<StringObj>()) {
return rhs.operator std::string() == val->data;
+ } else {
+ ICHECK(false) << "PatternMatcher: Unsupported TVMDataType " << lhs;
}
break;
case kTVMObjectHandle:
@@ -140,6 +142,13 @@ bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) {
} else if (auto* val = lhs.as<StringObj>()) {
return rhs.operator String() == val->data;
}
+ } else {
+ // Compare the objects for structural equality
+ static auto* structural_equal = runtime::Registry::Get("node.StructuralEqual");
+ ICHECK(structural_equal) << "node.StructuralEqual is not registered.";
+ if ((*structural_equal)(lhs, GetRef<ObjectRef>(rhs.ptr<Object>()), false, true)) {
+ return true;
+ }
}
break;
default:
diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py
index 229b990..f95a009 100644
--- a/tests/python/relay/test_dataflow_pattern.py
+++ b/tests/python/relay/test_dataflow_pattern.py
@@ -478,11 +478,17 @@ def test_no_match_func_attr():
def test_match_call_attr():
+ # String attr
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"data_layout": "NCHW"})
x = relay.var("x")
y = relay.var("y")
assert is_conv2d.match(relay.op.nn.conv2d(x, y))
+ # Array attr
+ is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"kernel_size": [3, 3]})
+ out = relay.op.nn.conv2d(x, y, kernel_size=[3, 3])
+ assert is_conv2d.match(out)
+
# non-operator call
attr_dict = {"call_attr": "attr"}
call_has_attr = wildcard()(wildcard()).has_attr(attr_dict)
@@ -508,6 +514,11 @@ def test_no_match_call_attr():
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"RandomAttr": "NCHW"})
assert not is_conv2d.match(relay.op.nn.conv2d(x, y))
+ # Array attr
+ is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"kernel_size": [3, 3]})
+ out = relay.op.nn.conv2d(x, y, kernel_size=[2, 1])
+ assert not is_conv2d.match(out)
+
# non-operator calls
call_has_attr = wildcard()(wildcard()).has_attr({"call_attr": "attr"})
wrong_key = tvm.ir.make_node("DictAttrs", **{"wrong": "attr"})