You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ru...@apache.org on 2022/11/12 06:25:29 UTC
[tvm] branch main updated: [TVMScript] Reorganize the folder structure (#12496)
This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b20b7c4ad4 [TVMScript] Reorganize the folder structure (#12496)
b20b7c4ad4 is described below
commit b20b7c4ad4ad3774a42f47614245f8eeabe875cb
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Fri Nov 11 22:25:23 2022 -0800
[TVMScript] Reorganize the folder structure (#12496)
This PR slightly reorganizes the `python/tvm/script` folder
to make future upstreaming more convenient.
Co-authored-by: Yaxing Cai <ca...@gmail.com>
---
include/tvm/ir/expr.h | 4 +
python/tvm/script/__init__.py | 7 +-
python/tvm/script/ir_builder/tir/__init__.py | 1 +
python/tvm/script/ir_builder/tir/ir.py | 11 +--
python/tvm/script/{_parser => parser}/__init__.py | 1 +
python/tvm/script/{_parser => parser}/_core.py | 0
.../script/{_parser => parser}/core/__init__.py | 0
.../script/{_parser => parser}/core/diagnostics.py | 39 ++++++---
.../script/{_parser => parser}/core/dispatch.py | 0
python/tvm/script/{_parser => parser}/core/doc.py | 0
.../script/{_parser => parser}/core/doc_core.py | 0
.../tvm/script/{_parser => parser}/core/entry.py | 10 +++
.../script/{_parser => parser}/core/evaluator.py | 0
.../tvm/script/{_parser => parser}/core/parser.py | 0
.../tvm/script/{_parser => parser}/core/utils.py | 36 +++++++-
.../tvm/script/{_parser => parser}/ir/__init__.py | 4 +-
python/tvm/script/{_parser => parser}/ir/entry.py | 32 +-------
python/tvm/script/{_parser => parser}/ir/parser.py | 0
.../tvm/script/{_parser => parser}/tir/__init__.py | 0
python/tvm/script/{_parser => parser}/tir/entry.py | 8 +-
.../script/{_parser => parser}/tir/operation.py | 12 +--
.../tvm/script/{_parser => parser}/tir/parser.py | 2 +-
python/tvm/script/{ => parser_v1}/__init__.py | 0
python/tvm/script/{ => parser_v1}/_ffi_api.py | 0
.../script/{ => parser_v1}/context_maintainer.py | 0
python/tvm/script/{ => parser_v1}/diagnostics.py | 0
python/tvm/script/{ => parser_v1}/meta_unparser.py | 0
python/tvm/script/{ => parser_v1}/parser.py | 0
python/tvm/script/{ => parser_v1}/registry.py | 0
python/tvm/script/{ => parser_v1}/tir/__init__.py | 0
python/tvm/script/{ => parser_v1}/tir/__init__.pyi | 0
python/tvm/script/{ => parser_v1}/tir/intrin.py | 2 +-
python/tvm/script/{ => parser_v1}/tir/node.py | 0
python/tvm/script/{ => parser_v1}/tir/prim_func.py | 0
.../script/{ => parser_v1}/tir/scope_handler.py | 0
.../tvm/script/{ => parser_v1}/tir/special_stmt.py | 0
python/tvm/script/{ => parser_v1}/tir/ty.py | 0
python/tvm/script/{ => parser_v1}/utils.py | 0
python/tvm/tir/__init__.py | 2 +
python/tvm/tir/buffer.py | 28 +++----
python/tvm/tir/expr.py | 2 +
python/tvm/tir/schedule/schedule.py | 6 +-
python/tvm/tir/tensor_intrin/cuda.py | 12 +--
python/tvm/tir/tensor_intrin/hexagon.py | 8 +-
python/tvm/tir/tensor_intrin/rocm.py | 1 +
.../test_hexagon/test_async_dma_pipeline.py | 6 +-
tests/python/relay/aot/test_pass_aot_lower_main.py | 6 +-
.../python/unittest/test_tir_lower_match_buffer.py | 42 ++++------
tests/python/unittest/test_tir_schedule_reindex.py | 2 +-
.../unittest/test_tir_schedule_transform_layout.py | 11 +--
.../python/unittest/test_tir_schedule_utilities.py | 2 +-
.../test_tir_transform_inject_virtual_thread.py | 1 +
...t_tir_transform_lower_cross_thread_reduction.py | 96 ----------------------
.../unittest/test_tir_transform_remove_assume.py | 6 +-
.../python/unittest/test_tvmscript_error_report.py | 1 -
.../unittest/test_tvmscript_ir_builder_tir.py | 4 +-
.../unittest/test_tvmscript_parser_evaluator.py | 4 +-
tests/python/unittest/test_tvmscript_parser_ir.py | 2 +-
.../unittest/test_tvmscript_parser_source.py | 4 +-
tests/python/unittest/test_tvmscript_parser_tir.py | 2 +-
tests/python/unittest/test_tvmscript_spans.py | 2 +-
.../python/unittest/test_tvmscript_syntax_sugar.py | 6 +-
62 files changed, 178 insertions(+), 247 deletions(-)
diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index b9afb4be2d..94927b4892 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -764,6 +764,10 @@ struct PackedFuncValueConverter<PrimExpr> {
return PrimExpr(ObjectPtr<Object>(nullptr));
}
if (val.type_code() == kDLInt) {
+ int64_t value = val.operator int64_t();
+ if (value > std::numeric_limits<int>::max() || value < std::numeric_limits<int>::min()) {
+ return IntImm(runtime::DataType::Int(64), value);
+ }
return IntImm(runtime::DataType::Int(32), val.operator int());
}
if (val.type_code() == kDLFloat) {
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/__init__.py
index 555659d0c5..21bdfa6f16 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/__init__.py
@@ -14,8 +14,5 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
-
-from . import tir
-
-from .parser import ir_module, from_source
+"""TVM Script APIs of TVM Python Package"""
+from .parser import ir, ir_module, parse as from_source, tir
diff --git a/python/tvm/script/ir_builder/tir/__init__.py b/python/tvm/script/ir_builder/tir/__init__.py
index 1e43d1af34..0a71af4db7 100644
--- a/python/tvm/script/ir_builder/tir/__init__.py
+++ b/python/tvm/script/ir_builder/tir/__init__.py
@@ -16,3 +16,4 @@
# under the License.
"""Package tvm.script.ir_builder.tir"""
from .ir import * # pylint: disable=wildcard-import,redefined-builtin
+from .ir import boolean as bool # pylint: disable=redefined-builtin
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index bd9e4e1db5..0678925e2f 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1665,13 +1665,14 @@ start_profile_intrinsic = _op_wrapper(_tir_op.start_profile_intrinsic)
end_profile_intrinsic = _op_wrapper(_tir_op.end_profile_intrinsic)
-class inline:
- """Inline function for meta-programming.
+class meta_var:
+ """A meta variable used in TVMScript metaprogramming. It means that the value of the variable
+ does not appear in the final TIR, but only stays in the parser.
Parameters
----------
value: Any
- The value to be inlined.
+ The meta variable.
"""
def __init__(self, value: Any) -> None:
@@ -1680,7 +1681,7 @@ class inline:
def __iter__(self):
def f():
for i in self.value:
- yield inline(i)
+ yield meta_var(i)
return f()
@@ -1844,7 +1845,7 @@ __all__ += [
"TVMBackendFreeWorkspace",
"start_profile_intrinsic",
"end_profile_intrinsic",
- "inline",
+ "meta_var",
"llvm_lookup_intrinsic_id",
"type_annotation",
"broadcast",
diff --git a/python/tvm/script/_parser/__init__.py b/python/tvm/script/parser/__init__.py
similarity index 97%
rename from python/tvm/script/_parser/__init__.py
rename to python/tvm/script/parser/__init__.py
index 38c8b88cc7..5161a2601c 100644
--- a/python/tvm/script/_parser/__init__.py
+++ b/python/tvm/script/parser/__init__.py
@@ -16,5 +16,6 @@
# under the Licens.
"""The parser"""
from . import _core, ir, tir
+from ._core import parse
from .ir import ir_module
from .tir import prim_func
diff --git a/python/tvm/script/_parser/_core.py b/python/tvm/script/parser/_core.py
similarity index 100%
rename from python/tvm/script/_parser/_core.py
rename to python/tvm/script/parser/_core.py
diff --git a/python/tvm/script/_parser/core/__init__.py b/python/tvm/script/parser/core/__init__.py
similarity index 100%
rename from python/tvm/script/_parser/core/__init__.py
rename to python/tvm/script/parser/core/__init__.py
diff --git a/python/tvm/script/_parser/core/diagnostics.py b/python/tvm/script/parser/core/diagnostics.py
similarity index 88%
rename from python/tvm/script/_parser/core/diagnostics.py
rename to python/tvm/script/parser/core/diagnostics.py
index b077d22142..d673e0eb13 100644
--- a/python/tvm/script/_parser/core/diagnostics.py
+++ b/python/tvm/script/parser/core/diagnostics.py
@@ -17,7 +17,6 @@
"""TVM Script Parser Source and diagnostics"""
import inspect
-import re
import sys
from typing import Union
@@ -144,18 +143,34 @@ def findsource(obj):
if not lines:
raise OSError("could not get source code")
qual_names = obj.__qualname__.replace(".<locals>", "<locals>").split(".")
- pattern_list = []
- for name in qual_names:
- if name.endswith("<locals>"):
- pattern_list.append(re.compile(r"^(\s*)def\s*" + name[:-8] + r"\b"))
- else:
- pattern_list.append(re.compile(r"^(\s*)class\s*" + name + r"\b"))
+ in_comment = 0
+ scope_stack = []
+ indent_info = {}
for i, line in enumerate(lines):
- match = pattern_list[0].match(line)
- if match:
- pattern_list.pop(0)
- if not pattern_list:
- return lines, i
+ n_comment = line.count('"""')
+ if n_comment:
+ # update multi-line comments status
+ in_comment = in_comment ^ (n_comment & 1)
+ continue
+ if in_comment:
+ # skip lines within multi-line comments
+ continue
+ indent = len(line) - len(line.lstrip())
+ tokens = line.split()
+ if len(tokens) > 1:
+ name = None
+ if tokens[0] == "def":
+ name = tokens[1].split(":")[0].split("(")[0] + "<locals>"
+ elif tokens[0] == "class":
+ name = tokens[1].split(":")[0].split("(")[0]
+ if name:
+ while scope_stack and indent_info[scope_stack[-1]] >= indent:
+ scope_stack.pop()
+ scope_stack.append(name)
+ indent_info[name] = indent
+ if scope_stack == qual_names:
+ return lines, i
+
raise OSError("could not find class definition")
diff --git a/python/tvm/script/_parser/core/dispatch.py b/python/tvm/script/parser/core/dispatch.py
similarity index 100%
rename from python/tvm/script/_parser/core/dispatch.py
rename to python/tvm/script/parser/core/dispatch.py
diff --git a/python/tvm/script/_parser/core/doc.py b/python/tvm/script/parser/core/doc.py
similarity index 100%
rename from python/tvm/script/_parser/core/doc.py
rename to python/tvm/script/parser/core/doc.py
diff --git a/python/tvm/script/_parser/core/doc_core.py b/python/tvm/script/parser/core/doc_core.py
similarity index 100%
rename from python/tvm/script/_parser/core/doc_core.py
rename to python/tvm/script/parser/core/doc_core.py
diff --git a/python/tvm/script/_parser/core/entry.py b/python/tvm/script/parser/core/entry.py
similarity index 83%
rename from python/tvm/script/_parser/core/entry.py
rename to python/tvm/script/parser/core/entry.py
index a0974c8fd4..bf6a118672 100644
--- a/python/tvm/script/_parser/core/entry.py
+++ b/python/tvm/script/parser/core/entry.py
@@ -40,6 +40,16 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None)
func : Any
The parsed TVMScript program.
"""
+ if extra_vars is None:
+ from tvm.script.parser import ir # pylint: disable=import-outside-toplevel
+ from tvm.script.parser import tir # pylint: disable=import-outside-toplevel
+
+ extra_vars = {
+ "I": ir,
+ "ir": ir,
+ "T": tir,
+ "tir": tir,
+ }
source = Source(program)
parser = Parser(source)
diff --git a/python/tvm/script/_parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py
similarity index 100%
rename from python/tvm/script/_parser/core/evaluator.py
rename to python/tvm/script/parser/core/evaluator.py
diff --git a/python/tvm/script/_parser/core/parser.py b/python/tvm/script/parser/core/parser.py
similarity index 100%
rename from python/tvm/script/_parser/core/parser.py
rename to python/tvm/script/parser/core/parser.py
diff --git a/python/tvm/script/_parser/core/utils.py b/python/tvm/script/parser/core/utils.py
similarity index 63%
rename from python/tvm/script/_parser/core/utils.py
rename to python/tvm/script/parser/core/utils.py
index 65e7166bfc..a304afddbe 100644
--- a/python/tvm/script/_parser/core/utils.py
+++ b/python/tvm/script/parser/core/utils.py
@@ -17,7 +17,10 @@
"""TVM Script Parser utils"""
import inspect
-from typing import Any, Callable, Dict
+from types import FrameType
+from typing import Any, Callable, Dict, List
+
+from .diagnostics import findsource
def inspect_function_capture(func: Callable) -> Dict[str, Any]:
@@ -59,3 +62,34 @@ def inspect_class_capture(cls: type) -> Dict[str, Any]:
func_vars = inspect_function_capture(v)
result.update(**func_vars)
return result
+
+
+def is_defined_in_class(frames: List[FrameType], obj: Any) -> bool:
+ """Check whether a object is defined in a class scope.
+
+ Parameters
+ ----------
+ frames : List[FrameType]
+ The frame stack of the object, obtained by `inspect.stack()`.
+
+ Returns
+ -------
+ res : bool
+ The result if the object is defined in a class scope.
+ """
+ if len(frames) > 2:
+ frame_info = frames[2]
+ code_context = frame_info.code_context
+ if code_context is None:
+ return False
+ line = code_context[0].strip()
+ if line.startswith("@") and "ir_module" in line:
+ return True
+ if line.startswith("class"):
+ lineno = frame_info.lineno
+ if lineno >= 2:
+ source, _ = findsource(obj)
+ line = source[lineno - 2].strip()
+ if line.startswith("@") and "ir_module" in line:
+ return True
+ return False
diff --git a/python/tvm/script/_parser/ir/__init__.py b/python/tvm/script/parser/ir/__init__.py
similarity index 89%
rename from python/tvm/script/_parser/ir/__init__.py
rename to python/tvm/script/parser/ir/__init__.py
index b15468d37a..fedd2f0a14 100644
--- a/python/tvm/script/_parser/ir/__init__.py
+++ b/python/tvm/script/parser/ir/__init__.py
@@ -17,6 +17,6 @@
"""The ir module parser"""
from . import parser as _parser
-from .entry import ir_module, is_defined_in_class
+from .entry import ir_module
-__all__ = ["ir_module", "is_defined_in_class"]
+__all__ = ["ir_module"]
diff --git a/python/tvm/script/_parser/ir/entry.py b/python/tvm/script/parser/ir/entry.py
similarity index 61%
rename from python/tvm/script/_parser/ir/entry.py
rename to python/tvm/script/parser/ir/entry.py
index e8bc8b702d..94fc3d2e2c 100644
--- a/python/tvm/script/_parser/ir/entry.py
+++ b/python/tvm/script/parser/ir/entry.py
@@ -17,41 +17,13 @@
"""The entry point of TVM parser for ir module."""
import inspect
-from typing import List, Type
-from types import FrameType
+from typing import Type
from tvm.ir import IRModule
from .._core import parse, utils
-def is_defined_in_class(frames: List[FrameType]) -> bool:
- """Check whether a object is defined in a class scope.
-
- Parameters
- ----------
- frames : List[FrameType]
- The frame stack of the object, obtained by `inspect.stack()`.
-
- Returns
- -------
- res : bool
- The result if the object is defined in a class scope.
- """
- if len(frames) > 2:
- maybe_class_frame = frames[2]
- statement_list = maybe_class_frame[4]
- if statement_list is None:
- return False
- first_statement = statement_list[0]
- line = first_statement.strip()
- if line.startswith("class "):
- return True
- if line.startswith("@") and "ir_module" in line:
- return True
- return False
-
-
def ir_module(mod: Type) -> IRModule:
"""The parsing method for ir module, by using `@ir_module` as decorator.
@@ -62,7 +34,7 @@ def ir_module(mod: Type) -> IRModule:
Returns
-------
- irmodule : IRModule
+ ir_module : IRModule
The parsed ir module.
"""
if not inspect.isclass(mod):
diff --git a/python/tvm/script/_parser/ir/parser.py b/python/tvm/script/parser/ir/parser.py
similarity index 100%
rename from python/tvm/script/_parser/ir/parser.py
rename to python/tvm/script/parser/ir/parser.py
diff --git a/python/tvm/script/_parser/tir/__init__.py b/python/tvm/script/parser/tir/__init__.py
similarity index 100%
rename from python/tvm/script/_parser/tir/__init__.py
rename to python/tvm/script/parser/tir/__init__.py
diff --git a/python/tvm/script/_parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py
similarity index 94%
rename from python/tvm/script/_parser/tir/entry.py
rename to python/tvm/script/parser/tir/entry.py
index 632b87aa24..a5c134a859 100644
--- a/python/tvm/script/_parser/tir/entry.py
+++ b/python/tvm/script/parser/tir/entry.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""The entry point of TVM parser for tir."""
-
import inspect
from typing import Callable, Union
@@ -23,7 +22,6 @@ from tvm.tir import Buffer, PrimFunc
from ...ir_builder.tir import buffer_decl, ptr
from .._core import parse, utils
-from ..ir import is_defined_in_class
def prim_func(func: Callable) -> Union[PrimFunc, Callable]:
@@ -41,7 +39,7 @@ def prim_func(func: Callable) -> Union[PrimFunc, Callable]:
"""
if not inspect.isfunction(func):
raise TypeError(f"Expect a function, but got: {func}")
- if is_defined_in_class(inspect.stack()):
+ if utils.is_defined_in_class(inspect.stack(), func):
return func
return parse(func, utils.inspect_function_capture(func))
@@ -57,7 +55,7 @@ class BufferProxy:
def __call__(
self,
shape,
- dtype="float32",
+ dtype=None,
data=None,
strides=None,
elem_offset=None,
@@ -67,6 +65,8 @@ class BufferProxy:
buffer_type="",
axis_separators=None,
) -> Buffer:
+ if dtype is None:
+ raise ValueError("Data type must be specified when constructing buffer")
return buffer_decl(
shape,
dtype=dtype,
diff --git a/python/tvm/script/_parser/tir/operation.py b/python/tvm/script/parser/tir/operation.py
similarity index 90%
rename from python/tvm/script/_parser/tir/operation.py
rename to python/tvm/script/parser/tir/operation.py
index ed8f07a063..f0c04f47cd 100644
--- a/python/tvm/script/_parser/tir/operation.py
+++ b/python/tvm/script/parser/tir/operation.py
@@ -46,12 +46,12 @@ def _register_expr_op(ty: Type): # pylint: disable=invalid-name
for i in [0, 1]:
# Case 1. binop
- r(doc.Add, i, lambda a, b: a + b)
- r(doc.Sub, i, lambda a, b: a - b)
- r(doc.Mult, i, lambda a, b: a * b)
- r(doc.Div, i, lambda a, b: a / b)
- r(doc.FloorDiv, i, lambda a, b: a // b)
- r(doc.Mod, i, lambda a, b: a % b)
+ r(doc.Add, i, tir.Add)
+ r(doc.Sub, i, tir.Sub)
+ r(doc.Mult, i, tir.Mul)
+ r(doc.Div, i, tir.Div)
+ r(doc.FloorDiv, i, tir.FloorDiv)
+ r(doc.Mod, i, tir.FloorMod)
r(doc.LShift, i, lambda a, b: a << b)
r(doc.RShift, i, lambda a, b: a >> b)
r(doc.BitOr, i, lambda a, b: a | b)
diff --git a/python/tvm/script/_parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py
similarity index 99%
rename from python/tvm/script/_parser/tir/parser.py
rename to python/tvm/script/parser/tir/parser.py
index 909238563f..1370758f5a 100644
--- a/python/tvm/script/_parser/tir/parser.py
+++ b/python/tvm/script/parser/tir/parser.py
@@ -125,7 +125,7 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) -
res : Any
The bound value.
"""
- if isinstance(value, T.inline):
+ if isinstance(value, T.meta_var):
return value.value
elif isinstance(value, (list, tuple)):
for i, v in enumerate(value):
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser_v1/__init__.py
similarity index 100%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser_v1/__init__.py
diff --git a/python/tvm/script/_ffi_api.py b/python/tvm/script/parser_v1/_ffi_api.py
similarity index 100%
rename from python/tvm/script/_ffi_api.py
rename to python/tvm/script/parser_v1/_ffi_api.py
diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/parser_v1/context_maintainer.py
similarity index 100%
rename from python/tvm/script/context_maintainer.py
rename to python/tvm/script/parser_v1/context_maintainer.py
diff --git a/python/tvm/script/diagnostics.py b/python/tvm/script/parser_v1/diagnostics.py
similarity index 100%
rename from python/tvm/script/diagnostics.py
rename to python/tvm/script/parser_v1/diagnostics.py
diff --git a/python/tvm/script/meta_unparser.py b/python/tvm/script/parser_v1/meta_unparser.py
similarity index 100%
rename from python/tvm/script/meta_unparser.py
rename to python/tvm/script/parser_v1/meta_unparser.py
diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser_v1/parser.py
similarity index 100%
rename from python/tvm/script/parser.py
rename to python/tvm/script/parser_v1/parser.py
diff --git a/python/tvm/script/registry.py b/python/tvm/script/parser_v1/registry.py
similarity index 100%
rename from python/tvm/script/registry.py
rename to python/tvm/script/parser_v1/registry.py
diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/parser_v1/tir/__init__.py
similarity index 100%
rename from python/tvm/script/tir/__init__.py
rename to python/tvm/script/parser_v1/tir/__init__.py
diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/parser_v1/tir/__init__.pyi
similarity index 100%
rename from python/tvm/script/tir/__init__.pyi
rename to python/tvm/script/parser_v1/tir/__init__.pyi
diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/parser_v1/tir/intrin.py
similarity index 99%
rename from python/tvm/script/tir/intrin.py
rename to python/tvm/script/parser_v1/tir/intrin.py
index 8e24f27325..9cde8e3f6d 100644
--- a/python/tvm/script/tir/intrin.py
+++ b/python/tvm/script/parser_v1/tir/intrin.py
@@ -22,7 +22,7 @@ from typing import Any, List
import tvm.tir
from tvm.tir import FloatImm
-from ...target import codegen
+from ....target import codegen
from ..registry import register
from ..utils import get_param_list, tvm_span_from_synr
diff --git a/python/tvm/script/tir/node.py b/python/tvm/script/parser_v1/tir/node.py
similarity index 100%
rename from python/tvm/script/tir/node.py
rename to python/tvm/script/parser_v1/tir/node.py
diff --git a/python/tvm/script/tir/prim_func.py b/python/tvm/script/parser_v1/tir/prim_func.py
similarity index 100%
rename from python/tvm/script/tir/prim_func.py
rename to python/tvm/script/parser_v1/tir/prim_func.py
diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/parser_v1/tir/scope_handler.py
similarity index 100%
rename from python/tvm/script/tir/scope_handler.py
rename to python/tvm/script/parser_v1/tir/scope_handler.py
diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/parser_v1/tir/special_stmt.py
similarity index 100%
rename from python/tvm/script/tir/special_stmt.py
rename to python/tvm/script/parser_v1/tir/special_stmt.py
diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/parser_v1/tir/ty.py
similarity index 100%
rename from python/tvm/script/tir/ty.py
rename to python/tvm/script/parser_v1/tir/ty.py
diff --git a/python/tvm/script/utils.py b/python/tvm/script/parser_v1/utils.py
similarity index 100%
rename from python/tvm/script/utils.py
rename to python/tvm/script/parser_v1/utils.py
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index d02f7fab7a..a2e341d823 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -48,6 +48,7 @@ from .function import PrimFunc, TensorIntrin, IndexMap
from .op import call_packed_lowered, call_cpacked_lowered
from .op import call_packed, call_cpacked, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
+from .op import tvm_check_return
from .op import tvm_stack_alloca, tvm_stack_make_shape, tvm_stack_make_array
from .op import tvm_tuple, tvm_struct_get, tvm_struct_set
from .op import address_of, lookup_param, assume, undef
@@ -74,6 +75,7 @@ from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod,
from .op import comm_reducer, min, max, sum
from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_right
from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace
+from .op import start_profile_intrinsic, end_profile_intrinsic
from .generic import add, subtract, multiply
from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py
index 6d9d3ce1d1..726d5d1c98 100644
--- a/python/tvm/tir/buffer.py
+++ b/python/tvm/tir/buffer.py
@@ -184,31 +184,31 @@ class Buffer(Object):
if not isinstance(indices, (tuple, list)):
indices = [indices]
- if any(isinstance(index, slice) and index.step is None for index in indices):
+ has_slice = any(isinstance(i, slice) for i in indices)
+ has_step = any(isinstance(i, slice) and i.step is not None for i in indices)
+ analyzer = Analyzer()
+ if has_slice and not has_step:
region = []
- analyzer = Analyzer()
- for index in indices:
+ for i, index in enumerate(indices):
if isinstance(index, slice):
- region.append(
- Range.from_min_extent(
- index.start, analyzer.simplify(index.stop - index.start)
- )
- )
+ start = 0 if index.start is None else index.start
+ stop = self.shape[i] if index.stop is None else index.stop
+ region.append(Range.from_min_extent(start, analyzer.simplify(stop - start)))
else:
region.append(Range.from_min_extent(index, 1))
return BufferRegion(self, region)
else:
- analyzer = Analyzer()
expr_indices = []
for index in indices:
if isinstance(index, slice):
- lanes = analyzer.simplify(
- (index.stop - index.start + index.step - 1) // index.step
- )
+ start = 0 if index.start is None else index.start
+ stop = self.shape[i] if index.stop is None else index.stop
+ step = 1 if index.step is None else index.step
+ lanes = analyzer.simplify((stop - start + step - 1) // step)
if lanes == 1:
- expr_indices.append(index.start)
+ expr_indices.append(start)
else:
- expr_indices.append(Ramp(index.start, index.step, int(lanes)))
+ expr_indices.append(Ramp(start, step, int(lanes)))
else:
expr_indices.append(index)
return BufferLoad(self, expr_indices)
diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py
index beefcb0d28..d52fbb83c3 100644
--- a/python/tvm/tir/expr.py
+++ b/python/tvm/tir/expr.py
@@ -1005,6 +1005,8 @@ class Select(PrimExprWithOp):
"""
def __init__(self, condition, true_value, false_value, span=None):
+ if isinstance(condition, bool):
+ condition = IntImm("bool", condition)
self.__init_handle_by_constructor__(
_ffi_api.Select, condition, true_value, false_value, span # type: ignore
)
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index c5b7937c60..170179d0d4 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -1832,7 +1832,7 @@ class Schedule(Object):
.. code-block:: python
- @tvm.script.tir
+ @T.prim_func
def before_decompose(a: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, [128, 128])
B = tir.match_buffer(b, [128, 128])
@@ -1851,13 +1851,13 @@ class Schedule(Object):
C = sch.get_block("C")
i, j, k = sch.get_loops(C)
sch.decompose_reduction(C, i)
- print(tvm.script.asscript(sch.mod["main"]))
+ print(sch.mod["main"].script())
After applying decompose-reduction, the IR becomes:
.. code-block:: python
- @tvm.script.tir
+ @T.prim_func
def after_decompose(a: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, [128, 128])
B = tir.match_buffer(b, [128, 128])
diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index 86dd2eee5c..0cde7f2464 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -138,7 +138,7 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
T.reads(shared[v0, v1])
- thread_id, local_id = index_map(v0, v1)
+ thread_id, local_id = T.meta_var(index_map(v0, v1))
T.writes(warp[thread_id, local_id])
warp[thread_id, local_id] = shared[v0, v1]
@@ -245,9 +245,9 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
i, j, k = T.axis.remap("SSR", [i, j, k])
b_row_ind, b_col_ind = maybe_swap(k, j)
- thread_id_C, local_id_C = index_map_C(i, j)
- thread_id_A, local_id_A = index_map_A(i, k)
- thread_id_B, local_id_B = index_map_B(b_row_ind, b_col_ind)
+ thread_id_C, local_id_C = T.meta_var(index_map_C(i, j))
+ thread_id_A, local_id_A = T.meta_var(index_map_A(i, k))
+ thread_id_B, local_id_B = T.meta_var(index_map_B(b_row_ind, b_col_ind))
T.reads(
C[thread_id_C, local_id_C],
@@ -339,7 +339,7 @@ def get_mma_fill_intrin(dtype, local_size):
for i0, i1 in T.grid(M_DIM, N_DIM):
with T.block("C_warp"):
i, j = T.axis.remap("SS", [i0, i1])
- thread_id, local_id = index_map(i, j)
+ thread_id, local_id = T.meta_var(index_map(i, j))
T.reads()
T.writes(C_warp[thread_id, local_id])
C_warp[thread_id, local_id] = zero
@@ -376,7 +376,7 @@ def get_mma_store_intrin(dtype, local_size, scope="global"):
for i0, i1 in T.grid(M_DIM, N_DIM):
with T.block("C_warp"):
v0, v1 = T.axis.remap("SS", [i0, i1])
- thread_id, local_id = index_map(v0, v1)
+ thread_id, local_id = T.meta_var(index_map(v0, v1))
T.reads(C_warp[thread_id, local_id])
T.writes(C[v0, v1])
C[v0, v1] = C_warp[thread_id, local_id]
diff --git a/python/tvm/tir/tensor_intrin/hexagon.py b/python/tvm/tir/tensor_intrin/hexagon.py
index 3cad94006d..6fa9dd8f00 100644
--- a/python/tvm/tir/tensor_intrin/hexagon.py
+++ b/python/tvm/tir/tensor_intrin/hexagon.py
@@ -30,10 +30,10 @@ def dot_product_32x4_u8u8i32_desc(
T.reads(C[0:32], A[0:4], B[0:32, 0:4])
T.writes(C[0:32])
for i in T.serial(0, 32):
- with T.init():
- C[i] = T.int32(0)
for k in T.serial(0, 4):
with T.block("update"):
+ with T.init():
+ C[i] = T.int32(0)
vi, vk = T.axis.remap("SR", [i, k])
C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
@@ -74,10 +74,10 @@ def dot_product_32x4_u8i8i32_desc(
T.reads(C[0:32], A[0:4], B[0:32, 0:4])
T.writes(C[0:32])
for i in T.serial(0, 32):
- with T.init():
- C[i] = T.int32(0)
for k in T.serial(0, 4):
with T.block("update"):
+ with T.init():
+ C[i] = T.int32(0)
vi, vk = T.axis.remap("SR", [i, k])
C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
diff --git a/python/tvm/tir/tensor_intrin/rocm.py b/python/tvm/tir/tensor_intrin/rocm.py
index 7a989d0bcc..3700f3e8da 100644
--- a/python/tvm/tir/tensor_intrin/rocm.py
+++ b/python/tvm/tir/tensor_intrin/rocm.py
@@ -17,6 +17,7 @@
# pylint: disable=invalid-name,missing-function-docstring
"""Intrinsics for AMDGPU tensorization."""
from tvm.script import tir as T
+
from .. import TensorIntrin
from .dot_product_common import dp4a_desc
diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
index 9f8e639b53..ef9b142d6f 100644
--- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
+++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
@@ -128,9 +128,9 @@ def get_single_dma_schedule(size_a, size_w):
a_buffer = T.match_buffer(a_input, a_shape, dtype="uint8", scope="global")
w_buffer = T.match_buffer(b_input, w_shape, dtype="uint8", scope="global")
c_buffer = T.match_buffer(c_output, out_shape, dtype="int32", scope="global")
- a_global_vtcm = T.alloc_buffer(a_shape, dtype="uint8", mem_scope="global.vtcm")
- w_global_vtcm = T.alloc_buffer(w_shape, dtype="uint8", mem_scope="global.vtcm")
- c_global_vtcm = T.alloc_buffer(out_shape, dtype="int32", mem_scope="global.vtcm")
+ a_global_vtcm = T.alloc_buffer(a_shape, dtype="uint8", scope="global")
+ w_global_vtcm = T.alloc_buffer(w_shape, dtype="uint8", scope="global")
+ c_global_vtcm = T.alloc_buffer(out_shape, dtype="int32", scope="global")
T.evaluate(
T.tvm_call_packed(
"device_api.hexagon.mem_copy_DLTensor",
diff --git a/tests/python/relay/aot/test_pass_aot_lower_main.py b/tests/python/relay/aot/test_pass_aot_lower_main.py
index 0a9d95247a..093305203a 100644
--- a/tests/python/relay/aot/test_pass_aot_lower_main.py
+++ b/tests/python/relay/aot/test_pass_aot_lower_main.py
@@ -17,11 +17,11 @@
# pylint: disable=line-too-long,missing-class-docstring,missing-module-docstring,missing-function-docstring,no-self-argument,unused-argument,invalid-name
import numpy as np
import pytest
-
import tvm
import tvm.testing
-from tvm.script import tir as T
+from tvm.ir import assert_structural_equal
from tvm.relay.backend.aot import AOTLowerMain, CallType
+from tvm.script import tir as T
def _make_const(dtype, shape):
@@ -48,7 +48,7 @@ def _assert_lowered_main(mod, main_func, call_type, print_script=False):
if print_script:
print(mod["__tvm_main__"].script())
- assert mod["__tvm_main__"].script() == main_func.script()
+ assert_structural_equal(mod["__tvm_main__"], main_func)
def test_single_call_cpacked():
diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py
index 6120cf2b67..535e0bb329 100644
--- a/tests/python/unittest/test_tir_lower_match_buffer.py
+++ b/tests/python/unittest/test_tir_lower_match_buffer.py
@@ -82,14 +82,13 @@ def opaque_access(a: T.handle, b: T.handle) -> None:
offset_factor=1,
)
T.evaluate(
- T.intrin_test(
+ intrin_test(
sub_A.data,
sub_A.elem_offset,
sub_A.strides[0],
sub_A.strides[1],
sub_A.shape[0],
sub_A.shape[1],
- dtype="handle",
)
)
for i, j, k in T.grid(64, 2, 8):
@@ -105,14 +104,13 @@ def opaque_access(a: T.handle, b: T.handle) -> None:
offset_factor=1,
)
T.evaluate(
- T.intrin_test(
+ intrin_test(
sub_B.data,
sub_B.elem_offset,
sub_B.strides[0],
sub_B.strides[1],
sub_B.shape[0],
sub_B.shape[1],
- dtype="handle",
)
)
@@ -126,14 +124,13 @@ def transformed_opaque_access(a: T.handle, b: T.handle) -> None:
T.reads([])
T.writes(A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16])
T.evaluate(
- T.intrin_test(
+ intrin_test(
A.data,
i * 131072 + j * 128 + k * 16,
8192,
128,
16,
1,
- dtype="handle",
)
)
for i, j, k in T.grid(64, 2, 8):
@@ -141,14 +138,13 @@ def transformed_opaque_access(a: T.handle, b: T.handle) -> None:
T.reads([])
T.writes(B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8])
T.evaluate(
- T.intrin_test(
+ intrin_test(
B.data,
i * 4096 + j * 2048 + k * 8,
64,
1,
32,
8,
- dtype="handle",
)
)
@@ -169,14 +165,13 @@ def high_dim_opaque_access(a: T.handle) -> None:
offset_factor=1,
)
T.evaluate(
- T.intrin_test(
+ intrin_test(
sub_A.data,
sub_A.elem_offset,
sub_A.strides[0],
sub_A.strides[1],
sub_A.shape[0],
sub_A.shape[1],
- dtype="handle",
)
)
@@ -189,14 +184,13 @@ def transformed_high_dim_opaque_access(a: T.handle) -> None:
T.reads([])
T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
T.evaluate(
- T.intrin_test(
+ intrin_test(
A.data,
i * 2048 + j * 1024 + k * 16,
64,
1,
16,
16,
- dtype="handle",
)
)
@@ -217,14 +211,13 @@ def high_dim_opaque_access_with_source_strides(a: T.handle) -> None:
offset_factor=1,
)
T.evaluate(
- T.intrin_test(
+ intrin_test(
sub_A.data,
sub_A.elem_offset,
sub_A.strides[0],
sub_A.strides[1],
sub_A.shape[0],
sub_A.shape[1],
- dtype="handle",
)
)
@@ -237,14 +230,13 @@ def transformed_high_dim_opaque_access_with_source_strides(a: T.handle) -> None:
T.reads([])
T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
T.evaluate(
- T.intrin_test(
+ intrin_test(
A.data,
i * 2576 + j * 1280 + k * 16,
80,
1,
16,
16,
- dtype="handle",
)
)
@@ -298,14 +290,13 @@ def recursive_match(a: T.handle, b: T.handle) -> None:
offset_factor=1,
)
T.evaluate(
- T.intrin_test(
+ intrin_test(
sub_sub_A.data,
sub_sub_A.elem_offset,
sub_sub_A.strides[0],
sub_sub_A.strides[1],
sub_sub_A.shape[0],
sub_sub_A.shape[1],
- dtype="handle",
)
)
for jjj, kkk in T.grid(4, 4):
@@ -343,14 +334,13 @@ def transformed_recursive_match(a: T.handle, b: T.handle) -> None:
]
)
T.evaluate(
- T.intrin_test(
+ intrin_test(
A.data,
i * 4096 + j * 1024 + jj * 256 + k * 16 + kk * 4,
64,
1,
4,
4,
- dtype="handle",
)
)
for jjj, kkk in T.grid(4, 4):
@@ -375,14 +365,13 @@ def symbolic_match(a: T.handle, b: T.handle, n: T.int32, m: T.int32) -> None:
sub_A[ii, jj] = 1
for j in range(0, 4):
T.evaluate(
- T.intrin_test(
+ intrin_test(
sub_B.data,
sub_B.elem_offset,
sub_B.strides[0],
sub_B.strides[1],
sub_B.shape[0],
sub_B.shape[1],
- dtype="handle",
)
)
@@ -399,14 +388,13 @@ def transformed_symbolic_match(a: T.handle, b: T.handle, n: T.int32, m: T.int32)
A[i * m + ii, jj] = 1
for j in range(0, 4):
T.evaluate(
- T.intrin_test(
+ intrin_test(
B.data,
i * n * (m * 4),
m * 4,
1,
2,
m * 4,
- dtype="handle",
)
)
@@ -423,14 +411,13 @@ def rank0_buffer(a: T.handle, b: T.handle) -> None:
sub_B = T.match_buffer(B[i, j], (), offset_factor=1)
sub_A[()] = 1
T.evaluate(
- T.intrin_test(
+ intrin_test(
sub_B.data,
sub_B.elem_offset,
0,
0,
0,
0,
- dtype="handle",
)
)
@@ -445,14 +432,13 @@ def transformed_rank0_buffer(a: T.handle, b: T.handle) -> None:
T.writes([A[i, j], B[i, j]])
A[i, j] = 1
T.evaluate(
- T.intrin_test(
+ intrin_test(
B.data,
i * 8 + j,
0,
0,
0,
0,
- dtype="handle",
)
)
diff --git a/tests/python/unittest/test_tir_schedule_reindex.py b/tests/python/unittest/test_tir_schedule_reindex.py
index 53bc726cea..b5e6694301 100644
--- a/tests/python/unittest/test_tir_schedule_reindex.py
+++ b/tests/python/unittest/test_tir_schedule_reindex.py
@@ -233,7 +233,7 @@ def mixed_dtype_reindex_write(
for ax0, ax1 in T.grid(T.int64(2), 1280):
with T.block("T_matmul_NT_reindex"):
v0 = T.axis.spatial(T.int64(2), ax0)
- (v1,) = T.axis.remap("S", [ax1])
+ v1 = T.axis.remap("S", [ax1])
T.reads(T_matmul_NT_reindex[v0, v1])
T.writes(T_matmul_NT[v0, v1])
T_matmul_NT[v0, v1] = T_matmul_NT_reindex[v0, v1]
diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py
index 282f1dcf49..ca5ac12a97 100644
--- a/tests/python/unittest/test_tir_schedule_transform_layout.py
+++ b/tests/python/unittest/test_tir_schedule_transform_layout.py
@@ -18,7 +18,6 @@
import sys
import pytest
-
import tvm
import tvm.testing
from tvm import tir
@@ -707,7 +706,7 @@ class TestPaddedTransformOfInputCreatesAssumption(BasePaddingCompare):
for i, j in T.grid(4, 4):
with T.block("buffer_A_assumption"):
vi, vj = T.axis.remap("SS", [i, j])
- T.assume(not (vi == 3 and 2 <= vj) or A[vi, vj] == 42)
+ T.evaluate(T.assume(not (vi == 3 and 2 <= vj) or A[vi, vj] == 42))
for i in T.serial(14):
with T.block("block"):
@@ -790,9 +789,11 @@ class TestPaddedTransformRepeatedBufferElement(tvm.testing.CompareBeforeAfter):
for i, j in T.grid(4, 4):
with T.block("buffer_A_assumption"):
vi, vj = T.axis.remap("SS", [i, j])
- T.assume(
- not (vi == 3 and 2 <= vj)
- or A[vi, vj] == A[((4 * vi + j) % 14) // 4, ((4 * vi + j) % 14) % 4]
+ T.evaluate(
+ T.assume(
+ not (vi == 3 and 2 <= vj)
+ or A[vi, vj] == A[((4 * vi + j) % 14) // 4, ((4 * vi + j) % 14) % 4]
+ )
)
B = T.alloc_buffer(14, "int32")
diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py
index 33ef0e2215..2f6c2f6a51 100644
--- a/tests/python/unittest/test_tir_schedule_utilities.py
+++ b/tests/python/unittest/test_tir_schedule_utilities.py
@@ -150,7 +150,7 @@ def tuple_reduction(data: T.Buffer[(4, 32), "float32"], T_add: T.Buffer[(4,), "f
data_red_temp_v1[ax0] = v_data_red_temp_v1
for i0 in range(4):
with T.block("T_add"):
- (ax0,) = T.axis.remap("S", [i0])
+ ax0 = T.axis.remap("S", [i0])
T.reads(data_red_temp_v0[ax0], data_red_temp_v1[ax0])
T.writes(T_add[ax0])
T_add[ax0] = data_red_temp_v0[ax0] + data_red_temp_v1[ax0]
diff --git a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py
index b4ea4e712d..eb5ed08bb5 100644
--- a/tests/python/unittest/test_tir_transform_inject_virtual_thread.py
+++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
+import tvm.testing
from tvm import te
from tvm.script import tir as T
diff --git a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
index 42c7fbc0d4..3ab09f01dd 100644
--- a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
+++ b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py
@@ -548,88 +548,6 @@ def single_reduction_loop_with_tensorize(
)
-@T.prim_func
-def nested_reduction_loop_with_inner_match_buffers(
- in0: T.Buffer[(4, 16), "int8"],
- in1: T.Buffer[(4, 16), "int8"],
- out: T.Buffer[(4, 4), "int32"],
-) -> None:
- # body
- # with T.block("root")
- for y in T.serial(4):
- with T.block("C"):
- yi = T.axis.spatial(4, y)
- T.reads(in0[yi, 0:16], in1[yi, 0:16])
- T.writes(out[yi, 0:4])
- for x in T.serial(4):
- xr = T.axis.reduce(4, x)
- with T.init():
- for i in T.serial(4):
- with T.block("C_init"):
- ii = T.axis.spatial(4, i)
- T.reads()
- T.writes(out[yi, ii])
- out[yi, ii] = 0
- with T.block("C"):
- T.reads(
- out[yi, xr],
- in0[yi, yi * 4 + xr : yi * 4 + xr + 4],
- in1[yi, yi * 4 + xr : yi * 4 + xr + 4],
- )
- T.writes(out[yi, xr])
- A = T.match_buffer(
- in0[yi, yi * 4 + xr : yi * 4 + xr + 4], [4], dtype="int8", offset_factor=1
- )
- B = T.match_buffer(
- in1[yi, yi * 4 + xr : yi * 4 + xr + 4], [4], dtype="int8", offset_factor=1
- )
- C = T.match_buffer(out[yi, xr], [1], dtype="int32", offset_factor=1)
- A_i8x4: T.int8x4 = A[0:4]
- A_i32: T.int32 = T.reinterpret(A_i8x4, dtype="int32")
- B_i8x4: T.int8x4 = B[0:4]
- B_i32: T.int32 = T.reinterpret(B_i8x4, dtype="int32")
- C[0] = A_i32 + B_i32 + C[0]
-
-
-@T.prim_func
-def nested_reduction_loop_with_outer_match_buffers(
- in0: T.Buffer[(4, 16), "int8"],
- in1: T.Buffer[(4, 16), "int8"],
- out: T.Buffer[(4, 4), "int32"],
-) -> None:
- # body
- # with T.block("root")
- for y in T.serial(4):
- with T.block("C"):
- yi = T.axis.spatial(4, y)
- T.reads(in0[yi, 0:16], in1[yi, 0:16])
- T.writes(out[yi, 0:4])
- A = T.match_buffer(in0[yi, 0:16], [16], dtype="int8", offset_factor=1)
- B = T.match_buffer(in1[yi, 0:16], [16], dtype="int8", offset_factor=1)
- C = T.match_buffer(out[yi, 0:4], [4], dtype="int32", offset_factor=1)
- for x in T.serial(4):
- xr = T.axis.reduce(4, x)
- with T.init():
- for i in T.serial(4):
- with T.block("C_init"):
- ii = T.axis.spatial(4, i)
- T.reads()
- T.writes(out[yi, ii])
- out[yi, ii] = 0
- with T.block("C"):
- T.reads(
- out[yi, xr],
- in0[yi, yi * 4 + xr : yi * 4 + xr + 4],
- in1[yi, yi * 4 + xr : yi * 4 + xr + 4],
- )
- T.writes(out[yi, xr])
- A_i8x4: T.int8x4 = A[yi * 4 + xr : yi * 4 + xr + 4]
- A_i32: T.int32 = T.reinterpret(A_i8x4, dtype="int32")
- B_i8x4: T.int8x4 = B[yi * 4 + xr : yi * 4 + xr + 4]
- B_i32: T.int32 = T.reinterpret(B_i8x4, dtype="int32")
- C[xr] = A_i32 + B_i32 + C[xr]
-
-
@T.prim_func
def reducer_max(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, [128, 128], dtype="float32")
@@ -1329,20 +1247,6 @@ def test_single_reduction_loop_with_tensorize():
)
-def test_nested_reduction_loop_with_inner_match_buffers():
- _check(
- nested_reduction_loop_with_inner_match_buffers,
- nested_reduction_loop_with_inner_match_buffers,
- )
-
-
-def test_nested_reduction_loop_with_outer_match_buffers():
- _check(
- nested_reduction_loop_with_outer_match_buffers,
- nested_reduction_loop_with_outer_match_buffers,
- )
-
-
def test_reducer_max():
_check(reducer_max, lowered_reducer_max)
diff --git a/tests/python/unittest/test_tir_transform_remove_assume.py b/tests/python/unittest/test_tir_transform_remove_assume.py
index 4223e40e3f..a2d68a0757 100644
--- a/tests/python/unittest/test_tir_transform_remove_assume.py
+++ b/tests/python/unittest/test_tir_transform_remove_assume.py
@@ -17,8 +17,8 @@
import tvm
import tvm.testing
-from tvm.script import tir as T
from tvm import TVMError
+from tvm.script import tir as T
class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
@@ -31,7 +31,7 @@ class TestRemoveAssume(BaseBeforeAfter):
"""Remove any instance of T.assume"""
def before(A: T.Buffer[1, "int32"]):
- T.assume(A[0] == 5)
+ T.evaluate(T.assume(A[0] == 5))
A[0] = 10
def expected(A: T.Buffer[1, "int32"]):
@@ -43,7 +43,7 @@ class TestRemoveAssumeLoop(BaseBeforeAfter):
def before(A: T.Buffer[16, "int32"]):
for i in T.serial(16):
- T.assume(A[i] == 0)
+ T.evaluate(T.assume(A[i] == 0))
for i in T.serial(16):
A[i] = 10
diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py
index 2ec52bfbfe..32293cccdc 100644
--- a/tests/python/unittest/test_tvmscript_error_report.py
+++ b/tests/python/unittest/test_tvmscript_error_report.py
@@ -116,7 +116,6 @@ def test_missing_type_annotation():
def test_invalid_for_function():
def invalid_for_function(a: T.handle) -> None:
A = T.match_buffer(a, (16, 16), "float32")
-
for i in T.evaluate(0.0): # error
for j in T.serial(0, 16):
A[i, j] = 0.0
diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
index a3df5a183b..29e03f8bb6 100644
--- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
@@ -489,8 +489,8 @@ def test_ir_builder_tir_decl_buffer():
def test_ir_builder_tir_inline():
with IRBuilder() as ib:
- m, n = T.inline(1), T.inline(2)
- a, b = T.inline([3, 4])
+ m, n = T.meta_var(1), T.meta_var(2)
+ a, b = T.meta_var([3, 4])
T.evaluate(m.value + n.value + a.value + b.value)
# the evaluate generated by IRBuilder
eval_actual = ib.get()
diff --git a/tests/python/unittest/test_tvmscript_parser_evaluator.py b/tests/python/unittest/test_tvmscript_parser_evaluator.py
index 4d65903060..0f03e47ff9 100644
--- a/tests/python/unittest/test_tvmscript_parser_evaluator.py
+++ b/tests/python/unittest/test_tvmscript_parser_evaluator.py
@@ -17,8 +17,8 @@
"""Unittests for tvm.script.parser.evaluator"""
import pytest
import tvm.testing
-from tvm.script._parser.core.diagnostics import Source
-from tvm.script._parser.core.evaluator import ExprEvaluator
+from tvm.script.parser.core.diagnostics import Source
+from tvm.script.parser.core.evaluator import ExprEvaluator
def _calc(expr, extra_vars=None):
diff --git a/tests/python/unittest/test_tvmscript_parser_ir.py b/tests/python/unittest/test_tvmscript_parser_ir.py
index b235d85bb4..d3e758fbe1 100644
--- a/tests/python/unittest/test_tvmscript_parser_ir.py
+++ b/tests/python/unittest/test_tvmscript_parser_ir.py
@@ -19,7 +19,7 @@
import pytest
import inspect
import tvm.testing
-from tvm.script._parser import ir_module
+from tvm.script.parser import ir_module
from tvm.ir import IRModule
diff --git a/tests/python/unittest/test_tvmscript_parser_source.py b/tests/python/unittest/test_tvmscript_parser_source.py
index cb93a2dcf6..f5dc17fdfe 100644
--- a/tests/python/unittest/test_tvmscript_parser_source.py
+++ b/tests/python/unittest/test_tvmscript_parser_source.py
@@ -18,8 +18,8 @@
import pytest
import inspect
import tvm.testing
-from tvm.script._parser.core.diagnostics import Source
-from tvm.script._parser.core import doc_core as doc
+from tvm.script.parser.core.diagnostics import Source
+from tvm.script.parser.core import doc_core as doc
from tvm.script import tir as T
diff --git a/tests/python/unittest/test_tvmscript_parser_tir.py b/tests/python/unittest/test_tvmscript_parser_tir.py
index cfa1dc62b3..e3f87928ac 100644
--- a/tests/python/unittest/test_tvmscript_parser_tir.py
+++ b/tests/python/unittest/test_tvmscript_parser_tir.py
@@ -19,7 +19,7 @@
import pytest
import inspect
import tvm.testing
-from tvm.script._parser import tir as T
+from tvm.script.parser import tir as T
from tvm import ir, tir
diff --git a/tests/python/unittest/test_tvmscript_spans.py b/tests/python/unittest/test_tvmscript_spans.py
index f863a4dd98..2c0522e3e3 100644
--- a/tests/python/unittest/test_tvmscript_spans.py
+++ b/tests/python/unittest/test_tvmscript_spans.py
@@ -16,7 +16,7 @@
# under the License.
-from tvm.script import tir as T
+from tvm.script.parser_v1 import tir as T
@T.prim_func
diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py
index 32572d392c..16f1cb0494 100644
--- a/tests/python/unittest/test_tvmscript_syntax_sugar.py
+++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py
@@ -288,8 +288,8 @@ def test_letstmt_bind_with_constant():
@T.prim_func
def constant_binds_wrapped():
- x = T.int32(1)
- y = T.float32(42.0)
+ x = T.meta_var(T.int32(1))
+ y = T.meta_var(T.float32(42.0))
T.evaluate(T.cast(x, "float32") + y)
assert_structural_equal(constant_binds, constant_binds_wrapped)
@@ -298,7 +298,7 @@ def test_letstmt_bind_with_constant():
def test_func_call():
def shared_16x16_to_ldmatrix_32x8_layout(i, j):
thread_id = (i % 8) * 4 + (j % 8) // 2
- return thread_id, (j // 8) * 4 + (i // 8) * 2 + (j % 2)
+ return T.meta_var((thread_id, (j // 8) * 4 + (i // 8) * 2 + (j % 2)))
@T.prim_func
def mma_sync_m16n16k16_desc(a: T.handle, b: T.handle, c: T.handle) -> None: