You are viewing a plain-text version of this content. The canonical version is available at the original link (the hyperlink was lost in this plain-text extraction).
Posted to commits@tvm.apache.org by sy...@apache.org on 2022/07/21 06:37:14 UTC

[tvm] branch main updated: [UX][TIR][Schedule] enhance function annotation for tir primitive (#12147)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e54f324311 [UX][TIR][Schedule] enhance function annotation for tir primitive (#12147)
e54f324311 is described below

commit e54f3243111e4a9a2bb83a18f937e637a55f1594
Author: Siyuan Feng <Hz...@sjtu.edu.cn>
AuthorDate: Thu Jul 21 14:37:08 2022 +0800

    [UX][TIR][Schedule] enhance function annotation for tir primitive (#12147)
    
    * [UX][TIR][Schedule] enhance function annotation for tir primitive
    
    * lint
    
    * fix mypy
    
    * fix pylint
---
 python/tvm/tir/schedule/_type_checker.py | 9 ++++++---
 python/tvm/tir/schedule/schedule.py      | 8 +++-----
 2 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/python/tvm/tir/schedule/_type_checker.py b/python/tvm/tir/schedule/_type_checker.py
index 2dc8ff9d58..564d23afad 100644
--- a/python/tvm/tir/schedule/_type_checker.py
+++ b/python/tvm/tir/schedule/_type_checker.py
@@ -17,7 +17,7 @@
 """Type checking functionality"""
 import functools
 import inspect
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
 import typing
 
 
@@ -216,7 +216,10 @@ def _type_check(v: Any, name: str, type_: Any) -> Optional[str]:
     return _TYPE_CHECK[key](v, name, *subtypes)
 
 
-def type_checked(func: Callable) -> Callable:
+FType = TypeVar("FType", bound=Callable[..., Any])
+
+
+def type_checked(func: FType) -> FType:
     """Type check the input arguments of a function."""
     sig = inspect.signature(func)
 
@@ -236,4 +239,4 @@ def type_checked(func: Callable) -> Callable:
                     raise TypeError(error_msg)
         return func(*args, **kwargs)
 
-    return wrap
+    return wrap  # type: ignore
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index fe0afa3011..73bb8140e1 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2319,10 +2319,10 @@ class Schedule(Object):
         self, block: BlockRV, buffer: Union[Tuple[str, int], str, Buffer]
     ) -> Tuple[str, int, Buffer]:
 
-        block_name = self.get(block).name_hint
+        block_obj: Block = self.get(block)
+        block_name = block_obj.name_hint
 
         def iter_buffers():
-            block_obj = self.get(block)
             for i, read in enumerate(block_obj.reads):
                 yield "read", i, read.buffer
             for i, write in enumerate(block_obj.writes):
@@ -2358,9 +2358,7 @@ class Schedule(Object):
                 f"Expected 'read' or 'write', "
                 f"but received {buffer_index_type}"
             )
-            buffer_list = (
-                self.get(block).reads if buffer_index_type == "read" else self.get(block).writes
-            )
+            buffer_list = block_obj.reads if buffer_index_type == "read" else block_obj.writes
             assert 0 <= buffer_index < len(buffer_list), (
                 f"Invalid buffer_index {buffer_index}.  "
                 f"Block {block_name} has only "