You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wu...@apache.org on 2022/08/29 21:33:15 UTC

[tvm] branch main updated: [Utils] Handled Callable in tir.schedule._type_checker (#12633)

This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 74988d36bd [Utils] Handled Callable in tir.schedule._type_checker (#12633)
74988d36bd is described below

commit 74988d36bd578b791bbdcea383d343d62029e9cf
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Mon Aug 29 14:33:04 2022 -0700

    [Utils] Handled Callable in tir.schedule._type_checker (#12633)
    
    Previously, `Callable` was handled as an atomic type.  This worked
    when it was included as the last element of a `Union[]` annotation
    with no subtypes, but raised an error for other use cases, including
    `Optional[Callable]`.
    
    This commit adds explicit checks for `Callable` type annotations to
    validate whether the argument is callable, but doesn't recursively
    validate the signature of the callable object, because lambda
    functions cannot have type annotations
    (see https://peps.python.org/pep-3107/#lambda).
---
 python/tvm/tir/schedule/_type_checker.py           | 40 +++++++++++
 .../unittest/test_type_annotation_checker.py       | 77 ++++++++++++++++++----
 2 files changed, 103 insertions(+), 14 deletions(-)

diff --git a/python/tvm/tir/schedule/_type_checker.py b/python/tvm/tir/schedule/_type_checker.py
index d45b4fb84b..0b48dfc2b0 100644
--- a/python/tvm/tir/schedule/_type_checker.py
+++ b/python/tvm/tir/schedule/_type_checker.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 """Type checking functionality"""
+import collections
+import collections.abc
 import functools
 import inspect
 from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
@@ -26,6 +28,7 @@ def _is_none_type(type_: Any) -> bool:
 
 
 if hasattr(typing, "_GenericAlias"):
+    # For python versions 3.7 onward, check the __origin__ attribute.
 
     class _Subtype:
         @staticmethod
@@ -71,7 +74,15 @@ if hasattr(typing, "_GenericAlias"):
                     return list(subtypes)
             return None
 
+        @staticmethod
+        def callable(type_: Any) -> Optional[List[type]]:
+            if _Subtype._origin(type_) is collections.abc.Callable:
+                subtypes = type_.__args__
+                return subtypes
+            return None
+
 elif hasattr(typing, "_Union"):
+    # For python 3.6 and below, check the __name__ attribute, or CallableMeta.
 
     class _Subtype:  # type: ignore
         @staticmethod
@@ -114,6 +125,13 @@ elif hasattr(typing, "_Union"):
                     return list(subtypes)
             return None
 
+        @staticmethod
+        def callable(type_: Any) -> Optional[List[type]]:
+            if isinstance(type_, typing.CallableMeta):  # type: ignore # pylint: disable=no-member,protected-access
+                subtypes = type_.__args__
+                return subtypes
+            return None
+
 
 def _dispatcher(type_: Any) -> Tuple[str, List[type]]:
     if _is_none_type(type_):
@@ -139,12 +157,27 @@ def _dispatcher(type_: Any) -> Tuple[str, List[type]]:
     if subtype is not None:
         return "union", subtype
 
+    subtype = _Subtype.callable(type_)
+    if subtype is not None:
+        return "callable", subtype
+
     return "atomic", [type_]
 
 
+def callable_str(subtypes):
+    if subtypes:
+        *arg_types, return_type = subtypes
+        arg_str = ", ".join(_type2str(arg_type) for arg_type in arg_types)
+        return_type_str = _type2str(return_type)
+        return f"Callable[[{arg_str}], {return_type_str}]"
+    else:
+        return "Callable"
+
+
 _TYPE2STR: Dict[Any, Callable] = {
     "none": lambda: "None",
     "atomic": lambda t: str(t.__name__),
+    "callable": callable_str,
     "list": lambda t: f"List[{_type2str(t)}]",
     "dict": lambda k, v: f"Dict[{_type2str(k)}, {_type2str(v)}]",
     "tuple": lambda *t: f"Tuple[{', '.join([_type2str(x) for x in t])}]",
@@ -188,6 +221,12 @@ def _type_check_vtable() -> Dict[str, Callable]:
     def _type_check_atomic(v: Any, name: str, type_: Any) -> Optional[str]:
         return None if isinstance(v, type_) else _type_check_err(v, name, type_)
 
+    def _type_check_callable(v: Any, name: str, *_subtypes: Any) -> Optional[str]:
+        # Current implementation only validates that the argument is
+        # callable, and doesn't validate the arguments accepted by the
+        # callable, if any.
+        return None if callable(v) else _type_check_err(v, name, Callable)
+
     def _type_check_list(v: List[Any], name: str, type_: Any) -> Optional[str]:
         if not isinstance(v, (list, tuple)):
             return _type_check_err(v, name, list)
@@ -234,6 +273,7 @@ def _type_check_vtable() -> Dict[str, Callable]:
     return {
         "none": _type_check_none,
         "atomic": _type_check_atomic,
+        "callable": _type_check_callable,
         "list": _type_check_list,
         "dict": _type_check_dict,
         "tuple": _type_check_tuple,
diff --git a/tests/python/unittest/test_type_annotation_checker.py b/tests/python/unittest/test_type_annotation_checker.py
index e84ae043d3..204c153313 100644
--- a/tests/python/unittest/test_type_annotation_checker.py
+++ b/tests/python/unittest/test_type_annotation_checker.py
@@ -17,13 +17,22 @@
 """Test type checker based on python's type annotations"""
 
 import sys
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Tuple, Union, Callable
 
 import pytest
+import _pytest
 
 from tvm.tir.schedule._type_checker import type_checked
 
 
+def int_func(x: int) -> int:
+    return 2 * x
+
+
+def str_func(x: str) -> str:
+    return 2 * x
+
+
 test_cases = [
     {
         "type_annotation": int,
@@ -90,30 +99,71 @@ test_cases = [
             None,
         ],
     },
+    {
+        "type_annotation": Callable,
+        "positive_cases": [str_func, int_func],
+        "negative_cases": [
+            None,
+            "x",
+            42,
+        ],
+    },
+    {
+        "type_annotation": Callable[[int], int],
+        "positive_cases": [int_func],
+        "negative_cases": [
+            None,
+            "x",
+            42,
+            pytest.param(
+                str_func,
+                marks=pytest.mark.xfail(
+                    reason="Signature of Callable arguments not currently checked"
+                ),
+            ),
+        ],
+    },
 ]
 
-positive_cases = [
-    (config["type_annotation"], case) for config in test_cases for case in config["positive_cases"]
-]
-
-negative_cases = [
-    (config["type_annotation"], case) for config in test_cases for case in config["negative_cases"]
-]
 
+def make_parametrization(type_annotation, case):
+    if isinstance(case, _pytest.mark.structures.ParameterSet):
+        marks = case.marks
+        (case,) = case.values
+    else:
+        marks = []
 
-def format_name(type_annotation, case):
     try:
-        name = type_annotation.__name__
+        annotation_name = type_annotation.__name__
     except AttributeError:
-        name = str(type_annotation).replace("typing.", "")
+        annotation_name = str(type_annotation).replace("typing.", "")
+
+    if hasattr(case, "__name__"):
+        case_name = case.__name__
+    else:
+        case_name = str(case)
 
-    return f"{name}_{case}"
+    name = f"{annotation_name}, {case_name}"
+
+    return pytest.param(type_annotation, case, marks=marks, id=name)
+
+
+positive_cases = [
+    make_parametrization(config["type_annotation"], case)
+    for config in test_cases
+    for case in config["positive_cases"]
+]
+
+negative_cases = [
+    make_parametrization(config["type_annotation"], case)
+    for config in test_cases
+    for case in config["negative_cases"]
+]
 
 
 @pytest.mark.parametrize(
     ["type_annotation", "case"],
     positive_cases,
-    ids=[format_name(t, c) for t, c in positive_cases],
 )
 def test_matches_type(type_annotation, case):
     @type_checked
@@ -126,7 +176,6 @@ def test_matches_type(type_annotation, case):
 @pytest.mark.parametrize(
     ["type_annotation", "case"],
     negative_cases,
-    ids=[format_name(t, c) for t, c in negative_cases],
 )
 def test_not_matches(type_annotation, case):
     @type_checked