You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2020/06/02 00:14:47 UTC
[incubator-tvm] branch master updated: [PatternLang] Simplify
Pattern API Implementations (#5703)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 43162d6 [PatternLang] Simplify Pattern API Implementations (#5703)
43162d6 is described below
commit 43162d669d466c69bc6a64771b2fb8441f0a1c69
Author: Cody Yu <co...@gmail.com>
AuthorDate: Mon Jun 1 17:14:33 2020 -0700
[PatternLang] Simplify Pattern API Implementations (#5703)
* Add syntatic sugar; include pattern to API docs
* fix doc warnings
---
docs/api/python/index.rst | 1 +
.../{index.rst => relay/dataflow_pattern.rst} | 36 +---
docs/langref/relay_pattern.rst | 21 +-
python/tvm/relay/dataflow_pattern/__init__.py | 230 ++++++++++++++-------
tests/python/relay/test_dataflow_pattern.py | 55 +++--
5 files changed, 201 insertions(+), 142 deletions(-)
diff --git a/docs/api/python/index.rst b/docs/api/python/index.rst
index 50d7a3d..bee6e56 100644
--- a/docs/api/python/index.rst
+++ b/docs/api/python/index.rst
@@ -37,6 +37,7 @@ Python API
relay/transform
relay/analysis
relay/backend
+ relay/dataflow_pattern
relay/testing
autotvm
rpc
diff --git a/docs/api/python/index.rst b/docs/api/python/relay/dataflow_pattern.rst
similarity index 70%
copy from docs/api/python/index.rst
copy to docs/api/python/relay/dataflow_pattern.rst
index 50d7a3d..fe1d4e9 100644
--- a/docs/api/python/index.rst
+++ b/docs/api/python/relay/dataflow_pattern.rst
@@ -15,33 +15,11 @@
specific language governing permissions and limitations
under the License.
-Python API
-==========
+tvm.relay.dataflow_pattern
+--------------------------
-.. toctree::
- :maxdepth: 2
-
- runtime
- ndarray
- error
- ir
- target
- tir
- te
- driver
- relay/index
- relay/frontend
- relay/nn
- relay/vision
- relay/image
- relay/transform
- relay/analysis
- relay/backend
- relay/testing
- autotvm
- rpc
- micro
- contrib
- graph_runtime
- vta/index
- topi
+.. automodule:: tvm.relay.dataflow_pattern
+ :members:
+ :imported-members:
+ :exclude-members: Object, Node
+ :autosummary:
diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst
index 7d5deb2..962dcc6 100644
--- a/docs/langref/relay_pattern.rst
+++ b/docs/langref/relay_pattern.rst
@@ -114,7 +114,7 @@ Since there are not call nodes, we need to use specific pattern nodes to match t
x = relay.var('x')
y = relay.var('y')
z = relay.var('z')
- tuple_pattern = TuplePattern((wildcard(), wildcard(), wildcard()))
+ tuple_pattern = is_tuple((wildcard(), wildcard(), wildcard()))
assert tuple_pattern.match(relay.expr.Tuple((x,y,z)))
The next example is matching a pattern of batch_norm -> get(0) -> relu:
@@ -123,7 +123,7 @@ The next example is matching a pattern of batch_norm -> get(0) -> relu:
def test_match_tuple_get_item():
bn_node = is_op('nn.batch_norm')(wildcard(), wildcard(), wildcard(), wildcard(), wildcard())
- tuple_get_item_node = TupleGetItemPattern(bn_node, 0)
+ tuple_get_item_node = is_tuple_get_item(bn_node, 0)
pat = is_op('nn.relu')(tuple_get_item_node)
x = relay.var('x', shape=(1, 8))
@@ -142,7 +142,7 @@ if a specific parameter in a subgraph has been bound or not.
.. code-block:: python
def test_match_constant():
- conv2d = is_op('nn.conv2d')(wildcard(), ConstantPattern())
+ conv2d = is_op('nn.conv2d')(wildcard(), is_constant())
pattern = is_op('nn.bias_add')(conv2d, wildcard())
x = relay.var('x', shape=(1, 3, 224, 224))
@@ -162,12 +162,12 @@ if a specific parameter in a subgraph has been bound or not.
assert pattern.match(mod['main'].body)
On the other hand, if you need to match the constant with a specific value, you can directly
-use ``ExprPattern``. This could be useful for algebraic simplify.
+use ``is_expr``. This could be useful for algebraic simplification.
.. code-block:: python
def test_match_plus_zero():
- zero = (ExprPattern(relay.const(0)) | ExprPattern(relay.const(0.0)))
+ zero = (is_expr(relay.const(0)) | is_expr(relay.const(0.0)))
pattern = wildcard() + zero
x = relay.Var('x')
@@ -193,7 +193,7 @@ The next example is matching a diamond with two inputs at the top of the diamond
def test_match_diamond():
# Pattern
- is_conv2d = is_op('nn.conv2d')(is_input(), is_input())
+ is_conv2d = is_op('nn.conv2d')(is_var(), is_var())
path1 = is_op('nn.relu')(is_conv2d)
path2 = is_op('nn.leaky_relu')(is_conv2d)
diamond = is_op('add')(path1, path2)
@@ -213,7 +213,7 @@ The final example is matching diamonds with a post-dominator relationship. We em
def test_match_dom_diamond():
# Pattern
- is_conv2d = is_op('nn.conv2d')(is_input(), is_input())
+ is_conv2d = is_op('nn.conv2d')(is_var(), is_var())
reduction = is_op('add')(wildcard(), wildcard())
diamond = dominates(is_conv2d, is_elemwise, reduction)
@@ -240,7 +240,12 @@ The high level design is to introduce a language of patterns for now we propose
| pattern(pattern1, ... patternN)
| has_type(pattern, type)
| has_attr(pattern, attrs)
- | is_input(name)
+ | is_var(name)
+ | is_constant()
+ | is_expr(expr)
+ | is_op(op_name)
+ | is_tuple([pattern1, ... patternN])
+ | is_tuple_get_item(pattern, index)
| pattern1 `|` pattern2
| dominates(parent_pattern, path_pattern, child_pattern)
diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py
index f1d0784..e6a1a5e 100644
--- a/python/tvm/relay/dataflow_pattern/__init__.py
+++ b/python/tvm/relay/dataflow_pattern/__init__.py
@@ -15,12 +15,16 @@
# specific language governing permissions and limitations
# under the License.
"""The Relay Pattern Language and tooling."""
-from tvm.relay.expr import RelayExpr as Expr
+# pylint: disable=no-member
+from typing import Callable, Dict, List, Optional
+
import tvm._ffi
-from ...ir.base import Node
+from tvm.relay.expr import RelayExpr as Expr
+
+from ... import _ffi as tvm_ffi
from ...ir import make_node
+from ...ir.base import Node
from ...runtime import Object
-from ... import _ffi as tvm_ffi
from ..op import get
from . import _ffi as ffi
@@ -61,7 +65,7 @@ class DFPattern(Node):
def __truediv__(self, other):
return is_op("divide")(self, other)
- def has_attr(self, attrs):
+ def has_attr(self, attrs: Dict[str, Object]):
"""
Add an attribute constraint to this pattern
@@ -77,13 +81,13 @@ class DFPattern(Node):
attrs = make_node("DictAttrs", **attrs)
return AttrPattern(self, attrs)
- def has_type(self, ttype):
+ def has_type(self, ttype: tvm.ir.type.Type):
"""
Add a type constraint to this pattern
Parameters
----------
- ttype: tvm.relay.Type
+ ttype: tvm.ir.type.Type
The type to match
Returns
@@ -109,7 +113,10 @@ class DFPattern(Node):
"""
return match(self, expr)
- def partition(self, expr: Expr, attrs=None, check=lambda x: True) -> Expr:
+ def partition(self,
+ expr: Expr,
+ attrs: Optional[Dict[str, Object]] = None,
+ check: Callable[[Expr], bool] = lambda x: True) -> Expr:
"""
Parition the expression into functions defined by this pattern
@@ -119,7 +126,7 @@ class DFPattern(Node):
The expression to match.
attrs : Optional[Dict[str, Object]]
A dictionary of Attribute name/values to add to the paritioned function
- check : Function
+ check : Callable[[Expr], bool]
A function to perform more complicated checks on the matched expression.
Returns true if partitioning should proceed, false otherwise.
@@ -130,9 +137,9 @@ class DFPattern(Node):
"""
return partition(self, expr, attrs, check)
- def dominates(self, parent, path=None):
+ def dominates(self, parent: "DFPattern", path: "DFPattern" = None):
"""
- Create a dominator for this pattern
+ Create a dominator for this pattern.
Parameters
----------
@@ -144,15 +151,15 @@ class DFPattern(Node):
Returns
-------
result: tvm.relay.dataflow_pattern.DFPattern
- The resulting DominatorPattern
+ The resulting DominatorPattern.
"""
if path is None:
path = wildcard()
return DominatorPattern(parent, path, self)
- def optional(self, option_constructor):
+ def optional(self, option_constructor: Callable[["DFPattern"], "DFPattern"]):
"""
- Create a optional user of this pattern
+ Create an optional user of this pattern.
Parameters
----------
@@ -168,26 +175,60 @@ class DFPattern(Node):
return self | option_constructor(self)
-def is_input(name: str = "") -> DFPattern:
+def is_var(name: str = "") -> "DFPattern":
"""
- Syntatic sugar for creating an optionally named VarPattern
+    Syntactic sugar for creating an optionally named VarPattern.
Parameters
----------
name: str
- The name of the input pattern to match
+ The name of the input pattern to match.
Returns
-------
result: tvm.relay.dataflow_pattern.DFPattern
- The resulting InputPattern
+ The resulting pattern.
"""
return VarPattern(name)
-def is_op(op_name: str) -> DFPattern:
+def is_constant() -> "DFPattern":
+ """
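+    Syntactic sugar for creating a ConstantPattern.
+
+    This function takes no parameters. The resulting pattern
+    matches a constant regardless of its value; to match a
+    constant with a specific value, use ``is_expr`` instead
+    (see ``tvm.relay.dataflow_pattern.is_expr``).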
+
+ Returns
+ -------
+ result: tvm.relay.dataflow_pattern.DFPattern
+ The resulting pattern.
+ """
+ return ConstantPattern()
+
+
+def is_expr(expr: Expr) -> "DFPattern":
+ """
+    Syntactic sugar for creating an ExprPattern.
+
+ Parameters
+ ----------
+ expr: Expr
+ The Relay expression to match.
+
+ Returns
+ -------
+ result: tvm.relay.dataflow_pattern.DFPattern
+ The resulting pattern.
+ """
+ return ExprPattern(expr)
+
+
+def is_op(op_name: str) -> "DFPattern":
"""
- Syntatic sugar for creating an operator ExprPattern
+    Syntactic sugar for creating an operator ExprPattern.
Parameters
----------
@@ -203,19 +244,56 @@ def is_op(op_name: str) -> DFPattern:
return ExprPattern(op)
-def wildcard() -> DFPattern:
+def is_tuple(fields: tvm.ir.container.Array) -> "DFPattern":
"""
- Syntatic sugar for creating a WildcardPattern
+    Syntactic sugar for creating a TuplePattern.
+
+ Parameters
+ ----------
+ fields : Array[tvm.relay.dataflow_pattern.DFPattern]
+ The fields in the tuple.
Returns
-------
result: tvm.relay.dataflow_pattern.DFPattern
- The resulting WildcardPattern
+ The resulting pattern.
+ """
+ return TuplePattern(fields)
+
+
+def is_tuple_get_item(tuple_value: "DFPattern", index: int) -> "DFPattern":
+ """
+    Syntactic sugar for creating a TupleGetItemPattern.
+
+ Parameters
+ ----------
+ tuple_value: tvm.relay.dataflow_pattern.DFPattern
+ The input tuple expression.
+
+ index: int
+ The index.
+
+ Returns
+ -------
+ result: tvm.relay.dataflow_pattern.DFPattern
+ The resulting pattern.
+ """
+ return TupleGetItemPattern(tuple_value, index)
+
+
+def wildcard() -> "DFPattern":
+ """
+    Syntactic sugar for creating a WildcardPattern.
+
+ Returns
+ -------
+ result: tvm.relay.dataflow_pattern.DFPattern
+ The resulting pattern.
"""
return WildcardPattern()
-def has_type(ttype, pattern: DFPattern = None) -> DFPattern:
+def has_type(ttype, pattern: "DFPattern" = None) -> "DFPattern":
"""
Syntatic sugar for creating a TypePattern
@@ -224,7 +302,7 @@ def has_type(ttype, pattern: DFPattern = None) -> DFPattern:
pattern: tvm.relay.dataflow_pattern.DFPattern
The pattern that needs type annotation
- ttype: tvm.relay.Type
+ ttype: tvm.ir.type.Type
The type to match
Returns
@@ -237,7 +315,7 @@ def has_type(ttype, pattern: DFPattern = None) -> DFPattern:
return TypePattern(pattern, ttype)
-def has_attr(attrs, pattern=None) -> DFPattern:
+def has_attr(attrs, pattern=None) -> "DFPattern":
"""
Syntatic sugar for creating an AttrPattern
@@ -259,7 +337,7 @@ def has_attr(attrs, pattern=None) -> DFPattern:
return pattern.has_attr(attrs)
-def dominates(parent: DFPattern, path: DFPattern, child: DFPattern) -> DFPattern:
+def dominates(parent: "DFPattern", path: "DFPattern", child: "DFPattern") -> "DFPattern":
"""
Syntatic sugar for creating an Dominator pattern
@@ -275,12 +353,12 @@ def dominates(parent: DFPattern, path: DFPattern, child: DFPattern) -> DFPattern
Returns
-------
result: tvm.relay.dataflow_pattern.DFPattern
- The resulting DominatorPattern
+ The resulting DominatorPattern.
"""
return DominatorPattern(parent, path, child)
-def match(pattern: DFPattern, expr: Expr) -> bool:
+def match(pattern: "DFPattern", expr: Expr) -> bool:
"""
Match a pattern to an expression
@@ -321,13 +399,12 @@ class VarPattern(DFPattern):
The name of the variable. Optional, if not provided,
the pattern will match any VarNode.
- type_annotation: tvm.relay.Type, optional
+ type_annotation: tvm.ir.type.Type, optional
The type annotation on the variable.
"""
- def __init__(self, name_hint="", type_annotation=None):
- self.__init_handle_by_constructor__(
- ffi.VarPattern, name_hint, type_annotation)
+ def __init__(self, name_hint: str = "", type_annotation: Optional[tvm.ir.type.Type] = None):
+ self.__init_handle_by_constructor__(ffi.VarPattern, name_hint, type_annotation)
@register_df_node
@@ -350,19 +427,22 @@ class CallPattern(DFPattern):
args: List[realy.dataflow_pattern.DFPattern]
The arguments to the call.
- attrs: Optional[tvm.Attrs]
+ attrs: Optional[tvm.ir.attrs.Attrs]
Attributes to the call, can be None
- type_args: Optional[List[tvm.relay.Type]]
+ type_args: Optional[List[tvm.ir.type.Type]]
The additional type arguments, this is only
used in advanced usecase of template functions.
"""
- def __init__(self, op, args, attrs=None, type_args=None):
+ def __init__(self,
+ op: "DFPattern",
+ args: List["DFPattern"],
+ attrs: Optional[tvm.ir.attrs.Attrs] = None,
+ type_args: Optional[List[tvm.ir.type.Type]] = None):
if not type_args:
type_args = []
- self.__init_handle_by_constructor__(
- ffi.CallPattern, op, args, attrs, type_args)
+ self.__init_handle_by_constructor__(ffi.CallPattern, op, args, attrs, type_args)
@register_df_node
@@ -371,14 +451,14 @@ class TuplePattern(DFPattern):
Parameters
----------
- fields : List[tvm.relay.dataflow_pattern.DFPattern]
+ fields : Array[tvm.relay.dataflow_pattern.DFPattern]
The fields in the tuple.
"""
- def __init__(self, fields):
+ def __init__(self, fields: tvm.ir.container.Array):
self.__init_handle_by_constructor__(ffi.TuplePattern, fields)
- def __getitem__(self, index):
+ def __getitem__(self, index: int):
if index >= len(self):
raise IndexError("TuplePattern index out of range")
return self.fields[index]
@@ -403,9 +483,8 @@ class TupleGetItemPattern(DFPattern):
The index.
"""
- def __init__(self, tuple_value: DFPattern, index):
- self.__init_handle_by_constructor__(
- ffi.TupleGetItemPattern, tuple_value, index)
+ def __init__(self, tuple_value: "DFPattern", index: int):
+ self.__init_handle_by_constructor__(ffi.TupleGetItemPattern, tuple_value, index)
@register_df_node
@@ -415,14 +494,13 @@ class AltPattern(DFPattern):
Parameters
----------
left: tvm.relay.dataflow_pattern.DFPattern
- One possible matching Pattern
+ One possible matching pattern.
right: tvm.relay.dataflow_pattern.DFPattern
- One possible matching Pattern
+ One possible matching pattern.
"""
- def __init__(self, left: DFPattern, right: DFPattern):
- self.__init_handle_by_constructor__(
- ffi.AltPattern, left, right)
+ def __init__(self, left: "DFPattern", right: "DFPattern"):
+ self.__init_handle_by_constructor__(ffi.AltPattern, left, right)
@register_df_node
@@ -441,34 +519,32 @@ class TypePattern(DFPattern):
Parameters
----------
pattern: tvm.relay.dataflow_pattern.DFPattern
- The input pattern that needs type annotation
+ The input pattern that needs type annotation.
- ttype: tvm.relay.Type
- The type to match
+ ttype: tvm.ir.type.Type
+ The type to match.
"""
- def __init__(self, pattern: DFPattern, ttype):
- self.__init_handle_by_constructor__(
- ffi.TypePattern, pattern, ttype)
+ def __init__(self, pattern: "DFPattern", ttype: tvm.ir.type.Type):
+ self.__init_handle_by_constructor__(ffi.TypePattern, pattern, ttype)
@register_df_node
class AttrPattern(DFPattern):
"""Get match an expression with a certain attributes.
- Currently only supports Op Attributes, not call Attributes
+ Currently only supports Op Attributes, not call Attributes.
Parameters
----------
pattern: tvm.relay.dataflow_pattern.DFPattern
The input pattern.
- attrs: tvm.Attrs
- The attributes to match
+ attrs: tvm.ir.attrs.Attrs
+ The attributes to match.
"""
- def __init__(self, pattern: DFPattern, attrs):
- self.__init_handle_by_constructor__(
- ffi.AttrPattern, pattern, attrs)
+ def __init__(self, pattern: "DFPattern", attrs: tvm.ir.attrs.Attrs):
+ self.__init_handle_by_constructor__(ffi.AttrPattern, pattern, attrs)
@register_df_node
@@ -479,22 +555,21 @@ class DominatorPattern(DFPattern):
----------
parent: tvm.relay.dataflow_pattern.DFPattern
The parent, i.e., the single node which produces something,
- later aggregated by the child
+ later aggregated by the child.
path: tvm.relay.dataflow_pattern.DFPattern
The fuzzy path pattern between parent and child,
- typically matches elementwise ops
+ typically matches elementwise ops.
child: tvm.relay.dataflow_pattern.DFPattern
The last node in the domination which is the end user
- for all nodes in the path and the parent
+ for all nodes in the path and the parent.
"""
- def __init__(self, parent: DFPattern, path: DFPattern, child: DFPattern):
- self.__init_handle_by_constructor__(
- ffi.DominatorPattern, parent, path, child)
+ def __init__(self, parent: "DFPattern", path: "DFPattern", child: "DFPattern"):
+ self.__init_handle_by_constructor__(ffi.DominatorPattern, parent, path, child)
class DFPatternCallback:
- """A Callback for Pattern Rewriting
+ """A Callback for Pattern Rewriting.
When rewrite is called on this DFPatternCallback, the backend will find matches for the
pattern, call the callback function, and replace the matched expression with whatever
@@ -515,11 +590,11 @@ class DFPatternCallback:
Returns
-------
result : tvm.relay.Expr
- The Expression with matched subgraphs rewritten by the callbacks
+ The Expression with matched subgraphs rewritten by the callbacks.
"""
return rewrite(self, expr)
- def callback(self, pre, post, node_map):
+ def callback(self, pre: Expr, post: Expr, node_map: tvm.ir.container.Map) -> Expr:
"""
Callback function to use when we found a match to the pattern
@@ -529,7 +604,7 @@ class DFPatternCallback:
The matching expression from the original graph.
post : tvm.relay.Expr
The matching expression with rewritten inputs
- node_map : Map(DFPattern, List(Expr))
+ node_map : tvm.ir.container.Map[DFPattern, List[Expr]]
The map between patterns and matched expressions
Returns
@@ -542,13 +617,12 @@ class DFPatternCallback:
class _DFPatternCallback(Object):
"""C++ implemenation"""
def __init__(self, pattern, callback):
- self.__init_handle_by_constructor__(
- ffi.DFPatternCallback, pattern, callback)
+ self.__init_handle_by_constructor__(ffi.DFPatternCallback, pattern, callback)
def rewrite(callbacks, expr: Expr) -> Expr:
"""
- Rewrite expression with the given callbacks
+ Rewrite expression with the given callbacks.
Parameters
----------
@@ -560,7 +634,7 @@ def rewrite(callbacks, expr: Expr) -> Expr:
Returns
-------
result : tvm.relay.Expr
- The Expression with matched subgraphs rewritten by the callbacks
+ The Expression with matched subgraphs rewritten by the callbacks.
"""
if isinstance(callbacks, DFPatternCallback):
tmp = [_DFPatternCallback(callbacks.pattern, callbacks.callback)]
@@ -571,7 +645,11 @@ def rewrite(callbacks, expr: Expr) -> Expr:
return ffi.rewrite(tmp, expr)
-def partition(pattern: DFPattern, expr: Expr, attrs=None, check=lambda x: True) -> Expr:
+
+def partition(pattern: "DFPattern",
+ expr: Expr,
+ attrs: Optional[Dict[str, Object]] = None,
+ check: Callable[[Expr], bool] = lambda x: True) -> Expr:
"""
Parition the expression into a series of functions that match the pattern
@@ -583,7 +661,7 @@ def partition(pattern: DFPattern, expr: Expr, attrs=None, check=lambda x: True)
The expression to split into functions
attrs : Optional[Dict[str, Object]]
A dict of attributes to apply to the partitioned function
- check : Function
+ check : Callable[[Expr], bool]
A function to perform more complicated checks on the matched expression.
Returns true if partitioning should proceed, false otherwise.
diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py
index 89abb2e..8d67db5 100644
--- a/tests/python/relay/test_dataflow_pattern.py
+++ b/tests/python/relay/test_dataflow_pattern.py
@@ -32,19 +32,19 @@ K_BROADCAST = 1
## NODE TESTS
def test_expr_pattern():
- ep = ExprPattern(relay.var('x', shape=(4, 1)))
+ ep = is_expr(relay.var('x', shape=(4, 1)))
assert isinstance(ep, ExprPattern)
assert isinstance(ep.expr, relay.Var)
def test_var_pattern():
- v = is_input("x")
+ v = is_var("x")
assert isinstance(v, VarPattern)
assert v.name == "x"
def test_constant_pattern():
- c = ConstantPattern()
+ c = is_constant()
assert isinstance(c, ConstantPattern)
@@ -65,7 +65,7 @@ def test_CallPattern():
def test_TuplePattern():
wc1 = wildcard()
wc2 = wildcard()
- t = TuplePattern([wc1, wc2])
+ t = is_tuple([wc1, wc2])
assert isinstance(t, TuplePattern)
assert isinstance(t.fields[0], WildcardPattern)
assert isinstance(t.fields[1], WildcardPattern)
@@ -74,8 +74,8 @@ def test_TuplePattern():
def test_TupleGetItemPattern():
wc1 = wildcard()
wc2 = wildcard()
- t = TuplePattern([wc1, wc2])
- tgi = TupleGetItemPattern(t, 1)
+ t = is_tuple([wc1, wc2])
+ tgi = is_tuple_get_item(t, 1)
assert isinstance(tgi, TupleGetItemPattern)
assert isinstance(tgi.tuple, TuplePattern)
assert isinstance(tgi.tuple.fields[0], WildcardPattern)
@@ -120,10 +120,10 @@ def test_match_op_or():
def test_match_call_commutive():
x = relay.var('x')
y = relay.var('y')
- add_pattern = is_op('add')(is_input("x"), is_input("y"))
+ add_pattern = is_op('add')(is_var("x"), is_var("y"))
assert add_pattern.match(x + y)
assert add_pattern.match(y + x)
- mul_pattern = is_op('multiply')(is_input("x"), is_input("y"))
+ mul_pattern = is_op('multiply')(is_var("x"), is_var("y"))
assert mul_pattern.match(x * y)
assert mul_pattern.match(y * x)
@@ -131,10 +131,10 @@ def test_match_call_commutive():
def test_no_match_call_commutive():
x = relay.var('x')
y = relay.var('y')
- add_pattern = is_op('subtract')(is_input("x"), is_input("y"))
+ add_pattern = is_op('subtract')(is_var("x"), is_var("y"))
assert add_pattern.match(x - y)
assert not add_pattern.match(y - x)
- add_pattern = is_op('divide')(is_input("x"), is_input("y"))
+ add_pattern = is_op('divide')(is_var("x"), is_var("y"))
assert add_pattern.match(x / y)
assert not add_pattern.match(y / x)
@@ -211,7 +211,7 @@ def test_no_match_option():
def test_match_const():
- conv2d = is_op('nn.conv2d')(wildcard(), ConstantPattern())
+ conv2d = is_op('nn.conv2d')(wildcard(), is_constant())
pattern = is_op('nn.bias_add')(conv2d, wildcard())
x = relay.var('x', shape=(1, 3, 224, 224))
@@ -232,11 +232,11 @@ def test_match_tuple():
x = relay.var('x')
y = relay.var('y')
z = relay.op.op.get("add")
- tuple_pattern = TuplePattern((is_input("x"), wildcard(), is_op("add")))
+ tuple_pattern = is_tuple((is_var("x"), wildcard(), is_op("add")))
assert tuple_pattern.match(relay.expr.Tuple((x, y, z)))
- tuple_pattern = TuplePattern((is_input("x"), wildcard(), is_op("add")))
- tuple_get_item_pattern = TupleGetItemPattern(tuple_pattern, 1)
+ tuple_pattern = is_tuple((is_var("x"), wildcard(), is_op("add")))
+ tuple_get_item_pattern = is_tuple_get_item(tuple_pattern, 1)
assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 1))
@@ -244,11 +244,11 @@ def test_no_match_tuple():
x = relay.var('x')
y = relay.var('y')
z = relay.op.op.get("add")
- tuple_pattern = TuplePattern((is_input('x'), wildcard(), is_op("add"), wildcard()))
+ tuple_pattern = is_tuple((is_var('x'), wildcard(), is_op("add"), wildcard()))
assert not tuple_pattern.match(relay.expr.Tuple((x, y, z)))
- tuple_pattern = TuplePattern((is_input('x'), wildcard(), is_op("add")))
- tuple_get_item_pattern = TupleGetItemPattern(tuple_pattern, 1)
+ tuple_pattern = is_tuple((is_var('x'), wildcard(), is_op("add")))
+ tuple_get_item_pattern = is_tuple_get_item(tuple_pattern, 1)
assert not tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple(
(x, y, z)), 2))
@@ -596,7 +596,7 @@ class BatchnormCallback(DFPatternCallback):
self.mean = wildcard()
self.beta = wildcard()
self.gamma = wildcard()
- self.eps = ConstantPattern()
+ self.eps = is_constant()
self.pattern = self.gamma * (self.x - self.mean) / is_op("sqrt")(self.var + self.eps) + \
self.beta
@@ -760,8 +760,8 @@ def test_quadruple_rewrite_dominator():
def algebraic_simplify(expr):
- zero = (ExprPattern(relay.const(0)) | ExprPattern(relay.const(0.0)))
- one = (ExprPattern(relay.const(1)) | ExprPattern(relay.const(1.0)))
+ zero = (is_expr(relay.const(0)) | is_expr(relay.const(0.0)))
+ one = (is_expr(relay.const(1)) | is_expr(relay.const(1.0)))
class ElwiseNullCallback(DFPatternCallback):
def callback(self, pre, post, node_map):
@@ -1182,35 +1182,32 @@ def test_partition_constant_embedding():
assert tvm.ir.structural_equal(lifted_func(x, wc, b), pattern.partition(reluc))
# Check lifting of input matches
- pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_input()),
+ pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_var()),
wildcard()))
assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
assert tvm.ir.structural_equal(reluc, pattern.partition(reluc)) #Constants are not Inputs
# Check embedding of constant matches
- pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(),
- ConstantPattern()),
+ pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_constant()),
wildcard()))
assert tvm.ir.structural_equal(relu, pattern.partition(relu))
assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))
# Check embedding of constant ExprPatterns
- pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(),
- ExprPattern(wc)),
+ pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_expr(wc)),
wildcard()))
assert tvm.ir.structural_equal(relu, pattern.partition(relu))
assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))
# Check lifting/embedding of Alt matches
- pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_input()
- | ConstantPattern()),
- wildcard()))
+ pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(
+ wildcard(), is_var() | is_constant()), wildcard()))
assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))
# Check lifting/embedding of Alt matches with the other ordering
pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(
- wildcard(), ConstantPattern() | is_input()), wildcard()))
+ wildcard(), is_constant() | is_var()), wildcard()))
assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc))