You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wu...@apache.org on 2022/08/29 21:33:15 UTC
[tvm] branch main updated: [Utils] Handled Callable in tir.schedule._type_checker (#12633)
This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 74988d36bd [Utils] Handled Callable in tir.schedule._type_checker (#12633)
74988d36bd is described below
commit 74988d36bd578b791bbdcea383d343d62029e9cf
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Mon Aug 29 14:33:04 2022 -0700
[Utils] Handled Callable in tir.schedule._type_checker (#12633)
Previously, `Callable` was handled as an atomic type. This worked
when it was included as the last element of a `Union[]` annotation with
no subtypes, but raised an error in other use cases, including
`Optional[Callable]`.
This commit adds explicit checks for `Callable` type annotations that
validate whether the argument is callable. It doesn't recursively
validate the signature of the callable object, because lambda functions
cannot have type annotations (https://peps.python.org/pep-3107/#lambda),
so there would be no annotated signature to check them against.
---
python/tvm/tir/schedule/_type_checker.py | 40 +++++++++++
.../unittest/test_type_annotation_checker.py | 77 ++++++++++++++++++----
2 files changed, 103 insertions(+), 14 deletions(-)
diff --git a/python/tvm/tir/schedule/_type_checker.py b/python/tvm/tir/schedule/_type_checker.py
index d45b4fb84b..0b48dfc2b0 100644
--- a/python/tvm/tir/schedule/_type_checker.py
+++ b/python/tvm/tir/schedule/_type_checker.py
@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""Type checking functionality"""
+import collections
+import collections.abc
import functools
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
@@ -26,6 +28,7 @@ def _is_none_type(type_: Any) -> bool:
if hasattr(typing, "_GenericAlias"):
+ # For python versions 3.7 onward, check the __origin__ attribute.
class _Subtype:
@staticmethod
@@ -71,7 +74,15 @@ if hasattr(typing, "_GenericAlias"):
return list(subtypes)
return None
+ @staticmethod
+ def callable(type_: Any) -> Optional[List[type]]:
+ if _Subtype._origin(type_) is collections.abc.Callable:
+ subtypes = type_.__args__
+ return subtypes
+ return None
+
elif hasattr(typing, "_Union"):
+ # For python 3.6 and below, check the __name__ attribute, or CallableMeta.
class _Subtype: # type: ignore
@staticmethod
@@ -114,6 +125,13 @@ elif hasattr(typing, "_Union"):
return list(subtypes)
return None
+ @staticmethod
+ def callable(type_: Any) -> Optional[List[type]]:
+ if isinstance(type_, typing.CallableMeta): # type: ignore # pylint: disable=no-member,protected-access
+ subtypes = type_.__args__
+ return subtypes
+ return None
+
def _dispatcher(type_: Any) -> Tuple[str, List[type]]:
if _is_none_type(type_):
@@ -139,12 +157,27 @@ def _dispatcher(type_: Any) -> Tuple[str, List[type]]:
if subtype is not None:
return "union", subtype
+ subtype = _Subtype.callable(type_)
+ if subtype is not None:
+ return "callable", subtype
+
return "atomic", [type_]
+def callable_str(subtypes):
+ if subtypes:
+ *arg_types, return_type = subtypes
+ arg_str = ", ".join(_type2str(arg_type) for arg_type in arg_types)
+ return_type_str = _type2str(return_type)
+ return f"Callable[[{arg_str}], {return_type_str}]"
+ else:
+ return "Callable"
+
+
_TYPE2STR: Dict[Any, Callable] = {
"none": lambda: "None",
"atomic": lambda t: str(t.__name__),
+ "callable": callable_str,
"list": lambda t: f"List[{_type2str(t)}]",
"dict": lambda k, v: f"Dict[{_type2str(k)}, {_type2str(v)}]",
"tuple": lambda *t: f"Tuple[{', '.join([_type2str(x) for x in t])}]",
@@ -188,6 +221,12 @@ def _type_check_vtable() -> Dict[str, Callable]:
def _type_check_atomic(v: Any, name: str, type_: Any) -> Optional[str]:
return None if isinstance(v, type_) else _type_check_err(v, name, type_)
+ def _type_check_callable(v: Any, name: str, *_subtypes: Any) -> Optional[str]:
+ # Current implementation only validates that the argument is
+ # callable, and doesn't validate the arguments accepted by the
+ # callable, if any.
+ return None if callable(v) else _type_check_err(v, name, Callable)
+
def _type_check_list(v: List[Any], name: str, type_: Any) -> Optional[str]:
if not isinstance(v, (list, tuple)):
return _type_check_err(v, name, list)
@@ -234,6 +273,7 @@ def _type_check_vtable() -> Dict[str, Callable]:
return {
"none": _type_check_none,
"atomic": _type_check_atomic,
+ "callable": _type_check_callable,
"list": _type_check_list,
"dict": _type_check_dict,
"tuple": _type_check_tuple,
diff --git a/tests/python/unittest/test_type_annotation_checker.py b/tests/python/unittest/test_type_annotation_checker.py
index e84ae043d3..204c153313 100644
--- a/tests/python/unittest/test_type_annotation_checker.py
+++ b/tests/python/unittest/test_type_annotation_checker.py
@@ -17,13 +17,22 @@
"""Test type checker based on python's type annotations"""
import sys
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Tuple, Union, Callable
import pytest
+import _pytest
from tvm.tir.schedule._type_checker import type_checked
+def int_func(x: int) -> int:
+ return 2 * x
+
+
+def str_func(x: str) -> str:
+ return 2 * x
+
+
test_cases = [
{
"type_annotation": int,
@@ -90,30 +99,71 @@ test_cases = [
None,
],
},
+ {
+ "type_annotation": Callable,
+ "positive_cases": [str_func, int_func],
+ "negative_cases": [
+ None,
+ "x",
+ 42,
+ ],
+ },
+ {
+ "type_annotation": Callable[[int], int],
+ "positive_cases": [int_func],
+ "negative_cases": [
+ None,
+ "x",
+ 42,
+ pytest.param(
+ str_func,
+ marks=pytest.mark.xfail(
+ reason="Signature of Callable arguments not currently checked"
+ ),
+ ),
+ ],
+ },
]
-positive_cases = [
- (config["type_annotation"], case) for config in test_cases for case in config["positive_cases"]
-]
-
-negative_cases = [
- (config["type_annotation"], case) for config in test_cases for case in config["negative_cases"]
-]
+def make_parametrization(type_annotation, case):
+ if isinstance(case, _pytest.mark.structures.ParameterSet):
+ marks = case.marks
+ (case,) = case.values
+ else:
+ marks = []
-def format_name(type_annotation, case):
try:
- name = type_annotation.__name__
+ annotation_name = type_annotation.__name__
except AttributeError:
- name = str(type_annotation).replace("typing.", "")
+ annotation_name = str(type_annotation).replace("typing.", "")
+
+ if hasattr(case, "__name__"):
+ case_name = case.__name__
+ else:
+ case_name = str(case)
- return f"{name}_{case}"
+ name = f"{annotation_name}, {case_name}"
+
+ return pytest.param(type_annotation, case, marks=marks, id=name)
+
+
+positive_cases = [
+ make_parametrization(config["type_annotation"], case)
+ for config in test_cases
+ for case in config["positive_cases"]
+]
+
+negative_cases = [
+ make_parametrization(config["type_annotation"], case)
+ for config in test_cases
+ for case in config["negative_cases"]
+]
@pytest.mark.parametrize(
["type_annotation", "case"],
positive_cases,
- ids=[format_name(t, c) for t, c in positive_cases],
)
def test_matches_type(type_annotation, case):
@type_checked
@@ -126,7 +176,6 @@ def test_matches_type(type_annotation, case):
@pytest.mark.parametrize(
["type_annotation", "case"],
negative_cases,
- ids=[format_name(t, c) for t, c in negative_cases],
)
def test_not_matches(type_annotation, case):
@type_checked