You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/06/11 06:06:54 UTC

[tvm] branch main updated: support matching attributes with more complex objects (#8240)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4e9760b  support matching attributes with more complex objects (#8240)
4e9760b is described below

commit 4e9760bd477c8351c23c3cf482f3d0c8fd09efa0
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Fri Jun 11 00:06:23 2021 -0600

    support matching attributes with more complex objects (#8240)
---
 docs/langref/relay_pattern.rst              | 11 +++++++++++
 src/relay/ir/dataflow_matcher.cc            |  9 +++++++++
 tests/python/relay/test_dataflow_pattern.py | 11 +++++++++++
 3 files changed, 31 insertions(+)

diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst
index 257fe08..49d3a42 100644
--- a/docs/langref/relay_pattern.rst
+++ b/docs/langref/relay_pattern.rst
@@ -80,6 +80,17 @@ Here is another example to match an op with a specific attribute:
         y = relay.var('y')
         assert not is_conv2d.match(relay.op.nn.conv2d(x, y))
 
+Or a convolution with a specific kernel size:
+
+.. code-block:: python
+
+    def test_match_kernel_size():
+        is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"kernel_size": [3, 3]})
+        x = relay.var('x')
+        y = relay.var('y')
+        assert is_conv2d.match(relay.op.nn.conv2d(x, y, kernel_size=[3, 3]))
+      
+
 
 Matching an Optional Op
 ***********************
diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index 6ed24d5..5ce06d9 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -131,6 +131,8 @@ bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) {
         return rhs.operator std::string() == val->value;
       } else if (auto* val = lhs.as<StringObj>()) {
         return rhs.operator std::string() == val->data;
+      } else {
+        ICHECK(false) << "PatternMatcher: Unsupported TVMDataType " << lhs;
       }
       break;
     case kTVMObjectHandle:
@@ -140,6 +142,13 @@ bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) {
         } else if (auto* val = lhs.as<StringObj>()) {
           return rhs.operator String() == val->data;
         }
+      } else {
+        // Compare the objects for structural equality
+        static auto* structural_equal = runtime::Registry::Get("node.StructuralEqual");
+        ICHECK(structural_equal) << "node.StructuralEqual is not registered.";
+        if ((*structural_equal)(lhs, GetRef<ObjectRef>(rhs.ptr<Object>()), false, true)) {
+          return true;
+        }
       }
       break;
     default:
diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py
index 229b990..f95a009 100644
--- a/tests/python/relay/test_dataflow_pattern.py
+++ b/tests/python/relay/test_dataflow_pattern.py
@@ -478,11 +478,17 @@ def test_no_match_func_attr():
 
 
 def test_match_call_attr():
+    # String attr
     is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"data_layout": "NCHW"})
     x = relay.var("x")
     y = relay.var("y")
     assert is_conv2d.match(relay.op.nn.conv2d(x, y))
 
+    # Array attr
+    is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"kernel_size": [3, 3]})
+    out = relay.op.nn.conv2d(x, y, kernel_size=[3, 3])
+    assert is_conv2d.match(out)
+
     # non-operator call
     attr_dict = {"call_attr": "attr"}
     call_has_attr = wildcard()(wildcard()).has_attr(attr_dict)
@@ -508,6 +514,11 @@ def test_no_match_call_attr():
     is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"RandomAttr": "NCHW"})
     assert not is_conv2d.match(relay.op.nn.conv2d(x, y))
 
+    # Array attr
+    is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"kernel_size": [3, 3]})
+    out = relay.op.nn.conv2d(x, y, kernel_size=[2, 1])
+    assert not is_conv2d.match(out)
+
     # non-operator calls
     call_has_attr = wildcard()(wildcard()).has_attr({"call_attr": "attr"})
     wrong_key = tvm.ir.make_node("DictAttrs", **{"wrong": "attr"})