You are viewing a plain-text version of this content. The canonical link to it appeared as a hyperlink in the original page and was not preserved in this text version.
Posted to commits@tvm.apache.org by sy...@apache.org on 2022/07/21 06:37:14 UTC
[tvm] branch main updated: [UX][TIR][Schedule] enhance function annotation for tir primitive (#12147)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e54f324311 [UX][TIR][Schedule] enhance function annotation for tir primitive (#12147)
e54f324311 is described below
commit e54f3243111e4a9a2bb83a18f937e637a55f1594
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Thu Jul 21 14:37:08 2022 +0800
[UX][TIR][Schedule] enhance function annotation for tir primitive (#12147)
* [UX][TIR][Schedule] enhance function annotation for tir primitive
* lint
* fix mypy
* fix pylint
---
python/tvm/tir/schedule/_type_checker.py | 9 ++++++---
python/tvm/tir/schedule/schedule.py | 8 +++-----
2 files changed, 9 insertions(+), 8 deletions(-)
diff --git a/python/tvm/tir/schedule/_type_checker.py b/python/tvm/tir/schedule/_type_checker.py
index 2dc8ff9d58..564d23afad 100644
--- a/python/tvm/tir/schedule/_type_checker.py
+++ b/python/tvm/tir/schedule/_type_checker.py
@@ -17,7 +17,7 @@
"""Type checking functionality"""
import functools
import inspect
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
import typing
@@ -216,7 +216,10 @@ def _type_check(v: Any, name: str, type_: Any) -> Optional[str]:
return _TYPE_CHECK[key](v, name, *subtypes)
-def type_checked(func: Callable) -> Callable:
+FType = TypeVar("FType", bound=Callable[..., Any])
+
+
+def type_checked(func: FType) -> FType:
"""Type check the input arguments of a function."""
sig = inspect.signature(func)
@@ -236,4 +239,4 @@ def type_checked(func: Callable) -> Callable:
raise TypeError(error_msg)
return func(*args, **kwargs)
- return wrap
+ return wrap # type: ignore
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index fe0afa3011..73bb8140e1 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2319,10 +2319,10 @@ class Schedule(Object):
self, block: BlockRV, buffer: Union[Tuple[str, int], str, Buffer]
) -> Tuple[str, int, Buffer]:
- block_name = self.get(block).name_hint
+ block_obj: Block = self.get(block)
+ block_name = block_obj.name_hint
def iter_buffers():
- block_obj = self.get(block)
for i, read in enumerate(block_obj.reads):
yield "read", i, read.buffer
for i, write in enumerate(block_obj.writes):
@@ -2358,9 +2358,7 @@ class Schedule(Object):
f"Expected 'read' or 'write', "
f"but received {buffer_index_type}"
)
- buffer_list = (
- self.get(block).reads if buffer_index_type == "read" else self.get(block).writes
- )
+ buffer_list = block_obj.reads if buffer_index_type == "read" else block_obj.writes
assert 0 <= buffer_index < len(buffer_list), (
f"Invalid buffer_index {buffer_index}. "
f"Block {block_name} has only "