You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/11/02 19:01:52 UTC

[GitHub] [tvm] shingjan opened a new pull request #9432: [TIR][WIP] Add type hint for TIR

shingjan opened a new pull request #9432:
URL: https://github.com/apache/tvm/pull/9432


   This PR adds type hints and enables auto-completion for `TIR` and for the namespace `tvm.script.tir` (imported as `T`).
   
   Co-authored-by: Zihao Ye <zi...@gmail.com>
   
   cc: @vinx13 @junrushao1994 
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r742614926



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,270 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from numbers import Number
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, pos: Tuple[Union[int, PrimExpr]]) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> IterVar: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype, storage_scope) -> IterVar: ...
+def func_attr(attrs: Dict) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(extents, dtype, scope: str, condition=True, annotations=None) -> None: ...
+def launch_thread(env_var, extent): ...
+def realize(buffer_slice: BufferSlice, scope: str, condition=True) -> None: ...
+def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
+def Assert(condition, message): ...
+
+"""
+Scope handler - Loops
+"""
+
+def serial(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def parallel(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def vectorized(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def unroll(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def thread_binding(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    thread: str,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def for_range(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int] = None,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def grid(*extents: List[Union[PrimExpr, int]]) -> Iterable[Tuple[IterVar]]: ...
+
+"""
+ty - redefine types
+"""
+
+class boolean: ...
+class handle: ...
+class int8: ...
+class int16: ...
+class int32: ...
+class int64: ...
+class float16: ...

Review comment:
       Yes, it did. Because an `__init__` method cannot return anything but `None`, one workaround is to keep those definitions only in `script/tir/ty.py`.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 commented on pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
vinx13 commented on pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#issuecomment-958416998


   We will need a few test cases that cover various TIR elements (block-related: T.block, T.match_buffer, T.alloc_buffer, T.block_attr; loop-related: T.grid, T.serial, T.thread_binding, etc.). You can find good examples of TIR scripts in previous test cases. Please also add a command that runs these tests to `https://github.com/apache/tvm/blob/main/tests/scripts/task_mypy.sh` and make sure they pass.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR][WIP] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r741449460



##########
File path: python/tvm/tir/__init__.pyi
##########
@@ -0,0 +1,612 @@
+from typing import (

Review comment:
       Thanks for the prompt review! In an earlier discussion, Wuwei and I decided to put `PrimExpr` and `PrimExprWithOp` (for future use) in `tir/__init__.pyi` instead of `script/tir/__init__.pyi`. As a result, stubs in `script/tir` will need to import `PrimExpr`, `IterVar`, etc. from `tir`. I am curious about your take on this.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r742315159



##########
File path: mypy.ini
##########
@@ -23,6 +23,14 @@ follow_imports = skip
 ignore_errors = False
 strict_optional = False
 
+#
+# Note: not all tests under .tests/ are typed 
+# Therefore include test files that should be
+# checked by mypy here
+#
+files = 
+    tests/python/unittest/test_tvmscript_type.py

Review comment:
       Good point. I included that test in `mypy.ini` because PyTorch appears to add individual tests to mypy this [way](https://github.com/pytorch/pytorch/blob/master/mypy.ini). Adding it to `task_mypy.sh` instead should also work.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r742308542



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Object,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Buffer, IterVar, Var
+from .node import BufferSlice
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...

Review comment:
       The definition of `bool` here shadows every use of Python's built-in `bool` in type annotations throughout this file. My workaround is to omit the type annotation on any method/class parameter that would be typed `bool`. For example, this line would become just `def bool(imm) -> PrimExpr: ...`
   
   

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Object,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Buffer, IterVar, Var
+from .node import BufferSlice
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...

Review comment:
       done

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Object,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Buffer, IterVar, Var
+from .node import BufferSlice
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...
+def int16(imm: int) -> PrimExpr: ...
+def int32(imm: int) -> PrimExpr: ...
+def int64(imm: int) -> PrimExpr: ...
+def uint8(imm: int) -> PrimExpr: ...
+def uint16(imm: int) -> PrimExpr: ...
+def uint32(imm: int) -> PrimExpr: ...
+def uint64(imm: int) -> PrimExpr: ...
+def float8(imm: int) -> PrimExpr: ...
+def float16(imm: int) -> PrimExpr: ...
+def float32(imm: int) -> PrimExpr: ...
+def float64(imm: int) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def range(begin, end): ...

Review comment:
       removed




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r742258491



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Object,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Buffer, IterVar, Var
+from .node import BufferSlice
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...

Review comment:
       `imm` should have type `Union[PrimExpr, bool, Number]`.

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Object,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Buffer, IterVar, Var
+from .node import BufferSlice
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...

Review comment:
       For this function and the ones below, use `imm: Union[PrimExpr, Number]`.

##########
File path: mypy.ini
##########
@@ -23,6 +23,14 @@ follow_imports = skip
 ignore_errors = False
 strict_optional = False
 
+#
+# Note: not all tests under .tests/ are typed 
+# Therefore include test files that should be
+# checked by mypy here
+#
+files = 
+    tests/python/unittest/test_tvmscript_type.py

Review comment:
       I'd prefer adding the command to https://github.com/apache/tvm/blob/main/tests/scripts/task_mypy.sh
   My concern is that adding it to the config file is less explicit than adding it to a centralized file of test commands. Also, IIUC, this means the specified file would be checked every time mypy runs, even when mypy is being used to type-check other modules.

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Object,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Buffer, IterVar, Var
+from .node import BufferSlice
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...
+def int16(imm: int) -> PrimExpr: ...
+def int32(imm: int) -> PrimExpr: ...
+def int64(imm: int) -> PrimExpr: ...
+def uint8(imm: int) -> PrimExpr: ...
+def uint16(imm: int) -> PrimExpr: ...
+def uint32(imm: int) -> PrimExpr: ...
+def uint64(imm: int) -> PrimExpr: ...
+def float8(imm: int) -> PrimExpr: ...
+def float16(imm: int) -> PrimExpr: ...
+def float32(imm: int) -> PrimExpr: ...
+def float64(imm: int) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def range(begin, end): ...

Review comment:
       This duplicates the definitions below.

##########
File path: tests/python/unittest/test_tvmscript_type.py
##########
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import sys
+import pytest
+import tvm
+from tvm import tir
+from tvm.script import tir as T
+
+"""
+This module tests the type of
+T.prim_func, T.handle, T.match_buffer, T.block
+T.reads, T.writes, T.alloc_buffer, T.serial
+T.block_attr, T.float32
+"""
+
+
+@pytest.mark.mypy_testing

Review comment:
       Remove this; it is not needed.

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,261 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+from . import axis
+from .ty import ConcreteType
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, pos: Tuple[Union[int, PrimExpr]]) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+class Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...
+def int16(imm: int) -> PrimExpr: ...
+def int32(imm: int) -> PrimExpr: ...
+def int64(imm: int) -> PrimExpr: ...
+def uint8(imm: int) -> PrimExpr: ...
+def uint16(imm: int) -> PrimExpr: ...
+def uint32(imm: int) -> PrimExpr: ...
+def uint64(imm: int) -> PrimExpr: ...
+def float8(imm: int) -> PrimExpr: ...
+def float16(imm: int) -> PrimExpr: ...
+def float32(imm: int) -> PrimExpr: ...
+def float64(imm: int) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+Axis
+"""
+
+def reduce_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def range(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def scan_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def opaque_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+def axis_spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_remap(iter_types: str, loop_vars: List[Var]) -> IterVar: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype, storage_scope) -> IterVar: ...
+def func_attr(attrs: Dict) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(extents, dtype, scope: str, condition=True, annotations=None) -> None: ...
+def launch_thread(env_var, extent): ...
+def realize(buffer_slice: BufferSlice, scope: str, condition=True) -> None: ...
+def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
+def Assert(condition, message): ...
+
+"""
+Scope handler - Loops
+"""
+
+def serial(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def parallel(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def vectorized(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def unroll(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def thread_binding(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    thread: str,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def for_range(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int] = None,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def grid(*extents: List[Union[PrimExpr, int]]) -> Iterable[Tuple[IterVar]]: ...
+
+"""
+ty
+"""
+boolean = ConcreteType("bool")
+handle = ConcreteType("handle")

Review comment:
       Is it necessary to instantiate `ConcreteType`? Alternatively, you can just add `class boolean` to declare a type.
   Also, if you want to add the contents of `ty.py` here, the list should be exhaustive (cover every type defined there).

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,261 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+from . import axis
+from .ty import ConcreteType
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, pos: Tuple[Union[int, PrimExpr]]) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+class Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...
+def int16(imm: int) -> PrimExpr: ...
+def int32(imm: int) -> PrimExpr: ...
+def int64(imm: int) -> PrimExpr: ...
+def uint8(imm: int) -> PrimExpr: ...
+def uint16(imm: int) -> PrimExpr: ...
+def uint32(imm: int) -> PrimExpr: ...
+def uint64(imm: int) -> PrimExpr: ...
+def float8(imm: int) -> PrimExpr: ...
+def float16(imm: int) -> PrimExpr: ...
+def float32(imm: int) -> PrimExpr: ...
+def float64(imm: int) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+Axis
+"""
+
+def reduce_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def range(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def scan_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def opaque_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+def axis_spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...

Review comment:
       This should be converted to a static method of `class axis` (`axis_spatial` is not a keyword exposed to users),
   e.g. 
   ```
   class axis:
     @staticmethod
     def spatial(...)
   ```

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,270 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from numbers import Number
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, pos: Tuple[Union[int, PrimExpr]]) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm) -> PrimExpr: ...

Review comment:
       You can use `builtins.bool` to refer to Python's built-in `bool` type.

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,270 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from numbers import Number
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, pos: Tuple[Union[int, PrimExpr]]) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> IterVar: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype, storage_scope) -> IterVar: ...
+def func_attr(attrs: Dict) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(extents, dtype, scope: str, condition=True, annotations=None) -> None: ...
+def launch_thread(env_var, extent): ...
+def realize(buffer_slice: BufferSlice, scope: str, condition=True) -> None: ...
+def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
+def Assert(condition, message): ...
+
+"""
+Scope handler - Loops
+"""
+
+def serial(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def parallel(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def vectorized(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def unroll(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def thread_binding(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    thread: str,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def for_range(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int] = None,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def grid(*extents: List[Union[PrimExpr, int]]) -> Iterable[Tuple[IterVar]]: ...
+
+"""
+ty - redefine types
+"""
+
+class boolean: ...
+class handle: ...
+class int8: ...
+class int16: ...
+class int32: ...
+class int64: ...
+class float16: ...

Review comment:
       Does this shadow function names such as `float16` above?
   If so, you could instead turn those functions into the `__init__` methods of these types.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r743217991



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,278 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, *pos: Union[int, PrimExpr, slice, Any]) -> Buffer: ...

Review comment:
       isn't the return value `PrimExpr`?

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,278 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):

Review comment:
       `Buffer` shouldn't subclass `Var`




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r743977546



##########
File path: tests/python/unittest/test_tvmscript_type.py
##########
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring,invalid-name,pointless-string-statement
+from tvm.script import tir as T
+
+
+@T.prim_func
+def element_wise_storage_align(a: T.handle, c: T.handle) -> None:
+    """
+    This prim func include necessary buffer types that need to be checked
+    e.g. reads/writes, match_buffer/alloc_buffer, serial/block etc.
+    """
+    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    # body
+    with T.block("root"):
+        T.reads([])
+        T.writes([])
+        B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1)
+        for i0 in T.serial(0, 128):
+            for ax1 in T.serial(0, 128):
+                with T.block("B"):
+                    vi, vj = T.axis.remap("SS", [i0, ax1])
+                    T.reads([A[vi, vj]])
+                    T.writes([B[vi, vj]])
+                    T.block_attr({"buffer_dim_align": [[0, 0, 128, 127]]})
+                    B[vi, vj] = A[vi, vj] * T.float32(2)
+            for i1 in T.serial(0, 128):
+                with T.block("C"):
+                    vi_1, vj_1 = T.axis.remap("SS", [i0, i1])
+                    T.reads([B[vi_1, vj_1]])
+                    T.writes([C[vi_1, vj_1]])
+                    C[vi_1, vj_1] = B[vi_1, vj_1] + T.float32(1)
+
+
+"""

Review comment:
       remove this




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r743318193



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,278 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):

Review comment:
       is it because `__setitem__` is missing?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r743148098



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,267 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(
+        self: Buffer, pos: Tuple[Union[int, PrimExpr], Union[int, PrimExpr]]
+    ) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype, storage_scope) -> IterVar: ...

Review comment:
       ```suggestion
   def buffer_var(dtype: str, storage_scope: str) -> IterVar: ...
   ```

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,267 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(
+        self: Buffer, pos: Tuple[Union[int, PrimExpr], Union[int, PrimExpr]]
+    ) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...

Review comment:
       Can we add type annotations (parameter and return types) to these functions?

##########
File path: tests/python/unittest/test_tvmscript_type.py
##########
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import sys

Review comment:
       not needed

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,267 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(
+        self: Buffer, pos: Tuple[Union[int, PrimExpr], Union[int, PrimExpr]]
+    ) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype, storage_scope) -> IterVar: ...
+def func_attr(attrs: Dict) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(
+    extents, dtype, scope: str, condition: builtins.bool = True, annotations=None
+) -> None: ...
+def launch_thread(env_var, extent) -> None: ...
+def realize(buffer_slice: BufferSlice, scope: str, condition: builtins.bool = True) -> None: ...

Review comment:
       same as above

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,267 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(
+        self: Buffer, pos: Tuple[Union[int, PrimExpr], Union[int, PrimExpr]]
+    ) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype, storage_scope) -> IterVar: ...
+def func_attr(attrs: Dict) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(
+    extents, dtype, scope: str, condition: builtins.bool = True, annotations=None

Review comment:
       ```suggestion
       extents, dtype, scope: str, condition: Union[PrimExpr, builtins.bool] = True, annotations=None
   ```
   If possible, please also add type annotations for the other arguments.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r745112463



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,353 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import Range
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class PrimExpr:
+    def __init__(self: PrimExpr) -> None: ...
+    @overload
+    def __add__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+    @overload
+    def __add__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    @overload
+    def __sub__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+    @overload
+    def __sub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    @overload
+    def __mul__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+    @overload
+    def __mul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    @overload
+    def __div__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+    @overload
+    def __div__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    def __radd__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    def __rsub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    def __rmul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    def __rdiv__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+
+class Var(PrimExpr): ...
+class IterVar(Var): ...
+
+class Buffer:
+    @overload
+    def __getitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int]]) -> PrimExpr: ...
+    @overload
+    def __getitem__(self: Buffer, pos: Union[PrimExpr, int]) -> PrimExpr: ...
+    @overload
+    def __setitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int]], value: PrimExpr) -> None: ...
+    @overload
+    def __setitem__(self: Buffer, pos: Union[PrimExpr, int], value: PrimExpr) -> None: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, builtins.int]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype: str) -> PrimExpr: ...
+def max_value(dtype: str) -> PrimExpr: ...
+def floordiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def floormod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncmod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncdiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def abs(x: PrimExpr) -> PrimExpr: ...
+def load(
+    dtype: str, var: Var, index: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = None
+) -> PrimExpr: ...
+def cast(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def ramp(base: PrimExpr, stride: Any, lanes: int) -> PrimExpr: ...
+def broadcast(value: PrimExpr, lanes: int) -> PrimExpr: ...
+def iter_var(var: Union[Var, str], dom: Range, iter_type: int, thread_tag: str) -> IterVar: ...
+def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr: ...
+def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ...
+def evaluate(value: PrimExpr) -> None: ...
+def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def store(
+    var: Var, index: PrimExpr, value: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = True
+) -> None: ...
+def comm_reducer(lambda_io: Tuple[List[PrimExpr]], identities: List[PrimExpr]) -> PrimExpr: ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @overload
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
+    @overload
+    @staticmethod
+    def spatial(
+        dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr
+    ) -> IterVar: ...
+    @overload
+    @staticmethod
+    def S(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
+    @overload
+    @staticmethod
+    def S(dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr) -> IterVar: ...
+    @overload
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
+    @overload
+    @staticmethod
+    def reduce(
+        dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr
+    ) -> IterVar: ...
+    @overload
+    @staticmethod
+    def R(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
+    @overload
+    @staticmethod
+    def R(dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr) -> IterVar: ...
+    @overload
+    @staticmethod
+    def scan(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
+    @overload
+    @staticmethod
+    def scan(
+        dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr
+    ) -> IterVar: ...
+    @overload
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
+    @overload
+    @staticmethod
+    def opaque(
+        dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr
+    ) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ...
+
+def get_axis(begin: PrimExpr, end: PrimExpr, iter_type: int) -> IterVar: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype: str, storage_scope: str) -> Var: ...
+def func_attr(attrs: Mapping[str, Object]) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(
+    extents: List[PrimExpr],
+    dtype: str,
+    scope: str,
+    condition: Union[PrimExpr, builtins.bool] = True,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Var: ...
+def launch_thread(env_var: Var, extent: Union[int, PrimExpr]) -> Var: ...
+def realize(
+    buffer_slice: BufferSlice, scope: str, condition: Union[PrimExpr, builtins.bool] = True
+) -> None: ...
+def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
+def Assert(condition: Union[PrimExpr, builtins.bool], message: str) -> PrimExpr: ...
+
+"""
+Scope handler - Loops
+"""
+
+def serial(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def parallel(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def vectorized(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def unroll(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def thread_binding(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    thread: str,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def for_range(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int] = None,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def grid(*extents: Union[PrimExpr, int]) -> Iterable[Tuple[IterVar]]: ...
+
+"""
+ty - redefine types
+"""
+
+class boolean: ...
+
+class handle:
+    def __getitem__(self: handle, pos: Tuple[Union[int, PrimExpr, slice]]) -> Buffer: ...
+    def __setitem__(
+        self: handle, pos: Tuple[Union[int, PrimExpr, slice]], value: Buffer

Review comment:
       done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r742304506



##########
File path: tests/python/unittest/test_tvmscript_type.py
##########
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import sys
+import pytest
+import tvm
+from tvm import tir
+from tvm.script import tir as T
+
+"""
+This module tests the type of
+T.prim_func, T.handle, T.match_buffer, T.block
+T.reads, T.writes, T.alloc_buffer, T.serial
+T.block_attr, T.float32
+"""
+
+
+@pytest.mark.mypy_testing

Review comment:
       done

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Object,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Buffer, IterVar, Var
+from .node import BufferSlice
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...

Review comment:
       The definition of `bool` here actually shadows every use of Python's native `bool` type annotation in this file. My workaround is to omit the type annotation for any method/class parameters that would be typed `bool`. E.g. this line would just be `def bool(imm) -> PrimExpr: ...`
   
   

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Object,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Buffer, IterVar, Var
+from .node import BufferSlice
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...

Review comment:
       done

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Object,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Buffer, IterVar, Var
+from .node import BufferSlice
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...
+def int16(imm: int) -> PrimExpr: ...
+def int32(imm: int) -> PrimExpr: ...
+def int64(imm: int) -> PrimExpr: ...
+def uint8(imm: int) -> PrimExpr: ...
+def uint16(imm: int) -> PrimExpr: ...
+def uint32(imm: int) -> PrimExpr: ...
+def uint64(imm: int) -> PrimExpr: ...
+def float8(imm: int) -> PrimExpr: ...
+def float16(imm: int) -> PrimExpr: ...
+def float32(imm: int) -> PrimExpr: ...
+def float64(imm: int) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def range(begin, end): ...

Review comment:
       removed

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,261 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+from . import axis
+from .ty import ConcreteType
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, pos: Tuple[Union[int, PrimExpr]]) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+class Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...
+def int16(imm: int) -> PrimExpr: ...
+def int32(imm: int) -> PrimExpr: ...
+def int64(imm: int) -> PrimExpr: ...
+def uint8(imm: int) -> PrimExpr: ...
+def uint16(imm: int) -> PrimExpr: ...
+def uint32(imm: int) -> PrimExpr: ...
+def uint64(imm: int) -> PrimExpr: ...
+def float8(imm: int) -> PrimExpr: ...
+def float16(imm: int) -> PrimExpr: ...
+def float32(imm: int) -> PrimExpr: ...
+def float64(imm: int) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+Axis
+"""
+
+def reduce_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def range(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def scan_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def opaque_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+def axis_spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...

Review comment:
       done

##########
File path: mypy.ini
##########
@@ -23,6 +23,14 @@ follow_imports = skip
 ignore_errors = False
 strict_optional = False
 
+#
+# Note: not all tests under .tests/ are typed 
+# Therefore include test files that should be
+# checked by mypy here
+#
+files = 
+    tests/python/unittest/test_tvmscript_type.py

Review comment:
       Good point. I included that test in `mypy.ini` because PyTorch seems to add individual tests to mypy this [way](https://github.com/pytorch/pytorch/blob/master/mypy.ini). But adding it to `task_mypy.sh` should also work.

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,261 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+from . import axis
+from .ty import ConcreteType
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, pos: Tuple[Union[int, PrimExpr]]) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+class Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...
+def int16(imm: int) -> PrimExpr: ...
+def int32(imm: int) -> PrimExpr: ...
+def int64(imm: int) -> PrimExpr: ...
+def uint8(imm: int) -> PrimExpr: ...
+def uint16(imm: int) -> PrimExpr: ...
+def uint32(imm: int) -> PrimExpr: ...
+def uint64(imm: int) -> PrimExpr: ...
+def float8(imm: int) -> PrimExpr: ...
+def float16(imm: int) -> PrimExpr: ...
+def float32(imm: int) -> PrimExpr: ...
+def float64(imm: int) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+Axis
+"""
+
+def reduce_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def range(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def scan_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def opaque_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+def axis_spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_remap(iter_types: str, loop_vars: List[Var]) -> IterVar: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype, storage_scope) -> IterVar: ...
+def func_attr(attrs: Dict) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(extents, dtype, scope: str, condition=True, annotations=None) -> None: ...
+def launch_thread(env_var, extent): ...
+def realize(buffer_slice: BufferSlice, scope: str, condition=True) -> None: ...
+def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
+def Assert(condition, message): ...
+
+"""
+Scope handler - Loops
+"""
+
+def serial(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def parallel(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def vectorized(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def unroll(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def thread_binding(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    thread: str,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def for_range(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int] = None,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def grid(*extents: List[Union[PrimExpr, int]]) -> Iterable[Tuple[IterVar]]: ...
+
+"""
+ty
+"""
+boolean = ConcreteType("bool")
+handle = ConcreteType("handle")

Review comment:
       done

##########
File path: tests/python/unittest/test_tvmscript_type.py
##########
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import sys
+import pytest
+import tvm
+from tvm import tir
+from tvm.script import tir as T
+
+"""
+This module tests the type of
+T.prim_func, T.handle, T.match_buffer, T.block
+T.reads, T.writes, T.alloc_buffer, T.serial
+T.block_attr, T.float32
+"""
+
+
+@pytest.mark.mypy_testing

Review comment:
       done

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Object,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Buffer, IterVar, Var
+from .node import BufferSlice
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...

Review comment:
       The definition of `bool` here actually shadows every use of Python's native `bool` type annotation in this file. My workaround is to omit the type annotation for any method/class parameters that would be typed `bool`. E.g. this line would just be `def bool(imm) -> PrimExpr: ...`
   
   

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Object,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Buffer, IterVar, Var
+from .node import BufferSlice
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...

Review comment:
       done

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Object,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Buffer, IterVar, Var
+from .node import BufferSlice
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...
+def int16(imm: int) -> PrimExpr: ...
+def int32(imm: int) -> PrimExpr: ...
+def int64(imm: int) -> PrimExpr: ...
+def uint8(imm: int) -> PrimExpr: ...
+def uint16(imm: int) -> PrimExpr: ...
+def uint32(imm: int) -> PrimExpr: ...
+def uint64(imm: int) -> PrimExpr: ...
+def float8(imm: int) -> PrimExpr: ...
+def float16(imm: int) -> PrimExpr: ...
+def float32(imm: int) -> PrimExpr: ...
+def float64(imm: int) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def range(begin, end): ...

Review comment:
       removed

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,261 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+from . import axis
+from .ty import ConcreteType
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, pos: Tuple[Union[int, PrimExpr]]) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+class Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...
+def int16(imm: int) -> PrimExpr: ...
+def int32(imm: int) -> PrimExpr: ...
+def int64(imm: int) -> PrimExpr: ...
+def uint8(imm: int) -> PrimExpr: ...
+def uint16(imm: int) -> PrimExpr: ...
+def uint32(imm: int) -> PrimExpr: ...
+def uint64(imm: int) -> PrimExpr: ...
+def float8(imm: int) -> PrimExpr: ...
+def float16(imm: int) -> PrimExpr: ...
+def float32(imm: int) -> PrimExpr: ...
+def float64(imm: int) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+Axis
+"""
+
+def reduce_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def range(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def scan_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def opaque_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+def axis_spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...

Review comment:
       done

##########
File path: mypy.ini
##########
@@ -23,6 +23,14 @@ follow_imports = skip
 ignore_errors = False
 strict_optional = False
 
+#
+# Note: not all tests under .tests/ are typed 
+# Therefore include test files that should be
+# checked by mypy here
+#
+files = 
+    tests/python/unittest/test_tvmscript_type.py

Review comment:
       Good point. I included that test in `mypy.ini` because PyTorch seems to add individual tests to mypy this [way](https://github.com/pytorch/pytorch/blob/master/mypy.ini). But adding it to `task_mypy.sh` should also work.

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,261 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+from . import axis
+from .ty import ConcreteType
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, pos: Tuple[Union[int, PrimExpr]]) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+class Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...
+def int16(imm: int) -> PrimExpr: ...
+def int32(imm: int) -> PrimExpr: ...
+def int64(imm: int) -> PrimExpr: ...
+def uint8(imm: int) -> PrimExpr: ...
+def uint16(imm: int) -> PrimExpr: ...
+def uint32(imm: int) -> PrimExpr: ...
+def uint64(imm: int) -> PrimExpr: ...
+def float8(imm: int) -> PrimExpr: ...
+def float16(imm: int) -> PrimExpr: ...
+def float32(imm: int) -> PrimExpr: ...
+def float64(imm: int) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+Axis
+"""
+
+def reduce_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def range(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def scan_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def opaque_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+def axis_spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_remap(iter_types: str, loop_vars: List[Var]) -> IterVar: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype, storage_scope) -> IterVar: ...
+def func_attr(attrs: Dict) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(extents, dtype, scope: str, condition=True, annotations=None) -> None: ...
+def launch_thread(env_var, extent): ...
+def realize(buffer_slice: BufferSlice, scope: str, condition=True) -> None: ...
+def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
+def Assert(condition, message): ...
+
+"""
+Scope handler - Loops
+"""
+
+def serial(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def parallel(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def vectorized(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def unroll(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def thread_binding(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    thread: str,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def for_range(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int] = None,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def grid(*extents: List[Union[PrimExpr, int]]) -> Iterable[Tuple[IterVar]]: ...
+
+"""
+ty
+"""
+boolean = ConcreteType("bool")
+handle = ConcreteType("handle")

Review comment:
       done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r742258491



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Object,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Buffer, IterVar, Var
+from .node import BufferSlice
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...

Review comment:
       `imm` should have type `Union[PrimExpr, bool, Number]`.

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Object,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Buffer, IterVar, Var
+from .node import BufferSlice
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...

Review comment:
       (for this function and below) `imm: Union[PrimExpr, Number]`

##########
File path: mypy.ini
##########
@@ -23,6 +23,14 @@ follow_imports = skip
 ignore_errors = False
 strict_optional = False
 
+#
+# Note: not all tests under .tests/ are typed 
+# Therefore include test files that should be
+# checked by mypy here
+#
+files = 
+    tests/python/unittest/test_tvmscript_type.py

Review comment:
       I'd prefer adding the command to https://github.com/apache/tvm/blob/main/tests/scripts/task_mypy.sh
   My concern is that adding it to the config file is not as explicit as adding it to a centralized file of test commands. Also, IIUC, this implies checking the specified file every time Mypy runs (even when it is intended to type-check other modules).

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Object,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Buffer, IterVar, Var
+from .node import BufferSlice
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...
+def int16(imm: int) -> PrimExpr: ...
+def int32(imm: int) -> PrimExpr: ...
+def int64(imm: int) -> PrimExpr: ...
+def uint8(imm: int) -> PrimExpr: ...
+def uint16(imm: int) -> PrimExpr: ...
+def uint32(imm: int) -> PrimExpr: ...
+def uint64(imm: int) -> PrimExpr: ...
+def float8(imm: int) -> PrimExpr: ...
+def float16(imm: int) -> PrimExpr: ...
+def float32(imm: int) -> PrimExpr: ...
+def float64(imm: int) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def range(begin, end): ...

Review comment:
       Duplicates the definitions below.

##########
File path: tests/python/unittest/test_tvmscript_type.py
##########
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import sys
+import pytest
+import tvm
+from tvm import tir
+from tvm.script import tir as T
+
+"""
+This module tests the type of
+T.prim_func, T.handle, T.match_buffer, T.block
+T.reads, T.writes, T.alloc_buffer, T.serial
+T.block_attr, T.float32
+"""
+
+
+@pytest.mark.mypy_testing

Review comment:
       Remove this; it is not needed.

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,261 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+from . import axis
+from .ty import ConcreteType
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, pos: Tuple[Union[int, PrimExpr]]) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+class Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...
+def int16(imm: int) -> PrimExpr: ...
+def int32(imm: int) -> PrimExpr: ...
+def int64(imm: int) -> PrimExpr: ...
+def uint8(imm: int) -> PrimExpr: ...
+def uint16(imm: int) -> PrimExpr: ...
+def uint32(imm: int) -> PrimExpr: ...
+def uint64(imm: int) -> PrimExpr: ...
+def float8(imm: int) -> PrimExpr: ...
+def float16(imm: int) -> PrimExpr: ...
+def float32(imm: int) -> PrimExpr: ...
+def float64(imm: int) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+Axis
+"""
+
+def reduce_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def range(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def scan_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def opaque_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+def axis_spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_remap(iter_types: str, loop_vars: List[Var]) -> IterVar: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype, storage_scope) -> IterVar: ...
+def func_attr(attrs: Dict) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(extents, dtype, scope: str, condition=True, annotations=None) -> None: ...
+def launch_thread(env_var, extent): ...
+def realize(buffer_slice: BufferSlice, scope: str, condition=True) -> None: ...
+def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
+def Assert(condition, message): ...
+
+"""
+Scope handler - Loops
+"""
+
+def serial(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def parallel(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def vectorized(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def unroll(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def thread_binding(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    thread: str,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def for_range(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int] = None,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def grid(*extents: List[Union[PrimExpr, int]]) -> Iterable[Tuple[IterVar]]: ...
+
+"""
+ty
+"""
+boolean = ConcreteType("bool")
+handle = ConcreteType("handle")

Review comment:
       Is it necessary to instantiate `ConcreteType`? Alternatively, you can just add `class boolean` to declare a type.
   Also, if you want to add the contents of `ty.py` here, it should be exhaustive.

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,261 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+from . import axis
+from .ty import ConcreteType
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, pos: Tuple[Union[int, PrimExpr]]) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+class Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...
+def int16(imm: int) -> PrimExpr: ...
+def int32(imm: int) -> PrimExpr: ...
+def int64(imm: int) -> PrimExpr: ...
+def uint8(imm: int) -> PrimExpr: ...
+def uint16(imm: int) -> PrimExpr: ...
+def uint32(imm: int) -> PrimExpr: ...
+def uint64(imm: int) -> PrimExpr: ...
+def float8(imm: int) -> PrimExpr: ...
+def float16(imm: int) -> PrimExpr: ...
+def float32(imm: int) -> PrimExpr: ...
+def float64(imm: int) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+Axis
+"""
+
+def reduce_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def range(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def scan_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def opaque_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+def axis_spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...

Review comment:
       This should be converted to a static method of `class axis` (`axis_spatial` is not a keyword exposed to users),
   e.g. 
   ```
   class axis:
     @staticmethod
     def spatial(...)
   ```

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,270 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from numbers import Number
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, pos: Tuple[Union[int, PrimExpr]]) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm) -> PrimExpr: ...

Review comment:
       You can use `builtins.bool` to refer to Python's `bool` type.

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,270 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from numbers import Number
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, pos: Tuple[Union[int, PrimExpr]]) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> IterVar: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype, storage_scope) -> IterVar: ...
+def func_attr(attrs: Dict) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(extents, dtype, scope: str, condition=True, annotations=None) -> None: ...
+def launch_thread(env_var, extent): ...
+def realize(buffer_slice: BufferSlice, scope: str, condition=True) -> None: ...
+def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
+def Assert(condition, message): ...
+
+"""
+Scope handler - Loops
+"""
+
+def serial(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def parallel(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def vectorized(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def unroll(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def thread_binding(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    thread: str,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def for_range(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int] = None,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def grid(*extents: List[Union[PrimExpr, int]]) -> Iterable[Tuple[IterVar]]: ...
+
+"""
+ty - redefine types
+"""
+
+class boolean: ...
+class handle: ...
+class int8: ...
+class int16: ...
+class int32: ...
+class int64: ...
+class float16: ...

Review comment:
       Does this class shadow the function of the same name, like `float16` above?
   If so, you could instead try turning those functions into the `__init__` methods of these types.

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Object,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Buffer, IterVar, Var
+from .node import BufferSlice
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...

Review comment:
       `imm` should have type `Union[PrimExpr, bool, Number]`.

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Object,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Buffer, IterVar, Var
+from .node import BufferSlice
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...

Review comment:
       (For this function and the ones below) `imm: Union[PrimExpr, Number]`

##########
File path: mypy.ini
##########
@@ -23,6 +23,14 @@ follow_imports = skip
 ignore_errors = False
 strict_optional = False
 
+#
+# Note: not all tests under .tests/ are typed 
+# Therefore include test files that should be
+# checked by mypy here
+#
+files = 
+    tests/python/unittest/test_tvmscript_type.py

Review comment:
       I'd prefer adding the command to https://github.com/apache/tvm/blob/main/tests/scripts/task_mypy.sh
   My concern is that adding it to the config file is not as explicit as adding it to a centralized file of test commands. Also, IIUC, this means the specified file is checked every time Mypy runs (even when Mypy is intended to type-check other modules).

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Object,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Buffer, IterVar, Var
+from .node import BufferSlice
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...
+def int16(imm: int) -> PrimExpr: ...
+def int32(imm: int) -> PrimExpr: ...
+def int64(imm: int) -> PrimExpr: ...
+def uint8(imm: int) -> PrimExpr: ...
+def uint16(imm: int) -> PrimExpr: ...
+def uint32(imm: int) -> PrimExpr: ...
+def uint64(imm: int) -> PrimExpr: ...
+def float8(imm: int) -> PrimExpr: ...
+def float16(imm: int) -> PrimExpr: ...
+def float32(imm: int) -> PrimExpr: ...
+def float64(imm: int) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def range(begin, end): ...

Review comment:
       This duplicates the definitions below.

##########
File path: tests/python/unittest/test_tvmscript_type.py
##########
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import sys
+import pytest
+import tvm
+from tvm import tir
+from tvm.script import tir as T
+
+"""
+This module tests the type of
+T.prim_func, T.handle, T.match_buffer, T.block
+T.reads, T.writes, T.alloc_buffer, T.serial
+T.block_attr, T.float32
+"""
+
+
+@pytest.mark.mypy_testing

Review comment:
       Remove this, as it is not needed.

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,261 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+from . import axis
+from .ty import ConcreteType
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, pos: Tuple[Union[int, PrimExpr]]) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+class Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...
+def int16(imm: int) -> PrimExpr: ...
+def int32(imm: int) -> PrimExpr: ...
+def int64(imm: int) -> PrimExpr: ...
+def uint8(imm: int) -> PrimExpr: ...
+def uint16(imm: int) -> PrimExpr: ...
+def uint32(imm: int) -> PrimExpr: ...
+def uint64(imm: int) -> PrimExpr: ...
+def float8(imm: int) -> PrimExpr: ...
+def float16(imm: int) -> PrimExpr: ...
+def float32(imm: int) -> PrimExpr: ...
+def float64(imm: int) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+Axis
+"""
+
+def reduce_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def range(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def scan_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def opaque_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+def axis_spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_remap(iter_types: str, loop_vars: List[Var]) -> IterVar: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype, storage_scope) -> IterVar: ...
+def func_attr(attrs: Dict) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(extents, dtype, scope: str, condition=True, annotations=None) -> None: ...
+def launch_thread(env_var, extent): ...
+def realize(buffer_slice: BufferSlice, scope: str, condition=True) -> None: ...
+def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
+def Assert(condition, message): ...
+
+"""
+Scope handler - Loops
+"""
+
+def serial(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def parallel(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def vectorized(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def unroll(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def thread_binding(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    thread: str,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def for_range(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int] = None,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def grid(*extents: List[Union[PrimExpr, int]]) -> Iterable[Tuple[IterVar]]: ...
+
+"""
+ty
+"""
+boolean = ConcreteType("bool")
+handle = ConcreteType("handle")

Review comment:
       Is it necessary to instantiate `ConcreteType`? Alternatively, you can just add `class boolean` to declare a type.
   Also, if you want to add the contents of ty.py here, they should be exhaustive.

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,261 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+from . import axis
+from .ty import ConcreteType
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, pos: Tuple[Union[int, PrimExpr]]) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+class Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...
+def int16(imm: int) -> PrimExpr: ...
+def int32(imm: int) -> PrimExpr: ...
+def int64(imm: int) -> PrimExpr: ...
+def uint8(imm: int) -> PrimExpr: ...
+def uint16(imm: int) -> PrimExpr: ...
+def uint32(imm: int) -> PrimExpr: ...
+def uint64(imm: int) -> PrimExpr: ...
+def float8(imm: int) -> PrimExpr: ...
+def float16(imm: int) -> PrimExpr: ...
+def float32(imm: int) -> PrimExpr: ...
+def float64(imm: int) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+Axis
+"""
+
+def reduce_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def range(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def scan_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def opaque_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+def axis_spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...

Review comment:
       This should be converted to a static method of `class axis` (`axis_spatial` is not a keyword exposed to users),
   e.g. 
   ```
   class axis:
     @staticmethod
     def spatial(...)
   ```

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,270 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from numbers import Number
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, pos: Tuple[Union[int, PrimExpr]]) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm) -> PrimExpr: ...

Review comment:
       You can use `builtins.bool` to refer to Python's built-in `bool` type.

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,270 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from numbers import Number
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, pos: Tuple[Union[int, PrimExpr]]) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> IterVar: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype, storage_scope) -> IterVar: ...
+def func_attr(attrs: Dict) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(extents, dtype, scope: str, condition=True, annotations=None) -> None: ...
+def launch_thread(env_var, extent): ...
+def realize(buffer_slice: BufferSlice, scope: str, condition=True) -> None: ...
+def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
+def Assert(condition, message): ...
+
+"""
+Scope handler - Loops
+"""
+
+def serial(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def parallel(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def vectorized(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def unroll(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def thread_binding(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    thread: str,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def for_range(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int] = None,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def grid(*extents: List[Union[PrimExpr, int]]) -> Iterable[Tuple[IterVar]]: ...
+
+"""
+ty - redefine types
+"""
+
+class boolean: ...
+class handle: ...
+class int8: ...
+class int16: ...
+class int32: ...
+class int64: ...
+class float16: ...

Review comment:
       Does this class shadow the function of the same name, such as `float16` above?
   If so, you could instead turn those functions into the `__init__` methods of these types.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r744006730



##########
File path: tests/python/unittest/test_tvmscript_type.py
##########
@@ -0,0 +1,82 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring,invalid-name,pointless-string-statement
+from tvm.script import tir as T
+
+
+@T.prim_func
+def element_wise_storage_align(a: T.handle, c: T.handle) -> None:
+    """
+    This prim func include necessary buffer types that need to be checked
+    e.g. reads/writes, match_buffer/alloc_buffer, serial/block etc.
+    """
+    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    # body
+    with T.block("root"):
+        T.reads([])
+        T.writes([])
+        B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1)
+        for i0 in T.serial(0, 128):
+            for ax1 in T.serial(0, 128):
+                with T.block("B"):
+                    vi, vj = T.axis.remap("SS", [i0, ax1])
+                    T.reads([A[vi, vj]])
+                    T.writes([B[vi, vj]])
+                    T.block_attr({"buffer_dim_align": [[0, 0, 128, 127]]})
+                    B[vi, vj] = A[vi, vj] * T.float32(2)
+            for i1 in T.serial(0, 128):
+                with T.block("C"):
+                    vi_1, vj_1 = T.axis.remap("SS", [i0, i1])
+                    T.reads([B[vi_1, vj_1]])
+                    T.writes([C[vi_1, vj_1]])
+                    C[vi_1, vj_1] = B[vi_1, vj_1] + T.float32(1)
+
+
+@T.prim_func
+def element_wise_env_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None:
+    """

Review comment:
       This has been moved out of the function.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#issuecomment-961529551


   Right now, with `mypy --strict`, the only errors we get are:
   ```
   python/tvm/script/tir/__init__.pyi: note: In class "Buffer":
   python/tvm/script/tir/__init__.pyi:44:14: error: Class cannot subclass "Var" (has type "Any")
   python/tvm/script/tir/__init__.pyi: note: In function "prim_func":
   python/tvm/script/tir/__init__.pyi:195:27: error: Missing type parameters for generic type "Callable"
   python/tvm/script/tir/__init__.pyi: note: At top level:
   python/tvm/script/tir/__init__.pyi:208:13: error: Missing type parameters for generic type "ContextManager"
   python/tvm/script/tir/__init__.pyi:212:12: error: Missing type parameters for generic type "ContextManager"
   python/tvm/script/tir/__init__.pyi:215:11: error: Missing type parameters for generic type "ContextManager"
   Found 5 errors in 1 file (checked 1 source file)
   ```
   
   This should be ready for another round of review now.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r745086428



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,353 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import Range
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class PrimExpr:
+    def __init__(self: PrimExpr) -> None: ...
+    @overload
+    def __add__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+    @overload
+    def __add__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    @overload
+    def __sub__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+    @overload
+    def __sub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    @overload
+    def __mul__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+    @overload
+    def __mul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    @overload
+    def __div__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+    @overload
+    def __div__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    def __radd__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    def __rsub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    def __rmul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    def __rdiv__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+
+class Var(PrimExpr): ...
+class IterVar(Var): ...
+
+class Buffer:
+    @overload
+    def __getitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int]]) -> PrimExpr: ...
+    @overload
+    def __getitem__(self: Buffer, pos: Union[PrimExpr, int]) -> PrimExpr: ...
+    @overload
+    def __setitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int]], value: PrimExpr) -> None: ...
+    @overload
+    def __setitem__(self: Buffer, pos: Union[PrimExpr, int], value: PrimExpr) -> None: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, builtins.int]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype: str) -> PrimExpr: ...
+def max_value(dtype: str) -> PrimExpr: ...
+def floordiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def floormod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncmod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncdiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def abs(x: PrimExpr) -> PrimExpr: ...
+def load(
+    dtype: str, var: Var, index: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = None
+) -> PrimExpr: ...
+def cast(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def ramp(base: PrimExpr, stride: Any, lanes: int) -> PrimExpr: ...
+def broadcast(value: PrimExpr, lanes: int) -> PrimExpr: ...
+def iter_var(var: Union[Var, str], dom: Range, iter_type: int, thread_tag: str) -> IterVar: ...
+def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr: ...
+def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ...
+def evaluate(value: PrimExpr) -> None: ...
+def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def store(
+    var: Var, index: PrimExpr, value: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = True
+) -> None: ...
+def comm_reducer(lambda_io: Tuple[List[PrimExpr]], identities: List[PrimExpr]) -> PrimExpr: ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @overload
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
+    @overload
+    @staticmethod
+    def spatial(
+        dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr
+    ) -> IterVar: ...
+    @overload
+    @staticmethod
+    def S(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
+    @overload
+    @staticmethod
+    def S(dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr) -> IterVar: ...
+    @overload
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
+    @overload
+    @staticmethod
+    def reduce(
+        dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr
+    ) -> IterVar: ...
+    @overload
+    @staticmethod
+    def R(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
+    @overload
+    @staticmethod
+    def R(dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr) -> IterVar: ...
+    @overload
+    @staticmethod
+    def scan(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
+    @overload
+    @staticmethod
+    def scan(
+        dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr
+    ) -> IterVar: ...
+    @overload
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
+    @overload
+    @staticmethod
+    def opaque(
+        dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr
+    ) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ...
+
+def get_axis(begin: PrimExpr, end: PrimExpr, iter_type: int) -> IterVar: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype: str, storage_scope: str) -> Var: ...
+def func_attr(attrs: Mapping[str, Object]) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(
+    extents: List[PrimExpr],
+    dtype: str,
+    scope: str,
+    condition: Union[PrimExpr, builtins.bool] = True,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Var: ...
+def launch_thread(env_var: Var, extent: Union[int, PrimExpr]) -> Var: ...
+def realize(
+    buffer_slice: BufferSlice, scope: str, condition: Union[PrimExpr, builtins.bool] = True
+) -> None: ...
+def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
+def Assert(condition: Union[PrimExpr, builtins.bool], message: str) -> PrimExpr: ...
+
+"""
+Scope handler - Loops
+"""
+
+def serial(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def parallel(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def vectorized(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def unroll(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def thread_binding(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    thread: str,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def for_range(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int] = None,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def grid(*extents: Union[PrimExpr, int]) -> Iterable[Tuple[IterVar]]: ...
+
+"""
+ty - redefine types
+"""
+
+class boolean: ...
+
+class handle:
+    def __getitem__(self: handle, pos: Tuple[Union[int, PrimExpr, slice]]) -> Buffer: ...
+    def __setitem__(
+        self: handle, pos: Tuple[Union[int, PrimExpr, slice]], value: Buffer

Review comment:
       Please also update this similarly to `Buffer`; in addition, `__setitem__` should return `None`.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r745084911



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -33,21 +33,51 @@ from numbers import Number
 import builtins
 
 from tvm.tir.function import PrimFunc
-from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.tir import Range
 from tvm.runtime import Object
 from .node import BufferSlice
 
 """
 redefine types
 """
 
+class PrimExpr:
+    def __init__(self: PrimExpr) -> None: ...
+    @overload
+    def __add__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+    @overload
+    def __add__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    @overload
+    def __sub__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+    @overload
+    def __sub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    @overload
+    def __mul__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+    @overload
+    def __mul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    @overload
+    def __div__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+    @overload
+    def __div__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    def __radd__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    def __rsub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    def __rmul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    def __rdiv__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+
+class Var(PrimExpr): ...
+class IterVar(Var): ...
+
 class Buffer:
     @overload
-    def __getitem__(self: Buffer, pos: List[Union[PrimExpr, int]]) -> PrimExpr: ...
+    def __getitem__(
+        self: Buffer, pos: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]]

Review comment:
       done

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -33,21 +33,51 @@ from numbers import Number
 import builtins
 
 from tvm.tir.function import PrimFunc
-from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.tir import Range
 from tvm.runtime import Object
 from .node import BufferSlice
 
 """
 redefine types
 """
 
+class PrimExpr:
+    def __init__(self: PrimExpr) -> None: ...
+    @overload
+    def __add__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+    @overload
+    def __add__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    @overload
+    def __sub__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+    @overload
+    def __sub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    @overload
+    def __mul__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+    @overload
+    def __mul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    @overload
+    def __div__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+    @overload
+    def __div__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    def __radd__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    def __rsub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    def __rmul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    def __rdiv__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+
+class Var(PrimExpr): ...
+class IterVar(Var): ...
+
 class Buffer:
     @overload
-    def __getitem__(self: Buffer, pos: List[Union[PrimExpr, int]]) -> PrimExpr: ...
+    def __getitem__(
+        self: Buffer, pos: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]]
+    ) -> PrimExpr: ...
     @overload
     def __getitem__(self: Buffer, pos: Union[PrimExpr, int]) -> PrimExpr: ...
     @overload
-    def __setitem__(self: Buffer, pos: List[Union[PrimExpr, int]], value: PrimExpr) -> None: ...
+    def __setitem__(

Review comment:
       done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r743979772



##########
File path: tests/python/unittest/test_tvmscript_type.py
##########
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring,invalid-name,pointless-string-statement
+from tvm.script import tir as T
+
+
+@T.prim_func
+def element_wise_storage_align(a: T.handle, c: T.handle) -> None:
+    """
+    This prim func include necessary buffer types that need to be checked
+    e.g. reads/writes, match_buffer/alloc_buffer, serial/block etc.
+    """
+    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    # body
+    with T.block("root"):
+        T.reads([])
+        T.writes([])
+        B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1)
+        for i0 in T.serial(0, 128):
+            for ax1 in T.serial(0, 128):
+                with T.block("B"):
+                    vi, vj = T.axis.remap("SS", [i0, ax1])
+                    T.reads([A[vi, vj]])
+                    T.writes([B[vi, vj]])
+                    T.block_attr({"buffer_dim_align": [[0, 0, 128, 127]]})
+                    B[vi, vj] = A[vi, vj] * T.float32(2)
+            for i1 in T.serial(0, 128):
+                with T.block("C"):
+                    vi_1, vj_1 = T.axis.remap("SS", [i0, i1])
+                    T.reads([B[vi_1, vj_1]])
+                    T.writes([C[vi_1, vj_1]])
+                    C[vi_1, vj_1] = B[vi_1, vj_1] + T.float32(1)
+
+
+"""

Review comment:
       removed




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r743975106



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,285 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, *pos: Union[int, PrimExpr, slice, Any]) -> PrimExpr: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype: str) -> PrimExpr: ...
+def max_value(dtype: str) -> PrimExpr: ...
+def floordiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def floormod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncmod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncdiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def abs(x: PrimExpr) -> PrimExpr: ...
+def load(dtype: str, var: Var, index: PrimExpr, predicate: PrimExpr = None) -> PrimExpr: ...
+def cast(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def ramp(base: PrimExpr, stride: Any, lanes: int) -> PrimExpr: ...
+def broadcast(value: PrimExpr, lanes: int) -> PrimExpr: ...
+def iter_var(var: Union[Var, str], dom: Range, iter_type: int, thread_tag: str) -> IterVar: ...
+def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr: ...
+def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ...
+def evaluate(value: PrimExpr) -> PrimExpr: ...
+def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def store(var: Var, index: PrimExpr, value: PrimExpr, predicate: PrimExpr = True) -> None: ...
+def comm_reducer(lambda_io: Tuple[List[PrimExpr]], identities: List[PrimExpr]) -> PrimExpr: ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ...
+
+def get_axis(begin: PrimExpr, end: PrimExpr, iter_type: int) -> IterVar: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype: str, storage_scope: str) -> Var: ...
+def func_attr(attrs: Mapping[str, Object]) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(
+    extents: List[PrimExpr],
+    dtype: str,
+    scope: str,
+    condition: Union[PrimExpr, builtins.bool] = True,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Var: ...
+def launch_thread(env_var: Var, extent: PrimExpr) -> Var: ...
+def realize(
+    buffer_slice: BufferSlice, scope: str, condition: Union[PrimExpr, builtins.bool] = True
+) -> None: ...
+def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
+def Assert(condition: Union[PrimExpr, builtins.bool], message: str) -> PrimExpr: ...
+
+"""
+Scope handler - Loops
+"""
+
+def serial(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def parallel(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def vectorized(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def unroll(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def thread_binding(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    thread: str,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def for_range(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int] = None,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def grid(*extents: Union[PrimExpr, int]) -> Iterable[Tuple[IterVar]]: ...
+
+"""
+ty - redefine types
+"""
+
+class boolean: ...
+
+class handle:
+    def __getitem__(self: handle, pos: Tuple[Union[Number, PrimExpr, slice]]) -> Buffer: ...
+    @property
+    def data(self: handle) -> Ptr: ...
+
+# class float32:
+# def __new__(self, imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+# def __init__(self, imm: Union[PrimExpr, Number]) -> None: ...

Review comment:
       removed

##########
File path: python/tvm/script/tir/ty.py
##########
@@ -30,8 +30,11 @@ def evaluate(self):
         """Return an actual ir.Type Object that this Generic class wraps"""
         raise TypeError("Cannot get tvm.Type from a generic type")
 
+    def __call__(self):
+        raise NotImplementedError

Review comment:
       done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r743974539



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,285 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, *pos: Union[int, PrimExpr, slice, Any]) -> PrimExpr: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype: str) -> PrimExpr: ...
+def max_value(dtype: str) -> PrimExpr: ...
+def floordiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def floormod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncmod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncdiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def abs(x: PrimExpr) -> PrimExpr: ...
+def load(dtype: str, var: Var, index: PrimExpr, predicate: PrimExpr = None) -> PrimExpr: ...
+def cast(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def ramp(base: PrimExpr, stride: Any, lanes: int) -> PrimExpr: ...
+def broadcast(value: PrimExpr, lanes: int) -> PrimExpr: ...
+def iter_var(var: Union[Var, str], dom: Range, iter_type: int, thread_tag: str) -> IterVar: ...
+def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr: ...
+def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ...
+def evaluate(value: PrimExpr) -> PrimExpr: ...
+def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def store(var: Var, index: PrimExpr, value: PrimExpr, predicate: PrimExpr = True) -> None: ...
+def comm_reducer(lambda_io: Tuple[List[PrimExpr]], identities: List[PrimExpr]) -> PrimExpr: ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ...
+
+def get_axis(begin: PrimExpr, end: PrimExpr, iter_type: int) -> IterVar: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype: str, storage_scope: str) -> Var: ...
+def func_attr(attrs: Mapping[str, Object]) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(
+    extents: List[PrimExpr],
+    dtype: str,
+    scope: str,
+    condition: Union[PrimExpr, builtins.bool] = True,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Var: ...
+def launch_thread(env_var: Var, extent: PrimExpr) -> Var: ...
+def realize(
+    buffer_slice: BufferSlice, scope: str, condition: Union[PrimExpr, builtins.bool] = True
+) -> None: ...
+def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
+def Assert(condition: Union[PrimExpr, builtins.bool], message: str) -> PrimExpr: ...
+
+"""
+Scope handler - Loops
+"""
+
+def serial(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def parallel(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def vectorized(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def unroll(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def thread_binding(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    thread: str,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def for_range(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int] = None,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def grid(*extents: Union[PrimExpr, int]) -> Iterable[Tuple[IterVar]]: ...
+
+"""
+ty - redefine types
+"""
+
+class boolean: ...
+
+class handle:
+    def __getitem__(self: handle, pos: Tuple[Union[Number, PrimExpr, slice]]) -> Buffer: ...
+    @property
+    def data(self: handle) -> Ptr: ...

Review comment:
       `T.handle` is used this way [here](https://github.com/apache/tvm/blob/048994bd934a39b45f88f4e929d8214e7918dd8e/tests/python/unittest/test_tvmscript_roundtrip.py#L2606), where `stack_tcode` is declared as a `T.handle` on this [line](https://github.com/apache/tvm/blob/048994bd934a39b45f88f4e929d8214e7918dd8e/tests/python/unittest/test_tvmscript_roundtrip.py#L2421) above.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r743148098



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,267 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(
+        self: Buffer, pos: Tuple[Union[int, PrimExpr], Union[int, PrimExpr]]
+    ) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype, storage_scope) -> IterVar: ...

Review comment:
       ```suggestion
   def buffer_var(dtype: str, storage_scope: str) -> IterVar: ...
   ```

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,267 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(
+        self: Buffer, pos: Tuple[Union[int, PrimExpr], Union[int, PrimExpr]]
+    ) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...

Review comment:
       Can we add type annotations to these functions?

##########
File path: tests/python/unittest/test_tvmscript_type.py
##########
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import sys

Review comment:
       This import is not needed.

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,267 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(
+        self: Buffer, pos: Tuple[Union[int, PrimExpr], Union[int, PrimExpr]]
+    ) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype, storage_scope) -> IterVar: ...
+def func_attr(attrs: Dict) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(
+    extents, dtype, scope: str, condition: builtins.bool = True, annotations=None
+) -> None: ...
+def launch_thread(env_var, extent) -> None: ...
+def realize(buffer_slice: BufferSlice, scope: str, condition: builtins.bool = True) -> None: ...

Review comment:
       same as above

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,267 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(
+        self: Buffer, pos: Tuple[Union[int, PrimExpr], Union[int, PrimExpr]]
+    ) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype, storage_scope) -> IterVar: ...
+def func_attr(attrs: Dict) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(
+    extents, dtype, scope: str, condition: builtins.bool = True, annotations=None

Review comment:
       ```suggestion
       extents, dtype, scope: str, condition: Union[PrimExpr, builtins.bool] = True, annotations=None
   ```
   If possible, please also add types for the other arguments.

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,267 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(
+        self: Buffer, pos: Tuple[Union[int, PrimExpr], Union[int, PrimExpr]]
+    ) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype, storage_scope) -> IterVar: ...

Review comment:
       ```suggestion
   def buffer_var(dtype: str, storage_scope: str) -> Var: ...
   ```

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,267 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(
+        self: Buffer, pos: Tuple[Union[int, PrimExpr], Union[int, PrimExpr]]
+    ) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...

Review comment:
       Can we add type annotations to these functions? If some arguments can't be typed, please explicitly annotate them as `Any`.

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,278 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, *pos: Union[int, PrimExpr, slice, Any]) -> Buffer: ...

Review comment:
       Shouldn't the return type be `PrimExpr`?

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,278 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):

Review comment:
       `Buffer` shouldn't subclass `Var`

##########
File path: tests/python/unittest/test_tvmscript_type.py
##########
@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+from tvm.script import tir as T
+
+"""
+This prim_func tests the type of

Review comment:
       I think that, rather than listing these types here, it would be clearer to briefly describe the general intent of these test cases.

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,278 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):

Review comment:
       Is it because `__setitem__` is missing?

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,285 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, *pos: Union[int, PrimExpr, slice, Any]) -> PrimExpr: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype: str) -> PrimExpr: ...
+def max_value(dtype: str) -> PrimExpr: ...
+def floordiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def floormod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncmod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncdiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def abs(x: PrimExpr) -> PrimExpr: ...
+def load(dtype: str, var: Var, index: PrimExpr, predicate: PrimExpr = None) -> PrimExpr: ...
+def cast(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def ramp(base: PrimExpr, stride: Any, lanes: int) -> PrimExpr: ...
+def broadcast(value: PrimExpr, lanes: int) -> PrimExpr: ...
+def iter_var(var: Union[Var, str], dom: Range, iter_type: int, thread_tag: str) -> IterVar: ...
+def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr: ...
+def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ...
+def evaluate(value: PrimExpr) -> PrimExpr: ...
+def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def store(var: Var, index: PrimExpr, value: PrimExpr, predicate: PrimExpr = True) -> None: ...
+def comm_reducer(lambda_io: Tuple[List[PrimExpr]], identities: List[PrimExpr]) -> PrimExpr: ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ...
+
+def get_axis(begin: PrimExpr, end: PrimExpr, iter_type: int) -> IterVar: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype: str, storage_scope: str) -> Var: ...
+def func_attr(attrs: Mapping[str, Object]) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(
+    extents: List[PrimExpr],
+    dtype: str,
+    scope: str,
+    condition: Union[PrimExpr, builtins.bool] = True,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Var: ...
+def launch_thread(env_var: Var, extent: PrimExpr) -> Var: ...
+def realize(
+    buffer_slice: BufferSlice, scope: str, condition: Union[PrimExpr, builtins.bool] = True
+) -> None: ...
+def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
+def Assert(condition: Union[PrimExpr, builtins.bool], message: str) -> PrimExpr: ...
+
+"""
+Scope handler - Loops
+"""
+
+def serial(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def parallel(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def vectorized(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def unroll(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def thread_binding(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    thread: str,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def for_range(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int] = None,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def grid(*extents: Union[PrimExpr, int]) -> Iterable[Tuple[IterVar]]: ...
+
+"""
+ty - redefine types
+"""
+
+class boolean: ...
+
+class handle:
+    def __getitem__(self: handle, pos: Tuple[Union[Number, PrimExpr, slice]]) -> Buffer: ...
+    @property
+    def data(self: handle) -> Ptr: ...

Review comment:
       Are they needed? Do you have a use case?

##########
File path: python/tvm/script/tir/ty.py
##########
@@ -30,8 +30,11 @@ def evaluate(self):
         """Return an actual ir.Type Object that this Generic class wraps"""
         raise TypeError("Cannot get tvm.Type from a generic type")
 
+    def __call__(self):
+        raise NotImplementedError

Review comment:
       Also add a comment explaining why this is needed.
   ```suggestion
           raise NotImplementedError()
   ```

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,285 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, *pos: Union[int, PrimExpr, slice, Any]) -> PrimExpr: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype: str) -> PrimExpr: ...
+def max_value(dtype: str) -> PrimExpr: ...
+def floordiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def floormod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncmod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncdiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def abs(x: PrimExpr) -> PrimExpr: ...
+def load(dtype: str, var: Var, index: PrimExpr, predicate: PrimExpr = None) -> PrimExpr: ...
+def cast(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def ramp(base: PrimExpr, stride: Any, lanes: int) -> PrimExpr: ...
+def broadcast(value: PrimExpr, lanes: int) -> PrimExpr: ...
+def iter_var(var: Union[Var, str], dom: Range, iter_type: int, thread_tag: str) -> IterVar: ...
+def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr: ...
+def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ...
+def evaluate(value: PrimExpr) -> PrimExpr: ...
+def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def store(var: Var, index: PrimExpr, value: PrimExpr, predicate: PrimExpr = True) -> None: ...
+def comm_reducer(lambda_io: Tuple[List[PrimExpr]], identities: List[PrimExpr]) -> PrimExpr: ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ...
+
+def get_axis(begin: PrimExpr, end: PrimExpr, iter_type: int) -> IterVar: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype: str, storage_scope: str) -> Var: ...
+def func_attr(attrs: Mapping[str, Object]) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(
+    extents: List[PrimExpr],
+    dtype: str,
+    scope: str,
+    condition: Union[PrimExpr, builtins.bool] = True,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Var: ...
+def launch_thread(env_var: Var, extent: PrimExpr) -> Var: ...
+def realize(
+    buffer_slice: BufferSlice, scope: str, condition: Union[PrimExpr, builtins.bool] = True
+) -> None: ...
+def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
+def Assert(condition: Union[PrimExpr, builtins.bool], message: str) -> PrimExpr: ...
+
+"""
+Scope handler - Loops
+"""
+
+def serial(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def parallel(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def vectorized(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def unroll(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def thread_binding(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    thread: str,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def for_range(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int] = None,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def grid(*extents: Union[PrimExpr, int]) -> Iterable[Tuple[IterVar]]: ...
+
+"""
+ty - redefine types
+"""
+
+class boolean: ...
+
+class handle:
+    def __getitem__(self: handle, pos: Tuple[Union[Number, PrimExpr, slice]]) -> Buffer: ...
+    @property
+    def data(self: handle) -> Ptr: ...
+
+# class float32:
+# def __new__(self, imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+# def __init__(self, imm: Union[PrimExpr, Number]) -> None: ...

Review comment:
       Remove this if it is not needed.

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,267 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(
+        self: Buffer, pos: Tuple[Union[int, PrimExpr], Union[int, PrimExpr]]
+    ) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype, storage_scope) -> IterVar: ...

Review comment:
       ```suggestion
   def buffer_var(dtype: str, storage_scope: str) -> IterVar: ...
   ```

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,267 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(
+        self: Buffer, pos: Tuple[Union[int, PrimExpr], Union[int, PrimExpr]]
+    ) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...

Review comment:
       Can we add type annotations to these functions?

##########
File path: tests/python/unittest/test_tvmscript_type.py
##########
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import sys

Review comment:
       This import is not needed.

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,267 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(
+        self: Buffer, pos: Tuple[Union[int, PrimExpr], Union[int, PrimExpr]]
+    ) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype, storage_scope) -> IterVar: ...
+def func_attr(attrs: Dict) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(
+    extents, dtype, scope: str, condition: builtins.bool = True, annotations=None
+) -> None: ...
+def launch_thread(env_var, extent) -> None: ...
+def realize(buffer_slice: BufferSlice, scope: str, condition: builtins.bool = True) -> None: ...

Review comment:
       same as above

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,267 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(
+        self: Buffer, pos: Tuple[Union[int, PrimExpr], Union[int, PrimExpr]]
+    ) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype, storage_scope) -> IterVar: ...
+def func_attr(attrs: Dict) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(
+    extents, dtype, scope: str, condition: builtins.bool = True, annotations=None

Review comment:
       ```suggestion
       extents, dtype, scope: str, condition: Union[PrimExpr, builtins.bool] = True, annotations=None
   ```
   If possible, please also add type annotations for the other arguments.

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,267 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(
+        self: Buffer, pos: Tuple[Union[int, PrimExpr], Union[int, PrimExpr]]
+    ) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype, storage_scope) -> IterVar: ...

Review comment:
       ```suggestion
   def buffer_var(dtype: str, storage_scope: str) -> Var: ...
   ```

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,267 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(
+        self: Buffer, pos: Tuple[Union[int, PrimExpr], Union[int, PrimExpr]]
+    ) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...

Review comment:
       can we type these functions? If there are some args that can't be typed, please explicitly add `Any` as its type

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,278 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, *pos: Union[int, PrimExpr, slice, Any]) -> Buffer: ...

Review comment:
       isn't the return value `PrimExpr`?

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,278 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):

Review comment:
       `Buffer` shouldn't subclass `Var`

##########
File path: tests/python/unittest/test_tvmscript_type.py
##########
@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+from tvm.script import tir as T
+
+"""
+This prim_func tests the type of

Review comment:
       Rather than listing these types here, I think it would be clearer to describe the general intention of these test cases.

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,278 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):

Review comment:
       is it because `__setitem__` is missing?

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,285 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, *pos: Union[int, PrimExpr, slice, Any]) -> PrimExpr: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype: str) -> PrimExpr: ...
+def max_value(dtype: str) -> PrimExpr: ...
+def floordiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def floormod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncmod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncdiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def abs(x: PrimExpr) -> PrimExpr: ...
+def load(dtype: str, var: Var, index: PrimExpr, predicate: PrimExpr = None) -> PrimExpr: ...
+def cast(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def ramp(base: PrimExpr, stride: Any, lanes: int) -> PrimExpr: ...
+def broadcast(value: PrimExpr, lanes: int) -> PrimExpr: ...
+def iter_var(var: Union[Var, str], dom: Range, iter_type: int, thread_tag: str) -> IterVar: ...
+def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr: ...
+def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ...
+def evaluate(value: PrimExpr) -> PrimExpr: ...
+def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def store(var: Var, index: PrimExpr, value: PrimExpr, predicate: PrimExpr = True) -> None: ...
+def comm_reducer(lambda_io: Tuple[List[PrimExpr]], identities: List[PrimExpr]) -> PrimExpr: ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ...
+
+def get_axis(begin: PrimExpr, end: PrimExpr, iter_type: int) -> IterVar: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype: str, storage_scope: str) -> Var: ...
+def func_attr(attrs: Mapping[str, Object]) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(
+    extents: List[PrimExpr],
+    dtype: str,
+    scope: str,
+    condition: Union[PrimExpr, builtins.bool] = True,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Var: ...
+def launch_thread(env_var: Var, extent: PrimExpr) -> Var: ...
+def realize(
+    buffer_slice: BufferSlice, scope: str, condition: Union[PrimExpr, builtins.bool] = True
+) -> None: ...
+def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
+def Assert(condition: Union[PrimExpr, builtins.bool], message: str) -> PrimExpr: ...
+
+"""
+Scope handler - Loops
+"""
+
+def serial(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def parallel(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def vectorized(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def unroll(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def thread_binding(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    thread: str,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def for_range(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int] = None,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def grid(*extents: Union[PrimExpr, int]) -> Iterable[Tuple[IterVar]]: ...
+
+"""
+ty - redefine types
+"""
+
+class boolean: ...
+
+class handle:
+    def __getitem__(self: handle, pos: Tuple[Union[Number, PrimExpr, slice]]) -> Buffer: ...
+    @property
+    def data(self: handle) -> Ptr: ...

Review comment:
       Are they needed? Do you have a use case?

##########
File path: python/tvm/script/tir/ty.py
##########
@@ -30,8 +30,11 @@ def evaluate(self):
         """Return an actual ir.Type Object that this Generic class wraps"""
         raise TypeError("Cannot get tvm.Type from a generic type")
 
+    def __call__(self):
+        raise NotImplementedError

Review comment:
       Also, add a comment explaining why this is needed.
   ```suggestion
           raise NotImplementedError()
   ```

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,285 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, *pos: Union[int, PrimExpr, slice, Any]) -> PrimExpr: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype: str) -> PrimExpr: ...
+def max_value(dtype: str) -> PrimExpr: ...
+def floordiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def floormod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncmod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncdiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def abs(x: PrimExpr) -> PrimExpr: ...
+def load(dtype: str, var: Var, index: PrimExpr, predicate: PrimExpr = None) -> PrimExpr: ...
+def cast(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def ramp(base: PrimExpr, stride: Any, lanes: int) -> PrimExpr: ...
+def broadcast(value: PrimExpr, lanes: int) -> PrimExpr: ...
+def iter_var(var: Union[Var, str], dom: Range, iter_type: int, thread_tag: str) -> IterVar: ...
+def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr: ...
+def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ...
+def evaluate(value: PrimExpr) -> PrimExpr: ...
+def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def store(var: Var, index: PrimExpr, value: PrimExpr, predicate: PrimExpr = True) -> None: ...
+def comm_reducer(lambda_io: Tuple[List[PrimExpr]], identities: List[PrimExpr]) -> PrimExpr: ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ...
+
+def get_axis(begin: PrimExpr, end: PrimExpr, iter_type: int) -> IterVar: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype: str, storage_scope: str) -> Var: ...
+def func_attr(attrs: Mapping[str, Object]) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(
+    extents: List[PrimExpr],
+    dtype: str,
+    scope: str,
+    condition: Union[PrimExpr, builtins.bool] = True,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Var: ...
+def launch_thread(env_var: Var, extent: PrimExpr) -> Var: ...
+def realize(
+    buffer_slice: BufferSlice, scope: str, condition: Union[PrimExpr, builtins.bool] = True
+) -> None: ...
+def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
+def Assert(condition: Union[PrimExpr, builtins.bool], message: str) -> PrimExpr: ...
+
+"""
+Scope handler - Loops
+"""
+
+def serial(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def parallel(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def vectorized(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def unroll(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def thread_binding(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    thread: str,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def for_range(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int] = None,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def grid(*extents: Union[PrimExpr, int]) -> Iterable[Tuple[IterVar]]: ...
+
+"""
+ty - redefine types
+"""
+
+class boolean: ...
+
+class handle:
+    def __getitem__(self: handle, pos: Tuple[Union[Number, PrimExpr, slice]]) -> Buffer: ...
+    @property
+    def data(self: handle) -> Ptr: ...
+
+# class float32:
+# def __new__(self, imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+# def __init__(self, imm: Union[PrimExpr, Number]) -> None: ...

Review comment:
       Please remove this commented-out code if it is not needed.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#issuecomment-963828259


   Quick question: do we plan to add docstrings in this PR?


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 commented on pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
vinx13 commented on pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#issuecomment-958416998


   We will need a few test cases that cover various TIR elements (block-related: T.block, T.match_buffer, T.alloc_buffer, T.block_attr; loop-related: T.grid, T.serial, T.thread_binding, etc.). You can find good examples of TIR scripts in previous test cases. Please also add a command that runs these tests to `https://github.com/apache/tvm/blob/main/tests/scripts/task_mypy.sh` and make sure they pass.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r743310268



##########
File path: tests/python/unittest/test_tvmscript_type.py
##########
@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+from tvm.script import tir as T
+
+"""
+This prim_func tests the type of

Review comment:
       done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r744028799



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,285 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer:
+    def __getitem__(self: Buffer, pos: Union[PrimExpr, int, slice]) -> PrimExpr: ...

Review comment:
       ```
       @overload
       def __getitem__(self: Buffer, pos: List[Union[PrimExpr, int]]) -> PrimExpr: ...
       @overload
       def __getitem__(self: Buffer, pos: Union[PrimExpr, int]) -> PrimExpr: ...
       @overload
       def __setitem__(self: Buffer, pos: List[Union[PrimExpr, int]], value: PrimExpr) -> None: ...
       @overload
       def __setitem__(self: Buffer, pos: Union[PrimExpr, int], value: PrimExpr) -> None: ...
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r742311358



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,261 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+from . import axis
+from .ty import ConcreteType
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, pos: Tuple[Union[int, PrimExpr]]) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+class Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...
+def int16(imm: int) -> PrimExpr: ...
+def int32(imm: int) -> PrimExpr: ...
+def int64(imm: int) -> PrimExpr: ...
+def uint8(imm: int) -> PrimExpr: ...
+def uint16(imm: int) -> PrimExpr: ...
+def uint32(imm: int) -> PrimExpr: ...
+def uint64(imm: int) -> PrimExpr: ...
+def float8(imm: int) -> PrimExpr: ...
+def float16(imm: int) -> PrimExpr: ...
+def float32(imm: int) -> PrimExpr: ...
+def float64(imm: int) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+Axis
+"""
+
+def reduce_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def range(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def scan_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def opaque_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+def axis_spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...

Review comment:
       done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 merged pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
junrushao1994 merged pull request #9432:
URL: https://github.com/apache/tvm/pull/9432


   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r743319112



##########
File path: python/tvm/script/tir/ty.py
##########
@@ -30,8 +30,11 @@ def evaluate(self):
         """Return an actual ir.Type Object that this Generic class wraps"""
         raise TypeError("Cannot get tvm.Type from a generic type")
 
+    def __call__(self):
+        raise NotImplementedError

Review comment:
       Also add a comment explaining why this is needed.
   ```suggestion
           raise NotImplementedError()
   ```

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,285 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, *pos: Union[int, PrimExpr, slice, Any]) -> PrimExpr: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype: str) -> PrimExpr: ...
+def max_value(dtype: str) -> PrimExpr: ...
+def floordiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def floormod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncmod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncdiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def abs(x: PrimExpr) -> PrimExpr: ...
+def load(dtype: str, var: Var, index: PrimExpr, predicate: PrimExpr = None) -> PrimExpr: ...
+def cast(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def ramp(base: PrimExpr, stride: Any, lanes: int) -> PrimExpr: ...
+def broadcast(value: PrimExpr, lanes: int) -> PrimExpr: ...
+def iter_var(var: Union[Var, str], dom: Range, iter_type: int, thread_tag: str) -> IterVar: ...
+def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr: ...
+def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ...
+def evaluate(value: PrimExpr) -> PrimExpr: ...
+def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def store(var: Var, index: PrimExpr, value: PrimExpr, predicate: PrimExpr = True) -> None: ...
+def comm_reducer(lambda_io: Tuple[List[PrimExpr]], identities: List[PrimExpr]) -> PrimExpr: ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ...
+
+def get_axis(begin: PrimExpr, end: PrimExpr, iter_type: int) -> IterVar: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype: str, storage_scope: str) -> Var: ...
+def func_attr(attrs: Mapping[str, Object]) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(
+    extents: List[PrimExpr],
+    dtype: str,
+    scope: str,
+    condition: Union[PrimExpr, builtins.bool] = True,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Var: ...
+def launch_thread(env_var: Var, extent: PrimExpr) -> Var: ...
+def realize(
+    buffer_slice: BufferSlice, scope: str, condition: Union[PrimExpr, builtins.bool] = True
+) -> None: ...
+def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
+def Assert(condition: Union[PrimExpr, builtins.bool], message: str) -> PrimExpr: ...
+
+"""
+Scope handler - Loops
+"""
+
+def serial(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def parallel(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def vectorized(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def unroll(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def thread_binding(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    thread: str,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def for_range(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int] = None,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def grid(*extents: Union[PrimExpr, int]) -> Iterable[Tuple[IterVar]]: ...
+
+"""
+ty - redefine types
+"""
+
+class boolean: ...
+
+class handle:
+    def __getitem__(self: handle, pos: Tuple[Union[Number, PrimExpr, slice]]) -> Buffer: ...
+    @property
+    def data(self: handle) -> Ptr: ...
+
+# class float32:
+# def __new__(self, imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+# def __init__(self, imm: Union[PrimExpr, Number]) -> None: ...

Review comment:
       remove if not needed




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9432: [TIR][WIP] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r741388855



##########
File path: python/tvm/tir/__init__.pyi
##########
@@ -0,0 +1,612 @@
+from typing import (

Review comment:
       Hey, just wanted to get a bit of context here. Why is this file under `tir` and not `script/tir`?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR][WIP] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r741449460



##########
File path: python/tvm/tir/__init__.pyi
##########
@@ -0,0 +1,612 @@
+from typing import (

Review comment:
       Thanks for the prompt review! There was a discussion in which Wuwei and I decided to put `PrimExpr` and `PrimExprWithOp` (for future use) here in `tir/__init__.pyi` instead of `script/tir/__init__.pyi`. Therefore stubs from `script/tir` will need to import `PrimExpr`, `IterVar`, etc. from `tir`. I am curious about your take on this.

##########
File path: python/tvm/tir/__init__.pyi
##########
@@ -0,0 +1,612 @@
+from typing import (

Review comment:
       Thanks for the prompt review! There is a discussion in which Wuwei and I decided to put `PrimExpr` and `PrimExprWithOp` (for future use) here in `tir/__init__.pyi` instead of `script/tir/__init__.pyi`. Therefore stubs from `script/tir` will need to import `PrimExpr`, `IterVar` and etc. from `tir`. I am curious about your take on this.

##########
File path: python/tvm/tir/__init__.pyi
##########
@@ -0,0 +1,612 @@
+from typing import (

Review comment:
       this stub is removed




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r742273240



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,261 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+from . import axis
+from .ty import ConcreteType
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, pos: Tuple[Union[int, PrimExpr]]) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+class Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...
+def int16(imm: int) -> PrimExpr: ...
+def int32(imm: int) -> PrimExpr: ...
+def int64(imm: int) -> PrimExpr: ...
+def uint8(imm: int) -> PrimExpr: ...
+def uint16(imm: int) -> PrimExpr: ...
+def uint32(imm: int) -> PrimExpr: ...
+def uint64(imm: int) -> PrimExpr: ...
+def float8(imm: int) -> PrimExpr: ...
+def float16(imm: int) -> PrimExpr: ...
+def float32(imm: int) -> PrimExpr: ...
+def float64(imm: int) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+Axis
+"""
+
+def reduce_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def range(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def scan_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def opaque_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+def axis_spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_remap(iter_types: str, loop_vars: List[Var]) -> IterVar: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype, storage_scope) -> IterVar: ...
+def func_attr(attrs: Dict) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(extents, dtype, scope: str, condition=True, annotations=None) -> None: ...
+def launch_thread(env_var, extent): ...
+def realize(buffer_slice: BufferSlice, scope: str, condition=True) -> None: ...
+def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
+def Assert(condition, message): ...
+
+"""
+Scope handler - Loops
+"""
+
+def serial(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def parallel(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def vectorized(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def unroll(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def thread_binding(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    thread: str,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def for_range(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int] = None,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def grid(*extents: List[Union[PrimExpr, int]]) -> Iterable[Tuple[IterVar]]: ...
+
+"""
+ty
+"""
+boolean = ConcreteType("bool")
+handle = ConcreteType("handle")

Review comment:
       Is it necessary to instantiate `ConcreteType`? Alternatively, you can just add `class boolean` to declare a type.
   Also, if you want to add the contents of ty.py here, it should be exhaustive.

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,261 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+from . import axis
+from .ty import ConcreteType
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, pos: Tuple[Union[int, PrimExpr]]) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+class Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...
+def int16(imm: int) -> PrimExpr: ...
+def int32(imm: int) -> PrimExpr: ...
+def int64(imm: int) -> PrimExpr: ...
+def uint8(imm: int) -> PrimExpr: ...
+def uint16(imm: int) -> PrimExpr: ...
+def uint32(imm: int) -> PrimExpr: ...
+def uint64(imm: int) -> PrimExpr: ...
+def float8(imm: int) -> PrimExpr: ...
+def float16(imm: int) -> PrimExpr: ...
+def float32(imm: int) -> PrimExpr: ...
+def float64(imm: int) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+Axis
+"""
+
+def reduce_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def range(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def scan_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def opaque_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+def axis_spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...

Review comment:
       Should be converted to a static method of `class axis` (`axis_spatial` is not a keyword exposed to the user),
   e.g. 
   ```
   class axis:
     @staticmethod
     def spatial(...)
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r742304506



##########
File path: tests/python/unittest/test_tvmscript_type.py
##########
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import sys
+import pytest
+import tvm
+from tvm import tir
+from tvm.script import tir as T
+
+"""
+This module tests the type of
+T.prim_func, T.handle, T.match_buffer, T.block
+T.reads, T.writes, T.alloc_buffer, T.serial
+T.block_attr, T.float32
+"""
+
+
+@pytest.mark.mypy_testing

Review comment:
       done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r743974539



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,285 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, *pos: Union[int, PrimExpr, slice, Any]) -> PrimExpr: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype: str) -> PrimExpr: ...
+def max_value(dtype: str) -> PrimExpr: ...
+def floordiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def floormod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncmod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncdiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def abs(x: PrimExpr) -> PrimExpr: ...
+def load(dtype: str, var: Var, index: PrimExpr, predicate: PrimExpr = None) -> PrimExpr: ...
+def cast(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def ramp(base: PrimExpr, stride: Any, lanes: int) -> PrimExpr: ...
+def broadcast(value: PrimExpr, lanes: int) -> PrimExpr: ...
+def iter_var(var: Union[Var, str], dom: Range, iter_type: int, thread_tag: str) -> IterVar: ...
+def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr: ...
+def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ...
+def evaluate(value: PrimExpr) -> PrimExpr: ...
+def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def store(var: Var, index: PrimExpr, value: PrimExpr, predicate: PrimExpr = True) -> None: ...
+def comm_reducer(lambda_io: Tuple[List[PrimExpr]], identities: List[PrimExpr]) -> PrimExpr: ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ...
+
+def get_axis(begin: PrimExpr, end: PrimExpr, iter_type: int) -> IterVar: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype: str, storage_scope: str) -> Var: ...
+def func_attr(attrs: Mapping[str, Object]) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(
+    extents: List[PrimExpr],
+    dtype: str,
+    scope: str,
+    condition: Union[PrimExpr, builtins.bool] = True,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Var: ...
+def launch_thread(env_var: Var, extent: PrimExpr) -> Var: ...
+def realize(
+    buffer_slice: BufferSlice, scope: str, condition: Union[PrimExpr, builtins.bool] = True
+) -> None: ...
+def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
+def Assert(condition: Union[PrimExpr, builtins.bool], message: str) -> PrimExpr: ...
+
+"""
+Scope handler - Loops
+"""
+
+def serial(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def parallel(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def vectorized(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def unroll(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def thread_binding(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    thread: str,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def for_range(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int] = None,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def grid(*extents: Union[PrimExpr, int]) -> Iterable[Tuple[IterVar]]: ...
+
+"""
+ty - redefine types
+"""
+
+class boolean: ...
+
+class handle:
+    def __getitem__(self: handle, pos: Tuple[Union[Number, PrimExpr, slice]]) -> Buffer: ...
+    @property
+    def data(self: handle) -> Ptr: ...

Review comment:
       `T.handle` is used as such [here](https://github.com/apache/tvm/blob/048994bd934a39b45f88f4e929d8214e7918dd8e/tests/python/unittest/test_tvmscript_roundtrip.py#L2606)




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r742568716



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,270 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from numbers import Number
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, pos: Tuple[Union[int, PrimExpr]]) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm) -> PrimExpr: ...

Review comment:
       done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r742578238



##########
File path: mypy.ini
##########
@@ -23,6 +23,14 @@ follow_imports = skip
 ignore_errors = False
 strict_optional = False
 
+#
+# Note: not all tests under .tests/ are typed 
+# Therefore include test files that should be
+# checked by mypy here
+#
+files = 
+    tests/python/unittest/test_tvmscript_type.py

Review comment:
       The problem with adding it to `task_mypy.sh` is that our test case is not part of the `tvm` Python module, so imports like `tvm` or `tvm.tir` cannot be resolved by mypy.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r743308012



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,278 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, *pos: Union[int, PrimExpr, slice, Any]) -> Buffer: ...

Review comment:
       changed

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,278 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):

Review comment:
       done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#issuecomment-961529551


   Right now, with `mypy --strict`, the only errors we get are:
   ```
   python/tvm/script/tir/__init__.pyi: note: In class "Buffer":
   python/tvm/script/tir/__init__.pyi:44:14: error: Class cannot subclass "Var" (has type "Any")
   python/tvm/script/tir/__init__.pyi: note: In function "prim_func":
   python/tvm/script/tir/__init__.pyi:195:27: error: Missing type parameters for generic type "Callable"
   python/tvm/script/tir/__init__.pyi: note: At top level:
   python/tvm/script/tir/__init__.pyi:208:13: error: Missing type parameters for generic type "ContextManager"
   python/tvm/script/tir/__init__.pyi:212:12: error: Missing type parameters for generic type "ContextManager"
   python/tvm/script/tir/__init__.pyi:215:11: error: Missing type parameters for generic type "ContextManager"
   Found 5 errors in 1 file (checked 1 source file)
   ```
   
   Should be good for another round of review now


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r742578238



##########
File path: mypy.ini
##########
@@ -23,6 +23,14 @@ follow_imports = skip
 ignore_errors = False
 strict_optional = False
 
+#
+# Note: not all tests under .tests/ are typed 
+# Therefore include test files that should be
+# checked by mypy here
+#
+files = 
+    tests/python/unittest/test_tvmscript_type.py

Review comment:
       The problem with adding it to `task_mypy.sh` is that our test case is not part of the `tvm` Python module, so imports like `tvm` or `tvm.tir` cannot be resolved by mypy.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 commented on a change in pull request #9432: [WIP][TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r741453933



##########
File path: python/tvm/tir/__init__.pyi
##########
@@ -0,0 +1,612 @@
+from typing import (

Review comment:
       For the `tvm.tir` package, I think adding inline type annotations where necessary is sufficient. The main reason we need a stub file for `tvm.script.tir` is that we want to expose keywords that do not have an actual implementation.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r744001201



##########
File path: tests/python/unittest/test_tvmscript_type.py
##########
@@ -0,0 +1,82 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring,invalid-name,pointless-string-statement
+from tvm.script import tir as T
+
+
+@T.prim_func
+def element_wise_storage_align(a: T.handle, c: T.handle) -> None:
+    """
+    This prim func include necessary buffer types that need to be checked
+    e.g. reads/writes, match_buffer/alloc_buffer, serial/block etc.
+    """
+    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    # body
+    with T.block("root"):
+        T.reads([])
+        T.writes([])
+        B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1)
+        for i0 in T.serial(0, 128):
+            for ax1 in T.serial(0, 128):
+                with T.block("B"):
+                    vi, vj = T.axis.remap("SS", [i0, ax1])
+                    T.reads([A[vi, vj]])
+                    T.writes([B[vi, vj]])
+                    T.block_attr({"buffer_dim_align": [[0, 0, 128, 127]]})
+                    B[vi, vj] = A[vi, vj] * T.float32(2)
+            for i1 in T.serial(0, 128):
+                with T.block("C"):
+                    vi_1, vj_1 = T.axis.remap("SS", [i0, i1])
+                    T.reads([B[vi_1, vj_1]])
+                    T.writes([C[vi_1, vj_1]])
+                    C[vi_1, vj_1] = B[vi_1, vj_1] + T.float32(1)
+
+
+@T.prim_func
+def element_wise_env_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None:
+    """

Review comment:
       Docstrings are currently not supported by TIR; you can turn it into a comment or move it outside the function instead.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r744013432



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,285 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer:
+    def __getitem__(self: Buffer, pos: Union[PrimExpr, int, slice]) -> PrimExpr: ...
+    def __setitem__(self: Buffer, pos: Union[PrimExpr, int, slice], value: PrimExpr) -> None: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype: str) -> PrimExpr: ...
+def max_value(dtype: str) -> PrimExpr: ...
+def floordiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def floormod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncmod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncdiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def abs(x: PrimExpr) -> PrimExpr: ...
+def load(dtype: str, var: Var, index: PrimExpr, predicate: PrimExpr = None) -> PrimExpr: ...
+def cast(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def ramp(base: PrimExpr, stride: Any, lanes: int) -> PrimExpr: ...
+def broadcast(value: PrimExpr, lanes: int) -> PrimExpr: ...
+def iter_var(var: Union[Var, str], dom: Range, iter_type: int, thread_tag: str) -> IterVar: ...
+def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr: ...
+def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ...
+def evaluate(value: PrimExpr) -> PrimExpr: ...

Review comment:
       done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r745070880



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -33,21 +33,51 @@ from numbers import Number
 import builtins
 
 from tvm.tir.function import PrimFunc
-from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.tir import Range
 from tvm.runtime import Object
 from .node import BufferSlice
 
 """
 redefine types
 """
 
+class PrimExpr:
+    def __init__(self: PrimExpr) -> None: ...
+    @overload
+    def __add__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+    @overload
+    def __add__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    @overload
+    def __sub__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+    @overload
+    def __sub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    @overload
+    def __mul__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+    @overload
+    def __mul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    @overload
+    def __div__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+    @overload
+    def __div__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    def __radd__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    def __rsub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    def __rmul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    def __rdiv__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+
+class Var(PrimExpr): ...
+class IterVar(Var): ...
+
 class Buffer:
     @overload
-    def __getitem__(self: Buffer, pos: List[Union[PrimExpr, int]]) -> PrimExpr: ...
+    def __getitem__(
+        self: Buffer, pos: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]]

Review comment:
       It should be a `Sequence` instead of a 2-element tuple, since a buffer can be accessed with N-D indices.

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -33,21 +33,51 @@ from numbers import Number
 import builtins
 
 from tvm.tir.function import PrimFunc
-from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.tir import Range
 from tvm.runtime import Object
 from .node import BufferSlice
 
 """
 redefine types
 """
 
+class PrimExpr:
+    def __init__(self: PrimExpr) -> None: ...
+    @overload
+    def __add__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+    @overload
+    def __add__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    @overload
+    def __sub__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+    @overload
+    def __sub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    @overload
+    def __mul__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+    @overload
+    def __mul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    @overload
+    def __div__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+    @overload
+    def __div__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    def __radd__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    def __rsub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    def __rmul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    def __rdiv__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+
+class Var(PrimExpr): ...
+class IterVar(Var): ...
+
 class Buffer:
     @overload
-    def __getitem__(self: Buffer, pos: List[Union[PrimExpr, int]]) -> PrimExpr: ...
+    def __getitem__(
+        self: Buffer, pos: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]]
+    ) -> PrimExpr: ...
     @overload
     def __getitem__(self: Buffer, pos: Union[PrimExpr, int]) -> PrimExpr: ...
     @overload
-    def __setitem__(self: Buffer, pos: List[Union[PrimExpr, int]], value: PrimExpr) -> None: ...
+    def __setitem__(

Review comment:
       same as above




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r745035324



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,285 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer:
+    def __getitem__(self: Buffer, pos: Union[PrimExpr, int, slice]) -> PrimExpr: ...

Review comment:
       done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r744012190



##########
File path: tests/python/unittest/test_tvmscript_type.py
##########
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring,invalid-name,pointless-string-statement
+from tvm.script import tir as T
+
+"""
+This prim func include necessary buffer types that need to be checked
+e.g. reads/writes, match_buffer/alloc_buffer, serial/block etc.
+"""
+
+
+@T.prim_func
+def element_wise_storage_align(a: T.handle, c: T.handle) -> None:
+    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    # body
+    with T.block("root"):
+        T.reads([])
+        T.writes([])
+        B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1)
+        for i0 in T.serial(0, 128):
+            for ax1 in T.serial(0, 128):
+                with T.block("B"):
+                    vi, vj = T.axis.remap("SS", [i0, ax1])

Review comment:
       It would be great to also cover `axis.spatial` explicitly (`axis.remap` is syntactic sugar for per-axis `axis.spatial`/`axis.reduce` calls):
   ```suggestion
                       vi = T.axis.S(128, i0)
                       vj = T.axis.S(128, ax1)
   ```

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,285 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer:
+    def __getitem__(self: Buffer, pos: Union[PrimExpr, int, slice]) -> PrimExpr: ...
+    def __setitem__(self: Buffer, pos: Union[PrimExpr, int, slice], value: PrimExpr) -> None: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype: str) -> PrimExpr: ...
+def max_value(dtype: str) -> PrimExpr: ...
+def floordiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def floormod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncmod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncdiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def abs(x: PrimExpr) -> PrimExpr: ...
+def load(dtype: str, var: Var, index: PrimExpr, predicate: PrimExpr = None) -> PrimExpr: ...
+def cast(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def ramp(base: PrimExpr, stride: Any, lanes: int) -> PrimExpr: ...
+def broadcast(value: PrimExpr, lanes: int) -> PrimExpr: ...
+def iter_var(var: Union[Var, str], dom: Range, iter_type: int, thread_tag: str) -> IterVar: ...
+def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr: ...
+def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ...
+def evaluate(value: PrimExpr) -> PrimExpr: ...

Review comment:
       ```suggestion
   def evaluate(value: PrimExpr) -> None: ...
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r745065409



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,285 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer:
+    def __getitem__(self: Buffer, pos: Union[PrimExpr, int, slice]) -> PrimExpr: ...
+    def __setitem__(self: Buffer, pos: Union[PrimExpr, int, slice], value: PrimExpr) -> None: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype: str) -> PrimExpr: ...
+def max_value(dtype: str) -> PrimExpr: ...
+def floordiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def floormod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncmod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncdiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def abs(x: PrimExpr) -> PrimExpr: ...
+def load(dtype: str, var: Var, index: PrimExpr, predicate: PrimExpr = None) -> PrimExpr: ...
+def cast(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def ramp(base: PrimExpr, stride: Any, lanes: int) -> PrimExpr: ...
+def broadcast(value: PrimExpr, lanes: int) -> PrimExpr: ...
+def iter_var(var: Union[Var, str], dom: Range, iter_type: int, thread_tag: str) -> IterVar: ...
+def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr: ...
+def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ...
+def evaluate(value: PrimExpr) -> None: ...
+def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def store(var: Var, index: PrimExpr, value: PrimExpr, predicate: PrimExpr = True) -> None: ...
+def comm_reducer(lambda_io: Tuple[List[PrimExpr]], identities: List[PrimExpr]) -> PrimExpr: ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...

Review comment:
       done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r744037726



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,285 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer:
+    def __getitem__(self: Buffer, pos: Union[PrimExpr, int, slice]) -> PrimExpr: ...
+    def __setitem__(self: Buffer, pos: Union[PrimExpr, int, slice], value: PrimExpr) -> None: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype: str) -> PrimExpr: ...
+def max_value(dtype: str) -> PrimExpr: ...
+def floordiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def floormod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncmod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncdiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def abs(x: PrimExpr) -> PrimExpr: ...
+def load(dtype: str, var: Var, index: PrimExpr, predicate: PrimExpr = None) -> PrimExpr: ...
+def cast(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def ramp(base: PrimExpr, stride: Any, lanes: int) -> PrimExpr: ...
+def broadcast(value: PrimExpr, lanes: int) -> PrimExpr: ...
+def iter_var(var: Union[Var, str], dom: Range, iter_type: int, thread_tag: str) -> IterVar: ...
+def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr: ...
+def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ...
+def evaluate(value: PrimExpr) -> None: ...
+def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def store(var: Var, index: PrimExpr, value: PrimExpr, predicate: PrimExpr = True) -> None: ...
+def comm_reducer(lambda_io: Tuple[List[PrimExpr]], identities: List[PrimExpr]) -> PrimExpr: ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...

Review comment:
       It might be clearer to write the two versions (one taking `PrimExpr`, one taking `Tuple[PrimExpr, PrimExpr]`) as separate `@overload`s.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR][WIP] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r741449460



##########
File path: python/tvm/tir/__init__.pyi
##########
@@ -0,0 +1,612 @@
+from typing import (

Review comment:
       Thanks for the prompt review! In an earlier discussion, Wuwei and I decided to put `PrimExpr` and `PrimExprWithOp` (for future use) in `tir/__init__.pyi` instead of `script/tir/__init__.pyi`. As a result, the stubs in `script/tir` need to import `PrimExpr`, `IterVar`, etc. from `tir`. I am curious about your take on this.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r743148098



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,267 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(
+        self: Buffer, pos: Tuple[Union[int, PrimExpr], Union[int, PrimExpr]]
+    ) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype, storage_scope) -> IterVar: ...

Review comment:
       ```suggestion
   def buffer_var(dtype: str, storage_scope: str) -> Var: ...
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r743308028



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,278 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):

Review comment:
       If `Buffer` is not a subclass of `Var`, mypy reports the following unexpected errors:
   ```
   tests/python/unittest/test_tvmscript_type.py: note: In function "element_wise_storage_align":
   tests/python/unittest/test_tvmscript_type.py:44:21: error: Unsupported target for indexed assignment ("Buffer")
   tests/python/unittest/test_tvmscript_type.py:50:21: error: Unsupported target for indexed assignment ("Buffer")
   tests/python/unittest/test_tvmscript_type.py: note: In function "element_wise_env_thread_x":
   tests/python/unittest/test_tvmscript_type.py:75:21: error: Unsupported target for indexed assignment ("Buffer")
   tests/python/unittest/test_tvmscript_type.py:80:21: error: Unsupported target for indexed assignment ("Buffer")
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r742317963



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,261 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+from . import axis
+from .ty import ConcreteType
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, pos: Tuple[Union[int, PrimExpr]]) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+class Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...
+def int16(imm: int) -> PrimExpr: ...
+def int32(imm: int) -> PrimExpr: ...
+def int64(imm: int) -> PrimExpr: ...
+def uint8(imm: int) -> PrimExpr: ...
+def uint16(imm: int) -> PrimExpr: ...
+def uint32(imm: int) -> PrimExpr: ...
+def uint64(imm: int) -> PrimExpr: ...
+def float8(imm: int) -> PrimExpr: ...
+def float16(imm: int) -> PrimExpr: ...
+def float32(imm: int) -> PrimExpr: ...
+def float64(imm: int) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+Axis
+"""
+
+def reduce_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def range(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def scan_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def opaque_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+def axis_spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_remap(iter_types: str, loop_vars: List[Var]) -> IterVar: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype, storage_scope) -> IterVar: ...
+def func_attr(attrs: Dict) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(extents, dtype, scope: str, condition=True, annotations=None) -> None: ...
+def launch_thread(env_var, extent): ...
+def realize(buffer_slice: BufferSlice, scope: str, condition=True) -> None: ...
+def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
+def Assert(condition, message): ...
+
+"""
+Scope handler - Loops
+"""
+
+def serial(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def parallel(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def vectorized(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def unroll(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def thread_binding(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    thread: str,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def for_range(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int] = None,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def grid(*extents: List[Union[PrimExpr, int]]) -> Iterable[Tuple[IterVar]]: ...
+
+"""
+ty
+"""
+boolean = ConcreteType("bool")
+handle = ConcreteType("handle")

Review comment:
       done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] junrushao1994 commented on a change in pull request #9432: [TIR][WIP] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r741388855



##########
File path: python/tvm/tir/__init__.pyi
##########
@@ -0,0 +1,612 @@
+from typing import (

Review comment:
       Hey, just wanted to get a bit of context here. Why is this file under `tir` and not `script/tir`?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r742258491



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Object,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Buffer, IterVar, Var
+from .node import BufferSlice
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...

Review comment:
       `imm` should have type `Union[PrimExpr, bool, Number]`

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Object,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Buffer, IterVar, Var
+from .node import BufferSlice
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...

Review comment:
       (For this function and those below) `imm: Union[PrimExpr, Number]`

##########
File path: mypy.ini
##########
@@ -23,6 +23,14 @@ follow_imports = skip
 ignore_errors = False
 strict_optional = False
 
+#
+# Note: not all tests under .tests/ are typed 
+# Therefore include test files that should be
+# checked by mypy here
+#
+files = 
+    tests/python/unittest/test_tvmscript_type.py

Review comment:
       I'd prefer adding the command to https://github.com/apache/tvm/blob/main/tests/scripts/task_mypy.sh
   My concern is that adding it to the config file is not as explicit as adding it to a centralized file of test commands. Also, IIUC, this means the specified file is checked every time mypy runs (even when the intent is to type-check other modules).

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Object,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Buffer, IterVar, Var
+from .node import BufferSlice
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...
+def int16(imm: int) -> PrimExpr: ...
+def int32(imm: int) -> PrimExpr: ...
+def int64(imm: int) -> PrimExpr: ...
+def uint8(imm: int) -> PrimExpr: ...
+def uint16(imm: int) -> PrimExpr: ...
+def uint32(imm: int) -> PrimExpr: ...
+def uint64(imm: int) -> PrimExpr: ...
+def float8(imm: int) -> PrimExpr: ...
+def float16(imm: int) -> PrimExpr: ...
+def float32(imm: int) -> PrimExpr: ...
+def float64(imm: int) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def range(begin, end): ...

Review comment:
       This duplicates the definition below.

##########
File path: tests/python/unittest/test_tvmscript_type.py
##########
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import sys
+import pytest
+import tvm
+from tvm import tir
+from tvm.script import tir as T
+
+"""
+This module tests the type of
+T.prim_func, T.handle, T.match_buffer, T.block
+T.reads, T.writes, T.alloc_buffer, T.serial
+T.block_attr, T.float32
+"""
+
+
+@pytest.mark.mypy_testing

Review comment:
       Remove this, as it is not needed.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r742304506



##########
File path: tests/python/unittest/test_tvmscript_type.py
##########
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import sys
+import pytest
+import tvm
+from tvm import tir
+from tvm.script import tir as T
+
+"""
+This module tests the type of
+T.prim_func, T.handle, T.match_buffer, T.block
+T.reads, T.writes, T.alloc_buffer, T.serial
+T.block_attr, T.float32
+"""
+
+
+@pytest.mark.mypy_testing

Review comment:
       done

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Object,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Buffer, IterVar, Var
+from .node import BufferSlice
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...

Review comment:
       The definition of `bool` here actually shadows every use of Python's native `bool` type annotation in this file. My workaround is to omit the type annotation for any method/class parameter that would be typed `bool`. E.g., this line would just be `def bool(imm) -> PrimExpr: ...`
   
   

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Object,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Buffer, IterVar, Var
+from .node import BufferSlice
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...

Review comment:
       done

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Object,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Buffer, IterVar, Var
+from .node import BufferSlice
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...
+def int16(imm: int) -> PrimExpr: ...
+def int32(imm: int) -> PrimExpr: ...
+def int64(imm: int) -> PrimExpr: ...
+def uint8(imm: int) -> PrimExpr: ...
+def uint16(imm: int) -> PrimExpr: ...
+def uint32(imm: int) -> PrimExpr: ...
+def uint64(imm: int) -> PrimExpr: ...
+def float8(imm: int) -> PrimExpr: ...
+def float16(imm: int) -> PrimExpr: ...
+def float32(imm: int) -> PrimExpr: ...
+def float64(imm: int) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def range(begin, end): ...

Review comment:
       removed

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,261 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+from . import axis
+from .ty import ConcreteType
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, pos: Tuple[Union[int, PrimExpr]]) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+class Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...
+def int16(imm: int) -> PrimExpr: ...
+def int32(imm: int) -> PrimExpr: ...
+def int64(imm: int) -> PrimExpr: ...
+def uint8(imm: int) -> PrimExpr: ...
+def uint16(imm: int) -> PrimExpr: ...
+def uint32(imm: int) -> PrimExpr: ...
+def uint64(imm: int) -> PrimExpr: ...
+def float8(imm: int) -> PrimExpr: ...
+def float16(imm: int) -> PrimExpr: ...
+def float32(imm: int) -> PrimExpr: ...
+def float64(imm: int) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+Axis
+"""
+
+def reduce_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def range(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def scan_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def opaque_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+def axis_spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...

Review comment:
       done

##########
File path: mypy.ini
##########
@@ -23,6 +23,14 @@ follow_imports = skip
 ignore_errors = False
 strict_optional = False
 
+#
+# Note: not all tests under .tests/ are typed 
+# Therefore include test files that should be
+# checked by mypy here
+#
+files = 
+    tests/python/unittest/test_tvmscript_type.py

Review comment:
       Good point. The reason I included that test in `mypy.ini` is that PyTorch seems to add individual tests to mypy this [way](https://github.com/pytorch/pytorch/blob/master/mypy.ini). But adding it to `task_mypy.sh` should also work.

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,261 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+from . import axis
+from .ty import ConcreteType
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, pos: Tuple[Union[int, PrimExpr]]) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+class Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: int) -> PrimExpr: ...
+def int8(imm: int) -> PrimExpr: ...
+def int16(imm: int) -> PrimExpr: ...
+def int32(imm: int) -> PrimExpr: ...
+def int64(imm: int) -> PrimExpr: ...
+def uint8(imm: int) -> PrimExpr: ...
+def uint16(imm: int) -> PrimExpr: ...
+def uint32(imm: int) -> PrimExpr: ...
+def uint64(imm: int) -> PrimExpr: ...
+def float8(imm: int) -> PrimExpr: ...
+def float16(imm: int) -> PrimExpr: ...
+def float32(imm: int) -> PrimExpr: ...
+def float64(imm: int) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+Axis
+"""
+
+def reduce_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def range(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def scan_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+def opaque_axis(begin: Union[PrimExpr, int], end: Union[PrimExpr, int]) -> IterVar: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+def axis_spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+def axis_remap(iter_types: str, loop_vars: List[Var]) -> IterVar: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype, storage_scope) -> IterVar: ...
+def func_attr(attrs: Dict) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(extents, dtype, scope: str, condition=True, annotations=None) -> None: ...
+def launch_thread(env_var, extent): ...
+def realize(buffer_slice: BufferSlice, scope: str, condition=True) -> None: ...
+def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
+def Assert(condition, message): ...
+
+"""
+Scope handler - Loops
+"""
+
+def serial(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def parallel(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def vectorized(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def unroll(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def thread_binding(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    thread: str,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def for_range(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int] = None,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def grid(*extents: List[Union[PrimExpr, int]]) -> Iterable[Tuple[IterVar]]: ...
+
+"""
+ty
+"""
+boolean = ConcreteType("bool")
+handle = ConcreteType("handle")

Review comment:
       done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r743241697



##########
File path: tests/python/unittest/test_tvmscript_type.py
##########
@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+from tvm.script import tir as T
+
+"""
+This prim_func tests the type of

Review comment:
       I think rather than listing these types here, it would be clearer to describe the general intention of these cases.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r743318643



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,285 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, *pos: Union[int, PrimExpr, slice, Any]) -> PrimExpr: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype: str) -> PrimExpr: ...
+def max_value(dtype: str) -> PrimExpr: ...
+def floordiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def floormod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncmod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncdiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def abs(x: PrimExpr) -> PrimExpr: ...
+def load(dtype: str, var: Var, index: PrimExpr, predicate: PrimExpr = None) -> PrimExpr: ...
+def cast(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def ramp(base: PrimExpr, stride: Any, lanes: int) -> PrimExpr: ...
+def broadcast(value: PrimExpr, lanes: int) -> PrimExpr: ...
+def iter_var(var: Union[Var, str], dom: Range, iter_type: int, thread_tag: str) -> IterVar: ...
+def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr: ...
+def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ...
+def evaluate(value: PrimExpr) -> PrimExpr: ...
+def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def store(var: Var, index: PrimExpr, value: PrimExpr, predicate: PrimExpr = True) -> None: ...
+def comm_reducer(lambda_io: Tuple[List[PrimExpr]], identities: List[PrimExpr]) -> PrimExpr: ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ...
+
+def get_axis(begin: PrimExpr, end: PrimExpr, iter_type: int) -> IterVar: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype: str, storage_scope: str) -> Var: ...
+def func_attr(attrs: Mapping[str, Object]) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(
+    extents: List[PrimExpr],
+    dtype: str,
+    scope: str,
+    condition: Union[PrimExpr, builtins.bool] = True,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Var: ...
+def launch_thread(env_var: Var, extent: PrimExpr) -> Var: ...
+def realize(
+    buffer_slice: BufferSlice, scope: str, condition: Union[PrimExpr, builtins.bool] = True
+) -> None: ...
+def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
+def Assert(condition: Union[PrimExpr, builtins.bool], message: str) -> PrimExpr: ...
+
+"""
+Scope handler - Loops
+"""
+
+def serial(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def parallel(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def vectorized(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def unroll(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def thread_binding(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    thread: str,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def for_range(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int] = None,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def grid(*extents: Union[PrimExpr, int]) -> Iterable[Tuple[IterVar]]: ...
+
+"""
+ty - redefine types
+"""
+
+class boolean: ...
+
+class handle:
+    def __getitem__(self: handle, pos: Tuple[Union[Number, PrimExpr, slice]]) -> Buffer: ...
+    @property
+    def data(self: handle) -> Ptr: ...

Review comment:
       Are they needed? Do you have a use case?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r742415048



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,270 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from numbers import Number
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, pos: Tuple[Union[int, PrimExpr]]) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm) -> PrimExpr: ...

Review comment:
       You can use `builtins.bool` to refer to Python's bool type.

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,270 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+)
+from numbers import Number
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, pos: Tuple[Union[int, PrimExpr]]) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...
+def max_value(dtype): ...
+def floordiv(x: PrimExpr, y: PrimExpr): ...
+def floormod(x: PrimExpr, y: PrimExpr): ...
+def abs(x): ...
+def load(dtype, var, index, predicate=None): ...
+def cast(value, dtype): ...
+def ramp(base, stride, lanes): ...
+def broadcast(value, lanes): ...
+def iter_var(var, dom, iter_type, thread_tag): ...
+def max(a, b): ...
+def min(a, b): ...
+def get_axis(begin, end, iter_type): ...
+def Select(cond, if_body, else_body): ...
+def evaluate(value): ...
+def store(var, index, value, predicate=True): ...
+def comm_reducer(lambda_io, identities): ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data=None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def S(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def R(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def scan(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> IterVar: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype, storage_scope) -> IterVar: ...
+def func_attr(attrs: Dict) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(extents, dtype, scope: str, condition=True, annotations=None) -> None: ...
+def launch_thread(env_var, extent): ...
+def realize(buffer_slice: BufferSlice, scope: str, condition=True) -> None: ...
+def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
+def Assert(condition, message): ...
+
+"""
+Scope handler - Loops
+"""
+
+def serial(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def parallel(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def vectorized(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def unroll(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def thread_binding(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    thread: str,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def for_range(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int] = None,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def grid(*extents: List[Union[PrimExpr, int]]) -> Iterable[Tuple[IterVar]]: ...
+
+"""
+ty - redefine types
+"""
+
+class boolean: ...
+class handle: ...
+class int8: ...
+class int16: ...
+class int32: ...
+class int64: ...
+class float16: ...

Review comment:
       Does it shadow the function name, like `float16` above?
   If so, you could instead make those functions the `__init__` methods of these types.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r743308012



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,278 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, *pos: Union[int, PrimExpr, slice, Any]) -> Buffer: ...

Review comment:
       changed

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,278 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):

Review comment:
       done

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,278 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):

Review comment:
       If `Buffer` is not a subclass of `Var`, mypy throws this odd error:
   ```
   tests/python/unittest/test_tvmscript_type.py: note: In function "element_wise_storage_align":
   tests/python/unittest/test_tvmscript_type.py:44:21: error: Unsupported target for indexed assignment ("Buffer")
   tests/python/unittest/test_tvmscript_type.py:50:21: error: Unsupported target for indexed assignment ("Buffer")
   tests/python/unittest/test_tvmscript_type.py: note: In function "element_wise_env_thread_x":
   tests/python/unittest/test_tvmscript_type.py:75:21: error: Unsupported target for indexed assignment ("Buffer")
   tests/python/unittest/test_tvmscript_type.py:80:21: error: Unsupported target for indexed assignment ("Buffer")
   ```

##########
File path: tests/python/unittest/test_tvmscript_type.py
##########
@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+from tvm.script import tir as T
+
+"""
+This prim_func tests the type of

Review comment:
       done

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,278 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(self: Buffer, *pos: Union[int, PrimExpr, slice, Any]) -> Buffer: ...

Review comment:
       changed

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,278 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):

Review comment:
       done

##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,278 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):

Review comment:
       If `Buffer` is not a subclass of `Var`, mypy throws this odd error:
   ```
   tests/python/unittest/test_tvmscript_type.py: note: In function "element_wise_storage_align":
   tests/python/unittest/test_tvmscript_type.py:44:21: error: Unsupported target for indexed assignment ("Buffer")
   tests/python/unittest/test_tvmscript_type.py:50:21: error: Unsupported target for indexed assignment ("Buffer")
   tests/python/unittest/test_tvmscript_type.py: note: In function "element_wise_env_thread_x":
   tests/python/unittest/test_tvmscript_type.py:75:21: error: Unsupported target for indexed assignment ("Buffer")
   tests/python/unittest/test_tvmscript_type.py:80:21: error: Unsupported target for indexed assignment ("Buffer")
   ```

##########
File path: tests/python/unittest/test_tvmscript_type.py
##########
@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+from tvm.script import tir as T
+
+"""
+This prim_func tests the type of

Review comment:
       done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r743149621



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,267 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import PrimExpr, Range, IterVar, Var
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class Buffer(Var):
+    def __getitem__(
+        self: Buffer, pos: Tuple[Union[int, PrimExpr], Union[int, PrimExpr]]
+    ) -> Buffer: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, Number]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, Number]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype): ...

Review comment:
       Can we type these functions? If some args can't be typed, please explicitly annotate them with `Any`.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r745086428



##########
File path: python/tvm/script/tir/__init__.pyi
##########
@@ -0,0 +1,353 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+from typing import (
+    Any,
+    Callable,
+    ContextManager,
+    Dict,
+    Iterable,
+    Optional,
+    Tuple,
+    Union,
+    Sequence,
+    List,
+    Mapping,
+    overload,
+)
+from numbers import Number
+import builtins
+
+from tvm.tir.function import PrimFunc
+from tvm.tir import Range
+from tvm.runtime import Object
+from .node import BufferSlice
+
+"""
+redefine types
+"""
+
+class PrimExpr:
+    def __init__(self: PrimExpr) -> None: ...
+    @overload
+    def __add__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+    @overload
+    def __add__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    @overload
+    def __sub__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+    @overload
+    def __sub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    @overload
+    def __mul__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+    @overload
+    def __mul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    @overload
+    def __div__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
+    @overload
+    def __div__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    def __radd__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    def __rsub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    def __rmul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+    def __rdiv__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
+
+class Var(PrimExpr): ...
+class IterVar(Var): ...
+
+class Buffer:
+    @overload
+    def __getitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int]]) -> PrimExpr: ...
+    @overload
+    def __getitem__(self: Buffer, pos: Union[PrimExpr, int]) -> PrimExpr: ...
+    @overload
+    def __setitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int]], value: PrimExpr) -> None: ...
+    @overload
+    def __setitem__(self: Buffer, pos: Union[PrimExpr, int], value: PrimExpr) -> None: ...
+    @property
+    def data(self: Buffer) -> Ptr: ...
+
+"""
+Variables and constants
+"""
+
+def bool(imm: Union[PrimExpr, builtins.bool, builtins.int]) -> PrimExpr: ...
+def int8(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def int16(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def int32(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def int64(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def uint8(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def uint16(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def uint32(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def uint64(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def float8(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def float16(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def float32(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+def float64(imm: Union[PrimExpr, int]) -> PrimExpr: ...
+
+"""
+Intrinsic
+"""
+
+def min_value(dtype: str) -> PrimExpr: ...
+def max_value(dtype: str) -> PrimExpr: ...
+def floordiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def floormod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncmod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def truncdiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def abs(x: PrimExpr) -> PrimExpr: ...
+def load(
+    dtype: str, var: Var, index: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = None
+) -> PrimExpr: ...
+def cast(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def ramp(base: PrimExpr, stride: Any, lanes: int) -> PrimExpr: ...
+def broadcast(value: PrimExpr, lanes: int) -> PrimExpr: ...
+def iter_var(var: Union[Var, str], dom: Range, iter_type: int, thread_tag: str) -> IterVar: ...
+def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
+def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr: ...
+def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ...
+def evaluate(value: PrimExpr) -> None: ...
+def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def store(
+    var: Var, index: PrimExpr, value: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = True
+) -> None: ...
+def comm_reducer(lambda_io: Tuple[List[PrimExpr]], identities: List[PrimExpr]) -> PrimExpr: ...
+
+"""
+Unary operator
+"""
+
+def exp2(x: PrimExpr) -> PrimExpr: ...
+def exp10(x: PrimExpr) -> PrimExpr: ...
+def erf(x: PrimExpr) -> PrimExpr: ...
+def tanh(x: PrimExpr) -> PrimExpr: ...
+def sigmoid(x: PrimExpr) -> PrimExpr: ...
+def log(x: PrimExpr) -> PrimExpr: ...
+def log2(x: PrimExpr) -> PrimExpr: ...
+def log10(x: PrimExpr) -> PrimExpr: ...
+def log1p(x: PrimExpr) -> PrimExpr: ...
+def tan(x: PrimExpr) -> PrimExpr: ...
+def cos(x: PrimExpr) -> PrimExpr: ...
+def cosh(x: PrimExpr) -> PrimExpr: ...
+def acos(x: PrimExpr) -> PrimExpr: ...
+def acosh(x: PrimExpr) -> PrimExpr: ...
+def sin(x: PrimExpr) -> PrimExpr: ...
+def sinh(x: PrimExpr) -> PrimExpr: ...
+def asin(x: PrimExpr) -> PrimExpr: ...
+def asinh(x: PrimExpr) -> PrimExpr: ...
+def atan(x: PrimExpr) -> PrimExpr: ...
+def atanh(x: PrimExpr) -> PrimExpr: ...
+def atan2(x: PrimExpr) -> PrimExpr: ...
+def sqrt(x: PrimExpr) -> PrimExpr: ...
+def rsqrt(x: PrimExpr) -> PrimExpr: ...
+
+"""
+special_stmt - Buffers
+"""
+
+def match_buffer(
+    param: Union[Var, BufferSlice],
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def buffer_decl(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+def alloc_buffer(
+    shape: Sequence[Union[PrimExpr, int]],
+    dtype: str = "float32",
+    data: Var = None,
+    strides: Optional[Sequence[int]] = None,
+    elem_offset: Optional[int] = None,
+    scope: str = "global",
+    align: int = -1,
+    offset_factor: int = 0,
+    buffer_type: str = "default",
+) -> Buffer: ...
+
+"""
+special_stmt - Reads/Writes
+"""
+
+def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ...
+def block_attr(attrs: Mapping[str, Object]) -> None: ...
+
+"""
+special_stmt - Axis
+"""
+
+class axis:
+    @overload
+    @staticmethod
+    def spatial(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
+    @overload
+    @staticmethod
+    def spatial(
+        dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr
+    ) -> IterVar: ...
+    @overload
+    @staticmethod
+    def S(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
+    @overload
+    @staticmethod
+    def S(dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr) -> IterVar: ...
+    @overload
+    @staticmethod
+    def reduce(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
+    @overload
+    @staticmethod
+    def reduce(
+        dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr
+    ) -> IterVar: ...
+    @overload
+    @staticmethod
+    def R(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
+    @overload
+    @staticmethod
+    def R(dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr) -> IterVar: ...
+    @overload
+    @staticmethod
+    def scan(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
+    @overload
+    @staticmethod
+    def scan(
+        dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr
+    ) -> IterVar: ...
+    @overload
+    @staticmethod
+    def opaque(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
+    @overload
+    @staticmethod
+    def opaque(
+        dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr
+    ) -> IterVar: ...
+    @staticmethod
+    def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ...
+
+def get_axis(begin: PrimExpr, end: PrimExpr, iter_type: int) -> IterVar: ...
+
+"""
+special_stmt - Annotations
+"""
+
+def buffer_var(dtype: str, storage_scope: str) -> Var: ...
+def func_attr(attrs: Mapping[str, Object]) -> None: ...
+def prim_func(input_func: Callable) -> PrimFunc: ...
+
+"""
+special_stmt - Threads and Bindings
+"""
+
+def env_thread(env_name: str) -> IterVar: ...
+def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
+
+"""
+Scope handler
+"""
+
+class block(ContextManager):
+    def __init__(self, name_hint: str = "") -> None: ...
+    def __enter__(self) -> Sequence[IterVar]: ...
+
+class init(ContextManager):
+    def __init__(self) -> None: ...
+
+class let(ContextManager):
+    def __init__(self, var: Var, value: PrimExpr) -> None: ...
+
+def where(cond: PrimExpr) -> None: ...
+def allocate(
+    extents: List[PrimExpr],
+    dtype: str,
+    scope: str,
+    condition: Union[PrimExpr, builtins.bool] = True,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Var: ...
+def launch_thread(env_var: Var, extent: Union[int, PrimExpr]) -> Var: ...
+def realize(
+    buffer_slice: BufferSlice, scope: str, condition: Union[PrimExpr, builtins.bool] = True
+) -> None: ...
+def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
+def Assert(condition: Union[PrimExpr, builtins.bool], message: str) -> PrimExpr: ...
+
+"""
+Scope handler - Loops
+"""
+
+def serial(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def parallel(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def vectorized(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def unroll(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def thread_binding(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int],
+    thread: str,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def for_range(
+    begin: Union[PrimExpr, int],
+    end: Union[PrimExpr, int] = None,
+    annotations: Optional[Mapping[str, Object]] = None,
+) -> Iterable[IterVar]: ...
+def grid(*extents: Union[PrimExpr, int]) -> Iterable[Tuple[IterVar]]: ...
+
+"""
+ty - redefine types
+"""
+
+class boolean: ...
+
+class handle:
+    def __getitem__(self: handle, pos: Tuple[Union[int, PrimExpr, slice]]) -> Buffer: ...
+    def __setitem__(
+        self: handle, pos: Tuple[Union[int, PrimExpr, slice]], value: Buffer

Review comment:
       Please also update this. In addition, `__setitem__` should return `None`.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#issuecomment-961529551


   Right now, with `mypy --strict`, the only errors we get are:
   ```
   python/tvm/script/tir/__init__.pyi: note: In class "Buffer":
   python/tvm/script/tir/__init__.pyi:44:14: error: Class cannot subclass "Var" (has type "Any")
   python/tvm/script/tir/__init__.pyi: note: In function "prim_func":
   python/tvm/script/tir/__init__.pyi:195:27: error: Missing type parameters for generic type "Callable"
   python/tvm/script/tir/__init__.pyi: note: At top level:
   python/tvm/script/tir/__init__.pyi:208:13: error: Missing type parameters for generic type "ContextManager"
   python/tvm/script/tir/__init__.pyi:212:12: error: Missing type parameters for generic type "ContextManager"
   python/tvm/script/tir/__init__.pyi:215:11: error: Missing type parameters for generic type "ContextManager"
   Found 5 errors in 1 file (checked 1 source file)
   ```
   
   This should be ready for another round of review now.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9432: [TIR] Add type hint for TIR

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9432:
URL: https://github.com/apache/tvm/pull/9432#discussion_r744013498



##########
File path: tests/python/unittest/test_tvmscript_type.py
##########
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring,invalid-name,pointless-string-statement
+from tvm.script import tir as T
+
+"""
+This prim func include necessary buffer types that need to be checked
+e.g. reads/writes, match_buffer/alloc_buffer, serial/block etc.
+"""
+
+
+@T.prim_func
+def element_wise_storage_align(a: T.handle, c: T.handle) -> None:
+    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    # body
+    with T.block("root"):
+        T.reads([])
+        T.writes([])
+        B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1)
+        for i0 in T.serial(0, 128):
+            for ax1 in T.serial(0, 128):
+                with T.block("B"):
+                    vi, vj = T.axis.remap("SS", [i0, ax1])

Review comment:
       This has been changed, and pylint and mypy now pass.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org