You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/09/17 22:55:08 UTC
[tvm] 01/02: IRBuilder
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch ir-builder-v2
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit e6e27a4d179b36a492f83822c742739ea10f5a3c
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Mon Aug 15 12:33:04 2022 -0700
IRBuilder
---
include/tvm/tir/op.h | 33 +-
python/tvm/script/__init__.py | 6 +-
python/tvm/script/context_maintainer.py | 251 ----
python/tvm/script/diagnostics.py | 55 -
python/tvm/script/meta_unparser.py | 45 -
python/tvm/script/parser.py | 1392 --------------------
python/tvm/script/{ => parser}/__init__.py | 12 +-
python/tvm/script/{__init__.py => parser/_core.py} | 13 +-
python/tvm/script/{ => parser/core}/__init__.py | 7 +-
python/tvm/script/parser/core/diagnostics.py | 175 +++
python/tvm/script/parser/core/dispatch.py | 63 +
python/tvm/script/parser/core/doc.py | 361 +++++
.../script/{printer => parser/core}/doc_core.py | 0
.../{tir/prim_func.py => parser/core/entry.py} | 46 +-
python/tvm/script/parser/core/evaluator.py | 284 ++++
python/tvm/script/parser/core/parser.py | 273 ++++
.../{tir/prim_func.py => parser/core/utils.py} | 39 +-
python/tvm/script/{ => parser/ir}/__init__.py | 8 +-
.../script/{tir/__init__.py => parser/ir/entry.py} | 24 +-
.../{tir/__init__.py => parser/ir/parser.py} | 28 +-
python/tvm/script/{ => parser/tir}/__init__.py | 11 +-
python/tvm/script/parser/tir/entry.py | 101 ++
python/tvm/script/parser/tir/operation.py | 84 ++
python/tvm/script/parser/tir/parser.py | 268 ++++
python/tvm/script/registry.py | 62 -
python/tvm/script/tir/__init__.pyi | 487 -------
python/tvm/script/tir/intrin.py | 231 ----
python/tvm/script/tir/node.py | 218 ---
python/tvm/script/tir/scope_handler.py | 793 -----------
python/tvm/script/tir/special_stmt.py | 964 --------------
python/tvm/script/tir/ty.py | 216 ---
python/tvm/script/utils.py | 105 --
python/tvm/tir/analysis/analysis.py | 6 +-
python/tvm/tir/expr.py | 15 +-
python/tvm/tir/schedule/block_scope.py | 2 +-
python/tvm/tir/schedule/schedule.py | 6 +-
python/tvm/tir/schedule/state.py | 3 +-
python/tvm/tir/stmt.py | 4 +
python/tvm/tir/usmp/transform/transform.py | 5 +-
src/tir/ir/stmt.cc | 10 +
src/tir/op/op.cc | 22 +
41 files changed, 1793 insertions(+), 4935 deletions(-)
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 0939e25efd..63b7baaea7 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -527,7 +527,13 @@ TVM_DLL PrimExpr isnan(PrimExpr x, Span span = Span());
* \return The result expression.
*/
TVM_DLL PrimExpr isfinite(PrimExpr x, Span span = Span());
-
+/*!
+ * \brief Check if x is nullptr.
+ * \param x The input data
+ * \param span The location of this operation in the source.
+ * \return The result expression.
+ */
+TVM_DLL PrimExpr isnullptr(PrimExpr x, Span span = Span());
/*!
* \brief Check if x is infinite.
* \param x The input data
@@ -601,6 +607,15 @@ TVM_DLL PrimExpr min(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr>
TVM_DLL PrimExpr prod(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
Span span = Span());
+/*!
+ * \brief Calculate fmod(x, y)
+ * \param x Left operand.
+ * \param y Right operand.
+ * \param span The location of this operation in the source.
+ * \return The result expression.
+ */
+TVM_DLL PrimExpr fmod(PrimExpr x, PrimExpr y, Span span = Span());
+
/*!
* \brief Calculate floor(x)
* \param x The input expression.
@@ -675,6 +690,22 @@ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high, Span sp
TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s,
Span span = Span());
+/*!
+ * \brief Returns the address of the buffer element accessed by buffer_load.
+ * \param buffer_load The input BufferLoad.
+ * \param span The location of this operation in the source.
+ * \return The address of an element in the buffer.
+ */
+TVM_DLL PrimExpr address_of(tir::BufferLoad buffer_load, Span span = Span());
+
+/*!
+ * \brief Returns the parameter with the given name.
+ * \param param_name The parameter name.
+ * \param span The location of this operation in the source.
+ * \return The handle of the parameter.
+ */
+TVM_DLL PrimExpr lookup_param(String param_name, Span span = Span());
+
// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/__init__.py
index 555659d0c5..3107ada88b 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/__init__.py
@@ -15,7 +15,5 @@
# specific language governing permissions and limitations
# under the License.
"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
-
-from . import tir
-
-from .parser import ir_module, from_source
+from . import parser
+from .parser import ir, ir_module, parse, tir
diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/context_maintainer.py
deleted file mode 100644
index f7f16855c7..0000000000
--- a/python/tvm/script/context_maintainer.py
+++ /dev/null
@@ -1,251 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Script Context Maintainer for TIR"""
-
-from typing import List, Mapping, Union, Optional, Dict, Callable
-import synr
-
-
-import tvm
-from tvm.ir import Span
-from tvm.ir.expr import Range
-from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
-from tvm.runtime import Object
-from tvm.tir.expr import IterVar
-from .tir.node import BufferSlice
-
-
-class BlockInfo:
- """Information for block and block_realize signature
-
- Examples
- ----------
- .. code-block:: python
-
- @T.prim_func
- def example_func(a: T.handle, b: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, (16, 16), "float32")
- B = T.match_buffer(b, (16, 16), "float32")
- C = T.match_buffer(a, (16, 16), "float32")
-
- for i, j, k in T.grid(16, 16, 16):
- with T.block("matmul"):
- vi = T.axis.S(16, i)
- vj = T.axis.S(16, j)
- vk = T.axis.R(16, k) # iter_bindings = {vj: i, vj: j, vk: k}
-
- T.where(True) # predicate of the block_realize
-
- T.reads(A[0:16, 0:16], B[0: 16, 0: 16]) # reads region of the block
- T.writes(C[0: 16, 0: 16]) # writes region of the block
- T.block_attr({"attr_key": "attr_value"}) # block annotations
-
- # alloc_buffers inside the block
- CC = T.alloc_buffer((1, 1), dtype="float32")
-
- # match_buffers of the block,
- # which bind a sub-region of source buffer into a new buffer
- D = T.match_buffer(C[vi, vj], ())
-
- # init part of the block, executed when all reduce axes are the beginning value
- with T.init():
- C[vi, vj] = T.float32(0)
-
- # block body
- CC[0, 0] = A[vi, vk] * B[vj, vk]
- D[()] += CC[0, 0] # The same as C[vi, vj] += CC[0, 0]
- """
-
- alloc_buffers: List[Buffer] = []
- """List[Buffer]: list of T.alloc_buffer statements in the block signature"""
- match_buffers: List[MatchBufferRegion] = []
- """List[MatchBufferRegion]: list of T.match_buffer statements in the block signature"""
- iter_values: List[PrimExpr] = []
- """List[PrimExpr]: list of binding values for iter vars"""
- iter_vars: List[IterVar] = []
- """List[PrimExpr]: list of iter vars in the block"""
- reads: Optional[List[BufferSlice]] = None
- """Optional[List[BufferSlice]]:
- list of T.reads statements in the block signature, None for not-visited"""
- writes: Optional[List[BufferSlice]] = None
- """Optional[List[BufferSlice]]:
- list of T.writes statements in the block signature, None for not-visited"""
- annotations: Optional[Mapping[str, Object]] = None
- """Optional[Mapping[str, Object]]:
- list of T.block_attr statements in the block signature, None for not-visited"""
- predicate: Optional[PrimExpr] = None
- """Optional[PrimExpr]: block realize predicate, None for not-visited"""
- init: Optional[Stmt] = None
- """Optional[Stmt]: init part of the block, None for not-visited"""
-
- def __init__(self):
- self.alloc_buffers = []
- self.match_buffers = []
- self.iter_values = []
- self.iter_vars = []
- self.reads = None
- self.writes = None
- self.annotations = None
- self.predicate = None
- self.init = None
-
-
-class ContextMaintainer:
- """Maintain all the necessary context info
- Parameters
- ----------
- _report_error : Callable[[str, Union[Span, synr.ast.Span]], None]
- The report error function handle
- """
-
- # scope context
- node_stack: List[List[synr.ast.Node]] = []
- """List[List[synr.ast.Node]]: The ast nodes insides the current scope"""
- block_info_stack: List[BlockInfo] = []
- """List[BlockInfo]: The block info for the current block scope"""
- loop_stack: Dict[Var, Range] = {}
- """Dict[Var, Range]: The dict from loop var to its domain outside the block"""
- symbols: List[Dict[str, Union[Var, Buffer]]] = []
- """List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for the current scope"""
- closure_vars: Dict[str, Object] = {}
- """ClosureVars: The closure vars defined in Python interpreter"""
-
- # function context
- func_params: List[Var] = []
- """List[Var]: The function parameters"""
- func_buffer_map: Mapping[Var, Buffer] = {}
- """Mapping[Var, Buffer]: The function buffer map"""
- func_preflattened_buffer_map: Mapping[Var, Buffer] = {}
- """Mapping[Var, Buffer]: The function buffer map, prior to any flattening."""
- func_dict_attr: Mapping[str, Object] = {}
- """Mapping[str, Object]: The function attrs"""
- func_var_env_dict: Mapping[Var, str] = {}
- """Mapping[Var, str]: The map from var to env thread"""
-
- # parser and analyzer
- analyzer: tvm.arith.Analyzer = tvm.arith.Analyzer()
- """tvm.arith.Analyzer: The analyzer for simplifying"""
- _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
- """Callable[[str, Union[Span, synr.ast.Span]], None]: The report error function handle"""
-
- # root alloc_buffer
- root_alloc_buffers: List[Buffer] = []
- """List[Buffer]: The buffers allocated under root block"""
-
- def __init__(
- self,
- _report_error: Callable[[str, Union[Span, synr.ast.Span]], None],
- closure_vars: Dict[str, Object],
- ):
- # scope context
- self.node_stack = []
- self.block_info_stack = []
- self.loop_stack = {}
- self.symbols = []
- self.closure_vars = closure_vars
- # function context
- self.func_params = []
- self.func_buffer_map = {}
- self.func_preflattened_buffer_map = {}
- self.func_dict_attr = {}
- self.func_var_env_dict = {}
- # parser and analyzer
- self._report_error = _report_error
- self.analyzer = tvm.arith.Analyzer()
- # root alloc_buffer
- self.root_alloc_buffers = []
-
- def enter_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
- """Creates a new scope
-
- Note
- ----
- This function is used for normal scopes that do not involve
- a `with block` scope. Use `enter_block_scope`
- for block scope cases.
-
- Parameters
- ----------
- nodes : Optional[List[synr.ast.Node]]
- The synr AST nodes in new scope
- """
- if nodes is None:
- nodes = []
- self.node_stack.append(list(reversed(nodes)))
- self.symbols.append(dict())
-
- def enter_block_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
- """Creates a new block scope, the function will call `enter_scope` implicitly
- Besides the behaviors of `enter_scope`, it will update loop_stack and block_info_stack
- to maintain block info.
-
- Note
- ----
- This function should be used to handle a block scope,
- aka the blocks that involve a `with block` scope.
-
- Parameters
- ----------
- nodes : Optional[List[synr.ast.Node]]
- The synr AST nodes in new scope
- """
- self.enter_scope(nodes)
- # Create a new BlockInfo for the new block
- self.block_info_stack.append(BlockInfo())
-
- def exit_scope(self):
- """Pop the inner most scope"""
- self.symbols.pop()
- self.node_stack.pop()
-
- def exit_block_scope(self):
- """Pop the inner most block scope, the function will call `exit_scope` implicitly"""
- self.exit_scope()
- # Pop block_info
- self.block_info_stack.pop()
-
- def update_symbol(self, name: str, symbol: Union[Buffer, Var], node: synr.ast.Node):
- """Append a symbol into current scope"""
- if isinstance(symbol, Buffer):
- if name in self.symbols[0]:
- self.report_error("Duplicate Buffer name: " + symbol.name, node.span)
- self.symbols[0][name] = symbol
- else:
- self.symbols[-1][name] = symbol
-
- def remove_symbol(self, name: str):
- """Remove a symbol"""
- for symbols in reversed(self.symbols):
- if name in symbols:
- symbols.pop(name)
- return
- raise RuntimeError("Internal error of tvm script parser: no symbol named " + name)
-
- def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var]]:
- """Look up symbol by name"""
- for symbols in reversed(self.symbols):
- if name in symbols:
- return symbols[name]
- return self.closure_vars.get(name)
-
- def report_error(self, message: str, span: Union[Span, synr.ast.Span]):
- self._report_error(message, span)
-
- def current_block_scope(self) -> BlockInfo:
- if self.block_info_stack:
- return self.block_info_stack[-1]
- return None
diff --git a/python/tvm/script/diagnostics.py b/python/tvm/script/diagnostics.py
deleted file mode 100644
index e676461ab3..0000000000
--- a/python/tvm/script/diagnostics.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Bridge from synr's (the library used for parsing the python AST)
- DiagnosticContext to TVM's diagnostics
-"""
-from synr import DiagnosticContext, ast
-
-import tvm
-from tvm.ir.diagnostics import DiagnosticContext as TVMCtx
-from tvm.ir.diagnostics import get_renderer, DiagnosticLevel, Diagnostic
-
-
-class TVMDiagnosticCtx(DiagnosticContext):
- """TVM diagnostics for synr"""
-
- diag_ctx: TVMCtx
-
- def __init__(self) -> None:
- self.diag_ctx = TVMCtx(tvm.IRModule(), get_renderer())
- self.source_name = None
-
- def to_tvm_span(self, src_name, ast_span: ast.Span) -> tvm.ir.Span:
- return tvm.ir.Span(
- src_name,
- ast_span.start_line,
- ast_span.end_line,
- ast_span.start_column,
- ast_span.end_column,
- )
-
- def add_source(self, name: str, source: str) -> None:
- src_name = self.diag_ctx.module.source_map.add(name, source)
- self.source_name = src_name
-
- def emit(self, _level, message, span):
- span = self.to_tvm_span(self.source_name, span)
- self.diag_ctx.emit(Diagnostic(DiagnosticLevel.ERROR, span, message))
- self.diag_ctx.render() # Raise exception on the first error we hit. TODO remove
-
- def render(self):
- self.diag_ctx.render()
diff --git a/python/tvm/script/meta_unparser.py b/python/tvm/script/meta_unparser.py
deleted file mode 100644
index b1472ccdc7..0000000000
--- a/python/tvm/script/meta_unparser.py
+++ /dev/null
@@ -1,45 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Unparse meta AST node into a dict"""
-# pylint: disable=invalid-name
-
-from synr import Transformer
-
-
-class MetaUnparser(Transformer):
- """Python AST Visitor to unparse meta AST node into a dict"""
-
- def transform(self, node):
- method = "transform_" + node.__class__.__name__
- visitor = getattr(self, method, None)
- if visitor is None:
- self.error(f"Unexpected node type {type(node)} when parsing __tvm_meta__", node.span)
- return visitor(node)
-
- def transform_DictLiteral(self, node):
- keys = [self.visit(key) for key in node.keys]
- values = [self.visit(value) for value in node.values]
- return dict(zip(keys, values))
-
- def transform_Tuple(self, node):
- return tuple(self.visit(element) for element in node.elts)
-
- def transform_ArrayLiteral(self, node):
- return [self.visit(element) for element in node.elts]
-
- def transform_Constant(self, node):
- return node.value
diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
deleted file mode 100644
index c34aae2345..0000000000
--- a/python/tvm/script/parser.py
+++ /dev/null
@@ -1,1392 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Script Parser For TIR
-
-We use [synr](https://synr.readthedocs.io) to get an AST that is stable over
-different python versions. Synr also provides an error handling context that we
-use for error reporting.
-"""
-# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return, broad-except
-import types
-import json
-import operator
-import inspect
-from typing import Any, Callable, Dict, List, Optional, Union
-from synr import ast, Transformer, to_ast
-
-import tvm
-from tvm import IRModule
-from tvm._ffi.base import TVMError
-from tvm.ir import GlobalVar
-from tvm.ir.function import BaseFunc
-from tvm.tir import buffer
-from tvm.tir.function import PrimFunc
-from . import _ffi_api
-from . import tir
-
-from .context_maintainer import ContextMaintainer
-from .meta_unparser import MetaUnparser
-from .registry import Registry
-from .diagnostics import TVMDiagnosticCtx
-from .utils import tvm_span_from_synr, synr_span_from_tvm, call_with_error_reporting
-
-from .tir.intrin import Intrin
-from .tir.node import Slice, BufferSlice
-from .tir.scope_handler import ScopeHandler, WithScopeHandler, ForScopeHandler
-from .tir.special_stmt import SpecialStmt
-from .tir import ty
-
-
-class CallArgumentReader(object):
- """Helper class to read required arguments from passed arguments.
-
- When parsing a function call, we need to match the arguments provided in
- the AST to the required arguments of the function. This class makes sure
- all the positional arguments are filled and also fill keyword arguments
- with thier default value if a different value was not provided.
- """
-
- def __init__(self, func_name, args, kwargs, parser, node):
- self.func_name = func_name
- self.args = args
- self.kwargs = kwargs
- self.parser = parser
- self.node = node
-
- def get_pos_only_arg(self, pos, name):
- """Get corresponding position only function argument from argument list"""
- if len(self.args) >= pos:
- arg = self.args[pos - 1]
- elif name not in self.kwargs:
- # If no positional argument was found in the AST, we see if it was
- # defined by name instead.
- # TODO(tkonolige): this error message is not quite correct. The
- # number of required arguments is >= pos
- self.parser.report_error(
- f"{self.func_name} requires {pos} arguments, but only {len(self.args)} were given.",
- self.node.span,
- )
- else:
- arg = self.kwargs[name]
-
- return arg
-
- def get_kwarg(self, pos, name, default):
- """Get corresponding keyword function argument from argument list.
-
- If the user hasn't provided the argument, set it to the default value.
- """
- if len(self.args) >= pos:
- arg = self.args[pos - 1]
- elif name in self.kwargs:
- arg = self.kwargs[name]
- else:
- return default
-
- return arg
-
- def get_varargs(self, pos):
- """Get corresponding variable argument from argument list"""
- if len(self.args) >= pos and len(self.kwargs) == 0:
- return self.args[pos - 1 :]
- return []
-
-
-class TVMScriptParser(Transformer):
- """Synr AST visitor pass which finally lowers to TIR.
-
- Notes for Extension
- -------------------
- 1. To support a new type of AST node, add a function transform_xxx().
- 2. To support new functions, add the function to the appropriate registry:
- We divide allowed function calls in TVM script into 3 categories,
- intrin, scope_handler and special_stmt.
- 1. intrin functions are low level functions like mod, load, and
- constants. They correspond to a tir `IRNode`. They must have a
- return value. The user can register intrin functions for the parser to
- use.
- 2. scope_handler functions have no return value. They take two
- arguments: the parser and the AST node. scope_handler functions are
- used in with and for statements.
- 3. special_stmt functions handle cases that do not have a corresponding
- tir `IRNode`. These functions take the parser and the AST node as
- arguments and may return a value.
- When visiting a Call node, we check the special_stmt registry first. If
- no registered function is found, we then check the intrin registry.
- When visiting With node, we check the with_scope registry.
- When visiting For node, we check the for_scope registry.
- """
-
- _binop_maker = {
- ast.BuiltinOp.Add: tvm.tir.Add,
- ast.BuiltinOp.Sub: tvm.tir.Sub,
- ast.BuiltinOp.Mul: tvm.tir.Mul,
- ast.BuiltinOp.Div: tvm.tir.Div,
- ast.BuiltinOp.FloorDiv: tvm.tir.FloorDiv,
- ast.BuiltinOp.Mod: tvm.tir.FloorMod,
- ast.BuiltinOp.BitOr: lambda lhs, rhs, span: operator.or_(lhs, rhs),
- ast.BuiltinOp.BitAnd: lambda lhs, rhs, span: operator.and_(lhs, rhs),
- ast.BuiltinOp.BitXor: lambda lhs, rhs, span: operator.xor(lhs, rhs),
- ast.BuiltinOp.GT: tvm.tir.GT,
- ast.BuiltinOp.GE: tvm.tir.GE,
- ast.BuiltinOp.LT: tvm.tir.LT,
- ast.BuiltinOp.LE: tvm.tir.LE,
- ast.BuiltinOp.Eq: tvm.tir.EQ,
- ast.BuiltinOp.NotEq: tvm.tir.NE,
- ast.BuiltinOp.And: tvm.tir.And,
- ast.BuiltinOp.Or: tvm.tir.Or,
- }
-
- _unaryop_maker = {
- ast.BuiltinOp.USub: lambda rhs, span: operator.neg(rhs),
- ast.BuiltinOp.Invert: lambda rhs, span: operator.invert(rhs),
- ast.BuiltinOp.Not: tvm.tir.Not,
- }
-
- # pylint gets confused here with synr.Transformer which doesn't have a
- # custom init, so just disable it
- def __init__(
- self, base_lineno, tir_namespace, closure_vars
- ): # pylint: disable=super-init-not-called
- self.context = None
-
- self.base_lineno = base_lineno
- self.current_lineno = 0
- self.current_col_offset = 0
- self.tir_namespace = tir_namespace
- self.closure_vars = closure_vars
- self.meta = None
- self._inside_buffer_sugar = False
-
- def init_function_parsing_env(self):
- """Initialize function parsing environment"""
- self.context = ContextMaintainer(self.report_error, self.closure_vars) # scope emitter
-
- def init_meta(self, meta_dict):
- if meta_dict is not None:
- self.meta = tvm.ir.load_json(json.dumps(meta_dict))
-
- def transform(self, node):
- """Generic transformation for visiting the AST. Dispatches to
- `transform_ClassName` for the appropriate ClassName."""
- old_lineno, old_col_offset = self.current_lineno, self.current_col_offset
-
- if hasattr(node, "lineno"):
- self.current_lineno = self.base_lineno + node.lineno - 1
- if hasattr(node, "col_offset"):
- self.current_col_offset = node.col_offset
-
- method = "transform_" + node.__class__.__name__
- visitor = getattr(self, method, self.generic_visit)
- transform_res = visitor(node)
-
- self.current_lineno, self.current_col_offset = old_lineno, old_col_offset
-
- return transform_res
-
- def match_tir_namespace(self, identifier: str) -> bool:
- """Check if the namespace is equal to tvm.script.tir"""
- return identifier in self.tir_namespace
-
- def report_error(self, message: str, span: Union[ast.Span, tvm.ir.Span]):
- """Report an error occuring at a location.
-
- This just dispatches to synr's DiagnosticContext.
-
- Parameters
- ----------
- message : str
- Error message
- span : Union[synr.ast.Span, tvm.ir.Span]
- Location of the error
- """
- if isinstance(span, tvm.ir.Span):
- span = synr_span_from_tvm(span)
- self.error(message, span)
-
- def parse_body(self, parent):
- """Parse remaining statements in this scope.
-
- Parameters
- ----------
- parent : synr.ast.Node
- Parent node of this scope. Errors will be reported here.
- """
- body = []
- spans = []
- stmt = parent
- while len(self.context.node_stack[-1]) > 0:
- stmt = self.context.node_stack[-1].pop()
- spans.append(stmt.span)
- res = self.transform(stmt)
- if res is not None:
- body.append(res)
- if len(body) == 0:
- self.report_error(
- "Expected another statement at the end of this block. Perhaps you "
- "used a concise statement and forgot to include a body afterwards.",
- stmt.span,
- )
- else:
- return (
- tvm.tir.SeqStmt(body, tvm_span_from_synr(ast.Span.union(spans)))
- if len(body) > 1
- else body[0]
- )
-
- def parse_arg_list(self, func, node_call):
- """Match the arguments of a function call in the AST to the required
- arguments of the function. This handles positional arguments,
- positional arguments specified by name, keyword arguments, and varargs.
-
- Parameters
- ----------
- func : Function
- The function that provides the signature
-
- node_call: Union[ast.Call, ast.TypeApply, ast.TypeCall]
- The AST call node that calls into the function.
-
- Returns
- -------
- arg_list : list
- The parsed positional argument.
- """
- assert isinstance(node_call, (ast.Call, ast.TypeApply, ast.TypeCall))
- # collect arguments
- args = [self.transform(arg) for arg in node_call.params]
- if isinstance(node_call, ast.TypeApply):
- kw_args = {} # TypeApply (e.g. foo[bar]) doesn't have kwargs defined in synr
- else:
- kw_args = {
- self.transform(k): self.transform(v) for k, v in node_call.keyword_params.items()
- }
- # get the name and parameter list of func
- if isinstance(func, (Intrin, ScopeHandler, SpecialStmt)):
- func_name, param_list = func.signature()
- else:
- self.report_error(
- "Internal Error: function must be of type Intrin, ScopeHandler or SpecialStmt, "
- f"but it is {type(func).__name__}",
- node_call.span,
- )
- # check arguments and parameter list and get a list of arguments
- reader = CallArgumentReader(func_name, args, kw_args, self, node_call)
- pos_only, kwargs, varargs = param_list
- internal_args = list()
-
- for i, arg_name in enumerate(pos_only):
- internal_args.append(reader.get_pos_only_arg(i + 1, arg_name))
- for i, arg_info in enumerate(kwargs):
- arg_name, default = arg_info
- internal_args.append(reader.get_kwarg(i + 1 + len(pos_only), arg_name, default=default))
- if varargs is not None:
- internal_args.extend(reader.get_varargs(len(pos_only) + len(kwargs) + 1))
- elif len(args) + len(kw_args) > len(pos_only) + len(kwargs):
- self.report_error(
- "Arguments mismatched. "
- + f"Expected {len(pos_only) + len(kwargs)} args but got "
- + f"{len(args) + len(kw_args)}",
- node_call.span,
- )
- return internal_args
-
- def parse_type(self, type_node, parent):
- """Parse a type annotation.
-
- We require the parent object to the type so that we have a place to
- report the error message if the type does not exist.
- """
- if type_node is None:
- self.report_error("A type annotation is required", parent.span)
- res_type = self.transform(type_node)
- return tvm.ir.TupleType([]) if res_type is None else res_type.evaluate()
-
- def generic_visit(self, node):
- """Fallback visitor if node type is not handled. Reports an error."""
-
- self.report_error(type(node).__name__ + " AST node is not supported", node.span)
-
- def transform_Module(self, node):
- """Module visitor
-
- Right now, we only support two formats for TVM Script.
-
- Example
- -------
- 1. Generate a PrimFunc (If the code is printed, then it may also contain metadata)
- .. code-block:: python
-
- import tvm
-
- @tvm.script
- def A(...):
- ...
-
- # returns a PrimFunc
- func = A
-
- 2. Generate an IRModule
- .. code-block:: python
-
- import tvm
-
- @tvm.script.ir_module
- class MyMod():
- @T.prim_func
- def A(...):
- ...
- @T.prim_func
- def B(...):
- ...
-
- __tvm_meta__ = ...
-
- # returns an IRModule
- mod = MyMod
- """
- if len(node.funcs) == 1:
- return self.transform(next(iter(node.funcs.values())))
- elif len(node.funcs) == 0:
- self.report_error(
- "You must supply at least one class or function definition", node.span
- )
- else:
- self.report_error(
- "Only one-function, one-class or function-with-meta source code is allowed",
- ast.Span.union([x.span for x in list(node.funcs.values())[1:]]),
- )
-
- def transform_Class(self, node):
- """Class definition visitor.
-
- A class can have multiple function definitions and a single
- :code:`__tvm_meta__` statement. Each class corresponds to a single
- :code:`IRModule`.
-
- Example
- -------
- .. code-block:: python
-
- @tvm.script.ir_module
- class MyClass:
- __tvm_meta__ = {}
- def A():
- T.evaluate(0)
- """
- if len(node.assignments) == 1:
- if not (
- len(node.assignments[0].lhs) == 1
- and isinstance(node.assignments[0].lhs[0], ast.Var)
- and node.assignments[0].lhs[0].id.name == "__tvm_meta__"
- ):
- self.report_error(
- "The only top level assignments allowed are `__tvm_meta__ = ...`",
- node.assignments[0].span,
- )
- self.init_meta(
- MetaUnparser().do_transform(node.assignments[0].rhs, self._diagnostic_context)
- )
- elif len(node.assignments) > 1:
- self.report_error(
- "Only a single top level `__tvm_meta__` is allowed",
- ast.Span.union([x.span for x in node.assignments[1:]]),
- )
-
- return IRModule(
- {GlobalVar(name): self.transform(func) for name, func in node.funcs.items()}
- )
-
def transform_Function(self, node):
    """Function definition visitor.

    Each function definition is translated to a single :code:`PrimFunc`.

    There are a couple restrictions on TVM Script functions:
    1. Function arguments must have their types specified. An argument whose
       annotation is a buffer type is treated as :code:`T.match_buffer` sugar
       and bound as a buffer.
    2. The body of the function can contain :code:`T.func_attr` to specify
       attributes of the function (like its global symbol).
    3. The body of the function can bind handles to buffers with
       :code:`T.match_buffer`, giving shape and dtype information to arguments.
    4. Return statements are not allowed; the body itself is the function.

    The function must be decorated with ``T.prim_func``, optionally preceded
    by an ``as_torch`` (or single-argument ``as_torch(...)``) decorator.

    Example
    -------
    .. code-block:: python

        @T.prim_func
        def my_function(x: T.handle):  # 1. Argument types
            T.func_attr({"global_symbol": "mmult"})  # 2. Function attributes
            X_1 = T.match_buffer(x, [1024, 1024])  # 3. Buffer binding
            T.evaluate(0)  # 4. The body
    """

    def check_as_torch_decorator(decorator: Union[ast.Call, ast.Var]) -> bool:
        """Return True if `decorator` is `as_torch` or a one-argument `as_torch(...)` call."""
        if isinstance(decorator, ast.Call):
            if len(decorator.params) != 1:
                return False
            func_name = decorator.func_name
        else:
            func_name = decorator
        # Always return a bool (the previous version fell off the end with None).
        return isinstance(func_name, ast.Var) and func_name.id.name == "as_torch"

    def check_decorator(decorators: List[ast.Expr]) -> bool:
        """Check the decorators are `T.prim_func`, optionally preceded by `as_torch`."""
        if len(decorators) > 2 or len(decorators) == 0:
            return False
        if len(decorators) == 2 and not check_as_torch_decorator(decorators[0]):
            return False
        last = decorators[-1]
        return (
            isinstance(last, ast.Attr)
            and isinstance(last.object, ast.Var)
            and self.match_tir_namespace(last.object.id.name)
            and last.field.name == "prim_func"
        )

    self.init_function_parsing_env()
    self.context.enter_scope(nodes=node.body.stmts)

    # add parameters of function
    for arg in node.params:
        # Note that this case is for T.match_buffer syntax sugar
        if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)) and isinstance(
            self.transform(arg.ty.func_name), ty.GenericBufferType
        ):
            result = self.handle_match_buffer_type(arg.ty, arg.name)
            if not isinstance(result, buffer.Buffer):
                self.report_error(
                    "The result type of evaluating TypeCall and TypeApply stmt"
                    f" is wrong: {type(result)}. It should be a Buffer",
                    node.span,
                )
            # The PrimFunc receives a handle; the buffer is attached via buffer_map.
            arg_name_with_handle = arg.name + "_handle"
            arg_var = tvm.te.var(arg_name_with_handle, tvm.ir.PrimType("handle"))
            self.context.func_buffer_map[arg_var] = result
            self.context.update_symbol(arg.name, result, node)
        else:
            arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg))
            self.context.update_symbol(arg.name, arg_var, node)
        self.context.func_params.append(arg_var)

    if not check_decorator(node.decorators):
        self.report_error(
            "All functions should be decorated by `T.prim_func`",
            node.span,
        )

    # fetch the body of root block
    body = self.parse_body(node)

    # return a tir.PrimFunc
    dict_attr = self.context.func_dict_attr
    ret_type = self.parse_type(node.ret_type, node) if node.ret_type is not None else None
    func = tvm.tir.PrimFunc(
        self.context.func_params,
        body,
        ret_type,
        buffer_map=self.context.func_buffer_map,
        preflattened_buffer_map=self.context.func_preflattened_buffer_map,
        attrs=tvm.ir.make_node("DictAttrs", **dict_attr) if dict_attr else None,
        span=tvm_span_from_synr(node.span),
    )

    # Each function contains an implicit root block in TensorIR. If the PrimFunc
    # is not a TensorIR func (e.g. TE scheduled func or low-level func), the root
    # block is not added. `_ffi_api.Complete` fixes up the PrimFunc:
    # 1. generate root block if necessary
    # 2. generate surrounding loops for blocks if necessary
    func = call_with_error_reporting(
        self.report_error,
        node.span,
        _ffi_api.Complete,
        func,
        self.context.root_alloc_buffers,
    )

    self.context.exit_scope()
    return func
-
def transform_Lambda(self, node):
    """Translate a lambda into its parameter variables and body.

    Returns
    -------
    (arg_vars, body) : tuple
        The list of freshly created parameter variables, followed by the
        transformed lambda body.
    """
    self.context.enter_scope(nodes=[node.body])

    arg_vars = []
    for param in node.params:
        # The parameter's dtype is not known yet. An empty dtype string is
        # used (intended as "void") so IRSubstitute can later replace the
        # variable without flagging a type mismatch.
        param_var = tvm.te.var(param.name, dtype="")
        arg_vars.append(param_var)
        self.context.update_symbol(param.name, param_var, node)

    if not isinstance(node.body, ast.Expr):
        self.report_error("The body of a lambda must be an expression", node.span)

    lambda_body = self.transform(node.body)
    self.context.exit_scope()
    return arg_vars, lambda_body
-
def transform_Assign(self, node):
    """Assign visitor
    AST abstract grammar:
        Assign(expr* targets, expr value, string? type_comment)

    Supported patterns of Assign:
    1. special stmts with return value
        1.1 Buffer = T.match_buffer()/T.buffer_decl()
        1.2 Var = T.var()
        1.3 Var = T.env_thread()
    2. (BufferStore) Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr
       (handled by transform_SubscriptAssign)
    3. (Store) Var[PrimExpr] = PrimExpr (handled by transform_SubscriptAssign)
    4. with scope handlers with concise scoping and var def
        4.1 var = T.allocate()
    5. A call to a pure python function, consuming and producing TVMScript values.
       The outputs are inlined into the following body (no variable is created).
       x, y = f(...)
    6. A let binding ``x = <call or constant>`` to a single variable, lowered
       to a tir.LetStmt.
    """

    if isinstance(node.rhs, ast.Call):
        # Pattern 1 & Pattern 4
        if isinstance(node.rhs.func_name, ast.Op):
            func = None
        else:
            func = self.transform(node.rhs.func_name)

        if isinstance(func, WithScopeHandler):
            if not func.concise_scope or not func.def_symbol:
                self.report_error(
                    "with scope handler " + func.signature()[0] + " is not suitable here",
                    node.rhs.span,
                )
            # Pattern 4
            arg_list = self.parse_arg_list(func, node.rhs)
            func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
            func.body = self.parse_body(node)
            return func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span)
        elif isinstance(func, SpecialStmt):
            # Pattern 1
            arg_list = self.parse_arg_list(func, node.rhs)
            func.handle(node, self.context, arg_list, node.rhs.func_name.span)
            return self.parse_body(node)
        elif isinstance(func, types.FunctionType):
            # Pattern 5
            args = [self.transform(arg) for arg in node.rhs.params]
            try:
                out = func(*args)
            except Exception as e:  # pylint: disable=broad-except
                self.report_error(
                    "Error occurred when invoking the function "
                    + func.__name__
                    + ": \n"
                    + str(e),
                    node.rhs.span,
                )

            if len(node.lhs) == 1 and not isinstance(out, list):
                out = [out]

            # User input: report a diagnostic instead of an `assert`, which is
            # stripped under `python -O` and gives no source location.
            if not hasattr(out, "__len__") or len(out) != len(node.lhs):
                self.report_error(
                    f"Cannot unpack the result of {func.__name__} into "
                    f"{len(node.lhs)} target(s).",
                    node.span,
                )

            for target, value in zip(node.lhs, out):
                self.context.update_symbol(target.id.name, value, node)

            body = self.parse_body(node)

            for target in node.lhs:
                self.context.remove_symbol(target.id.name)

            return body

    if isinstance(node.rhs, (ast.Call, ast.Constant)):
        # Pattern 6: let binding
        value = self.transform(node.rhs)
        # Only `x = ...` with a single unqualified variable can become a LetStmt;
        # tuple targets would otherwise silently bind only the first name.
        if len(node.lhs) != 1 or not isinstance(node.lhs[0], ast.Var):
            self.report_error(
                "Left hand side of assignment must be an unqualified variable",
                node.span,
            )
        ast_var = node.lhs[0]

        if node.ty is None and hasattr(value, "dtype"):
            var_ty = value.dtype
        else:
            var_ty = self.parse_type(node.ty, ast_var)

        var = tvm.te.var(
            ast_var.id.name,
            var_ty,
            span=tvm_span_from_synr(ast_var.span),
        )
        self.context.update_symbol(var.name, var, node)
        body = self.parse_body(node)
        self.context.remove_symbol(var.name)
        return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))

    self.report_error(
        "Assignments should be one of:\n"
        '  1. A "special statement" with return value\n'
        "      1.1 Buffer = T.match_buffer()/T.buffer_decl()\n"
        "      1.2 Var = T.var()\n"
        "      1.3 Var = T.env_thread()\n"
        "  2. A store into a buffer: Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr\n"
        "  3. A store into a variable: Var[PrimExpr] = PrimExpr\n"
        "  4. A with scope handler with concise scoping and var def\n"
        "      4.1 var = T.allocate()\n"
        "  5. The right-hand side being a call to a pure python function, consuming and\n"
        "     producing TVMScript values.\n"
        "      x, y = f(...)",
        node.span,
    )
    return None
-
def transform_SubscriptAssign(self, node):
    """Lower a statement of the form :code:`x[1] = 2`.

    A store into a :code:`tvm.tir.Buffer` becomes a :code:`tvm.tir.BufferStore`
    (slice indices are converted to index expressions first). A store through
    any other symbol is rejected. Handles need ``T.match_buffer`` for
    multi-dimensional access, and the legacy ``tir.Store`` is deprecated.
    """
    target_node = node.params[0]
    index_node = node.params[1]
    value_node = node.params[2]

    target = self.transform(target_node)
    indices = self.transform(index_node)
    value = self.transform(value_node)
    value_span = tvm_span_from_synr(value_node.span)

    if not isinstance(target, tvm.tir.Buffer):
        if target.dtype == "handle" and len(indices) != 1:
            self.report_error(
                "Handles only support one-dimensional indexing. Use `T.match_buffer` to "
                "construct a multidimensional buffer from a handle.",
                target_node.span,
            )
        if len(indices) != 1:
            self.report_error(
                f"Store is only allowed with one index, but {len(indices)} were provided.",
                index_node.span,
            )
        self.report_error(
            "Use of tir.Store has been deprecated in favor of tir.BufferStore.", node.span
        )
        return None

    if len(indices) != len(target.shape):
        self.report_error(
            f"Buffer {target.name} is {len(target.shape)}-dimensional, "
            f"cannot be indexed by {len(indices)}-dimensional indices.",
            index_node.span,
        )

    store_indices = [
        idx.as_index_expr(self.report_error) if isinstance(idx, Slice) else idx
        for idx in indices
    ]
    return tvm.tir.BufferStore(
        target,
        tvm.runtime.convert(value, span=value_span),
        store_indices,
        span=tvm_span_from_synr(node.span),
    )
-
def transform_AttrAssign(self, node):
    """Visitor for statements of the form :code:`x.y = 2`.

    The attribute ``x.y`` must already exist and be a :code:`tvm.tir.Var`.
    The assignment is lowered to a :code:`tvm.tir.LetStmt` that binds that
    variable to the value for the rest of the enclosing body.
    """
    obj = self.transform(node.params[0])
    field = node.params[1]
    value = self.transform(node.params[2])

    # Diagnostics go through report_error, like every other visitor here
    # (the previous `self.error` is not the parser's reporting method).
    if not hasattr(obj, field.name):
        self.report_error(f"Field {field.name} does not exist", field.span)

    var = getattr(obj, field.name)

    if not isinstance(var, tvm.tir.Var):
        self.report_error(
            f"Can only assign to tir.Var attributes, not {type(var).__name__}", node.span
        )

    body = self.parse_body(node)
    return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))
-
def transform_Assert(self, node):
    """Lower ``assert cond, msg`` into a :code:`tvm.tir.AssertStmt`.

    This is the concise form of :code:`with T.Assert(cond, msg):`. The
    statements following the assert become the AssertStmt's body, and a
    message is mandatory.
    """
    cond = self.transform(node.condition)
    if node.msg is None:
        self.report_error("Assert statements must have an error message.", node.span)
    msg = self.transform(node.msg)
    rest = self.parse_body(node)
    stmt_span = tvm_span_from_synr(node.span)
    return tvm.tir.AssertStmt(cond, tvm.runtime.convert(msg), rest, span=stmt_span)
-
def transform_For(self, node):
    """Lower a ``for`` loop driven by a For scope handler.

    AST abstract grammar:
        For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment)

    The iterator must be a call to a :code:`ForScopeHandler`, such as
    ``T.serial()``, ``T.parallel()``, ``T.vectorized()``, ``T.unroll()``,
    ``range()``, ``T.grid()`` or ``T.thread_binding()``. The current
    line/column are moved to the loop while it is parsed and restored after.
    """
    loop_call = node.rhs
    if not isinstance(loop_call, ast.Call):
        self.report_error("The loop iterator should be a function call.", loop_call.span)
    handler = self.transform(loop_call.func_name)
    if not isinstance(handler, ForScopeHandler):
        self.report_error(
            "Only For scope handlers can be used in a for statement.", loop_call.func_name.span
        )

    saved_position = (self.current_lineno, self.current_col_offset)
    self.current_lineno = node.span.start_line
    self.current_col_offset = node.span.start_column
    self.context.enter_scope(nodes=node.body.stmts)

    raw_args = self.parse_arg_list(handler, loop_call)
    loop_args = [tvm.runtime.convert(a, span=tvm_span_from_synr(loop_call.span)) for a in raw_args]
    handler_span = loop_call.func_name.span
    handler.enter_scope(node, self.context, loop_args, handler_span)
    handler.body = self.parse_body(node)
    result = handler.exit_scope(node, self.context, loop_args, handler_span)

    self.context.exit_scope()
    self.current_lineno, self.current_col_offset = saved_position
    return result
-
def transform_While(self, node):
    """Lower a ``while`` loop into a :code:`tvm.tir.While`.

    AST abstract grammar:
        While(expr condition, stmt* body)
    """
    cond = self.transform(node.condition)

    self.context.enter_scope(nodes=node.body.stmts)
    loop_body = self.parse_body(node)
    self.context.exit_scope()

    loop_span = tvm_span_from_synr(node.span)
    return tvm.tir.While(cond, loop_body, span=loop_span)
-
def transform_With(self, node):
    """Lower a ``with`` statement driven by a With scope handler.

    AST abstract grammar:
        With(withitem* items, stmt* body, string? type_comment)
        withitem = (expr context_expr, expr? optional_vars)

    Two forms are accepted:
    1. with scope handler with symbol def, e.g. ``with T.allocate() as targets:``
    2. with scope handler without symbol def, e.g.
       ``with T.block(*axes)/T.let()/T.Assert()/T.attr()/T.realize()``

    The body is parsed inside a new block scope. The current line/column
    are moved to the body while it is parsed and restored afterwards.
    """
    ctx_call = node.rhs
    if not isinstance(ctx_call, ast.Call):
        self.report_error(
            "The context expression of a `with` statement should be a function call.",
            ctx_call.span,
        )

    handler = self.transform(ctx_call.func_name)
    if not isinstance(handler, WithScopeHandler):
        self.report_error(
            f"Function {handler} cannot be used in a `with` statement.", ctx_call.func_name.span
        )

    saved_lineno, saved_col = self.current_lineno, self.current_col_offset
    body_span = node.body.span
    self.current_lineno = body_span.start_line
    self.current_col_offset = body_span.start_column
    self.context.enter_block_scope(nodes=node.body.stmts)

    handler_span = ctx_call.func_name.span
    args = self.parse_arg_list(handler, ctx_call)
    handler.enter_scope(node, self.context, args, handler_span)
    handler.body = self.parse_body(node)
    result = handler.exit_scope(node, self.context, args, handler_span)

    self.context.exit_block_scope()
    self.current_lineno, self.current_col_offset = saved_lineno, saved_col
    return result
-
def transform_If(self, node):
    """Lower an ``if`` statement into a :code:`tvm.tir.IfThenElse`.

    AST abstract grammar:
        If(expr test, stmt* body, stmt* orelse)

    Each branch is parsed in its own scope. When the ``else`` branch has no
    statements, the IfThenElse gets ``None`` as its else body.
    """
    cond = self.transform(node.condition)

    self.context.enter_scope(nodes=node.true.stmts)
    then_branch = self.parse_body(node)
    self.context.exit_scope()

    else_branch = None
    if len(node.false.stmts) != 0:
        self.context.enter_scope(nodes=node.false.stmts)
        else_branch = self.parse_body(node)
        self.context.exit_scope()

    return tvm.tir.IfThenElse(
        cond, then_branch, else_branch, span=tvm_span_from_synr(node.span)
    )
-
def transform_Call(self, node):
    """Call visitor

    Three kinds of call are handled:
    1. Built-in operators (``ast.Op``): subscripts, binary operators and unary
       operators. A unary minus applied directly to a literal is folded into a
       single negative constant.
    2. Intrinsics representing a PrimExpr/IterVar, e.g.
       tir.int/uint/float8/16/32/64/floormod/floordiv/load/cast/ramp/broadcast/max
       or tir.range/reduce_axis/scan_axis/opaque_axis.
    3. ``tir.Op(..., dtype=...)`` calls, which become a :code:`tvm.tir.Call`,
       and any other python callable, which is invoked directly.
    """

    if isinstance(node.func_name, ast.Op):
        op_name = node.func_name.name
        if op_name == ast.BuiltinOp.Subscript:
            return self.transform_Subscript(node)
        if op_name in self._binop_maker:
            lhs = self.transform(node.params[0])
            # There is no supertype for everything that can appear in an
            # expression, so we manually add what we might get here. The
            # parser does not distinguish statements from expressions, so a
            # more specific error is not available.
            if not isinstance(lhs, (tvm.tir.PrimExpr, BufferSlice)):
                # The placeholder must live in an f-string: adjacent literals
                # are only interpolated when that literal itself has the `f`.
                self.report_error(
                    "Left hand side of binary op must be a PrimExpr, "
                    f"but it is a {type(lhs).__name__}",
                    node.params[0].span,
                )
            rhs = self.transform(node.params[1])
            if not isinstance(rhs, (tvm.tir.PrimExpr, BufferSlice)):
                self.report_error(
                    "Right hand side of binary op must be a PrimExpr, "
                    f"but it is a {type(rhs).__name__}",
                    node.params[1].span,
                )
            return call_with_error_reporting(
                self.report_error,
                node.span,
                lambda node, lhs, rhs, span: self._binop_maker[node.func_name.name](
                    lhs, rhs, span=span
                ),
                node,
                lhs,
                rhs,
                tvm_span_from_synr(node.span),
            )
        if op_name in self._unaryop_maker:
            operand = self.transform(node.params[0])
            if op_name == ast.BuiltinOp.USub and isinstance(node.params[0], ast.Constant):
                # '-literal' should be parsed together for proper literal type inference
                if not isinstance(operand, (tvm.tir.IntImm, tvm.tir.FloatImm)):
                    self.report_error("The literal is illegal after -", node.params[0].span)
                return tvm.tir.const(-operand.value, span=tvm_span_from_synr(node.span))
            return self._unaryop_maker[op_name](operand, span=tvm_span_from_synr(node.span))
        self.report_error(f"Unsupported operator {op_name}.", node.func_name.span)
        return None

    func = self.transform(node.func_name)
    if isinstance(func, Intrin) and not func.stmt:
        # pattern 2: expression intrinsic
        arg_list = self.parse_arg_list(func, node)
        return call_with_error_reporting(
            self.report_error,
            node.func_name.span,
            func.handle,
            arg_list,
            node.func_name.span,
        )

    args = [self.transform(arg) for arg in node.params]
    kw_args = {self.transform(k): self.transform(v) for k, v in node.keyword_params.items()}
    if isinstance(func, tvm.tir.op.Op):
        # pattern 3: builtin op, needs an explicit result dtype
        if "dtype" not in kw_args:
            self.report_error(f"{func} requires a dtype keyword argument.", node.span)
        return tvm.tir.Call(kw_args["dtype"], func, args, span=tvm_span_from_synr(node.span))
    if callable(func):
        # pattern 3: arbitrary python callable
        return func(*args, **kw_args)
    self.report_error(
        f"Function is neither callable nor a tvm.tir.op.Op (it is a {type(func)}).",
        node.func_name.span,
    )
    return None
-
def transform_UnassignedCall(self, node):
    """Visitor for statements that are bare function calls.

    This handles calls that appear on their own line, such as ``T.realize``,
    as well as subscript/attribute assignments, which the AST also represents
    as unassigned calls.

    Examples
    --------
    .. code-block:: python

        @T.prim_func
        def f():
            A = T.buffer_decl([10, 10])
            T.realize(A[1:2, 1:2], "")  # This is an UnassignedCall
            A[1, 1] = 2  # This is also an UnassignedCall
    """
    call = node.call

    # The only builtin operators allowed as statements are `x[1] = 3`
    # (subscript assign) and `x.y = 3` (attribute assign).
    if isinstance(call.func_name, ast.Op):
        op_name = call.func_name.name
        if op_name == ast.BuiltinOp.SubscriptAssign:
            return self.transform_SubscriptAssign(call)
        if op_name == ast.BuiltinOp.AttrAssign:
            return self.transform_AttrAssign(call)
        self.report_error("Binary and unary operators are not allowed as a statement", node.span)

    func = self.transform(call.func_name)
    arg_list = self.parse_arg_list(func, call)
    name_span = call.func_name.span

    if isinstance(func, tir.scope_handler.AssertHandler):
        self.report_error(
            "A standalone `T.Assert` is not allowed. Use `assert condition, message` "
            "instead.",
            name_span,
        )

    if isinstance(func, Intrin):
        if not func.stmt:
            self.report_error("This intrinsic cannot be used as a statement.", call.span)
        else:
            return call_with_error_reporting(
                self.report_error, name_span, func.handle, arg_list, name_span
            )
    elif isinstance(func, WithScopeHandler) and func.concise_scope and not func.def_symbol:
        func.enter_scope(node, self.context, arg_list, name_span)
        func.body = self.parse_body(node)
        return func.exit_scope(node, self.context, arg_list, name_span)
    elif isinstance(func, SpecialStmt) and not func.def_symbol:
        func.handle(node, self.context, arg_list, name_span)
        return None

    self.report_error(
        "Unexpected statement. Expected an assert, an intrinsic, a with statement, or a "
        f"special statement, but got {type(func).__name__}.",
        name_span,
    )
    return None
-
def transform_Slice(self, node):
    """Build a :code:`Slice` from ``start:end:step``.

    The step must be a positive integer literal; anything else is reported.
    """
    lower = self.transform(node.start)
    upper = self.transform(node.end)
    step = node.step
    step_is_valid = (
        isinstance(step, ast.Constant) and isinstance(step.value, int) and step.value > 0
    )
    if not step_is_valid:
        self.report_error("Only positive integer step size is supported for slices.", step.span)
    return Slice(lower, upper, step.value, tvm_span_from_synr(node.span))
-
def transform_Subscript(self, node):
    """Array access visitor.

    Supported forms:
    1. ``Buffer[index, index, ...]``: buffer element access, producing a
       :code:`BufferSlice` (used for BufferLoad & BufferStore).
    2. ``Buffer[start:stop, start:stop, ...]``: buffer region access, also a
       :code:`BufferSlice` (e.g. for ``realize(buffer[...])``).
    3. ``Array[index]``: element access into a ``tvm.container.Array``.

    Indexing a ``Var`` is always reported. Handles must first be bound with
    ``T.match_buffer``, and ``tir.Load`` is deprecated.
    """
    base_node = node.params[0]
    index_node = node.params[1]

    symbol = self.transform(base_node)
    if symbol is None:
        # Only a plain variable node carries an identifier to name here.
        if isinstance(base_node, ast.Var):
            self.report_error(f"Variable {base_node.id.name} is not defined.", base_node.span)
        else:
            self.report_error("The subscripted expression is not defined.", base_node.span)

    indexes = [self.transform(x) for x in index_node.values]
    if isinstance(symbol, tvm.tir.expr.Var):
        if symbol.dtype == "handle":
            self.report_error(
                "Cannot read directly from a handle, use `T.match_buffer` "
                "to create a buffer to read from.",
                base_node.span,
            )
        # `!= 1` (not `> 1`) so an empty index list never reaches indexes[0].
        if len(indexes) != 1:
            self.report_error(
                "Only a single index can be provided when indexing into a `var`.",
                index_node.span,
            )
        index = indexes[0]
        if not isinstance(index, (tvm.tir.PrimExpr, int)):
            # `str + type` would raise TypeError; format the type's name instead.
            self.report_error(
                "Var load index should be an int or PrimExpr, but it is a "
                + type(index).__name__,
                node.span,
            )
        self.report_error(
            "Use of tir.Load has been deprecated in favor of tir.BufferLoad", node.span
        )
        return None
    if isinstance(symbol, tvm.tir.Buffer):
        return BufferSlice(symbol, indexes, self.report_error, span=tvm_span_from_synr(node.span))
    if isinstance(symbol, tvm.container.Array):
        if len(indexes) != 1:
            self.report_error(
                "Array access should be one-dimension access, but the indices are "
                + str(indexes),
                node.span,
            )
        index = indexes[0]
        if not isinstance(index, (int, tvm.tir.expr.IntImm)):
            self.report_error(
                "Array access index expected int or IntImm, but got " + type(index).__name__,
                node.span,
            )
        if int(index) >= len(symbol):
            self.report_error(
                f"Array access out of bound, size: {len(symbol)}, got index {index}.",
                node.span,
            )
        return symbol[int(index)]
    self.report_error(
        f"Cannot subscript from a {type(symbol).__name__}. Only variables, "
        "buffers and arrays are supported.",
        base_node.span,
    )
    return None
-
def transform_Attr(self, node):
    """Visitor for field access of the form `x.y`.

    This visitor is used to look up function and symbol names. Two cases:
    1. For `T.something` (any alias of the tir namespace), look up
       `tir.something` in the `Registry`. If it is not registered, fall back
       to the `tvm.ir.op.Op` with the same name. An unregistered op is
       reported as an error.
    2. Otherwise, evaluate `x` and return its attribute `y`.
    """

    def get_full_attr_name(attr: ast.Attr) -> str:
        """Join a chain of attribute accesses such as `T.a.b` into a dotted name."""
        reversed_parts = [attr.field.name]
        while isinstance(attr.object, ast.Attr):
            attr = attr.object
            reversed_parts.append(attr.field.name)
        if isinstance(attr.object, ast.Var):
            reversed_parts.append(attr.object.id.name)
        return ".".join(reversed(reversed_parts))

    if isinstance(node.object, (ast.Var, ast.Attr)):
        full_attr_name = get_full_attr_name(node)
        attr_object, fields = full_attr_name.split(".", maxsplit=1)
        if self.match_tir_namespace(attr_object):
            func_name = "tir." + fields
            res = Registry.lookup(func_name)
            if res is not None:
                return res
            try:
                return tvm.ir.op.Op.get(func_name)
            except TVMError as e:
                # The lookup error for an unregistered op mentions
                # "AttributeError". `str.find` returns -1 (truthy) when the text
                # is absent, so test membership explicitly and re-raise anything else.
                if "AttributeError" in str(e):
                    self.report_error(f"Unregistered function `tir.{fields}`.", node.span)
                else:
                    raise

    symbol = self.transform(node.object)
    if symbol is None:
        self.report_error("Unsupported Attribute expression.", node.object.span)
    if not hasattr(symbol, node.field.name):
        self.report_error(
            f"Type {type(symbol)} does not have a field called `{node.field.name}`.", node.span
        )
    return getattr(symbol, node.field.name)
-
def transform_TypeAttr(self, node):
    """Visitor for field access of the form `x.y` for types.

    We have two cases here:
    1. If the type is of the form `T.something` (any alias of the tir
       namespace), look up `something` in the `tir` namespace of this module.
    2. Otherwise evaluate `x` (e.g. `tvm.x`) and return its attribute `y`.
    """
    # `node.field` is an identifier node; attribute lookups need its string
    # name, as the tir branch below already used.
    field_name = node.field.name
    if isinstance(node.object, ast.TypeVar) and self.match_tir_namespace(node.object.id.name):
        if not hasattr(tir, field_name):
            self.report_error(f"Invalid type annotation `tir.{field_name}`.", node.span)
        return getattr(tir, field_name)

    symbol = self.transform(node.object)
    if symbol is None:
        self.report_error("Unsupported Attribute expression", node.object.span)
    if not hasattr(symbol, field_name):
        self.report_error(
            f"Type {type(symbol)} does not have a field called `{field_name}`.", node.span
        )
    return getattr(symbol, field_name)
-
def transform_DictLiteral(self, node):
    """Evaluate a dictionary literal such as ``{x: y, z: 2}``.

    Every key is transformed first, then every value. The two sequences are
    paired positionally, and a later duplicate key overwrites an earlier one.
    """
    transformed_keys = list(map(self.transform, node.keys))
    transformed_values = list(map(self.transform, node.values))
    result = {}
    for key, value in zip(transformed_keys, transformed_values):
        result[key] = value
    return result
-
def transform_Tuple(self, node):
    """Evaluate a tuple literal such as ``(x, y, 2)`` element by element."""
    return tuple(map(self.transform, node.values))
-
def transform_ArrayLiteral(self, node):
    """Evaluate a list literal such as ``[x, 2, 3]`` element by element."""
    return list(map(self.transform, node.values))
-
def transform_Var(self, node):
    """Resolve a bare identifier such as ``x`` in ``x = 2``.

    Lookup order: the special name ``meta``, then the handler Registry, then
    symbols in the current parsing context. An unknown name is reported.
    """
    ident = node.id.name
    if ident == "meta":
        return self.meta
    registered = Registry.lookup(ident)
    if registered is not None:
        return registered
    local = self.context.lookup_symbol(ident)
    if local is None:
        self.report_error(f"Unknown identifier {ident}.", node.span)
    return local
-
def transform_TypeVar(self, node):
    """Resolve an identifier used in a type position.

    The counterpart of `transform_Var` for types. It tries the Registry
    first, and falls back to the parsing context when the Registry result is
    falsy. An unknown name is reported.
    """
    ident = node.id.name
    resolved = Registry.lookup(ident) or self.context.lookup_symbol(ident)
    if resolved is None:
        self.report_error(f"Unknown identifier {ident}.", node.span)
    return resolved
-
def transform_Constant(self, node):
    """Convert a literal to a TVM object via :code:`tvm.runtime.convert`.

    Literals include `None`, `"strings"`, `2` (integers), `4.2` (floats) and
    booleans. The source span is attached to the converted value.
    """
    literal = node.value
    literal_span = tvm_span_from_synr(node.span)
    return tvm.runtime.convert(literal, span=literal_span)
-
def transform_TypeConstant(self, node):
    """Constant visitor for type positions.

    Inside buffer-annotation sugar (``self._inside_buffer_sugar``) the literal
    is converted like an expression constant. Otherwise the raw python value
    is returned.
    """
    return self.transform_Constant(node) if self._inside_buffer_sugar else node.value
-
def transform_TypeTuple(self, node):
    """Evaluate a tuple in a type position into a python list.

    Mostly used by `transform_TypeCall` and `transform_TypeApply`.
    """
    return list(map(self.transform, node.values))
-
def transform_TypeCall(self, node):
    """Evaluate an operator expression inside a ``T.Buffer`` parameter annotation.

    A TypeCall stores the builtin operator directly in ``node.func_name``,
    while :code:`ast.Call` stores it as ``node.func_name.name``. The node is
    therefore re-wrapped as an ``ast.Call`` and delegated to
    `transform_Call`. As a consequence, diagnostics for unsupported operators
    highlight the whole expression rather than just the operator.
    """
    wrapped_op = ast.Op(node.span, node.func_name)
    as_call = ast.Call(node.span, wrapped_op, node.params, node.keyword_params)
    return self.transform_Call(as_call)
-
def transform_TypeApply(self, node):
    """Visitor for Type[Type] expressions.

    Mostly used for ``T.Ptr`` expressions. The subscripted object must be a
    ``ty.TypeGeneric`` that supports ``__getitem__``. Each argument position
    for which ``require_type_generic_at`` is true must evaluate to a type. A
    single argument is passed unwrapped; several are passed as a list.
    """
    func = self.transform(node.func_name)

    if not isinstance(func, ty.TypeGeneric) or not hasattr(func, "__getitem__"):
        self.report_error(
            f"Use of type arguments requires a type that accepts type arguments (e.g. T.Ptr), "
            f"but found {type(func).__name__} instead.",
            node.span,
        )

    param_types = []
    for idx, param in enumerate(node.params):
        param_type = self.transform(param)
        if not isinstance(param_type, ty.TypeGeneric) and func.require_type_generic_at(idx):
            # Name the type of what the argument evaluated to; the syntax
            # node's class name means nothing to script authors.
            self.report_error(
                f"Expected a type but found {type(param_type).__name__} "
                f"at {idx}th type argument",
                param.span,
            )
        param_types.append(param_type)

    if len(param_types) == 1:
        return func[param_types[0]]
    return func[param_types]
-
def handle_match_buffer_type(self, node, buffer_name):
    """Lower a buffer-typed function parameter annotation.

    This is syntax sugar for ``T.match_buffer`` in function parameters. The
    annotation's special statement is evaluated with ``buffer_name`` as the
    default buffer name, and its result (the buffer) is returned.
    """
    handler = self.transform(node.func_name)
    assert isinstance(handler, SpecialStmt)

    # Raise the sugar flag only while the annotation's args/kwargs are parsed,
    # so type-position constants are converted like expressions.
    self._inside_buffer_sugar = True
    try:
        args = self.parse_arg_list(handler, node)
    finally:
        self._inside_buffer_sugar = False

    # The third positional slot is always the buffer 'name'.
    # TODO: this index is hardcoded as a workaround; better to derive it programmatically.
    if args[2] is None:
        args[2] = buffer_name
    return handler.handle(node, self.context, args, node.func_name.span)
-
def transform_Return(self, node):
    """Reject ``return`` statements, which TVM script does not support."""
    message = (
        "TVM script does not support return statements. Instead the last statement in any "
        "block is implicitly returned."
    )
    self.report_error(message, node.span)
-
-
def get_tir_namespace(script: Union[Callable, type]) -> List[str]:
    """Return the global names under which the `tir` script module is bound for `script`.

    Parameters
    ----------
    script : Union[Callable, type]
        A python function or class written in TVM script.

    Returns
    -------
    names : List[str]
        Every name in the defining module's globals bound to the `tir`
        module (for example ``"T"``).
    """
    assert inspect.isfunction(script) or inspect.isclass(script)
    if inspect.isfunction(script):
        env: Dict[str, Any] = script.__globals__
    else:
        # Classes have no `__globals__`; use their defining module's namespace.
        module = inspect.getmodule(script)
        env = vars(module) if module is not None else {}
    # Identity, not `==`: globals may define arbitrary `__eq__` (e.g. arrays or
    # PrimExprs), which could return non-bools or raise when truth-tested.
    return [name for name, value in env.items() if value is tir]
-
-
def from_source(
    input_func: Union[str, Callable], tir_prefix: Optional[List[str]] = None
) -> Union[PrimFunc, IRModule]:
    """Parse a function, or its source string, into a PrimFunc or IRModule.

    If possible, pass the TVM script in as a function so that line numbers and
    the filename will be accurate.

    Parameters
    ----------
    input_func : Union[str, Callable]
        The python function (or its source code) to be parsed.

    tir_prefix : Optional[List[str]]
        The names under which the tir namespace is referenced. Only used for
        str input; defaults to ["T", "tir"].

    Returns
    -------
    output : Union[PrimFunc, IRModule]
        The parsed PrimFunc or IRModule.

    Raises
    ------
    TypeError
        If `input_func` is neither a string nor a python function.
    """
    if isinstance(input_func, str):
        prefixes = tir_prefix if tir_prefix is not None else ["T", "tir"]
        return to_ast(input_func, TVMDiagnosticCtx(), TVMScriptParser(0, prefixes, {}))
    if not inspect.isfunction(input_func):
        raise TypeError("Only function definitions are supported.")

    _, first_line = inspect.getsourcelines(input_func)
    func_globals: Dict[str, Any] = input_func.__globals__
    tir_names = [name for name, value in func_globals.items() if value is tir]
    captured = inspect.getclosurevars(input_func)
    closure_vars = {**captured.nonlocals, **captured.globals}
    parser = TVMScriptParser(first_line, tir_names, closure_vars)
    return to_ast(input_func, TVMDiagnosticCtx(), parser)
-
-
def ir_module(input_module: type) -> IRModule:
    """Decorate a python class as a tvm IRModule.

    Every class attribute that is already a ``BaseFunc`` becomes a member of
    the module under its attribute name. Other attributes are ignored.

    Parameters
    ----------
    input_module : type
        The python class to be parsed.

    Returns
    -------
    output : IRModule
        The result IRModule.

    Raises
    ------
    TypeError
        If `input_module` is not a class.
    """
    if not inspect.isclass(input_module):
        raise TypeError("Only class definitions are supported.")
    members = {}
    for attr_name, attr_value in vars(input_module).items():
        if isinstance(attr_value, BaseFunc):
            members[attr_name] = attr_value
    return IRModule(members)
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/__init__.py
similarity index 83%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/__init__.py
index 555659d0c5..5161a2601c 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/__init__.py
@@ -13,9 +13,9 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
-# under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
-
-from . import tir
-
-from .parser import ir_module, from_source
+# under the License.
+"""The parser"""
+from . import _core, ir, tir
+from ._core import parse
+from .ir import ir_module
+from .tir import prim_func
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/_core.py
similarity index 76%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/_core.py
index 555659d0c5..4f5411dc36 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/_core.py
@@ -13,9 +13,10 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
-# under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
-
-from . import tir
-
-from .parser import ir_module, from_source
+# under the License.
+"""The core parser infra"""
+# pylint: disable=unused-import
+from .core import dispatch, doc, utils
+from .core.dispatch import OpMethod, register_op
+from .core.entry import parse
+from .core.parser import Parser
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/core/__init__.py
similarity index 85%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/core/__init__.py
index 555659d0c5..94d8dab032 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/core/__init__.py
@@ -14,8 +14,5 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
-
-from . import tir
-
-from .parser import ir_module, from_source
+"""The core parser infra"""
+from . import diagnostics, dispatch, doc, doc_core, entry, evaluator, parser, utils
diff --git a/python/tvm/script/parser/core/diagnostics.py b/python/tvm/script/parser/core/diagnostics.py
new file mode 100644
index 0000000000..51c26bbc24
--- /dev/null
+++ b/python/tvm/script/parser/core/diagnostics.py
@@ -0,0 +1,175 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import inspect
+import re
+import sys
+from typing import Union
+
+from tvm.ir import IRModule, SourceName, Span, diagnostics
+
+from . import doc
+
+
+class Source:
+    """Source text of a TVMScript program and its position in the original file.
+
+    ``program`` is either a plain string or a Python object (function/class)
+    whose source ``inspect`` can locate.
+    """
+
+    # File the program came from, or "<str>" for string input.
+    source_name: str
+    # 1-based line number of the program's first line within that file.
+    start_line: int
+    # Indentation width of the first source line; stripped from every line.
+    start_column: int
+    # The (dedented) program text that is actually parsed by ``as_ast``.
+    source: str
+    # Text of the whole enclosing module, used for diagnostic rendering.
+    full_source: str
+
+    def __init__(self, program: Union[str, doc.AST]):
+        # NOTE(review): despite the annotation, the non-string path below hands
+        # ``program`` to ``inspect``, i.e. it expects a function/class object.
+        if isinstance(program, str):
+            # String programs are taken verbatim and start at line 1, column 0.
+            self.source_name = "<str>"
+            self.start_line = 1
+            self.start_column = 0
+            self.source = program
+            self.full_source = program
+            return
+
+        self.source_name = inspect.getsourcefile(program)  # type: ignore
+        # The module's own ``getsourcelines`` (not ``inspect.getsourcelines``),
+        # which can also locate class definitions.
+        lines, self.start_line = getsourcelines(program)  # type: ignore
+        if lines:
+            self.start_column = len(lines[0]) - len(lines[0].lstrip())
+        else:
+            self.start_column = 0
+        if self.start_column and lines:
+            # Dedent nested definitions so that ``doc.parse`` receives top-level code.
+            self.source = "\n".join([l[self.start_column :].rstrip() for l in lines])
+        else:
+            self.source = "".join(lines)
+        try:
+            # It will cause a problem when running in Jupyter Notebook.
+            # `mod` will be <module '__main__'>, which is a built-in module
+            # and `getsource` will throw a TypeError
+            mod = inspect.getmodule(program)
+            if mod:
+                self.full_source = inspect.getsource(mod)
+            else:
+                self.full_source = self.source
+        except TypeError:
+            # It's a work around for Jupyter problem.
+            # Since `findsource` is an internal API of inspect, we just use it
+            # as a fallback method.
+            src, _ = inspect.findsource(program)  # type: ignore
+            self.full_source = "".join(src)
+
+    def as_ast(self) -> doc.AST:
+        """Parse ``self.source`` into a ``doc`` tree."""
+        return doc.parse(self.source)
+
+
+# Keep the stock ``inspect`` implementations before ``inspect.getfile`` is
+# monkey-patched; the class-aware replacements delegate to them for non-classes.
+_getfile = inspect.getfile  # pylint: disable=invalid-name
+_findsource = inspect.findsource  # pylint: disable=invalid-name
+
+
+def _patched_inspect_getfile(obj):
+    """Replacement for ``inspect.getfile`` that can also locate classes.
+
+    Non-class objects are delegated to the original ``inspect.getfile``. For a
+    class, the ``__file__`` of its defining module is tried first; if that is
+    unavailable, the file of a method defined directly on the class is used.
+
+    Raises
+    ------
+    TypeError
+        If no source file can be determined for the class.
+    """
+    if not inspect.isclass(obj):
+        return _getfile(obj)
+    mod = getattr(obj, "__module__", None)
+    if mod is not None:
+        # ``.get``: the defining module may be absent from ``sys.modules``; fall
+        # through to the method scan instead of failing with a KeyError.
+        file = getattr(sys.modules.get(mod), "__file__", None)
+        if file is not None:
+            return file
+    for _, member in inspect.getmembers(obj):
+        # Only methods whose qualname places them directly on ``obj`` count;
+        # inherited methods would point at a base class's file.
+        if inspect.isfunction(member):
+            if obj.__qualname__ + "." + member.__name__ == member.__qualname__:
+                return inspect.getfile(member)
+    # ``!r`` is a conversion; ``{obj:!r}`` would be a format spec that itself raises.
+    raise TypeError(f"Source for {obj!r} not found")
+
+
+def findsource(obj):
+    """Like ``inspect.findsource``, but locates classes by walking ``__qualname__``.
+
+    Non-class objects are delegated to the original ``inspect.findsource``.
+    For a class, each ``__qualname__`` component is matched in order against
+    ``class`` lines (enclosing/target classes) or ``def``/``async def`` lines
+    (enclosing functions, the ``<locals>`` components) of the defining file.
+
+    Returns
+    -------
+    Tuple[List[str], int]
+        All lines of the file and the 0-based index of the class statement.
+
+    Raises
+    ------
+    OSError
+        If the source file or the class definition cannot be found.
+    """
+    import linecache  # pylint: disable=import-outside-toplevel
+
+    if not inspect.isclass(obj):
+        return _findsource(obj)
+
+    file = inspect.getsourcefile(obj)
+    if file:
+        linecache.checkcache(file)
+    else:
+        file = inspect.getfile(obj)
+        # Only pseudo-files such as "<stdin>" may legitimately lack a source file.
+        if not (file.startswith("<") and file.endswith(">")):
+            raise OSError("source code not available")
+
+    module = inspect.getmodule(obj, file)
+    if module:
+        lines = linecache.getlines(file, module.__dict__)
+    else:
+        lines = linecache.getlines(file)
+    if not lines:
+        raise OSError("could not get source code")
+    # "f.<locals>.C" -> ["f<locals>", "C"]: a "<locals>" suffix marks an
+    # enclosing function; every other component is a class.
+    qual_names = obj.__qualname__.replace(".<locals>", "<locals>").split(".")
+    locals_suffix = "<locals>"
+    pattern_list = []
+    for name in qual_names:
+        if name.endswith(locals_suffix):
+            func_name = re.escape(name[: -len(locals_suffix)])
+            pattern_list.append(re.compile(r"^(\s*)(?:async\s+)?def\s+" + func_name + r"\b"))
+        else:
+            pattern_list.append(re.compile(r"^(\s*)class\s+" + re.escape(name) + r"\b"))
+    for i, line in enumerate(lines):
+        match = pattern_list[0].match(line)
+        if match:
+            pattern_list.pop(0)
+            if not pattern_list:
+                return lines, i
+    raise OSError("could not find class definition")
+
+
+def getsourcelines(obj):
+    """Return ``(block_lines, first_lineno)`` for ``obj`` using this module's ``findsource``.
+
+    Objects wrapped via ``__wrapped__`` are unwrapped first; ``first_lineno``
+    is 1-based.
+    """
+    target = inspect.unwrap(obj)
+    all_lines, start_idx = findsource(target)
+    block = inspect.getblock(all_lines[start_idx:])
+    return block, start_idx + 1
+
+
+# Process-wide monkey-patch: every caller of ``inspect.getfile`` (including
+# ``inspect.getsourcefile`` used above) gains the class-aware lookup.
+inspect.getfile = _patched_inspect_getfile
+
+
+class Diagnostics:
+    """Reports parser diagnostics against a :class:`Source` via TVM's DiagnosticContext."""
+
+    source: Source
+    ctx: diagnostics.DiagnosticContext
+
+    def __init__(self, source: Source):
+        # Register the whole module text so rendered spans can show surrounding code.
+        mod = IRModule()
+        mod.source_map.add(source.source_name, source.full_source)
+        self.source = source
+        self.ctx = diagnostics.DiagnosticContext(mod, diagnostics.get_renderer())
+
+    def _emit(self, node: doc.AST, message: str, level: diagnostics.DiagnosticLevel) -> None:
+        """Queue a diagnostic whose span covers ``node``.
+
+        ``node`` positions are relative to the dedented ``self.source`` and are
+        shifted by the source's start line/column to point into the real file.
+        Missing positions fall back to the first line / column of the source.
+        """
+        # Compare with None explicitly: column 0 is a valid offset and must not
+        # be replaced by a fallback (which previously applied the offset twice).
+        lineno = node.lineno if node.lineno is not None else 1
+        col_offset = node.col_offset if node.col_offset is not None else 0
+        end_lineno = node.end_lineno if node.end_lineno is not None else lineno
+        end_col_offset = (
+            node.end_col_offset if node.end_col_offset is not None else col_offset
+        )
+        lineno += self.source.start_line - 1
+        end_lineno += self.source.start_line - 1
+        # +1 turns the 0-based AST column into the (presumably 1-based) Span column.
+        col_offset += self.source.start_column + 1
+        end_col_offset += self.source.start_column + 1
+        self.ctx.emit(
+            diagnostics.Diagnostic(
+                level=level,
+                span=Span(
+                    source_name=SourceName(self.source.source_name),
+                    line=lineno,
+                    end_line=end_lineno,
+                    column=col_offset,
+                    end_column=end_col_offset,
+                ),
+                message=message,
+            )
+        )
+
+    def error(self, node: doc.AST, message: str) -> None:
+        """Emit an error for ``node`` and render the queued diagnostics.
+
+        NOTE(review): rendering presumably raises ``DiagnosticError`` once an
+        error is queued -- confirm against ``tvm.ir.diagnostics``.
+        """
+        self._emit(node, message, diagnostics.DiagnosticLevel.ERROR)
+        self.ctx.render()
diff --git a/python/tvm/script/parser/core/dispatch.py b/python/tvm/script/parser/core/dispatch.py
new file mode 100644
index 0000000000..f10b90961a
--- /dev/null
+++ b/python/tvm/script/parser/core/dispatch.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type
+
+from .doc import AST
+
+if TYPE_CHECKING:
+ from .parser import Parser
+
+
+# Parse handlers keyed by (dispatch token, type name); filled by ``register``.
+ParseMethod = Callable[["Parser", AST], None]
+ParseVTable: Dict[Tuple[str, str], ParseMethod] = {}
+
+# Operator overloads keyed by (operand ``_dispatch_type``, doc operator class,
+# operand index); the evaluator consults them before its default Python operators.
+OpMethod = Callable[..., Any]
+OpVTable: Dict[Tuple[Type, Type, int], OpMethod] = {}
+
+
+def register(token: str, type_name: str):
+    """Register a method for a dispatch token and type name.
+
+    Usable as a decorator: the method is stored in ``ParseVTable`` and returned
+    unchanged, so the decorated name keeps referring to the function. A later
+    registration for the same key replaces the earlier one.
+    """
+
+    def f(method: ParseMethod) -> ParseMethod:
+        ParseVTable[(token, type_name)] = method
+        return method
+
+    return f
+
+
+def get(
+    token: str,
+    type_name: str,
+    default: Optional[ParseMethod] = None,
+) -> Optional[ParseMethod]:
+    """Return the parse method registered for ``(token, type_name)``.
+
+    Falls back to ``default`` when nothing has been registered for that key.
+    """
+    key = (token, type_name)
+    if key in ParseVTable:
+        return ParseVTable[key]
+    return default
+
+
+def register_op(ty: Type, op: Type, operand_index: int):  # pylint: disable=invalid-name
+    """Register an operator overload for operands whose ``_dispatch_type`` is ``ty``.
+
+    ``op`` is a doc operator *class* (the evaluator looks overloads up by
+    ``type(op_node)``) and ``operand_index`` is the position of the
+    ``ty``-typed operand. Usable as a decorator: the method is returned
+    unchanged so the decorated name keeps referring to it.
+    """
+
+    def f(method: OpMethod) -> OpMethod:
+        OpVTable[(ty, op, operand_index)] = method
+        return method
+
+    return f
+
+
+def get_op(  # pylint: disable=invalid-name
+    ty: Type,
+    op: Type,
+    operand_index: int,
+    default: Optional[OpMethod] = None,
+) -> Optional[OpMethod]:
+    """Return the overload registered for ``(ty, op, operand_index)``, else ``default``."""
+    key = (ty, op, operand_index)
+    if key in OpVTable:
+        return OpVTable[key]
+    return default
diff --git a/python/tvm/script/parser/core/doc.py b/python/tvm/script/parser/core/doc.py
new file mode 100644
index 0000000000..f6a641cb64
--- /dev/null
+++ b/python/tvm/script/parser/core/doc.py
@@ -0,0 +1,361 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import ast
+import inspect
+import sys
+import typing
+from collections import defaultdict
+
+from . import doc_core as doc
+from .doc_core import * # pylint: disable=unused-import,wildcard-import,redefined-builtin,W0614
+
+# Converter signatures between Python ``ast`` nodes and the mirrored ``doc`` nodes.
+FnToDoc = typing.Callable[[ast.AST], doc.AST]
+FnFromDoc = typing.Callable[[doc.AST], ast.AST]
+
+
+class Entry:
+    """The pair of converters registered for one AST class name.
+
+    Each slot stays ``None`` until a converter is registered for it.
+    """
+
+    to_doc: typing.Optional[FnToDoc]
+    from_doc: typing.Optional[FnFromDoc]
+
+    def __init__(self):
+        self.to_doc, self.from_doc = None, None
+
+
+class Registry:
+    """Table mapping AST class names to their :class:`Entry`.
+
+    The shared instance lives in ``Registry._inst`` and is created by
+    ``_register_default`` when this module is imported.
+    """
+
+    _inst: typing.Optional["Registry"] = None
+    table: typing.Dict[str, Entry]
+
+    def __init__(self):
+        # Unknown names get a fresh, empty Entry on first item access.
+        entries: typing.Dict[str, Entry] = defaultdict(Entry)
+        self.table = entries
+
+
+def register_to_doc(name: str):
+    """Register an ``ast`` -> ``doc`` converter for the AST class called ``name``.
+
+    Usable as a decorator: the converter is returned unchanged. A later
+    registration for the same name replaces the earlier one.
+    """
+
+    def f(to_doc: FnToDoc):  # pylint: disable=redefined-outer-name
+        reg = Registry._inst  # pylint: disable=protected-access
+        reg.table[name].to_doc = to_doc
+        return to_doc
+
+    return f
+
+
+def register_from_doc(name: str):
+    """Register a ``doc`` -> ``ast`` converter for the AST class called ``name``.
+
+    Usable as a decorator: the converter is returned unchanged. A later
+    registration for the same name replaces the earlier one.
+    """
+
+    def f(from_doc_fn: FnFromDoc):
+        reg = Registry._inst  # pylint: disable=protected-access
+        reg.table[name].from_doc = from_doc_fn
+        return from_doc_fn
+
+    return f
+
+
+def _is_atomic_type(node):
+    """Return True for leaf values copied verbatim between ``ast`` and ``doc`` trees.
+
+    Leaves are ``None``, ``Ellipsis``, booleans, numbers, strings and bytes.
+    """
+    if node is None or node in (..., True, False):
+        return True
+    leaf_types = (int, float, str, bool, bytes, complex)
+    return isinstance(node, leaf_types)
+
+
+def _get_registry_entry(cls_name, attr):
+    """Return converter ``attr`` ("to_doc"/"from_doc") registered for ``cls_name``.
+
+    Only the last dotted component of ``cls_name`` is used. Returns ``None``
+    when no entry exists (the lookup does not create one).
+    """
+    short_name = cls_name.rsplit(".", 1)[-1]
+    table = Registry._inst.table  # pylint: disable=protected-access
+    if short_name not in table:
+        return None
+    return getattr(table[short_name], attr, None)
+
+
+def from_doc(node):
+    """Convert a ``doc`` tree (or list/tuple of trees) into Python ``ast`` nodes.
+
+    Atomic leaves pass through; lists and tuples are converted element-wise and
+    returned as a plain list/tuple respectively.
+
+    Raises
+    ------
+    NotImplementedError
+        If no converter is registered for a node's class.
+    """
+    if _is_atomic_type(node):
+        return node
+    if isinstance(node, (list, tuple)):
+        converted = [from_doc(item) for item in node]
+        return converted if isinstance(node, list) else tuple(converted)
+    cls_name = node.__class__.__name__
+    converter = _get_registry_entry(cls_name, "from_doc")
+    if not converter:
+        raise NotImplementedError(f"from_doc is not implemented for: {cls_name}")
+    return converter(node)
+
+
+def to_doc(node):
+    """Convert a Python ``ast`` tree (or list/tuple of trees) into ``doc`` nodes.
+
+    Atomic leaves pass through; lists and tuples are converted element-wise and
+    returned as a plain list/tuple respectively.
+
+    Raises
+    ------
+    NotImplementedError
+        If no converter is registered for a node's class.
+    """
+    if _is_atomic_type(node):
+        return node
+    if isinstance(node, (list, tuple)):
+        converted = [to_doc(item) for item in node]
+        return converted if isinstance(node, list) else tuple(converted)
+    cls_name = node.__class__.__name__
+    converter = _get_registry_entry(cls_name, "to_doc")
+    if not converter:
+        raise NotImplementedError(f"to_doc is not implemented for: {cls_name}")
+    return converter(node)
+
+
+def parse(
+    source,
+    filename="<unknown>",
+    mode="exec",
+) -> doc.AST:
+    """Parse Python ``source`` into a ``doc`` tree.
+
+    Parsing is first attempted with ``feature_version=(3, 8)``. If that fails
+    (e.g. the interpreter does not accept ``feature_version``, or the code is
+    rejected by the 3.8 grammar) the interpreter's native grammar is used; a
+    genuine syntax error then propagates from the second attempt.
+    """
+    try:
+        program = ast.parse(  # pylint: disable=unexpected-keyword-arg
+            source=source,
+            filename=filename,
+            mode=mode,
+            feature_version=(3, 8),
+        )
+    except Exception:  # pylint: disable=broad-except
+        # Not a bare ``except``: KeyboardInterrupt/SystemExit must still propagate.
+        program = ast.parse(
+            source=source,
+            filename=filename,
+            mode=mode,
+        )
+    return to_doc(program)
+
+
+class NodeVisitor:
+    """Read-only walker over ``doc`` trees.
+
+    ``visit`` dispatches to ``visit_<ClassName>`` when the subclass defines it,
+    otherwise to :meth:`generic_visit`, which recurses into every field listed
+    in the node class's ``_FIELDS``.
+    """
+
+    def visit(self, node: doc.AST) -> None:
+        if isinstance(node, (list, tuple)):
+            for child in node:
+                self.visit(child)
+        elif isinstance(node, doc.AST):
+            method_name = "visit_" + node.__class__.__name__.split(".")[-1]
+            visitor = getattr(self, method_name, self.generic_visit)
+            visitor(node)
+
+    def generic_visit(self, node: doc.AST) -> None:
+        for field in node.__class__._FIELDS:  # pylint: disable=protected-access
+            child = getattr(node, field, None)
+            # Only nested nodes and sequences are walked; leaves and None are skipped.
+            if isinstance(child, (doc.AST, list, tuple)):
+                self.visit(child)
+
+
+class NodeTransformer:
+    """Rebuilding walker over ``doc`` trees.
+
+    ``visit`` dispatches to ``visit_<ClassName>`` when the subclass defines it,
+    otherwise to :meth:`generic_visit`, which constructs a new node of the same
+    class from the (recursively visited) values of its ``_FIELDS``.
+    """
+
+    def visit(self, node: doc.AST) -> doc.AST:
+        if isinstance(node, (list, tuple)):
+            visited = [self.visit(child) for child in node]
+            return visited if isinstance(node, list) else tuple(visited)
+        if not isinstance(node, doc.AST):
+            return node
+        method_name = "visit_" + node.__class__.__name__.split(".")[-1]
+        return getattr(self, method_name, self.generic_visit)(node)
+
+    def generic_visit(self, node: doc.AST) -> doc.AST:
+        new_fields: typing.Dict[str, typing.Any] = {}
+        for field in node.__class__._FIELDS:  # pylint: disable=protected-access
+            child = getattr(node, field, None)
+            if isinstance(child, (doc.AST, list, tuple)):
+                child = self.visit(child)
+            new_fields[field] = child
+        return node.__class__(**new_fields)
+
+
+def _register_default():
+    """Create the registry and register field-by-field converters in both
+    directions for every ``doc.AST`` subclass that has a same-named ``ast`` class."""
+
+    class DefaultTranslator:
+        """Callable that rebuilds a node as ``doc_cls`` by converting each field."""
+
+        def __init__(self, doc_cls, func, fields):
+            # ``doc_cls`` is the *target* class: a doc class for to_doc
+            # converters, an ast class for from_doc converters.
+            self.doc_cls = doc_cls
+            self.func = func
+            self.fields = fields
+
+        def __call__(self, node):
+            kv = {attr: self.func(getattr(node, attr, None)) for attr in self.fields}
+            return self.doc_cls(**kv)
+
+    Registry._inst = Registry()  # pylint: disable=protected-access
+    for cls_name in dir(doc):
+        doc_cls = getattr(doc, cls_name)
+        if not hasattr(ast, cls_name):
+            continue
+        if not (inspect.isclass(doc_cls) and issubclass(doc_cls, doc.AST)):
+            continue
+        assert "." not in cls_name
+        fields = doc_cls._FIELDS  # pylint: disable=protected-access
+        register_to_doc(cls_name)(DefaultTranslator(doc_cls, to_doc, fields))
+        register_from_doc(cls_name)(
+            DefaultTranslator(getattr(ast, cls_name), from_doc, fields)
+        )
+
+
+def _py_version() -> typing.Tuple[int, int]:
+    """Return the running interpreter's ``(major, minor)`` version."""
+    major, minor = sys.version_info[:2]
+    return (major, minor)
+
+
+def _register_constant_handling():
+    """On Python 3.6/3.7, map the legacy literal nodes (``Str``, ``NameConstant``,
+    ``Num``, ``Bytes``, ``Ellipsis``) onto ``doc.Constant``. No-op elsewhere."""
+    if _py_version() not in [(3, 6), (3, 7)]:
+        return
+
+    def as_constant(f: typing.Union[str, typing.Callable[[ast.AST], typing.Any]]) -> FnToDoc:
+        # ``f`` is either the attribute holding the literal or a function computing it.
+        def to_doc_func(x: ast.AST) -> doc.Constant:
+            return doc.Constant(
+                value=getattr(x, f) if isinstance(f, str) else f(x),
+                kind=None,
+                s=None,
+                n=None,
+                lineno=x.lineno,
+                col_offset=x.col_offset,
+                # Pre-3.8 nodes carry no end positions; reuse the start position.
+                end_lineno=x.lineno,
+                end_col_offset=x.col_offset,
+            )
+
+        return to_doc_func
+
+    register_to_doc("Str")(as_constant("s"))
+    register_to_doc("NameConstant")(as_constant("value"))
+    register_to_doc("Num")(as_constant("n"))
+    register_to_doc("Bytes")(as_constant("s"))
+    register_to_doc("Ellipsis")(as_constant(lambda _: ...))
+
+
+def _register_subscription_handling():
+    """Convert between pre-3.9 ``ast`` subscripts and the flat ``doc`` form.
+
+    Before Python 3.9, ``Subscript.slice`` wraps the index in ``ast.Slice``,
+    ``ast.ExtSlice`` (several dims, at least one a slice) or ``ast.Index``. The
+    ``doc`` form stores the index expression directly, as Python 3.9+ does.
+    Nothing is registered on 3.9+.
+    """
+    if _py_version() >= (3, 9):
+        return
+
+    def subscript_to_doc(x: ast.Subscript) -> doc.Subscript:
+        if isinstance(x.slice, ast.Slice):
+            # a[lo:hi:step]
+            return doc.Subscript(
+                value=to_doc(x.value),
+                slice=doc.Slice(
+                    lower=to_doc(x.slice.lower),
+                    upper=to_doc(x.slice.upper),
+                    step=to_doc(x.slice.step),
+                    lineno=getattr(x.slice, "lineno", None),
+                    col_offset=getattr(x.slice, "col_offset", None),
+                    end_lineno=getattr(x.slice, "end_lineno", None),
+                    end_col_offset=getattr(x.slice, "end_col_offset", None),
+                ),
+                ctx=to_doc(x.ctx),
+                lineno=getattr(x, "lineno", None),
+                col_offset=getattr(x, "col_offset", None),
+                end_lineno=getattr(x, "end_lineno", None),
+                end_col_offset=getattr(x, "end_col_offset", None),
+            )
+        if isinstance(x.slice, ast.ExtSlice):
+            # a[lo:hi, k]: the dims become the elements of a doc.Tuple.
+            return doc.Subscript(
+                value=to_doc(x.value),
+                slice=doc.Tuple(
+                    elts=[to_doc(i) for i in x.slice.dims],
+                    ctx=doc.Load(
+                        lineno=None,
+                        col_offset=None,
+                        end_lineno=None,
+                        end_col_offset=None,
+                    ),
+                    lineno=getattr(x, "lineno", None),
+                    col_offset=getattr(x, "col_offset", None),
+                    end_lineno=getattr(x, "end_lineno", None),
+                    end_col_offset=getattr(x, "end_col_offset", None),
+                ),
+                ctx=to_doc(x.ctx),
+                lineno=getattr(x, "lineno", None),
+                col_offset=getattr(x, "col_offset", None),
+                end_lineno=getattr(x, "end_lineno", None),
+                end_col_offset=getattr(x, "end_col_offset", None),
+            )
+        if isinstance(x.slice, ast.Index):
+            # a[k] and a[i, j]: unwrap the Index.
+            return doc.Subscript(
+                value=to_doc(x.value),
+                slice=to_doc(x.slice.value),
+                ctx=to_doc(x.ctx),
+                lineno=getattr(x, "lineno", None),
+                col_offset=getattr(x, "col_offset", None),
+                end_lineno=getattr(x, "end_lineno", None),
+                end_col_offset=getattr(x, "end_col_offset", None),
+            )
+        raise TypeError(f"Unknown subscript type: {type(x.slice)}")
+
+    def subscript_from_doc(x: doc.Subscript) -> ast.Subscript:
+        if isinstance(x.slice, doc.Slice):
+            result = ast.Subscript(
+                value=from_doc(x.value),
+                slice=from_doc(x.slice),
+                ctx=from_doc(x.ctx),
+            )
+        elif isinstance(x.slice, doc.Tuple) and any(
+            isinstance(elt, doc.Slice) for elt in x.slice.elts
+        ):
+            # a[lo:hi, k]: every ExtSlice dim must itself be a slice node, so
+            # plain expressions are wrapped in ``ast.Index``.
+            result = ast.Subscript(
+                value=from_doc(x.value),
+                slice=ast.ExtSlice(
+                    dims=[
+                        from_doc(elt)
+                        if isinstance(elt, doc.Slice)
+                        else ast.Index(value=from_doc(elt))
+                        for elt in x.slice.elts
+                    ],
+                ),
+                ctx=from_doc(x.ctx),
+            )
+        else:
+            # a[k], and tuple indices without slices such as a[i, j]: plain Index.
+            result = ast.Subscript(
+                value=from_doc(x.value),
+                slice=ast.Index(value=from_doc(x.slice)),
+                ctx=from_doc(x.ctx),
+            )
+        result.lineno = x.lineno
+        result.col_offset = x.col_offset
+        result.end_lineno = x.end_lineno
+        result.end_col_offset = x.end_col_offset
+        return result
+
+    register_to_doc("Subscript")(subscript_to_doc)
+    register_from_doc("Subscript")(subscript_from_doc)
+
+
+def _register_index_handling():
+    """On Python < 3.9, unwrap ``ast.Index`` nodes to their inner expression.
+
+    No-op on 3.9+, where ``ast.Index`` is no longer produced by the parser.
+    """
+    if _py_version() >= (3, 9):
+        return
+
+    def index_to_doc(x: ast.Index) -> doc.Expr:
+        return to_doc(x.value)
+
+    # NOTE(review): ``index_to_doc`` discards the Index wrapper, so this reverse
+    # converter only runs if a doc class named "Index" is ever converted; it also
+    # reads ``x.ctx``, which expressions such as Constant may lack -- confirm.
+    def index_from_doc(x: doc.Expr) -> ast.Index:
+        result = ast.Index(value=from_doc(x), ctx=from_doc(x.ctx))
+        result.lineno = x.lineno
+        result.col_offset = x.col_offset
+        result.end_lineno = x.end_lineno
+        result.end_col_offset = x.end_col_offset
+        return result
+
+    register_to_doc("Index")(index_to_doc)
+    register_from_doc("Index")(index_from_doc)
+
+
+# Build the converter registry at import time: generic defaults first, then the
+# version-specific handlers, which overwrite default entries of the same name.
+_register_default()
+_register_constant_handling()
+_register_subscription_handling()
+_register_index_handling()
diff --git a/python/tvm/script/printer/doc_core.py b/python/tvm/script/parser/core/doc_core.py
similarity index 100%
rename from python/tvm/script/printer/doc_core.py
rename to python/tvm/script/parser/core/doc_core.py
diff --git a/python/tvm/script/tir/prim_func.py b/python/tvm/script/parser/core/entry.py
similarity index 51%
copy from python/tvm/script/tir/prim_func.py
copy to python/tvm/script/parser/core/entry.py
index 923eb97d27..ccf42e8c15 100644
--- a/python/tvm/script/tir/prim_func.py
+++ b/python/tvm/script/parser/core/entry.py
@@ -14,32 +14,30 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""TVM Script Interface for PrimFunc"""
+# pylint: disable=missing-docstring
+"""The entry point of TVM parser."""
+from typing import Any, Union
-import inspect
-from typing import Callable
+from ...ir_builder import IRBuilder
+from . import doc
+from .diagnostics import Source
+from .parser import Parser
-from tvm.tir.function import PrimFunc
-from ..parser import from_source
+def parse(program: Union[doc.AST, Any, str], extra_vars=None):
+    """Parse a TVMScript ``program`` and return what the IRBuilder produced.
+
+    ``program`` is handed to ``Source``, which accepts a string or an object
+    ``inspect`` can find source for. ``extra_vars`` supplies the names visible
+    to the script; when omitted, the ``ir``/``tir`` parser modules are exposed
+    as ``I``/``ir`` and ``T``/``tir``.
+    """
+    if extra_vars is None:
+        # Imported lazily (presumably to avoid an import cycle) -- confirm.
+        from tvm.script.parser import ir  # pylint: disable=import-outside-toplevel
+        from tvm.script.parser import tir  # pylint: disable=import-outside-toplevel
-def prim_func(input_func: Callable) -> PrimFunc:
-    """Decorate a python function as tvm script.
+        extra_vars = {
+            "I": ir,
+            "ir": ir,
+            "T": tir,
+            "tir": tir,
+        }
-    Parameters
-    ----------
-    func : input_func
-        The function to be parsed.
-
-    Returns
-    -------
-    output : PrimFunc
-        The result functions.
-    """
-    if inspect.isfunction(input_func):
-        result = from_source(input_func)
-        result.__name__ = input_func.__name__
-        result.__qualname__ = input_func.__qualname__
-        return result
-
-    raise TypeError("Only function definitions are supported.")
+    source = Source(program)
+    parser = Parser(source)
+    # The parser runs inside an active IRBuilder; its result is read afterwards.
+    with IRBuilder() as builder:
+        parser.parse(extra_vars=extra_vars)
+    return builder.get()
diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py
new file mode 100644
index 0000000000..8cbde66e39
--- /dev/null
+++ b/python/tvm/script/parser/core/evaluator.py
@@ -0,0 +1,284 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""AST Evaluation"""
+import ast
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
+
+from . import dispatch, doc
+
+if TYPE_CHECKING:
+ from .parser import Parser
+
+# Plain-Python semantics for each doc operator class; ``_eval_op`` falls back to
+# these when no operand supplies an overload registered via ``dispatch.register_op``.
+DEFAULT_OP: Dict[Type, Callable[..., Any]] = {
+    doc.Add: lambda a, b: a + b,
+    doc.Sub: lambda a, b: a - b,
+    doc.Mult: lambda a, b: a * b,
+    doc.Div: lambda a, b: a / b,
+    doc.FloorDiv: lambda a, b: a // b,
+    doc.Mod: lambda a, b: a % b,
+    doc.LShift: lambda a, b: a << b,
+    doc.RShift: lambda a, b: a >> b,
+    doc.BitOr: lambda a, b: a | b,
+    doc.BitXor: lambda a, b: a ^ b,
+    doc.BitAnd: lambda a, b: a & b,
+    doc.MatMult: lambda a, b: a @ b,
+    # fmt: off
+    doc.Pow: lambda a, b: a ** b,
+    # fmt: on
+    doc.Eq: lambda a, b: a == b,
+    doc.NotEq: lambda a, b: a != b,
+    doc.Lt: lambda a, b: a < b,
+    doc.LtE: lambda a, b: a <= b,
+    doc.Gt: lambda a, b: a > b,
+    doc.GtE: lambda a, b: a >= b,
+    doc.Is: lambda a, b: a is b,
+    doc.IsNot: lambda a, b: a is not b,
+    doc.In: lambda a, b: a in b,
+    doc.NotIn: lambda a, b: a not in b,
+    # Both operands are already evaluated by the caller, so no short-circuiting.
+    doc.And: lambda a, b: a and b,
+    doc.Or: lambda a, b: a or b,
+    # Unary operators take a single operand.
+    doc.Invert: lambda a: ~a,
+    doc.Not: lambda a: not a,
+    doc.UAdd: lambda a: +a,
+    doc.USub: lambda a: -a,
+}
+
+
+class ExprEvaluator:
+    """Evaluate a ``doc`` expression tree bottom-up against ``value_table``.
+
+    Each evaluated sub-expression is stored in ``value_table`` under a fresh
+    ``__tvm_tmp_value_<n>`` name and replaced by a ``doc.Name`` referring to
+    it. Operators go through ``_eval_op`` so operands can supply overloads
+    registered with ``dispatch.register_op``; other expressions are evaluated
+    with Python's ``eval``.
+    """
+
+    parser: "Parser"  # receives error reports
+    value_table: Dict[str, Any]  # visible names plus intermediate results
+    new_value_count: int  # counter used to name intermediate results
+
+    def __init__(self, parser: "Parser", value_table: Dict[str, Any]) -> None:
+        super().__init__()
+        self.parser = parser
+        self.value_table = value_table
+        self.new_value_count = 0
+
+    @staticmethod
+    def eval(parser: "Parser", value_table: Dict[str, Any], node: doc.AST) -> Any:
+        """Evaluate ``node`` and return its Python value.
+
+        ``value_table`` is extended in place with intermediate results.
+        """
+        self = ExprEvaluator(parser, value_table)
+        result = self._visit(node)  # pylint: disable=protected-access
+        if isinstance(result, doc.Name):
+            if result.id not in self.value_table:
+                self.parser.report_error(result, f"Undefined variable: {result.id}")
+            return self.value_table[result.id]
+        if isinstance(result, doc.Constant):
+            return result.value
+        raise TypeError(f"Unexpected result type: {type(result)}")
+
+    def _add_intermediate_result(self, value: Any) -> doc.Name:
+        """Store ``value`` under a fresh name and return a ``doc.Name`` loading it."""
+        name = f"__tvm_tmp_value_{self.new_value_count}"
+        self.new_value_count += 1
+        self.value_table[name] = value
+        lineno = 0
+        col_offset = 0
+        return doc.Name(
+            id=name,
+            ctx=doc.Load(
+                lineno=lineno,
+                col_offset=col_offset,
+                end_lineno=None,
+                end_col_offset=None,
+            ),
+            lineno=lineno,
+            col_offset=col_offset,
+            end_lineno=None,
+            end_col_offset=None,
+        )
+
+    def _visit(self, node: doc.AST) -> Any:
+        """Evaluate ``node`` recursively and return a leaf node standing for it.
+
+        Names, constants and operator tokens are returned as-is; any other
+        expression is evaluated and replaced by an intermediate ``doc.Name``.
+        """
+        if isinstance(node, list):
+            return [self._visit(n) for n in node]
+        if isinstance(node, tuple):
+            return tuple(self._visit(n) for n in node)
+        assert isinstance(node, doc.AST)
+        if isinstance(node, doc.Name):
+            if node.id not in self.value_table:
+                self.parser.report_error(node, f"Undefined variable: {node.id}")
+            return node
+        if isinstance(
+            node,
+            (
+                doc.Constant,
+                doc.expr_context,
+                doc.operator,
+                doc.boolop,
+                doc.unaryop,
+                doc.cmpop,
+            ),
+        ):
+            return node
+        if not isinstance(node, (doc.expr, doc.slice)):
+            return node
+        if isinstance(node, doc.Lambda):
+            # A lambda's parameters are not in ``value_table``, so its body cannot
+            # be evaluated piecewise; evaluate the lambda as a whole.
+            return self._eval_lambda(node)
+        fields = {}
+        for field in node.__class__._FIELDS:  # pylint: disable=protected-access
+            attr = getattr(node, field)
+            if isinstance(attr, (doc.AST, tuple, list)):
+                fields[field] = self._visit(attr)
+            else:
+                fields[field] = attr
+        try:
+            if isinstance(node, doc.BoolOp):
+                value = self._eval_bool_op(fields)
+            elif isinstance(node, doc.Compare):
+                value = self._eval_compare(fields)
+            elif isinstance(node, doc.UnaryOp):
+                value = self._eval_unary_op(fields)
+            elif isinstance(node, doc.BinOp):
+                value = self._eval_bin_op(fields)
+            elif isinstance(node, doc.Slice):
+                value = self._eval_slice(fields)
+            else:
+                value = self._eval_expr(node.__class__(**fields))
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.parser.report_error(node, str(e))
+            # Re-raise if report_error returns, rather than failing below with an
+            # UnboundLocalError on ``value``.
+            raise
+        return self._add_intermediate_result(value)
+
+    def _eval_lambda(self, node: doc.Lambda) -> Any:
+        """Evaluate a whole lambda expression and store the resulting function."""
+        try:
+            value = self._eval_expr(node)
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.parser.report_error(node, str(e))
+            raise
+        return self._add_intermediate_result(value)
+
+    def _eval_bool_op(self, fields: Dict[str, Any]) -> Any:
+        """Fold ``and``/``or`` over the (already evaluated) operands left to right."""
+        op = fields["op"]
+        if not isinstance(op, (doc.And, doc.Or)):
+            raise TypeError(f"Unexpected operator: {op}")
+        value = self._eval_expr(fields["values"][0])
+        for rhs in fields["values"][1:]:
+            value = _eval_op(op, values=[value, self._eval_expr(rhs)])
+        return value
+
+    def _eval_compare(self, fields: Dict[str, Any]) -> Any:
+        """Evaluate ``a op1 b op2 c ...`` with Python's chaining semantics.
+
+        A chain means ``(a op1 b) and (b op2 c) ...``: each comparison uses the
+        previous right operand as its left operand, and the pairwise results
+        are combined with the (dispatchable) ``and`` operator.
+        """
+        lhs = self._eval_expr(fields["left"])
+        result = None
+        for index, (op, rhs_node) in enumerate(zip(fields["ops"], fields["comparators"])):
+            rhs = self._eval_expr(rhs_node)
+            pair = _eval_op(op, values=[lhs, rhs])
+            if index == 0:
+                result = pair
+            else:
+                and_op = doc.And(
+                    lineno=None,
+                    col_offset=None,
+                    end_lineno=None,
+                    end_col_offset=None,
+                )
+                result = _eval_op(and_op, values=[result, pair])
+            lhs = rhs
+        return result
+
+    def _eval_unary_op(self, fields: Dict[str, Any]) -> Any:
+        """Apply the unary operator to its evaluated operand."""
+        value = self._eval_expr(fields["operand"])
+        value = _eval_op(fields["op"], values=[value])
+        return value
+
+    def _eval_bin_op(self, fields: Dict[str, Any]) -> Any:
+        """Apply the binary operator to its evaluated operands."""
+        return _eval_op(
+            fields["op"],
+            values=[
+                self._eval_expr(fields["left"]),
+                self._eval_expr(fields["right"]),
+            ],
+        )
+
+    def _eval_slice(self, fields: Dict[str, Any]) -> Any:
+        """Build a Python ``slice`` from the evaluated lower/upper/step parts."""
+        lower, upper, step = fields["lower"], fields["upper"], fields["step"]
+
+        lower = self._eval_expr(lower) if lower is not None else None
+        upper = self._eval_expr(upper) if upper is not None else None
+        step = self._eval_expr(step) if step is not None else None
+
+        return slice(lower, upper, step)
+
+    def _eval_expr(self, v: Any) -> Any:
+        """Evaluate a (leaf-ified) doc expression with ``value_table`` as globals."""
+        return _eval_expr(v, self.value_table)
+
+
+def eval_expr(
+    parser: "Parser",
+    node: Union[doc.expr, doc.Expression],
+    dict_globals: Optional[Dict[str, Any]],
+) -> Any:
+    """Evaluate ``node`` with the names in ``dict_globals`` in scope.
+
+    The mapping is copied first, so intermediate results recorded during
+    evaluation never leak into the caller's dict.
+    """
+    value_table: Dict[str, Any] = dict(dict_globals) if dict_globals is not None else {}
+    return ExprEvaluator.eval(parser, value_table, node)
+
+
+def eval_assign(
+    parser: "Parser",
+    target: doc.expr,
+    source: Any,
+) -> Dict[str, Any]:
+    """Bind ``source`` to the assignment pattern ``target`` (e.g. ``a, b``).
+
+    Returns every name bound by the assignment. Failures are reported through
+    ``parser`` against ``target`` and then re-raised.
+    """
+    try:
+        bindings = _eval_assign(target, source)
+    except Exception as err:  # pylint: disable=broad-except
+        parser.report_error(target, f"Failed to evaluate assignment: {str(err)}")
+        raise
+    return bindings
+
+
+def _eval_expr(
+    node: Union[doc.expr, doc.Expression],
+    dict_globals: Optional[Dict[str, Any]],
+) -> Any:
+    """Compile the doc expression back to Python and ``eval`` it.
+
+    ``dict_globals`` (an empty dict when None) is the global namespace; note
+    that ``eval`` runs arbitrary code from the script being parsed.
+    """
+    py_node = doc.from_doc(node)
+    if isinstance(py_node, ast.expr):
+        py_node = ast.Expression(body=py_node)
+    assert isinstance(py_node, ast.Expression), (
+        "Expects an ast.Expression, but gets: " + str(py_node)
+    )
+    namespace = {} if dict_globals is None else dict_globals
+    code = compile(ast.fix_missing_locations(py_node), filename="<ast>", mode="eval")
+    return eval(code, namespace)  # pylint: disable=eval-used
+
+
+def _eval_op(
+    op: doc.AST,
+    values: List[Any],
+):
+    """Apply operator node ``op`` to the already-evaluated ``values``.
+
+    Operands are scanned left to right; the first one whose type has a
+    ``_dispatch_type`` with an overload registered for this operator at its
+    position wins. Otherwise the plain-Python ``DEFAULT_OP`` entry is used.
+    """
+    op_type = type(op)
+    for index, operand in enumerate(values):
+        dispatch_type = getattr(type(operand), "_dispatch_type", None)
+        if dispatch_type is None:
+            continue
+        overload = dispatch.get_op(
+            ty=dispatch_type, op=op_type, operand_index=index, default=None
+        )
+        if overload is not None:
+            return overload(*values)
+    return DEFAULT_OP[op_type](*values)
+
+
+def _eval_assign(
+    target: doc.expr,
+    source: Any,
+) -> Dict[str, Any]:
+    """Execute ``<target> = source`` and return the names it binds.
+
+    The target is compiled into a real assignment whose right-hand side is a
+    placeholder local holding ``source``, so Python's own unpacking rules apply.
+    """
+    target = doc.from_doc(target)
+    assert isinstance(target, ast.expr)
+    RHS_VAR_NAME = "__tvm_rhs_var__"  # pylint: disable=invalid-name
+    assign_stmt = ast.Assign(
+        targets=[target],
+        value=ast.Name(id=RHS_VAR_NAME, ctx=ast.Load()),
+    )
+    module = ast.fix_missing_locations(ast.Module(body=[assign_stmt], type_ignores=[]))
+    local_vars = {RHS_VAR_NAME: source}
+    code = compile(module, filename="<ast>", mode="exec")
+    exec(code, {}, local_vars)  # pylint: disable=exec-used
+    # Drop the placeholder so only names bound by ``target`` remain.
+    del local_vars[RHS_VAR_NAME]
+    return local_vars
diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py
new file mode 100644
index 0000000000..e26324262f
--- /dev/null
+++ b/python/tvm/script/parser/core/parser.py
@@ -0,0 +1,273 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""The core parser"""
+from collections import defaultdict
+from contextlib import contextmanager
+from typing import Any, Callable, Dict, List, Optional, Set, Union
+
+from tvm.error import DiagnosticError
+
+from . import dispatch, doc
+from .diagnostics import Diagnostics, Source
+from .evaluator import eval_assign, eval_expr
+
+# Names of AST node types that ``Parser.visit`` hands to ``generic_visit``
+# instead of requiring a dedicated ``visit_<NodeType>`` method.
+DEFAULT_VISIT = {
+    "Interactive",
+    "Module",
+    "Expression",
+    "Pass",
+}
+
+
+def _deferred(f: Callable[[], None]):
+    """Return a context manager that calls ``f`` when its ``with`` block exits.
+
+    ``f`` runs both on normal exit and when an exception propagates out of the block.
+    """
+
+    @contextmanager
+    def _run_on_exit():
+        try:
+            yield
+        finally:
+            f()
+
+    manager = _run_on_exit()
+    return manager
+
+
+class VarTableFrame:
+    """A single scope of the variable table: the set of names defined in it."""
+
+    vars: Set[str]
+
+    def __init__(self):
+        self.vars = set()
+
+    def add(self, var: str):
+        """Record ``var`` in this scope; a second definition in the same scope is an error."""
+        if var not in self.vars:
+            self.vars.add(var)
+            return
+        raise ValueError(f"Variable {var} already defined in current scope")
+
+    def pop_all(self, fn_pop: Callable[[str], None]):
+        """Call ``fn_pop`` once for every name in this scope, then empty the scope."""
+        for defined in self.vars:
+            fn_pop(defined)
+        self.vars.clear()
+
+
+class VarTable:
+    """Scoped symbol table used by the parser.
+
+    ``frames`` is a stack of scopes recording which names each scope defined.
+    ``name2value`` maps every name to a stack of values; the last entry is the
+    currently visible binding.
+    """
+
+    frames: List[VarTableFrame]
+    name2value: Dict[str, List[Any]]
+
+    def __init__(self):
+        self.frames = []
+        self.name2value = defaultdict(list)
+
+    def with_frame(self):
+        """Push a new scope; the returned context manager pops it and its bindings on exit."""
+
+        def pop_frame():
+            frame = self.frames.pop()
+            frame.pop_all(lambda name: self.name2value[name].pop())
+
+        self.frames.append(VarTableFrame())
+        return _deferred(pop_frame)
+
+    def add(self, var: str, value: Any):
+        """Bind ``var`` to ``value`` in the innermost scope, shadowing outer bindings.
+
+        Raises ValueError (from ``VarTableFrame.add``) if ``var`` is already
+        defined in the innermost scope.
+        """
+        self.frames[-1].add(var)
+        self.name2value[var].append(value)
+
+    def get(self) -> Dict[str, Any]:
+        """Return a new dict mapping each bound name to its innermost value."""
+        return {key: values[-1] for key, values in self.name2value.items() if values}
+
+    def exist(self, value: Any) -> bool:
+        """Return True if ``value`` (compared by identity) is bound to any name.
+
+        Every entry of every per-name value stack is checked, including shadowed ones.
+        """
+        # Each entry of name2value is a *stack* (list) of bindings, so the search
+        # must look inside the stacks rather than compare against the lists.
+        return any(bound is value for stack in self.name2value.values() for bound in stack)
+
+
+def _dispatch_wrapper(func: dispatch.ParseMethod) -> dispatch.ParseMethod:
+    """Wrap a parse method so that unexpected exceptions are reported at ``node``.
+
+    ``DiagnosticError`` is re-raised untouched (it has already been reported).
+    Any other exception is passed to ``report_error`` and then re-raised.
+    """
+
+    def _wrapper(self: "Parser", node: doc.AST) -> None:
+        try:
+            return func(self, node)
+        except Exception as err:  # pylint: disable=broad-except
+            if not isinstance(err, DiagnosticError):
+                self.report_error(node, str(err))
+            raise
+
+    return _wrapper
+
+
+def _dispatch(self: "Parser", type_name: str) -> dispatch.ParseMethod:
+    """Find the handler for ``type_name``: first under the active token, then ``"default"``.
+
+    Falls back to the parser's ``generic_visit`` when neither token has one.
+    The result is always wrapped by ``_dispatch_wrapper``.
+    """
+    active_token = self.dispatch_tokens[-1]
+    for token in (active_token, "default"):
+        handler = dispatch.get(token=token, type_name=type_name, default=None)
+        if handler is not None:
+            return _dispatch_wrapper(handler)
+
+    def _fallback(parser, node):
+        return parser.generic_visit(node)
+
+    return _dispatch_wrapper(_fallback)
+
+
+class Parser(doc.NodeVisitor):
+    """The TVMScript parser.
+
+    Walks the doc AST of a source and forwards each supported node type to the
+    handler registered in ``dispatch`` for the active dispatch token, falling
+    back to the ``"default"`` token.
+    """
+
+    diag: Diagnostics
+    dispatch_tokens: List[str]
+    var_table: VarTable
+
+    def __init__(self, source: Source) -> None:
+        self.diag = Diagnostics(source)
+        self.dispatch_tokens = ["default"]
+        self.var_table = VarTable()
+
+    def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any:
+        """Visit the whole source with ``extra_vars`` bound in an outermost scope.
+
+        NOTE(review): nothing is returned here; the result is presumably collected
+        by the IR builder frames the handlers enter. Confirm with the caller.
+        """
+        if extra_vars is None:
+            extra_vars = {}
+        with self.var_table.with_frame():
+            for k, v in extra_vars.items():
+                self.var_table.add(k, v)
+            node = self.diag.source.as_ast()
+            self.visit(node)
+
+    def with_dispatch_token(self, token: str):
+        """Return a context manager that makes ``token`` the active dispatch token."""
+
+        def pop_token():
+            self.dispatch_tokens.pop()
+
+        self.dispatch_tokens.append(token)
+        return _deferred(pop_token)
+
+    def eval_expr(
+        self,
+        node: Union[doc.Expression, doc.expr],
+        extra_vars: Optional[Dict[str, Any]] = None,
+    ) -> Any:
+        """Evaluate ``node`` against the visible variables, overlaid with ``extra_vars``."""
+        var_values = self.var_table.get()
+        if extra_vars is not None:
+            var_values.update(extra_vars)
+        return eval_expr(self, node, var_values)
+
+    def _duplicate_lhs_check(self, target: doc.expr) -> Union[bool, Set[str]]:
+        """Return True if a name repeats within an assignment target, else the set of names.
+
+        Tuple/List targets are checked recursively. Any target that is neither
+        a Tuple/List nor a Name is reported as an error.
+        """
+        if isinstance(target, (doc.Tuple, doc.List)):
+            names: Set[str] = set()
+            for elt in target.elts:
+                res = self._duplicate_lhs_check(elt)
+                if isinstance(res, bool) and res:
+                    return True
+                assert isinstance(res, set)
+                if names & res:
+                    return True
+                names = names.union(res)
+            return names
+        if isinstance(target, doc.Name):
+            return {target.id}
+        self.report_error(target, "Invalid type in assign statement")
+        raise NotImplementedError
+
+    def eval_assign(
+        self,
+        target: doc.expr,
+        source: Any,
+        bind_value: Callable[["Parser", doc.expr, str, Any], Any],
+    ) -> Dict[str, Any]:
+        """Assign ``source`` to ``target`` and register each bound name in the var table.
+
+        ``bind_value`` turns each raw assigned value into what the name should
+        refer to. The raw name -> value mapping is returned.
+        """
+        if self._duplicate_lhs_check(target) is True:
+            self.report_error(target, "Duplicate vars assigned.")
+        var_values = eval_assign(self, target, source)
+        for k, v in var_values.items():
+            var = bind_value(self, target, k, v)
+            self.var_table.add(k, var)
+        return var_values
+
+    def report_error(self, node: doc.AST, msg: str) -> None:  # pylint: disable=no-self-use
+        """Report ``msg`` as an error located at ``node``."""
+        self.diag.error(node, msg)
+
+    def visit(self, node: doc.AST) -> None:
+        """Visit ``node``, or each item when given a list/tuple of nodes.
+
+        Node types in ``DEFAULT_VISIT`` go to ``generic_visit``. Every other type
+        needs a ``visit_<NodeType>`` method. Handler failures are reported at
+        ``node`` unless they are already diagnostics.
+        """
+        if isinstance(node, (list, tuple)):
+            for item in node:
+                self.visit(item)
+            return
+        if not isinstance(node, doc.AST):
+            return
+        name = node.__class__.__name__.split(".")[-1]
+        if name in DEFAULT_VISIT:
+            func = self.generic_visit
+        else:
+            func = getattr(self, "visit_" + name, None)
+        if func is None:
+            raise NotImplementedError(f"Visitor of AST node is not implemented: {name}")
+        try:
+            func(node)
+        except DiagnosticError:
+            # Already reported where it was raised; do not report it again at every
+            # enclosing node (same policy as ``_dispatch_wrapper``).
+            raise
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.report_error(node, str(e))
+            # Never swallow the failure, even if report_error returns normally.
+            raise
+
+    def visit_body(self, node: List[doc.stmt]) -> Any:
+        """Visit each statement of a body in order."""
+        for stmt in node:
+            self.visit(stmt)
+
+    def visit_tvm_annotation(self, node: doc.expr) -> Any:
+        return _dispatch(self, "tvm_annotation")(self, node)
+
+    def visit_FunctionDef(self, node: doc.FunctionDef) -> Any:  # pylint: disable=invalid-name
+        """Dispatch a function definition on the ``dispatch_token`` of its last decorator."""
+        if not node.decorator_list:
+            self.report_error(node, "Function must be decorated")
+        # TODO: only the last decorator is parsed
+        decorator = self.eval_expr(node.decorator_list[-1])
+        if not hasattr(decorator, "dispatch_token"):
+            self.report_error(node, "The parser does not understand the decorator")
+        token = decorator.dispatch_token
+        func = dispatch.get(token=token, type_name="FunctionDef", default=None)
+        if func is None:
+            self.report_error(node, "The parser does not understand the decorator")
+        try:
+            func(self, node)
+        except DiagnosticError:
+            raise
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.report_error(node, str(e))
+            raise
+
+    def visit_ClassDef(self, node: doc.ClassDef) -> Any:  # pylint: disable=invalid-name
+        """Dispatch a class definition to the ``"ir"`` ClassDef handler."""
+        func = dispatch.get(token="ir", type_name="ClassDef", default=None)
+        if func is None:
+            self.report_error(node, "The parser does not understand the decorator")
+        try:
+            func(self, node)
+        except DiagnosticError:
+            raise
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.report_error(node, str(e))
+            raise
+
+    def visit_arguments(self, node: doc.arguments) -> Any:
+        return _dispatch(self, "arguments")(self, node)
+
+    def visit_For(self, node: doc.For) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "For")(self, node)
+
+    def visit_While(self, node: doc.While) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "While")(self, node)
+
+    def visit_With(self, node: doc.With) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "With")(self, node)
+
+    def visit_Assign(self, node: doc.Assign) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Assign")(self, node)
+
+    def visit_Expr(self, node: doc.Expr) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Expr")(self, node)
+
+    def visit_If(self, node: doc.If) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "If")(self, node)
+
+    def visit_AnnAssign(self, node: doc.AnnAssign) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "AnnAssign")(self, node)
+
+    def visit_AugAssign(self, node: doc.AugAssign) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "AugAssign")(self, node)
+
+    def visit_Assert(self, node: doc.Assert) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Assert")(self, node)
+
+    def visit_Return(self, node: doc.Return) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Return")(self, node)
diff --git a/python/tvm/script/tir/prim_func.py b/python/tvm/script/parser/core/utils.py
similarity index 54%
rename from python/tvm/script/tir/prim_func.py
rename to python/tvm/script/parser/core/utils.py
index 923eb97d27..aae45fe6ff 100644
--- a/python/tvm/script/tir/prim_func.py
+++ b/python/tvm/script/parser/core/utils.py
@@ -14,32 +14,23 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""TVM Script Interface for PrimFunc"""
-
+# pylint: disable=missing-docstring
import inspect
-from typing import Callable
-
-from tvm.tir.function import PrimFunc
-from ..parser import from_source
-
+from typing import Any, Callable, Dict
-def prim_func(input_func: Callable) -> PrimFunc:
- """Decorate a python function as tvm script.
- Parameters
- ----------
- func : input_func
- The function to be parsed.
+def inspect_function_capture(func: Callable) -> Dict[str, Any]:
+    """Return the variables visible to ``func``: its globals plus its closure variables.
+
+    Closure (nonlocal) variables are merged last so they shadow globals of the
+    same name, matching Python's own name-resolution order.
+    """
+    captured = {
+        **func.__globals__,  # type: ignore
+        **inspect.getclosurevars(func).nonlocals,
+    }
+    return captured
- Returns
- -------
- output : PrimFunc
- The result functions.
- """
- if inspect.isfunction(input_func):
- result = from_source(input_func)
- result.__name__ = input_func.__name__
- result.__qualname__ = input_func.__qualname__
- return result
- raise TypeError("Only function definitions are supported.")
+def inspect_class_capture(cls: type) -> Dict[str, Any]:
+    """Merge the captured variables of every plain function defined directly on ``cls``.
+
+    Functions are visited in ``cls.__dict__`` order. When two of them capture
+    the same name, the later one wins.
+    """
+    merged: Dict[str, Any] = {}
+    methods = [member for member in cls.__dict__.values() if inspect.isfunction(member)]
+    for method in methods:
+        merged.update(**inspect_function_capture(method))
+    return merged
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/ir/__init__.py
similarity index 85%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/ir/__init__.py
index 555659d0c5..4cbd9910a2 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/ir/__init__.py
@@ -14,8 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+# pylint: disable=missing-docstring
+from . import parser as _parser
+from .entry import ir_module
-from . import tir
-
-from .parser import ir_module, from_source
+__all__ = ["ir_module"]
diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/parser/ir/entry.py
similarity index 63%
copy from python/tvm/script/tir/__init__.py
copy to python/tvm/script/parser/ir/entry.py
index 2f2b4bbc25..3c1e4de5a7 100644
--- a/python/tvm/script/tir/__init__.py
+++ b/python/tvm/script/parser/ir/entry.py
@@ -14,18 +14,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""TVMScript for TIR"""
+# pylint: disable=missing-docstring
+import inspect
+from typing import Type
-# Type system
-from .ty import void, boolean, handle, Ptr, Tuple, Buffer
+from tvm.ir import IRModule
-from .prim_func import prim_func
+from .._core import parse, utils
-# add all floating point and integer datatypes to the module
-for _dtype in ["float", "uint", "int"]:
- for _size in ["8", "16", "32", "64"]:
- for _lanes in ["", "x4", "x8", "x16", "x32"]:
- from . import ty
- _name = _dtype + _size + _lanes
- globals()[_name] = getattr(ty, _name)
+def ir_module(f: Type) -> IRModule:
+    """Parse the decorated class ``f`` as a TVMScript IRModule.
+
+    The variables captured by the class's functions are made available to the
+    parser. Raises TypeError when ``f`` is not a class.
+    """
+    if inspect.isclass(f):
+        captured = utils.inspect_class_capture(f)
+        return parse(f, captured)
+    raise TypeError(f"Expect a class, but got: {f}")
+
+
+setattr(ir_module, "dispatch_token", "ir")
diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/parser/ir/parser.py
similarity index 55%
rename from python/tvm/script/tir/__init__.py
rename to python/tvm/script/parser/ir/parser.py
index 2f2b4bbc25..8871d3b415 100644
--- a/python/tvm/script/tir/__init__.py
+++ b/python/tvm/script/parser/ir/parser.py
@@ -14,18 +14,24 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""TVMScript for TIR"""
+# pylint: disable=missing-docstring
+from ...ir_builder import ir as I
+from .._core import Parser, dispatch, doc
-# Type system
-from .ty import void, boolean, handle, Ptr, Tuple, Buffer
-from .prim_func import prim_func
+@dispatch.register(token="ir", type_name="ClassDef")
+def _visit_class_def(self: Parser, node: doc.ClassDef) -> None:
+    """Visit a class body inside a new var scope and an ``I.ir_module()`` frame.
+
+    The body is visited with the ``"ir"`` dispatch token active. The three
+    contexts are entered in order and exited in reverse.
+    """
+    with self.var_table.with_frame(), I.ir_module(), self.with_dispatch_token("ir"):
+        self.visit_body(node.body)
-# add all floating point and integer datatypes to the module
-for _dtype in ["float", "uint", "int"]:
- for _size in ["8", "16", "32", "64"]:
- for _lanes in ["", "x4", "x8", "x16", "x32"]:
- from . import ty
- _name = _dtype + _size + _lanes
- globals()[_name] = getattr(ty, _name)
+@dispatch.register(token="ir", type_name="Assign")
+def _visit_assign(_self: Parser, _node: doc.Assign) -> None:
+    # Assignments directly in the "ir" (class body) context are ignored.
+    pass
+
+
+@dispatch.register(token="ir", type_name="Expr")
+def _visit_expr(_self: Parser, _node: doc.Expr) -> None:
+    # Expression statements directly in the "ir" (class body) context are ignored.
+    pass
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/tir/__init__.py
similarity index 71%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/tir/__init__.py
index 555659d0c5..930764f73d 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/tir/__init__.py
@@ -14,8 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+# pylint: disable=missing-docstring
+from ...ir_builder.tir import * # pylint: disable=redefined-builtin
+from ...ir_builder.tir import ir as _tir
+from . import operation as _operation
+from . import parser as _parser
+from .entry import Buffer, Ptr, prim_func
-from . import tir
-
-from .parser import ir_module, from_source
+__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func"]
diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py
new file mode 100644
index 0000000000..db4e2dd9a3
--- /dev/null
+++ b/python/tvm/script/parser/tir/entry.py
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import inspect
+from typing import Callable, Union
+
+from tvm.tir import Buffer, PrimFunc
+
+from ...ir_builder.tir import buffer_decl, ptr
+from .._core import parse, utils
+
+
+def _is_defined_in_class(frames):
+    """Heuristically tell whether the decorator call happens inside a class body.
+
+    Looks at the source line of the stack entry two levels above the caller
+    (``frames[2]``, whose index 4 is the code context). It returns True when
+    that line starts a ``class`` statement or is an ``@...ir_module...``
+    decorator line.
+    """
+    if len(frames) <= 2:
+        return False
+    code_context = frames[2][4]
+    if code_context is None:
+        return False
+    line = code_context[0].strip()
+    return line.startswith("class ") or (line.startswith("@") and "ir_module" in line)
+
+
+def prim_func(f: Callable) -> Union[PrimFunc, Callable]:
+    """Parse a Python function as a TIR PrimFunc.
+
+    When the decorator is applied inside a class body (detected heuristically
+    from the call stack), ``f`` is returned unparsed. Presumably the enclosing
+    ``@ir_module`` class parses it as part of the module.
+
+    Raises TypeError if ``f`` is not a plain function.
+    """
+    if inspect.isfunction(f):
+        # inspect.stack() must be taken in this frame: the helper reads fixed depths.
+        if _is_defined_in_class(inspect.stack()):
+            return f
+        return parse(f, utils.inspect_function_capture(f))
+    raise TypeError(f"Expect a function, but got: {f}")
+
+
+setattr(prim_func, "dispatch_token", "tir")
+
+
+class BufferProxy:
+    """Front-end for ``buffer_decl`` usable as ``Buffer(...)`` or ``Buffer[...]``."""
+
+    def __call__(
+        self,
+        shape,
+        dtype="float32",
+        data=None,
+        strides=None,
+        elem_offset=None,
+        scope="global",
+        align=0,
+        offset_factor=0,
+        buffer_type="",
+        axis_separators=None,
+    ) -> Buffer:
+        """Declare a buffer; every argument is forwarded to ``buffer_decl`` unchanged."""
+        options = dict(
+            dtype=dtype,
+            data=data,
+            strides=strides,
+            elem_offset=elem_offset,
+            scope=scope,
+            align=align,
+            offset_factor=offset_factor,
+            buffer_type=buffer_type,
+            axis_separators=axis_separators,
+        )
+        return buffer_decl(shape, **options)
+
+    def __getitem__(self, keys) -> Buffer:
+        """Subscript form.
+
+        A non-tuple key, or a tuple whose second item is not a string, is the
+        shape itself. Any other tuple (length < 2, or a string second item such
+        as a dtype) is spread as positional arguments.
+        """
+        spread = isinstance(keys, tuple) and (len(keys) < 2 or isinstance(keys[1], str))
+        if spread:
+            return self(*keys)  # pylint: disable=no-member # type: ignore
+        return self(keys)
+
+
+class PtrProxy:
+    """Front-end for ``ptr`` usable as ``Ptr(...)`` or ``Ptr[...]``."""
+
+    def __call__(self, dtype, storage_scope="global"):
+        """Create a pointer of ``dtype`` in ``storage_scope``.
+
+        A callable ``dtype`` is called and its ``.dtype`` attribute is used instead.
+        """
+        resolved = dtype().dtype if callable(dtype) else dtype
+        return ptr(resolved, storage_scope)  # pylint: disable=no-member # type: ignore
+
+    def __getitem__(self, keys):
+        args = keys if isinstance(keys, tuple) else (keys,)
+        return self(*args)
+
+
+# Public ``Buffer``/``Ptr`` objects: proxy instances that are both callable and
+# subscriptable. NOTE: this rebinds ``Buffer``, shadowing the ``tvm.tir.Buffer``
+# class imported at the top of this module (BufferProxy's ``-> Buffer`` return
+# annotations were evaluated earlier and still refer to that class).
+Buffer = BufferProxy()  # pylint: disable=invalid-name
+Ptr = PtrProxy()  # pylint: disable=invalid-name
diff --git a/python/tvm/script/parser/tir/operation.py b/python/tvm/script/parser/tir/operation.py
new file mode 100644
index 0000000000..716525b984
--- /dev/null
+++ b/python/tvm/script/parser/tir/operation.py
@@ -0,0 +1,84 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+from typing import Type
+
+from tvm import tir
+from tvm.tir import IntImm
+
+from .._core import OpMethod, doc, register_op
+
+
+def _register_expr_op(ty: Type):  # pylint: disable=invalid-name
+    """Register TIR implementations of Python operators for operands of type ``ty``.
+
+    Sets ``ty._dispatch_type``. This is the attribute ``_eval_op`` looks up on an
+    operand's type. Then registers every binary, comparison and boolean operator
+    for operand positions 0 and 1, and every unary operator for position 0.
+    """
+    ty._dispatch_type = ty  # pylint: disable=protected-access
+
+    def _as_bool_imm(value):
+        # Python bools are lifted to a boolean IntImm before building tir.And/tir.Or.
+        return IntImm("bool", value) if isinstance(value, bool) else value
+
+    def _and(a, b):
+        return tir.And(_as_bool_imm(a), _as_bool_imm(b))
+
+    def _or(a, b):
+        return tir.Or(_as_bool_imm(a), _as_bool_imm(b))
+
+    for operand_index in (0, 1):
+        binary_methods = (
+            # Case 1. binop
+            (doc.Add, tir.Add),
+            (doc.Sub, tir.Sub),
+            (doc.Mult, tir.Mul),
+            (doc.Div, tir.Div),
+            (doc.FloorDiv, tir.FloorDiv),
+            (doc.Mod, tir.FloorMod),
+            (doc.LShift, lambda a, b: a << b),
+            (doc.RShift, lambda a, b: a >> b),
+            (doc.BitOr, lambda a, b: a | b),
+            (doc.BitXor, lambda a, b: a ^ b),
+            (doc.BitAnd, lambda a, b: a & b),
+            # doc.MatMult <-- not implemented
+            # doc.Pow <-- not implemented
+            # Case 2. cmpop
+            (doc.Eq, tir.EQ),
+            (doc.NotEq, tir.NE),
+            (doc.Lt, tir.LT),
+            (doc.LtE, tir.LE),
+            (doc.Gt, tir.GT),
+            (doc.GtE, tir.GE),
+            # doc.Is / doc.IsNot / doc.In / doc.NotIn <-- not implemented
+            # Case 3. boolop
+            (doc.And, _and),
+            (doc.Or, _or),
+        )
+        for op, method in binary_methods:
+            register_op(ty, op, operand_index)(method)
+
+    # Case 4. unaryop (single operand, position 0 only)
+    unary_methods = (
+        (doc.Invert, lambda a: ~a),
+        (doc.Not, tir.Not),
+        (doc.UAdd, lambda a: +a),
+        (doc.USub, lambda a: -a),
+    )
+    for op, method in unary_methods:
+        register_op(ty, op, 0)(method)
+
+
+_register_expr_op(tir.PrimExpr)
+_register_expr_op(tir.IterVar)
diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py
new file mode 100644
index 0000000000..351238c06f
--- /dev/null
+++ b/python/tvm/script/parser/tir/parser.py
@@ -0,0 +1,268 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import contextlib
+from functools import partial
+from typing import Any
+
+from tvm.ir import PrimType
+from tvm.tir import Buffer, IterVar, PrimExpr, Var
+
+from ...ir_builder import tir as T
+from ...ir_builder.base import IRBuilderFrame as Frame
+from ...ir_builder.base import name
+from .._core import Parser, dispatch, doc
+
+
+def bind_with_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
+    """Name the value(s) produced by a ``with ... as`` target.
+
+    Lists/tuples are named element-wise as ``{var_name}_{i}``. Buffers and Vars
+    are named ``var_name``. Anything else is reported as an error. The value is
+    returned unchanged.
+    """
+    if isinstance(value, (list, tuple)):
+        for index, item in enumerate(value):
+            bind_with_value(self, node, f"{var_name}_{index}", item)
+        return value
+    if not isinstance(value, (Buffer, Var)):
+        self.report_error(node, f"Do not know how to bind type: {type(value)} in with statement")
+        raise NotImplementedError
+    name(var_name, value)
+    return value
+
+
+def bind_for_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
+    """Name the loop variable(s) produced by a ``for`` target.
+
+    Lists/tuples are named element-wise as ``{var_name}_{i}`` using these same
+    for-loop rules, so every element must be a Var. Anything else is reported
+    as an error. The value is returned unchanged.
+    """
+    if isinstance(value, (list, tuple)):
+        for i, v in enumerate(value):
+            # Recurse with for-loop rules (not with-statement rules) so nested
+            # targets get the same validation and error message.
+            bind_for_value(self, node, f"{var_name}_{i}", v)
+        return value
+    if isinstance(value, Var):
+        name(var_name, value)
+        return value
+    self.report_error(node, f"Do not know how to bind type: {type(value)} in for statement")
+    raise NotImplementedError
+
+
+def bind_assign_value(self: Parser, _node: doc.expr, var_name: str, value: Any) -> Any:
+    """Bind the right-hand side of ``var_name = value``; return what the name refers to.
+
+    - ``T.inline``: unwrapped; its ``.value`` is returned without naming.
+    - list/tuple: elements are named ``{var_name}_{i}`` and the sequence is returned.
+    - IR builder frame: entered now (its ``__exit__`` is registered as a callback
+      on the frame); the ``__enter__`` result is named and returned.
+    - Buffer / IterVar, or a Var for which ``var_table.exist`` is False: named directly.
+    - any other PrimExpr: bound to a new ``T.var`` of the same dtype through a
+      ``T.let`` frame (entered now, exit registered as a callback); that Var is returned.
+    - anything else: returned unchanged.
+    """
+    if isinstance(value, T.inline):
+        return value.value
+    elif isinstance(value, (list, tuple)):
+        # NOTE(review): elements are bound with with-statement rules (only
+        # Buffer/Var accepted); confirm that is intended for assignments.
+        for i, v in enumerate(value):
+            bind_with_value(self, _node, f"{var_name}_{i}", v)
+        return value
+    elif isinstance(value, Frame):
+        # Enter the frame now; its exit runs as a callback, presumably when the
+        # enclosing builder frame finishes.
+        value.add_callback(partial(value.__exit__, None, None, None))
+        res = value.__enter__()
+        name(var_name, res)
+        return res
+    elif isinstance(value, (Buffer, IterVar)) or (
+        isinstance(value, Var) and not self.var_table.exist(value)
+    ):
+        name(var_name, value)
+        return value
+    elif isinstance(value, PrimExpr):
+        # Materialize the expression as a let-bound Var named after the target.
+        var = T.var(value.dtype)
+        name(var_name, var)
+        frame = T.let(var, value)
+        frame.add_callback(partial(frame.__exit__, None, None, None))
+        frame.__enter__()
+        return var
+    return value
+
+
+@dispatch.register(token="tir", type_name="For")
+def visit_for(self: Parser, node: doc.For) -> None:
+    """Parse a TIR ``for`` loop; the iterator must evaluate to a ``ForFrame``.
+
+    The loop variables are bound in a new scope that is opened before the
+    frame is entered.
+    """
+    for_frame = self.eval_expr(node.iter)
+    if not isinstance(for_frame, T.frame.ForFrame):
+        self.report_error(
+            node.iter,
+            "Expect the for loop to be one of the following: "
+            "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding",
+        )
+    with self.var_table.with_frame(), for_frame as loop_vars:
+        self.eval_assign(target=node.target, source=loop_vars, bind_value=bind_for_value)
+        self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="While")
+def visit_while(self: Parser, node: doc.While) -> None:
+    """Parse a TIR ``while`` loop.
+
+    The test is evaluated after the loop's var scope has been opened.
+    """
+    with self.var_table.with_frame(), T.While(self.eval_expr(node.test)):
+        self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="Assign")
+def visit_assign(self: Parser, node: doc.Assign) -> None:
+    """Parse ``lhs = rhs``: a buffer store for subscripted targets, a binding otherwise.
+
+    The right-hand side is evaluated first, then the indices, then the buffer.
+    """
+    if len(node.targets) != 1:
+        self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.")
+    target = node.targets[0]
+    value = self.eval_expr(node.value)
+    if not isinstance(target, doc.Subscript):
+        self.eval_assign(target=target, source=value, bind_value=bind_assign_value)
+        return
+    index_nodes = target.slice.elts if isinstance(target.slice, doc.Tuple) else [target.slice]
+    indices = [self.eval_expr(index) for index in index_nodes]
+    T.buffer_store(self.eval_expr(target.value), value, indices)
+
+
+@dispatch.register(token="tir", type_name="AugAssign")
+def visit_aug_assign(self: Parser, node: doc.AugAssign) -> None:
+    """Lower ``target op= value`` to ``target = target op value``.
+
+    Both operands are evaluated once and bound to temporary names in a
+    throw-away scope. A synthesized ``BinOp`` over those names is then
+    evaluated. The resulting store/binding happens in the enclosing scope.
+    """
+    lhs_pos = (
+        node.target.lineno,
+        node.target.col_offset,
+        node.target.end_lineno,
+        node.target.end_col_offset,
+    )
+    rhs_pos = (
+        node.value.lineno,
+        node.value.col_offset,
+        node.value.end_lineno,
+        node.value.end_col_offset,
+    )
+    # Read the target's current value first; it is switched back to Store below.
+    node.target.ctx = doc.Load(*lhs_pos)
+    with self.var_table.with_frame():
+        lhs_name = "__tvm_tmp_value_aug_assign_lhs"
+        rhs_name = "__tvm_tmp_value_aug_assign_rhs"
+        lhs_expr = self.eval_expr(node.target)
+        rhs_expr = self.eval_expr(node.value)
+        self.var_table.add(lhs_name, lhs_expr)
+        self.var_table.add(rhs_name, rhs_expr)
+        op = doc.BinOp(
+            doc.Name(lhs_name, doc.Load(*lhs_pos), *lhs_pos),
+            node.op,
+            doc.Name(rhs_name, doc.Load(*rhs_pos), *rhs_pos),
+            *lhs_pos,
+        )
+        rhs = self.eval_expr(op)
+    # Assign outside the temporary scope: a name bound inside it would be popped
+    # together with the temporaries, so the new value would be invisible afterwards.
+    lhs = node.target
+    lhs.ctx = doc.Store(*lhs_pos)
+    if isinstance(lhs, doc.Subscript):
+        if isinstance(lhs.slice, doc.Tuple):
+            indices = [self.eval_expr(index) for index in lhs.slice.elts]
+        else:
+            indices = [self.eval_expr(lhs.slice)]
+        T.buffer_store(self.eval_expr(lhs.value), rhs, indices)
+    else:
+        self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value)
+
+
+@dispatch.register(token="tir", type_name="AnnAssign")
+def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None:
+    """Parse ``x: <annotation> = value`` into a ``T.let`` binding.
+
+    The annotation must resolve to a Var. The target name is bound to that Var,
+    and a let frame binding it to ``value`` is entered (exit registered as a
+    callback on the frame).
+    """
+    value = self.eval_expr(node.value)
+    declared_var = self.visit_tvm_annotation(node.annotation)
+    if not isinstance(declared_var, Var):
+        self.report_error(node.annotation, "Annotation should be Var")
+    self.eval_assign(target=node.target, source=declared_var, bind_value=bind_assign_value)
+    let_frame = T.let(declared_var, value)
+    let_frame.add_callback(partial(let_frame.__exit__, None, None, None))
+    let_frame.__enter__()
+
+
+@dispatch.register(token="tir", type_name="With")
+def visit_with(self: Parser, node: doc.With) -> None:
+    """Parse a ``with`` statement whose items must all evaluate to IR builder frames.
+
+    A new var scope and then each frame (in item order) are entered on one
+    ExitStack, so they are exited in reverse order after the body. ``as``
+    targets are bound with ``bind_with_value``.
+    """
+    with contextlib.ExitStack() as stack:
+        stack.enter_context(self.var_table.with_frame())
+        for item in node.items:
+            frame = self.eval_expr(item.context_expr)
+            if not isinstance(frame, Frame):
+                self.report_error(
+                    item.context_expr, "Invalid context expression in the with-statement."
+                )
+            rhs = stack.enter_context(frame)
+            if item.optional_vars is not None:
+                self.eval_assign(target=item.optional_vars, source=rhs, bind_value=bind_with_value)
+        self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="FunctionDef")
+def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
+    """Parse a function whose decorator carries the ``"tir"`` token into ``T.prim_func``.
+
+    The function name is set with ``T.func_name``. A return annotation, if
+    present, is set with ``T.func_ret``; a callable annotation is called and
+    wrapped as ``PrimType`` of its ``dtype``. Arguments and body are visited
+    with the ``"tir"`` dispatch token active.
+    """
+    with self.var_table.with_frame():
+        # Inside the function, plain ``range(...)`` evaluates to ``T.serial``.
+        self.var_table.add("range", T.serial)
+        with T.prim_func():
+            T.func_name(node.name)
+            if node.returns is not None:
+                ret_type = self.eval_expr(node.returns)
+                if callable(ret_type):
+                    ret_type = PrimType(ret_type().dtype)
+                T.func_ret(ret_type)
+            with self.with_dispatch_token("tir"):
+                self.visit(node.args)
+                self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="arguments")
+def visit_arguments(self: Parser, node: doc.arguments) -> None:
+    """Declare each positional parameter with ``T.arg`` and bind its name in the var table.
+
+    Every parameter must have a type annotation. The annotation is resolved
+    through ``visit_tvm_annotation``.
+    """
+    # TODO: handle different types of arguments:
+    # - vararg: arg | None
+    # - kwonlyargs: list[arg]
+    # - kw_defaults: list[expr | None]
+    # - kwarg: arg | None
+    # - defaults: list[expr]
+    # - posonlyargs: list[arg]
+    for param_node in node.args:
+        annotation = param_node.annotation
+        if annotation is None:
+            self.report_error(param_node, "Type annotation is required for function parameters.")
+        declared = T.arg(param_node.arg, self.visit_tvm_annotation(annotation))
+        self.var_table.add(param_node.arg, declared)
+
+
+@dispatch.register(token="tir", type_name="tvm_annotation")
+def visit_tvm_annotation(self: Parser, node: doc.expr):
+    """Evaluate a type annotation; callable results are called and their return value is used."""
+    annotation = self.eval_expr(node)
+    return annotation() if callable(annotation) else annotation
+
+
+@dispatch.register(token="tir", type_name="Expr")
+def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
+    """Evaluate an expression statement; a resulting IR builder frame is entered.
+
+    The frame's ``__exit__`` is registered as a callback on the frame before
+    it is entered. Non-frame results are discarded.
+    """
+    result = self.eval_expr(node.value)
+    if not isinstance(result, Frame):
+        return
+    result.add_callback(partial(result.__exit__, None, None, None))
+    result.__enter__()
+
+
+@dispatch.register(token="tir", type_name="If")
+def visit_if(self: Parser, node: doc.If) -> None:
+    """Parse ``if``/``else`` into ``T.If`` with ``T.Then`` and an optional ``T.Else``.
+
+    Each branch gets its own var scope. A name bound in the then-branch is
+    therefore neither visible in the else-branch nor a "already defined"
+    conflict for a same-named binding there.
+    """
+    with self.var_table.with_frame():
+        with T.If(self.eval_expr(node.test)):
+            with T.Then():
+                with self.var_table.with_frame():
+                    self.visit_body(node.body)
+            if node.orelse:
+                with T.Else():
+                    with self.var_table.with_frame():
+                        self.visit_body(node.orelse)
+
+
+@dispatch.register(token="tir", type_name="Assert")
+def visit_assert(self: Parser, node: doc.Assert) -> None:
+    """Parse ``assert cond, msg`` into a ``T.Assert`` frame that is entered immediately.
+
+    The frame's ``__exit__`` is registered as a callback on the frame itself.
+    NOTE(review): ``node.msg`` is None for a bare ``assert cond``; it is passed
+    to ``eval_expr`` unchecked. Confirm that case is handled or rejected.
+    """
+    cond = self.eval_expr(node.test)
+    msg = self.eval_expr(node.msg)
+    frame = T.Assert(cond, msg)
+    frame.add_callback(partial(frame.__exit__, None, None, None))
+    frame.__enter__()
+
+
+@dispatch.register(token="tir", type_name="Return")
+def visit_return(self: Parser, node: doc.Return) -> None:
+    """Reject ``return`` statements in TIR functions by reporting an error at ``node``."""
+    self.report_error(node, "Return is not allowed.")
diff --git a/python/tvm/script/registry.py b/python/tvm/script/registry.py
deleted file mode 100644
index e7d90dd515..0000000000
--- a/python/tvm/script/registry.py
+++ /dev/null
@@ -1,62 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Script Parser Function Registry """
-# pylint: disable=inconsistent-return-statements, relative-beyond-top-level, import-outside-toplevel
-import types
-from typing import Union, Callable, Dict, Optional, Any
-
-
-class Registry(object):
- """Registration map
- All these maps are static
- """
-
- registrations: Dict[str, type] = dict()
-
- @staticmethod
- def lookup(name: str) -> Optional[Any]:
- if name in Registry.registrations:
- # every time we create a new handler
- # since we may want to keep some local info inside it
- return Registry.registrations[name]()
- return None
-
-
-def register(inputs: Union[Callable, type]) -> type:
- """Register Intrin/ScopeHandler/SpecialStmt"""
- registration: type
- if isinstance(inputs, types.FunctionType):
- # is function
- from .tir.intrin import Intrin
-
- def create_new_intrin(func) -> type:
- class NewIntrin(Intrin):
- def __init__(self):
- super().__init__(func)
-
- return NewIntrin
-
- registration = create_new_intrin(inputs)
- elif isinstance(inputs, type):
- # is class
- registration = inputs
- else:
- raise ValueError()
-
- key: str = registration().signature()[0]
- Registry.registrations[key] = registration
- return registration
diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi
deleted file mode 100644
index a64eed055a..0000000000
--- a/python/tvm/script/tir/__init__.pyi
+++ /dev/null
@@ -1,487 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=redefined-builtin
-from typing import (
- Any,
- Callable,
- ContextManager,
- Dict,
- Iterable,
- Optional,
- Tuple,
- Union,
- Sequence,
- List,
- Mapping,
- overload,
-)
-from numbers import Number
-import builtins
-
-from tvm.tir.function import PrimFunc
-from tvm.tir import Range
-from tvm.runtime import Object
-from tvm.target import Target
-from .node import BufferSlice
-
-"""
-redefine types
-"""
-
-class PrimExpr:
- def __init__(self: PrimExpr) -> None: ...
- @overload
- def __add__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
- @overload
- def __add__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
- @overload
- def __sub__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
- @overload
- def __sub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
- @overload
- def __mul__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
- @overload
- def __mul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
- @overload
- def __div__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
- @overload
- def __div__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
- def __mod__(self: PrimExpr, other: Union[int, float, PrimExpr]) -> PrimExpr: ...
- def __radd__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
- def __rsub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
- def __rmul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
- def __rdiv__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
- def __floordiv__(self: PrimExpr, other: Union[int, float, PrimExpr]) -> PrimExpr: ...
- def __index__(self: PrimExpr) -> int: ... # so range doesn't complain
-
-class Var(PrimExpr): ...
-class IterVar(Var): ...
-
-class Buffer:
- @overload
- def __getitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int, slice]]) -> PrimExpr: ...
- @overload
- def __getitem__(self: Buffer, pos: Union[PrimExpr, int, slice]) -> PrimExpr: ...
- @overload
- def __setitem__(
- self: Buffer, pos: Sequence[Union[PrimExpr, int, slice]], value: PrimExpr
- ) -> None: ...
- @overload
- def __setitem__(self: Buffer, pos: Union[PrimExpr, int, slice], value: PrimExpr) -> None: ...
- @property
- def data(self: Buffer) -> Ptr: ...
-
-"""
-Intrinsic
-"""
-
-def min_value(dtype: str) -> PrimExpr: ...
-def max_value(dtype: str) -> PrimExpr: ...
-def floordiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
-def floormod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
-def ceildiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
-def truncmod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
-def truncdiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
-def abs(x: PrimExpr) -> PrimExpr: ...
-def load(
- dtype: str, var: Var, index: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = None
-) -> PrimExpr: ...
-def cast(value: PrimExpr, dtype: str) -> PrimExpr: ...
-def ramp(base: PrimExpr, stride: Any, lanes: int) -> PrimExpr: ...
-def broadcast(value: PrimExpr, lanes: int) -> PrimExpr: ...
-def iter_var(var: Union[Var, str], dom: Range, iter_type: int, thread_tag: str) -> IterVar: ...
-def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
-def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
-def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr: ...
-def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ...
-def evaluate(value: PrimExpr) -> None: ...
-def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ...
-def vectorlow(value: PrimExpr, dtype: str) -> PrimExpr: ...
-def vectorhigh(value: PrimExpr, dtype: str) -> PrimExpr: ...
-def store(
- var: Var, index: PrimExpr, value: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = True
-) -> None: ...
-def comm_reducer(lambda_io: Callable[[Any, Any], Any], identities: List[PrimExpr]) -> PrimExpr: ...
-def llvm_lookup_intrinsic_id(name: str) -> PrimExpr: ...
-def preflattened_buffer(
- buf: Buffer,
- shape: Sequence[PrimExpr],
- dtype: str = "float32",
- data: Optional[Ptr] = None,
- strides: Optional[Sequence[int]] = None,
- elem_offset: Optional[int] = None,
- scope: str = "global",
- align: int = -1,
- offset_factor: int = 0,
- buffer_type: str = "default",
-) -> Buffer: ...
-
-"""
-Intrinsics - tvm builtin
-"""
-
-def tvm_thread_allreduce(
- *freduceargs: Union[PrimExpr, builtins.bool, Ptr], dtype: str
-) -> PrimExpr: ...
-
-"""
-Unary operator
-Note that any intrinsics not registered in script.tir.intrin
-should add "dtype" as an argument. This is different from their
-definition but intentional.
-"""
-
-def exp(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def exp2(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def exp10(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def erf(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def tanh(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def sigmoid(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def log(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def log2(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def log10(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def log1p(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def tan(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def cos(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def cosh(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def acos(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def acosh(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def sin(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def sinh(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def asin(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def asinh(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def atan(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def atanh(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def atan2(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def sqrt(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def rsqrt(x: PrimExpr, dtype: str) -> PrimExpr: ...
-
-"""
-special_stmt - Buffers
-"""
-
-def match_buffer(
- param: Union[Var, BufferSlice],
- shape: Sequence[Union[PrimExpr, int]],
- dtype: str = "float32",
- data: Var = None,
- strides: Optional[Sequence[int]] = None,
- elem_offset: Optional[int] = None,
- scope: str = "global",
- align: int = -1,
- offset_factor: int = 0,
- buffer_type: str = "default",
- axis_separators: Optional[List[int]] = None,
-) -> Buffer: ...
-def decl_buffer(
- shape: Sequence[Union[PrimExpr, int]],
- dtype: str = "float32",
- data: Var = None,
- strides: Optional[Sequence[int]] = None,
- elem_offset: Optional[int] = None,
- scope: str = "global",
- align: int = -1,
- offset_factor: int = 0,
- buffer_type: str = "default",
- axis_separators: Optional[List[int]] = None,
-) -> Buffer: ...
-def buffer_decl(
- shape: Sequence[Union[PrimExpr, int]],
- dtype: str = "float32",
- data: Var = None,
- strides: Optional[Sequence[int]] = None,
- elem_offset: Optional[int] = None,
- scope: str = "global",
- align: int = -1,
- offset_factor: int = 0,
- buffer_type: str = "default",
- axis_separators: Optional[List[int]] = None,
-) -> Buffer: ...
-def alloc_buffer(
- shape: Sequence[Union[PrimExpr, int]],
- dtype: str = "float32",
- data: Var = None,
- strides: Optional[Sequence[int]] = None,
- elem_offset: Optional[int] = None,
- scope: str = "global",
- align: int = -1,
- offset_factor: int = 0,
- buffer_type: str = "default",
- axis_separators: Optional[List[int]] = None,
-) -> Buffer: ...
-
-"""
-special_stmt - Reads/Writes
-"""
-
-@overload
-def reads(read_regions: List[BufferSlice]) -> None: ...
-@overload
-def reads(*read_regions: BufferSlice) -> None: ...
-@overload
-def writes(write_region: List[BufferSlice]) -> None: ...
-@overload
-def writes(*write_region: BufferSlice) -> None: ...
-def block_attr(attrs: Mapping[str, Object]) -> None: ...
-
-"""
-special_stmt - Axis
-"""
-
-class axis:
- @overload
- @staticmethod
- def spatial(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
- @overload
- @staticmethod
- def spatial(
- dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr
- ) -> IterVar: ...
- @overload
- @staticmethod
- def S(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
- @overload
- @staticmethod
- def S(dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr) -> IterVar: ...
- @overload
- @staticmethod
- def reduce(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
- @overload
- @staticmethod
- def reduce(
- dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr
- ) -> IterVar: ...
- @overload
- @staticmethod
- def R(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
- @overload
- @staticmethod
- def R(dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr) -> IterVar: ...
- @overload
- @staticmethod
- def scan(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
- @overload
- @staticmethod
- def scan(
- dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr
- ) -> IterVar: ...
- @overload
- @staticmethod
- def opaque(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
- @overload
- @staticmethod
- def opaque(
- dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr
- ) -> IterVar: ...
- @staticmethod
- def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ...
-
-def get_axis(begin: PrimExpr, end: PrimExpr, iter_type: int) -> IterVar: ...
-
-"""
-special_stmt - Annotations
-"""
-
-def buffer_var(dtype: str, storage_scope: str) -> Var: ...
-def func_attr(attrs: Mapping[str, Union[Object, str, bool, int, float]]) -> None: ...
-def prim_func(input_func: Callable) -> PrimFunc: ...
-
-"""
-special_stmt - Threads and Bindings
-"""
-
-def env_thread(env_name: str) -> IterVar: ...
-def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
-
-"""
-Scope handler
-"""
-
-class block(ContextManager):
- def __init__(self, name_hint: str = "") -> None: ...
- def __enter__(self) -> Sequence[IterVar]: ...
-
-class init(ContextManager):
- def __init__(self) -> None: ...
-
-class let(ContextManager):
- def __init__(self, var: Var, value: PrimExpr) -> None: ...
-
-def where(cond: PrimExpr) -> None: ...
-def allocate(
- extents: List[PrimExpr],
- dtype: str,
- scope: str,
- condition: Union[PrimExpr, builtins.bool] = True,
- annotations: Optional[Mapping[str, Object]] = None,
-) -> Buffer: ...
-def launch_thread(env_var: Var, extent: Union[int, PrimExpr]) -> Var: ...
-def realize(
- buffer_slice: BufferSlice, scope: str, condition: Union[PrimExpr, builtins.bool] = True
-) -> None: ...
-def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
-def Assert(condition: Union[PrimExpr, builtins.bool], message: str) -> PrimExpr: ...
-
-"""
-Scope handler - Loops
-"""
-
-@overload
-def serial(
- begin: Union[PrimExpr, int],
- end: Union[PrimExpr, int],
- annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def serial(
- end: Union[PrimExpr, int],
- annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def parallel(
- begin: Union[PrimExpr, int],
- end: Union[PrimExpr, int],
- annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def parallel(
- end: Union[PrimExpr, int],
- annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def vectorized(
- begin: Union[PrimExpr, int],
- end: Union[PrimExpr, int],
- annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def vectorized(
- end: Union[PrimExpr, int],
- annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def unroll(
- begin: Union[PrimExpr, int],
- end: Union[PrimExpr, int],
- annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def unroll(
- end: Union[PrimExpr, int],
- annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def thread_binding(
- begin: Union[PrimExpr, int],
- end: Union[PrimExpr, int],
- thread: str,
- annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def thread_binding(
- end: Union[PrimExpr, int],
- thread: str,
- annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def for_range(
- begin: Union[PrimExpr, int],
- end: Union[PrimExpr, int],
- annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def for_range(
- end: Union[PrimExpr, int],
- annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-def grid(*extents: Union[PrimExpr, int]) -> Iterable[Sequence[IterVar]]: ...
-
-"""
-ty - redefine types
-"""
-
-class boolean: ...
-
-class handle(Var):
- @overload
- def __getitem__(self: handle, pos: Sequence[Union[int, PrimExpr, slice]]) -> Buffer: ...
- @overload
- def __getitem__(self: handle, pos: Union[int, PrimExpr, slice]) -> Buffer: ...
- @overload
- def __setitem__(
- self: handle, pos: Sequence[Union[int, PrimExpr, slice]], value: Buffer
- ) -> None: ...
- @overload
- def __setitem__(self: handle, pos: Union[int, PrimExpr, slice], value: Buffer) -> None: ...
- @property
- def data(self: handle) -> Ptr: ...
-
-class Ptr: ...
-
-def target(target_str: Union[str, Mapping[str, Object]]) -> Target: ...
-
-class var(Var):
- def __init__(self: Var, dtype: str): ...
-
-class bool(PrimExpr):
- def __init__(self: bool, imm: Union[PrimExpr, builtins.bool, builtins.int]): ...
-
-class int8(PrimExpr):
- def __init__(self: int8, imm: Union[PrimExpr, int]): ...
-
-class int16(PrimExpr):
- def __init__(self: int16, imm: Union[PrimExpr, int]): ...
-
-class int32(PrimExpr):
- def __init__(self: int32, imm: Union[PrimExpr, int]): ...
-
-class int64(PrimExpr):
- def __init__(self: int64, imm: Union[PrimExpr, int]): ...
-
-class uint8(PrimExpr):
- def __init__(self: uint8, imm: Union[PrimExpr, int]): ...
-
-class uint16(PrimExpr):
- def __init__(self: uint16, imm: Union[PrimExpr, int]): ...
-
-class uint32(PrimExpr):
- def __init__(self: uint32, imm: Union[PrimExpr, int]): ...
-
-class uint64(PrimExpr):
- def __init__(self: uint64, imm: Union[PrimExpr, int]): ...
-
-# use typing.Literal instead for python 3.8 or higher
-import sys
-
-if sys.version_info >= (3, 8):
- from typing import Literal
-
- SpecialFloatLiteral = Literal["inf", "-inf", "nan"]
-else:
- SpecialFloatLiteral = str
-
-class float8(PrimExpr):
- def __init__(self: float8, imm: Union[PrimExpr, int, float, SpecialFloatLiteral]): ...
-
-class float16(PrimExpr):
- def __init__(self: float16, imm: Union[PrimExpr, int, float, SpecialFloatLiteral]): ...
-
-class float32(PrimExpr):
- def __init__(self: float32, imm: Union[PrimExpr, int, float, SpecialFloatLiteral]): ...
-
-class float64(PrimExpr):
- def __init__(self: float64, imm: Union[PrimExpr, int, float, SpecialFloatLiteral]): ...
diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/tir/intrin.py
deleted file mode 100644
index bd9aa1fdad..0000000000
--- a/python/tvm/script/tir/intrin.py
+++ /dev/null
@@ -1,231 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Script Parser Intrinsic Classes"""
-# pylint: disable=redefined-builtin, relative-beyond-top-level
-import builtins
-from typing import List, Any
-
-import tvm.tir
-from tvm.tir import FloatImm
-from ..registry import register
-from ...target import codegen
-from ..utils import get_param_list, tvm_span_from_synr
-
-
-class Intrin:
- def __init__(self, intrin, stmt=False):
- self.intrin = intrin
- self.stmt = stmt
-
- def signature(self):
- return "tir." + self.intrin.__name__, get_param_list(self.intrin)
-
- def handle(self, arg_list: List[Any], span: tvm.ir.Span):
- return self.intrin(*arg_list, span=tvm_span_from_synr(span))
-
-
-@register
-def bool(imm, span):
- return imm.astype("bool", span)
-
-
-# register all datatypes
-for _dtype in ["float", "uint", "int"]:
- for _size in ["8", "16", "32", "64"]:
- for _lanes in ["", "x4", "x8", "x16", "x32"]:
- _name = _dtype + _size + _lanes
-
- # nest closures so we copy the name string
- def wrap(name):
- def f(imm, span):
- if name.startswith("float"):
- if imm in {"inf", "-inf", "nan"}:
- return FloatImm(dtype=name, value=float(imm), span=span)
- return imm.astype(name, span)
-
- f.__name__ = name
- return f
-
- _intrin = wrap(_name)
- register(_intrin)
-
-
-@register
-def min_value(dtype, span):
- return tvm.tir.min_value(dtype, span)
-
-
-@register
-def max_value(dtype, span):
- return tvm.tir.max_value(dtype, span)
-
-
-@register
-def floordiv(x, y, span):
- return tvm.tir.floordiv(x, y, span)
-
-
-@register
-def floormod(x, y, span):
- return tvm.tir.floormod(x, y, span)
-
-
-@register
-def truncmod(x, y, span):
- return tvm.tir.truncmod(x, y, span)
-
-
-@register
-def truncdiv(x, y, span):
- return tvm.tir.truncdiv(x, y, span)
-
-
-@register
-def ceildiv(x, y, span):
- return tvm.tir.ceildiv(x, y, span)
-
-
-@register
-def abs(x, span):
- return tvm.tir.abs(x, span)
-
-
-@register
-def load(dtype, var, index, predicate=None, span=None):
- return tvm.tir.Load(dtype, var, index, predicate, span)
-
-
-@register
-def cast(value, dtype, span):
- return tvm.tir.Cast(dtype, value, span)
-
-
-@register
-def ramp(base, stride, lanes, span):
- return tvm.tir.Ramp(base, stride, lanes.value, span)
-
-
-@register
-def broadcast(value, lanes, span):
- return tvm.tir.Broadcast(value, lanes.value, span)
-
-
-@register
-def iter_var(var, dom, iter_type, thread_tag, span):
- iter_type = getattr(tvm.tir.IterVar, iter_type)
- return tvm.tir.IterVar(dom, var, iter_type, thread_tag, span)
-
-
-@register
-def max(a, b, span): # pylint: disable=redefined-builtin
- return tvm.tir.Max(a, b, span)
-
-
-@register
-def min(a, b, span): # pylint: disable=redefined-builtin
- return tvm.tir.Min(a, b, span)
-
-
-def get_axis(begin, end, iter_type, span):
- ana = tvm.arith.Analyzer()
- extent = ana.simplify(end - begin)
- block_var_dom = tvm.ir.Range.from_min_extent(begin, extent)
-
- iter_type_dict = {"data_par": 0, "reduce": 2, "scan": 3, "opaque": 4}
- return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type], span=span)
-
-
-@register
-def range(begin, end, span):
- return get_axis(begin, end, "data_par", span)
-
-
-@register
-def reduce_axis(begin, end, span):
- return get_axis(begin, end, "reduce", span)
-
-
-@register
-def scan_axis(begin, end, span):
- return get_axis(begin, end, "scan", span)
-
-
-@register
-def opaque_axis(begin, end, span):
- return get_axis(begin, end, "opaque", span)
-
-
-@register
-def Select(cond, if_body, else_body, span): # pylint: disable=invalid-name
- return tvm.tir.Select(cond, if_body, else_body, span)
-
-
-@register
-def Let(var, value, body, span): # pylint: disable=invalid-name
- return tvm.tir.Let(var, value, body, span)
-
-
-@register
-class EvaluateIntrin(Intrin):
- def __init__(self):
- def evaluate(value, span):
- return tvm.tir.Evaluate(value, span)
-
- super().__init__(evaluate, stmt=True)
-
-
-@register
-class StoreIntrin(Intrin):
- def __init__(self):
- def store(var, index, value, predicate=True, span=None):
- return tvm.tir.Store(var, value, index, predicate, span)
-
- super().__init__(store, stmt=True)
-
-
-@register
-class AssumeIntrin(Intrin):
- def __init__(self):
- def assume(constraint, span):
- return tvm.tir.Evaluate(
- tvm.tir.call_intrin("bool", "tir.assume", constraint, span=span)
- )
-
- super().__init__(assume, stmt=True)
-
-
-@register
-def comm_reducer(lambda_io, identities, span):
- """Create a CommReducer from lambda inputs/outputs and the identities"""
- lambda_input = lambda_io[0]
- lambda_output = lambda_io[1]
-
- num_args = len(lambda_input)
- num_arg_per_group = num_args // 2
- x = [lambda_input[i] for i in builtins.range(0, num_arg_per_group)]
- y = [lambda_input[i] for i in builtins.range(num_arg_per_group, num_args)]
-
- if not isinstance(lambda_output, tuple):
- lambda_output = (lambda_output,)
-
- return tvm.tir.CommReducer(x, y, lambda_output, identities, span)
-
-
-@register
-def llvm_lookup_intrinsic_id(name, span):
- # pylint: disable=unused-argument
- return codegen.llvm_lookup_intrinsic_id(name)
diff --git a/python/tvm/script/tir/node.py b/python/tvm/script/tir/node.py
deleted file mode 100644
index 29e79607fb..0000000000
--- a/python/tvm/script/tir/node.py
+++ /dev/null
@@ -1,218 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=redefined-builtin
-"""TVM Script nodes."""
-
-from typing import Optional, Union, List, Callable
-import synr
-from tvm.arith import Analyzer
-from tvm.runtime import ObjectGeneric, convert
-from tvm.tir import PrimExpr, Buffer, BufferLoad, IntImm, Ramp, BufferRegion
-from tvm.ir import Span, Range
-
-
-class Slice:
- """A helper class to present slice information for BufferSlice
-
- Parameters
- ----------
- start : Union[PrimExpr, int]
- The start index.
-
- stop : Optional[Union[PrimExpr, int]]
- The stop index, None means the Slice is an element-wise index
-
- step : int
- The slice step
-
- span : Optional[Span]
- The location of the slice in the source.
- """
-
- start: Union[PrimExpr, int]
- stop: Optional[Union[PrimExpr, int]]
- step: int
- span: Optional[Span]
-
- def __init__(
- self,
- start: Union[PrimExpr, int],
- stop: Optional[Union[PrimExpr, int]] = None,
- step: int = 1,
- span: Optional[Span] = None,
- ):
- self.start = start
- self.stop = stop
- self.step = step
- self.span = span
-
- def as_index_expr(self, report_error: Callable[[str, Union[Span, synr.ast.Span]], None]):
- """Helper to create index PrimExpr from slice object
- Parameters
- ----------
- report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
- The error report func
- """
- if self.stop is None:
- # scalar index
- return self.start
- if self.step < 1:
- report_error("Slice's step should be positive integer", self.span)
- lanes = Analyzer().simplify((self.stop - self.start + self.step - 1) // self.step)
- if not isinstance(lanes, (int, IntImm)):
- report_error("Slice's lanes should be constant for buffer indices", self.span)
- if lanes == 1:
- return self.start
- return Ramp(self.start, self.step, int(lanes), self.span)
-
-
-class BufferSlice(ObjectGeneric):
- """A generic object for representing general buffer access. Following cases are supported:
- - element wise access buffer[i, j], which can be converted to BufferLoad if necessary
- - slice access buffer[i: i + 1, j : j + 2]
- - union of element and slice buffer[i, j: j + 2]
-
- This node is used in TVMScript to parse BufferLoad, BufferRegion and Realize
-
- Parameters
- ----------
- buffer : Buffer
- The buffer.
-
- indices : List[Union[Slice, PrimExpr, int]]
- The access indexes can be slice, PrimExpr or int.
-
- report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
- The error report func
-
- span : Optional[Span]
- The location of the buffer access in the source.
- """
-
- buffer: Buffer
- slices: List[Slice]
- report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
- span: Optional[Span]
-
- def __init__(
- self,
- buffer: Buffer,
- indices: List[Union[Slice, PrimExpr, int]],
- report_error: Callable[[str, Union[Span, synr.ast.Span]], None],
- span: Optional[Span] = None,
- ):
- def check_index(index: Union[int, PrimExpr]):
- """Check input index is non-negative integer or PrimExpr"""
- if isinstance(index, int):
- if index < 0:
- report_error("Negative index is not allowed during buffer access", span)
- elif isinstance(index, PrimExpr):
- element_dtype = index.dtype.split("x", maxsplit=1)[0]
- if element_dtype[:3] != "int":
- report_error(
- "index expected an integer type PrimExpr but got " + str(index.dtype),
- index.span,
- )
- else:
- report_error(
- "Unsupported index type, expected int or tvm.tir.PrimExpr, but got "
- + str(type(index)),
- span,
- )
-
- slices: List[Union[Slice, BufferSlice]] = []
- for index in indices:
- if isinstance(index, Slice):
- index.start, index.stop = [convert(_) for _ in [index.start, index.stop]]
- check_index(index.start)
- check_index(index.stop)
- slices.append(index)
- elif isinstance(index, (PrimExpr, int)):
- check_index(index)
- slices.append(Slice(index))
- elif isinstance(index, BufferSlice):
- buffer_load = index.asobject()
- check_index(buffer_load)
- slices.append(Slice(buffer_load))
- else:
- report_error(
- "Unsupported index type for BufferSlice, "
- + "expected int, tvm.tir.PrimExpr, tvm.tir.Slice, but got "
- + str(type(index)),
- span,
- )
-
- self.buffer = buffer
- self.slices = slices
- self.report_error = report_error
- self.span = span
-
- def __str__(self):
- regions: List[str] = []
- for s in self.slices:
- if s.stop is None:
- regions.append(str(s.start))
- else:
- regions.append(str(s.start) + ": " + str(s.stop))
-
- return self.buffer.name + "[" + ", ".join(regions) + "]"
-
- def asobject(self) -> BufferLoad:
- """Convert object."""
- indices = [s.as_index_expr(self.report_error) for s in self.slices]
- return BufferLoad(self.buffer, indices, span=self.span)
-
- def as_buffer_region(self, analyzer: Optional[Analyzer] = None) -> BufferRegion:
- """Construct BufferRegion from BufferSlice
-
- Parameters
- ----------
- analyzer : Optional[tvm.arith.Analyzer]
- The analyzer for simplifying. If not provided, the method will construct a new one
-
- Returns
- -------
- buffer_region : BufferRegion
- The constructed BufferRegion.
- """
- region: List[Range] = []
- for s in self.slices:
- start = s.start if isinstance(s.start, PrimExpr) else IntImm("int32", s.start)
- extent = IntImm(start.dtype, 1) if s.stop is None else s.stop - s.start
- if not analyzer:
- analyzer = Analyzer()
- if isinstance(extent, PrimExpr):
- extent = analyzer.simplify(extent)
- if s.step != 1:
- self.report_error("BufferRegion do not support non-trivial stride", s.span)
- region.append(Range.from_min_extent(start, extent, span=s.span))
- return BufferRegion(self.buffer, region)
-
- def astype(self, dtype: str, span: Optional[Span] = None) -> PrimExpr:
- return self.asobject().astype(dtype, span)
-
- @property
- def dtype(self) -> str:
- """Return the dtype referenced by the slice.
-
- Implemented as a property so that ``slice.dtype`` has the same
- calling convention as ``primexpr.dtype``. This allows a
- BufferSlice object can be assigned to a variable without
- requiring a type annotation on the variable, similar to other
- expressions.
- """
- return self.asobject().dtype
diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py
deleted file mode 100644
index 1d2550eecd..0000000000
--- a/python/tvm/script/tir/scope_handler.py
+++ /dev/null
@@ -1,793 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Script Parser Scope Handler Classes"""
-# pylint: disable=redefined-builtin, unused-argument, invalid-name, relative-beyond-top-level
-from typing import Tuple, Any, Callable, Optional, List, Union, Mapping
-
-import synr
-import numpy as np
-import tvm.tir
-from tvm.runtime import Object, String, convert
-from tvm.ir import Span, Range
-from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind, IntImm
-
-from .node import BufferSlice
-
-from ..context_maintainer import ContextMaintainer
-from ..registry import register
-from ..utils import (
- get_param_list,
- tvm_span_from_synr,
- call_with_error_reporting,
-)
-
-
class ScopeHandler:
    """Common base for every TVMScript scope handler.

    A handler wraps a Python function ``func``; the function's name and
    parameter list define the TVMScript construct being parsed. The parser
    stores the parsed body in ``self.body`` before ``exit_scope`` runs.
    """

    def __init__(self, func: Callable):
        self.func: Callable = func
        self.body: Optional[Stmt] = None
        self.node: Optional[synr.ast.Node] = None
        self.context: Optional[ContextMaintainer] = None

    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
        """Return ``("tir." + func name, parameter list of func)``."""
        qualified_name = "tir." + self.func.__name__
        return qualified_name, get_param_list(self.func)

    def enter_scope(
        self,
        node: synr.ast.Node,
        context: ContextMaintainer,
        arg_list: List[Any],
        span: synr.ast.Span,
    ):
        """Hook invoked when the scope opens; the base implementation does nothing."""

    def exit_scope(
        self,
        node: synr.ast.Node,
        context: ContextMaintainer,
        arg_list: List[Any],
        span: synr.ast.Span,
    ):
        """Remember ``node``/``context`` and invoke ``func`` on ``arg_list``.

        ``func`` receives the synr span converted to a TVM span as its ``span``
        keyword; ``context.report_error`` is handed to
        ``call_with_error_reporting`` for error reporting.
        """
        self.node, self.context = node, context
        tvm_span = tvm_span_from_synr(span)
        return call_with_error_reporting(
            context.report_error, span, self.func, *arg_list, span=tvm_span
        )
-
-
class WithScopeHandler(ScopeHandler):
    """Base class for scope handlers used through ``with`` statements.

    Parameters
    ----------
    func : Callable
        Function whose name and signature define the construct.
    concise_scope : bool
        Stored as-is. Presumably this says whether the construct may also be
        written in the concise assignment form (``x = T.f(...)``); TODO confirm
        against the parser.
    def_symbol : bool
        Stored as-is. Presumably this says whether the construct defines new
        symbols.
    """

    def __init__(self, func, concise_scope, def_symbol):
        super().__init__(func)
        self.concise_scope = concise_scope
        self.def_symbol = def_symbol

    @staticmethod
    def get_optional_vars(node, context):
        """Return the ``as``-targets of a ``synr.ast.With`` node as a list.

        Parameters
        ----------
        node : synr.ast.With
            The ``with`` statement being parsed.
        context : ContextMaintainer
            Used to report malformed targets.

        Returns
        -------
        optional_vars : list
            ``node.lhs``. Elements that are not ``synr.ast.Var`` are reported
            through ``context.report_error``. If ``node.lhs`` is not a list at
            all, an error is reported and an empty list is returned. This
            matters only if ``report_error`` returns instead of raising;
            previously that path crashed with ``UnboundLocalError``.
        """
        assert isinstance(
            node, synr.ast.With
        ), f"WithScopeHandler expected synr.ast.With but got {type(node)}"

        if not isinstance(node.lhs, list):
            context.report_error(
                f"Invalid optional var definition, expected list of Var but got {type(node.lhs)}",
                node.span,
            )
            return []

        for var in node.lhs:
            if not isinstance(var, synr.ast.Var):
                context.report_error(
                    f"Invalid optional var definition, expected Var but got {type(var)}",
                    node.span,
                )
        return node.lhs
-
-
@register
class Allocate(WithScopeHandler):
    """With scope handler T.allocate(extents, dtype, scope, condition, annotations)

    ``enter_scope`` declares the pointer variable named by the single target.
    The registered ``allocate`` function later wraps the parsed body in a
    ``tvm.tir.Allocate`` of that variable.
    """

    def __init__(self):
        def allocate(extents, dtype, scope, condition=True, annotations=None, span=None):
            condition = tvm.runtime.convert(condition)
            # The converted scope is not forwarded: the storage scope was already
            # encoded in the pointer type of ``self.buffer_var`` by ``enter_scope``.
            scope = tvm.runtime.convert(scope)
            return tvm.tir.Allocate(
                self.buffer_var,
                dtype,
                extents,
                condition,
                self.body,
                annotations=annotations,
                span=span,
            )

        super().__init__(allocate, concise_scope=True, def_symbol=True)
        self.buffer_var = None

    def enter_scope(
        self,
        node: synr.ast.Node,
        context: ContextMaintainer,
        arg_list: List[Any],
        span: synr.ast.Span,
    ):
        """Create the buffer pointer variable and put it in the symbol table."""
        # Accept both ``with T.allocate(...) as x`` and ``x = T.allocate(...)``.
        if isinstance(node, synr.ast.With):
            targets = WithScopeHandler.get_optional_vars(node, context)
        elif isinstance(node, synr.ast.Assign):
            targets = node.lhs
        else:
            raise Exception("Internal Bug")
        if len(targets) != 1:
            context.report_error(f"Unexpected number of vars: 1 vs. {len(targets)}", node.span)
        name = targets[0].id.name
        var_span = targets[0].id.span

        def setup_buffer_var(
            extents, dtype, scope, condition=True, annotations=None, span: Span = None
        ):
            """Set ``self.buffer_var`` to a pointer to ``dtype`` in storage ``scope``."""
            pointer_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), scope)
            self.buffer_var = tvm.tir.Var(name, pointer_type, span)

        setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span))
        context.update_symbol(name, self.buffer_var, node)
-
-
@register
class AllocateConst(WithScopeHandler):
    """With scope handler T.allocate_const(raw_data, dtype, shape, annotations)

    TIR constant node to represent non-scalar constant. ``raw_data`` is turned
    into an NDArray of ``dtype`` and bound to the pointer variable declared in
    ``enter_scope``.
    """

    def __init__(self):
        def allocate_const(raw_data, dtype, shape, annotations=None, span=None):
            # Each element must expose ``.value``; presumably these are
            # IntImm/FloatImm nodes produced by the parser (TODO confirm).
            list_data = [elem.value for elem in raw_data]
            nd_data = tvm.nd.array(np.asarray(list_data, dtype=dtype))
            return tvm.tir.AllocateConst(
                self.buffer_var,
                dtype,
                shape,
                nd_data,
                self.body,
                annotations=annotations,
                span=span,
            )

        super().__init__(allocate_const, concise_scope=True, def_symbol=True)
        self.buffer_var = None

    def enter_scope(
        self,
        node: synr.ast.Node,
        context: ContextMaintainer,
        arg_list: List[Any],
        span: synr.ast.Span,
    ):
        """Create the constant's pointer variable and put it in the symbol table."""
        # Accept both ``with T.allocate_const(...) as x`` and ``x = T.allocate_const(...)``.
        if isinstance(node, synr.ast.With):
            targets = WithScopeHandler.get_optional_vars(node, context)
        elif isinstance(node, synr.ast.Assign):
            targets = node.lhs
        else:
            raise Exception("Internal Bug")
        if len(targets) != 1:
            context.report_error(f"Unexpected number of vars: 1 vs. {len(targets)}", node.span)
        name = targets[0].id.name
        var_span = targets[0].id.span

        def setup_buffer_var(
            data, dtype, shape, annotations: Optional[dict] = None, span: Span = None
        ):
            """Set ``self.buffer_var`` to a pointer to ``dtype`` (no storage scope given)."""
            buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype))
            self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span)

        setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span))
        context.update_symbol(name, self.buffer_var, node)
-
-
@register
class DeclBuffer(WithScopeHandler):
    """With scope handler T.decl_buffer(shape, dtype, data, strides, elem_offset, scope, align,
    offset_factor, buffer_type, axis_separators)

    The buffer is created in ``enter_scope``. When ``data`` is omitted, the
    resulting ``DeclBuffer`` is additionally wrapped in a ``tvm.tir.Allocate``
    of the buffer's data pointer.

    Example
    -------
    .. code-block:: python

        A = T.decl_buffer((128, 128), dtype="float32")
    """

    def __init__(self):
        def decl_buffer(
            shape,
            dtype="float32",
            data=None,
            strides=None,
            elem_offset=None,
            scope="global",
            align=-1,
            offset_factor=0,
            buffer_type="default",
            axis_separators=None,
            span=None,
        ):
            decl_stmt = tvm.tir.DeclBuffer(self.buffer, self.body, span=span)
            if data is None:
                # When data is not specified, the buffer is implicitly allocated.
                return tvm.tir.Allocate(
                    self.buffer.data,
                    dtype,
                    shape,
                    tvm.runtime.convert(True),
                    decl_stmt,
                    span=span,
                )
            return decl_stmt

        super().__init__(decl_buffer, concise_scope=True, def_symbol=True)
        # Populated by ``enter_scope``; defined here so the attribute always exists.
        self.buffer = None

    def enter_scope(
        self,
        node: synr.ast.Node,
        context: ContextMaintainer,
        arg_list: List[Any],
        span: synr.ast.Span,
    ):
        """Declare the buffer named by the single target and put it in the symbol table."""
        # Accept both ``with T.decl_buffer(...) as A`` and ``A = T.decl_buffer(...)``.
        if isinstance(node, synr.ast.With):
            targets = WithScopeHandler.get_optional_vars(node, context)
        elif isinstance(node, synr.ast.Assign):
            targets = node.lhs
        else:
            raise Exception("Internal Bug")
        if len(targets) != 1:
            context.report_error(f"Unexpected number of vars: 1 vs. {len(targets)}", node.span)
        name = targets[0].id.name
        var_span = targets[0].id.span

        def setup_buffer(
            shape,
            dtype,
            data,
            strides,
            elem_offset,
            scope,
            align,
            offset_factor,
            buffer_type,
            axis_separators,
            span: Span = None,
        ):
            self.buffer = tvm.tir.decl_buffer(
                shape=shape,
                dtype=dtype,
                data=data,
                strides=strides,
                elem_offset=elem_offset,
                scope=scope,
                data_alignment=align,
                offset_factor=offset_factor,
                buffer_type=buffer_type,
                axis_separators=axis_separators,
                name=name,
                span=span,
            )

        setup_buffer(*arg_list, span=tvm_span_from_synr(var_span))
        context.update_symbol(name, self.buffer, node)
-
-
@register
class LaunchThread(WithScopeHandler):
    """With scope handler T.launch_thread(env_var, extent)

    Produces an ``AttrStmt`` keyed ``"virtual_thread"`` for ``vthread`` tags
    and ``"thread_extent"`` otherwise.
    """

    def __init__(self):
        def launch_thread(env_var, extent, span):
            extent = tvm.runtime.convert(extent, span=span)
            # Thread tag registered for this env var (e.g. via T.env_thread).
            thread_tag = self.context.func_var_env_dict[env_var]
            if thread_tag == "vthread":
                attr_key = "virtual_thread"
            else:
                attr_key = "thread_extent"
            thread_iter = IterVar(
                (0, extent),
                env_var,
                IterVar.ThreadIndex,
                thread_tag,
                span=span,
            )
            return tvm.tir.AttrStmt(thread_iter, attr_key, extent, self.body, span=span)

        super().__init__(launch_thread, concise_scope=True, def_symbol=False)
-
-
@register
class Realize(WithScopeHandler):
    """With scope handler T.realize(buffer_bounds, scope, condition)

    Builds ``AttrStmt(buffer, "realize_scope", scope, BufferRealize(...))``.
    The realize bounds are taken from the slices of ``buffer_bounds``.
    """

    def __init__(self):
        def realize(
            buffer_slice: BufferSlice,
            scope: str,
            condition: bool = True,
            span: Optional[Span] = None,
        ):
            assert self.context, "call 'exit_scope' before 'enter_scope'"
            buffer: Buffer = buffer_slice.buffer
            bounds: List[Range] = []
            for s in buffer_slice.slices:
                # A scalar index (no stop) realizes a single element.
                start: Union[PrimExpr, int] = s.start
                extent: Union[PrimExpr, int] = 1 if s.stop is None else s.stop - s.start
                if isinstance(extent, PrimExpr):
                    extent = self.context.analyzer.simplify(extent)
                bounds.append(Range.from_min_extent(start, extent, span=s.span))

            scope = tvm.runtime.convert(scope, span=span)
            return tvm.tir.AttrStmt(
                buffer,
                "realize_scope",
                scope,
                tvm.tir.BufferRealize(buffer, bounds, condition, self.body, span=span),
                span=span,
            )

        super().__init__(realize, concise_scope=True, def_symbol=False)
-
-
@register
class Attr(WithScopeHandler):
    """With scope handler T.attr(attr_node, attr_key, value)

    Both ``attr_node`` and ``value`` are passed through ``tvm.runtime.convert``
    before being placed in the ``AttrStmt``.
    """

    def __init__(self):
        def attr(attr_node, attr_key, value, span):
            node_obj = tvm.runtime.convert(attr_node, span=span)
            value_obj = tvm.runtime.convert(value, span=span)
            return tvm.tir.AttrStmt(node_obj, attr_key, value_obj, self.body, span=span)

        super().__init__(attr, concise_scope=True, def_symbol=False)
-
-
@register
class AssertHandler(WithScopeHandler):
    """With scope handler T.Assert(condition, message)

    Wraps the body in an ``AssertStmt``; ``message`` is converted to a TVM object.
    """

    def __init__(self):
        def Assert(condition, message, span):
            message_obj = tvm.runtime.convert(message)
            return tvm.tir.AssertStmt(condition, message_obj, self.body, span=span)

        super().__init__(Assert, concise_scope=True, def_symbol=False)
-
-
@register
class Let(WithScopeHandler):
    """With scope handler T.let(var, value)

    As a scope it produces a ``LetStmt``. Calling the handler directly builds a
    ``Let`` expression instead.
    """

    def __init__(self):
        def let(var, value, span):
            stmt = tvm.tir.LetStmt(var, value, self.body, span=span)
            return stmt

        super().__init__(let, concise_scope=False, def_symbol=False)

    def __call__(self, var: tvm.tir.Var, value: tvm.tir.PrimExpr, body: tvm.tir.PrimExpr):
        """Return the expression form ``tvm.tir.Let(var, value, body)``."""
        expr = tvm.tir.Let(var, value, body)
        return expr
-
-
@register
class Block(WithScopeHandler):
    """With scope handler T.block(name)

    The registered ``block`` function turns the innermost entry of
    ``context.block_info_stack`` into a ``tvm.tir.Block``, wrapped in a
    ``tvm.tir.BlockRealize``.
    """

    def __init__(self):
        def block(name_hint: str = "", span: Optional[Span] = None):
            assert (
                self.node and self.context and self.body
            ), "call 'exit_scope' before 'enter_scope'"
            info = self.context.block_info_stack[-1]

            # Declared access regions; empty when undeclared (None) or empty.
            reads: List[BufferRegion] = (
                [region.as_buffer_region() for region in info.reads] if info.reads else []
            )
            writes: List[BufferRegion] = (
                [region.as_buffer_region() for region in info.writes] if info.writes else []
            )

            # Bit 0 set: reads undeclared; bit 1 set: writes undeclared. A nonzero
            # mask is stored as an annotation (presumably consumed later to detect
            # the missing access regions).
            detect_mask: int = int(info.reads is None) | (int(info.writes is None) << 1)
            annotations = {} if info.annotations is None else info.annotations
            if detect_mask:
                annotations["tir.script_parsing_detect_access"] = detect_mask

            inner = tvm.tir.Block(
                info.iter_vars,
                reads,
                writes,
                name_hint,
                self.body,
                info.init,
                info.alloc_buffers,
                info.match_buffers,
                annotations,
                span,
            )
            assert len(info.iter_vars) == len(info.iter_values)
            if info.predicate is None:
                predicate = tvm.tir.const(True, "bool")
            else:
                predicate = info.predicate
            return tvm.tir.BlockRealize(info.iter_values, predicate, inner, span)

        super().__init__(func=block, concise_scope=False, def_symbol=True)
        self.block_vars = None

    def enter_scope(
        self,
        node: synr.ast.Node,
        context: ContextMaintainer,
        arg_list: List[Any],
        span: synr.ast.Span,
    ):
        """Check that the block is a ``with`` statement that binds no ``as`` targets."""
        assert isinstance(
            node, synr.ast.With
        ), f"BlockScopeHandler expected to work on synr.ast.With but got {type(node)}"

        bound_names = [var.id.name for var in WithScopeHandler.get_optional_vars(node, context)]
        if bound_names:
            context.report_error(
                f"Block expected no optional_vars (e.g., `x` in `with block() as x`), "
                f"but got {bound_names}",
                node.span,
            )
-
-
@register
class InitBlock(WithScopeHandler):
    """With scope handler T.init()

    Stores the parsed body as the ``init`` of the entry at
    ``block_info_stack[-2]``. Presumably that is the enclosing block, since
    this scope itself occupies ``[-1]`` (TODO confirm).
    """

    def __init__(self):
        def init(span: Span = None):
            assert self.context, "call 'exit_scope' before 'enter_scope'"
            enclosing = self.context.block_info_stack[-2]
            if enclosing.init is not None:
                self.context.report_error("Duplicate init block declaration", span)
            enclosing.init = self.body

        super().__init__(func=init, concise_scope=False, def_symbol=True)
-
-
class LoopInfo:
    """Record describing one loop collected by a ``ForScopeHandler``."""

    # Declared for typing only; ``__init__`` never assigns ``loop_var``.
    loop_var: Var
    begin: PrimExpr
    extent: PrimExpr
    kind: ForKind
    thread_binding: Optional[str]
    annotations: Optional[Mapping[str, Object]]

    def __init__(
        self,
        begin: PrimExpr,
        extent: PrimExpr,
        kind: ForKind,
        thread_binding: Optional[str] = None,
        annotations: Optional[Mapping[str, Object]] = None,
    ) -> None:
        self.begin, self.extent, self.kind = begin, extent, kind
        self.thread_binding, self.annotations = thread_binding, annotations
-
-
class ForScopeHandler(ScopeHandler):
    """Base class for all for scope handlers.

    The wrapped function calls :meth:`create_loop_info` once per loop it
    describes. ``enter_scope`` creates one loop variable per recorded loop.
    ``exit_scope`` then nests ``tvm.tir.For`` nodes around the parsed body,
    with the first recorded loop outermost.
    """

    def __init__(self, func):
        super().__init__(func)
        self.loop_vars: List[Var] = []
        self.loop_info: List[LoopInfo] = []

    def enter_scope(
        self,
        node: synr.ast.Node,
        context: ContextMaintainer,
        arg_list: List[Any],
        span: synr.ast.Span,
    ):
        """Collect loop info, create the loop variables and bind them in ``context``.

        Raises
        ------
        NotImplementedError
            If a loop's begin or extent does not have an integer dtype.
        """
        assert isinstance(
            node, synr.ast.For
        ), f"ForScopeHandler expected synr.ast.For but got {type(node)}"

        loop_var_names: List[str] = []
        spans: List[Span] = []
        if isinstance(node.lhs, synr.ast.Var):
            loop_var_names.append(node.lhs.id.name)
            spans.append(tvm_span_from_synr(node.lhs.id.span))
        elif isinstance(node.lhs, list):
            for elt in node.lhs:
                if not isinstance(elt, synr.ast.Var):
                    context.report_error(
                        f"Invalid loop var. Expected a var, but got {type(elt)}", elt.span
                    )
                loop_var_names.append(elt.id.name)
                spans.append(tvm_span_from_synr(elt.id.span))
        else:
            context.report_error(
                f"Invalid loop var. Expected var or list of vars as lhs, but got {type(node.lhs)}",
                span,
            )

        self.node = node
        self.context = context
        # Calling ``func`` fills ``self.loop_info`` via ``create_loop_info``.
        call_with_error_reporting(context.report_error, span, self.func, *arg_list)
        if len(loop_var_names) != len(self.loop_info):
            self.context.report_error(
                f"Inconsistent number of vars and loops, got {len(loop_var_names)} "
                + f"vs {len(self.loop_info)}",
                self.node.span,
            )

        # A loop var is int64 when either bound is int64, otherwise int32.
        self.loop_vars = []
        for name, lv_span, info in zip(loop_var_names, spans, self.loop_info):
            if not info.begin.dtype.startswith("int"):
                raise NotImplementedError(f"Unsupported dtype in loop begin: {info.begin.dtype}")
            if not info.extent.dtype.startswith("int"):
                raise NotImplementedError(f"Unsupported dtype in loop extent: {info.extent.dtype}")
            dtype = "int64" if "int64" in (info.begin.dtype, info.extent.dtype) else "int32"
            self.loop_vars.append(tvm.te.var(name, dtype=dtype, span=lv_span))

        for loop_var, info in zip(self.loop_vars, self.loop_info):
            context.update_symbol(loop_var.name, loop_var, node)
            context.loop_stack[loop_var] = Range.from_min_extent(info.begin, info.extent)

    def exit_scope(
        self,
        node: synr.ast.Node,
        context: ContextMaintainer,
        arg_list: List[Any],
        span: synr.ast.Span,
    ):
        """Pop the loop vars from ``context.loop_stack`` and build the nested ``For`` nodes."""
        assert self.loop_vars, "call 'exit_scope' before 'enter_scope'"
        for loop_var in self.loop_vars:
            context.loop_stack.pop(loop_var)
        # Use assert here since we have checked it in `enter_scope`.
        assert len(self.loop_vars) == len(self.loop_info)

        # Wrap innermost-first so the first recorded loop ends up outermost.
        body = self.body
        for var, info in zip(reversed(self.loop_vars), reversed(self.loop_info)):
            body = tvm.tir.For(
                var,
                info.begin,
                info.extent,
                info.kind,
                body,
                info.thread_binding,
                info.annotations,
                span=tvm_span_from_synr(span),
            )

        return body

    def create_loop_info(
        self,
        begin: PrimExpr,
        end: PrimExpr,
        kind: ForKind,
        thread_binding: Optional[IterVar] = None,
        annotations: Optional[Mapping[str, Object]] = None,
    ) -> None:
        """Append a :class:`LoopInfo` for the half-open range ``[begin, end)``.

        Parameters
        ----------
        begin : PrimExpr
            The beginning value.

        end : PrimExpr
            The ending value (exclusive; the extent is ``end - begin``).

        kind : ForKind
            The type of the for.

        thread_binding : Optional[IterVar]
            The thread IterVar this loop binds to (``ThreadBinding`` passes one).

        annotations : Optional[Mapping[str, Object]]
            Additional annotation hints. Python ``str`` values are wrapped in
            ``tvm.runtime.String`` before being stored.

        Returns
        -------
        None
            The information is appended to ``self.loop_info``.
        """
        begin, end = convert(begin), convert(end)
        assert self.context and self.node, "call 'exit_scope' before 'enter_scope'"
        extent = end if begin == 0 else self.context.analyzer.simplify(end - begin)
        if begin == 0 and isinstance(extent, PrimExpr):
            # Give a zero start the same dtype as the extent.
            begin = IntImm(extent.dtype, 0, begin.span)
        self.annotations: Mapping[str, Object] = {}
        if annotations is not None:
            self.annotations = {
                key: String(val) if isinstance(val, str) else val
                for key, val in annotations.items()
            }
            # Forward the converted mapping; it was previously computed and discarded.
            annotations = self.annotations

        self.loop_info.append(LoopInfo(begin, extent, kind, thread_binding, annotations))
-
-
@register
class Serial(ForScopeHandler):
    """For scope handler T.serial(begin, end, annotations)

    ``T.serial(n)`` is shorthand for ``T.serial(0, n)``.
    """

    def __init__(self):
        def serial(
            begin: PrimExpr,
            end: PrimExpr = None,
            annotations: Optional[Mapping[str, Object]] = None,
        ):
            if end is None:
                begin, end = 0, begin
            self.create_loop_info(begin, end, ForKind.SERIAL, annotations=annotations)

        super().__init__(serial)
-
-
@register
class Parallel(ForScopeHandler):
    """For scope handler T.parallel(begin, end, annotations)

    ``T.parallel(n)`` is shorthand for ``T.parallel(0, n)``.
    """

    def __init__(self):
        def parallel(
            begin: PrimExpr,
            end: PrimExpr = None,
            annotations: Optional[Mapping[str, Object]] = None,
        ):
            if end is None:
                begin, end = 0, begin
            self.create_loop_info(begin, end, ForKind.PARALLEL, annotations=annotations)

        super().__init__(parallel)
-
-
@register
class Vectorized(ForScopeHandler):
    """For scope handler T.vectorized(begin, end, annotations)

    ``T.vectorized(n)`` is shorthand for ``T.vectorized(0, n)``.
    """

    def __init__(self):
        def vectorized(
            begin: PrimExpr,
            end: PrimExpr = None,
            annotations: Optional[Mapping[str, Object]] = None,
        ):
            if end is None:
                begin, end = 0, begin
            self.create_loop_info(begin, end, ForKind.VECTORIZED, annotations=annotations)

        super().__init__(vectorized)
-
-
@register
class Unroll(ForScopeHandler):
    """For scope handler T.unroll(begin, end, annotations)

    ``T.unroll(n)`` is shorthand for ``T.unroll(0, n)``.
    """

    def __init__(self):
        def unroll(
            begin: PrimExpr,
            end: PrimExpr = None,
            annotations: Optional[Mapping[str, Object]] = None,
        ):
            if end is None:
                begin, end = 0, begin
            self.create_loop_info(begin, end, ForKind.UNROLLED, annotations=annotations)

        super().__init__(unroll)
-
-
@register
class ThreadBinding(ForScopeHandler):
    """For scope handler T.thread_binding(begin, end, thread, annotations)

    Also accepts the two-argument form ``T.thread_binding(extent, "threadIdx.x")``.
    """

    def __init__(self):
        def thread_binding(
            begin: PrimExpr,
            end: PrimExpr = None,
            thread: str = None,
            annotations: Optional[Mapping[str, Object]] = None,
        ):
            if thread is None:
                if not isinstance(end, str):
                    raise ValueError("Thread cannot be None for thread_binding")
                # handle case like thread_binding(128, "threadIdx.x")
                thread, end = end, None
            if end is None:
                begin, end = 0, begin
            self.create_loop_info(
                begin,
                end,
                ForKind.THREAD_BINDING,
                thread_binding=IterVar(None, None, IterVar.ThreadIndex, thread),
                annotations=annotations,
            )

        super().__init__(thread_binding)
-
-
@register
class RangeHandler(ForScopeHandler):
    """For scope handler range(begin, end, annotations)

    Note that tir.range is totally the same as T.serial. It is registered under
    the bare name ``range`` rather than ``tir.range``.
    """

    def __init__(self):
        def for_range(
            begin: PrimExpr,
            end: PrimExpr = None,
            annotations: Optional[Mapping[str, Object]] = None,
        ):
            if end is None:
                begin, end = 0, begin
            self.create_loop_info(begin, end, ForKind.SERIAL, annotations=annotations)

        super().__init__(for_range)

    def signature(self):
        """Register as Python's ``range`` instead of ``tir.for_range``."""
        params = get_param_list(self.func)
        return "range", params
-
-
@register
class Grid(ForScopeHandler):
    """For scope handler T.grid(extents)

    Records one serial loop over ``[0, extent)`` per argument, in order.
    """

    def __init__(self):
        def grid(*extents: List[PrimExpr]):
            for ext in extents:
                self.create_loop_info(0, ext, ForKind.SERIAL)

        super().__init__(grid)
diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py
deleted file mode 100644
index 15502055b7..0000000000
--- a/python/tvm/script/tir/special_stmt.py
+++ /dev/null
@@ -1,964 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Script Parser Special Stmt Classes"""
-# pylint: disable=unused-argument, no-self-argument, inconsistent-return-statements
-# pylint: disable=relative-beyond-top-level
-from typing import Callable, List, Optional, Tuple, Any, Mapping, Union
-
-import synr
-from synr import ast
-from tvm.ir.expr import PrimExpr, Range
-
-import tvm.tir
-from tvm.runtime import Object, String
-from tvm.target import Target
-from tvm.ir import Span
-from tvm.tir import IntImm, IterVar, Var
-
-from .node import BufferSlice
-
-from ..context_maintainer import BlockInfo, ContextMaintainer
-from ..registry import register
-from ..utils import (
- get_param_list,
- tvm_span_from_synr,
- call_with_error_reporting,
-)
-
-
def convert_to_int(
    value: Union[IntImm, int],
    arg_name: str,
    report_error: Callable,
    span: Union[Span, synr.ast.Span],
) -> int:
    """Return ``value`` as a plain Python int.

    Parameters
    ----------
    value : Union[tvm.tir.IntImm, int]
        The value to convert; an ``IntImm`` contributes its ``.value``.
    arg_name : str
        Name of the argument, used in the error message.
    report_error: Callable
        Called with ``(message, span)`` when ``value`` is neither kind.
    span : Union[synr.ast.Span, tvm.ir.Span]
        Location of the error.

    Returns
    -------
    int or None
        ``None`` only if ``report_error`` returns instead of raising.
    """
    if isinstance(value, IntImm):
        return value.value
    if not isinstance(value, int):
        report_error(
            f"Expected int or IntImm for {arg_name}, but got {str(type(value))}",
            span,
        )
        return None
    return value
-
-
class SpecialStmt:
    """Common base for TVMScript special statements.

    ``func``'s name and parameter list define the statement. ``handle``
    invokes ``func`` after recording the node and parsing context.
    """

    def __init__(self, func: Callable, def_symbol: bool):
        self.func: Callable = func
        self.def_symbol: bool = def_symbol
        self.node: Optional[synr.ast.Node] = None
        self.context: Optional[ContextMaintainer] = None

    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
        """Return ``("tir." + func name, parameter list of func)``."""
        qualified_name = "tir." + self.func.__name__
        return qualified_name, get_param_list(self.func)

    def handle(
        self,
        node: ast.Node,
        context: ContextMaintainer,
        arg_list: List[Any],
        span: synr.ast.Span,
    ):
        """Record ``node``/``context`` and call ``func`` with the converted TVM span."""
        self.node, self.context = node, context
        tvm_span = tvm_span_from_synr(span)
        return call_with_error_reporting(
            context.report_error, span, self.func, *arg_list, span=tvm_span
        )
-
-
@register
class MatchBuffer(SpecialStmt):
    """Special Stmt match_buffer(param, shape, dtype, data, strides, elem_offset, scope, align,
    offset_factor, buffer_type, axis_separators)

    Note
    ----
    Behavior depends on the type of ``param``. A function-parameter ``Var``
    is recorded in ``context.func_buffer_map``. A ``BufferSlice`` becomes a
    ``MatchBufferRegion`` of the current block.

    Example
    -------
    Match buffer from function parameter
    .. code-block:: python
        A = T.match_buffer(a, (128, 128), dtype="float32")

    Match buffer from Buffer subregion
    .. code-block:: python
        A = T.match_buffer(B[0:128, i * 128 : i * 128 + 128], (128, 128), dtype="float32")
    """

    def __init__(self):
        def match_buffer(
            param,
            shape,
            dtype="float32",
            data=None,
            strides=None,
            elem_offset=None,
            scope="global",
            align=-1,
            offset_factor=0,
            buffer_type="default",
            axis_separators=None,
            span=None,
        ):
            node, ctx = self.node, self.context
            if not isinstance(node, ast.Assign) or len(node.lhs) != 1:
                ctx.report_error(
                    "`match_buffer` must be assigned to a single buffer, "
                    "e.g. A = match_buffer(...)",
                    node.span,
                )
            strides = [] if strides is None else strides
            align = convert_to_int(align, "align", ctx.report_error, node.span)
            offset_factor = convert_to_int(
                offset_factor, "offset_factor", ctx.report_error, node.span
            )
            buffer_name: str = node.lhs[0].id.name
            buffer = tvm.tir.decl_buffer(
                shape,
                dtype,
                buffer_name,
                data,
                strides,
                elem_offset,
                scope,
                align,
                offset_factor,
                buffer_type,
                axis_separators,
                span=span,
            )
            if isinstance(param, tvm.tir.Var):
                if param not in ctx.func_params:
                    ctx.report_error(
                        "Can not bind non-input param to buffer", node.rhs.params[0].span
                    )
                ctx.func_buffer_map[param] = buffer
            elif isinstance(param, BufferSlice):
                source_region = param.as_buffer_region()
                ctx.current_block_scope().match_buffers.append(
                    tvm.tir.MatchBufferRegion(buffer, source_region)
                )
            else:
                ctx.report_error(
                    "The source of match_buffer expected Var or BufferSlice, but got "
                    + str(type(param)),
                    node.rhs.params[0].span,
                )
            ctx.update_symbol(buffer_name, buffer, node)

        super().__init__(match_buffer, def_symbol=True)
-
-
@register
class BufferDeclare(SpecialStmt):
    """Special Stmt buffer_decl(shape, dtype, data, strides, elem_offset, scope, align,
    offset_factor, buffer_type, axis_separators)

    Declares a buffer, binds it to the single assignment target and returns
    it. The buffer is not added to any block's allocation list.

    Example
    -------
    .. code-block:: python
        A = T.buffer_decl((128, 128), dtype="float32")
    """

    def __init__(self):
        def buffer_decl(
            shape,
            dtype="float32",
            data=None,
            strides=None,
            elem_offset=None,
            scope="global",
            align=-1,
            offset_factor=0,
            buffer_type="default",
            axis_separators=None,
            span=None,
        ):
            node, ctx = self.node, self.context
            if not isinstance(node, ast.Assign) or len(node.lhs) != 1:
                ctx.report_error(
                    "`buffer_decl` must be assigned to a single buffer, e.g. A = buffer_decl(...)",
                    node.span,
                )

            strides = [] if strides is None else strides
            align = convert_to_int(align, "align", ctx.report_error, node.span)
            offset_factor = convert_to_int(
                offset_factor, "offset_factor", ctx.report_error, node.span
            )
            buffer_name: str = node.lhs[0].id.name
            declared = tvm.tir.decl_buffer(
                shape,
                dtype,
                buffer_name,
                data,
                strides,
                elem_offset,
                scope,
                align,
                offset_factor,
                buffer_type,
                axis_separators,
                span=span,
            )
            ctx.update_symbol(buffer_name, declared, node)
            return declared

        super().__init__(buffer_decl, def_symbol=True)
-
-
@register
class AllocBuffer(SpecialStmt):
    """Special function alloc_buffer(shape, dtype, data, strides, elem_offset, scope, align,
    offset_factor, buffer_type, axis_separators)

    The buffer is appended to the current block's ``alloc_buffers``. Outside
    any block it goes to ``context.root_alloc_buffers``.

    Example
    -------
    .. code-block:: python

        A = T.alloc_buffer((128, 128), dtype="float32")
    """

    def __init__(self):
        def alloc_buffer(
            shape,
            dtype="float32",
            data=None,
            strides=None,
            elem_offset=None,
            scope="global",
            align=-1,
            offset_factor=0,
            buffer_type="default",
            axis_separators=None,
            span=None,
        ):
            node, ctx = self.node, self.context
            if not isinstance(node, ast.Assign) or len(node.lhs) != 1:
                ctx.report_error(
                    "`alloc_buffer` must be assigned to a single buffer, "
                    "e.g. A = alloc_buffer(...)",
                    node.span,
                )

            strides = [] if strides is None else strides
            align = convert_to_int(align, "align", ctx.report_error, node.span)
            offset_factor = convert_to_int(
                offset_factor, "offset_factor", ctx.report_error, node.span
            )
            buffer_name: str = node.lhs[0].id.name
            buffer = tvm.tir.decl_buffer(
                shape,
                dtype,
                buffer_name,
                data,
                strides,
                elem_offset,
                scope,
                align,
                offset_factor,
                buffer_type,
                axis_separators,
                span=span,
            )
            block_scope = ctx.current_block_scope()
            if block_scope:
                block_scope.alloc_buffers.append(buffer)
            else:
                # If it is allocated outside all blocks, allocate it under root block.
                ctx.root_alloc_buffers.append(buffer)
            ctx.update_symbol(buffer_name, buffer, node)

        super().__init__(alloc_buffer, def_symbol=True)
-
-
@register
class BlockReads(SpecialStmt):
    """Special function reads([read_regions], *other_regions)

    Note
    ----
    *other_region is an unpackable list of BufferSlice to support
    reads syntax sugar like reads(BufferRegion1, BufferRegion2, ...)

    Example
    -------
    .. code-block:: python

        T.reads([A[vi: vi + 4, vk: vk + 4], B[vk: vk + 4, vj]])
    """

    def __init__(self):
        def reads(
            *read_regions: Union[BufferSlice, List[BufferSlice]],
            span: Span = None,
        ):
            assert self.context, "call 'exit_scope' before 'enter_scope'"
            block_scope = self.context.current_block_scope()
            if block_scope is None:
                self.context.report_error(
                    "Expected to declare read regions inside a block.",
                    span,
                )
            if block_scope.reads is not None:
                # Previously this message wrongly said "write" for a read declaration.
                self.context.report_error(
                    "Duplicate read region declaration, "
                    + "previous one is "
                    + ", ".join(str(x) for x in block_scope.reads),
                    span,
                )
            if len(read_regions) > 1:
                for read_region in read_regions:
                    if not isinstance(read_region, BufferSlice):
                        # Report the offending element's type, not the tuple's.
                        self.context.report_error(
                            "Incorrect input type. Expected *BufferSlice or List[BufferSlice],"
                            + f" but got {type(read_region)}",
                            span,
                        )
            elif len(read_regions) == 1:
                # Unwrap the list form: reads([A[...], B[...]]).
                if isinstance(read_regions[0], list):
                    read_regions = read_regions[0]

            block_scope.reads = read_regions

        super().__init__(reads, def_symbol=False)
-
-
@register
class BlockWrites(SpecialStmt):
    """Special function writes([write_regions], *other_regions)

    Note
    ----
    *other_region is an unpackable list of BufferSlice to support
    writes syntax sugar like writes(BufferRegion1, BufferRegion2, ...)

    Example
    -------
    .. code-block:: python

        T.writes([C[vi: vi + 4, vj]])
    """

    def __init__(self):
        def writes(
            *write_regions: Union[BufferSlice, List[BufferSlice]],
            span: Span = None,
        ):
            assert self.context, "call 'exit_scope' before 'enter_scope'"
            block_scope = self.context.current_block_scope()
            if block_scope is None:
                self.context.report_error(
                    "Expected to declare write regions inside a block.",
                    span,
                )
            if block_scope.writes is not None:
                self.context.report_error(
                    "Duplicate write region declaration, "
                    + "previous one is "
                    + ", ".join(str(x) for x in block_scope.writes),
                    span,
                )
            if len(write_regions) > 1:
                for write_region in write_regions:
                    if not isinstance(write_region, BufferSlice):
                        # Report the offending element's type, not the tuple's.
                        self.context.report_error(
                            "Incorrect input type. Expected *BufferSlice or List[BufferSlice],"
                            + f" but got {type(write_region)}",
                            span,
                        )
            elif len(write_regions) == 1:
                # Unwrap the list form: writes([C[...]]).
                if isinstance(write_regions[0], list):
                    write_regions = write_regions[0]
            block_scope.writes = write_regions

        super().__init__(writes, def_symbol=False)
-
-
-@register
-class BlockAttr(SpecialStmt):
- """Special function block_attr({attr_key: attr_value})
-
- Example
- -------
- .. code-block:: python
-
- T.block_attr({"double_buffer_scope": 1})
- """
-
- def __init__(self):
- def block_attr(attrs: Mapping[str, Object], span: Span = None):
- assert self.context, "call 'exit_scope' before 'enter_scope'"
- block_scope = self.context.current_block_scope()
- if block_scope is None:
- self.context.report_error(
- "Expected to declare block annotations inside a block.",
- span,
- )
- if block_scope.annotations is not None:
- self.context.report_error(
- "Duplicate block annotations declaration, "
- + "previous one is "
- + str(block_scope.annotations),
- span,
- )
- attrs = {
- key: String(val) if isinstance(val, str) else val for key, val in attrs.items()
- }
- block_scope.annotations = attrs
-
- super().__init__(block_attr, def_symbol=False)
-
-
-class BlockAxis(SpecialStmt):
- """Special stmt for defining a spatial block axis
- axis.S(dom, iter_value)
-
- Example
- -------
- .. code-block:: python
-
- vi = T.axis.S(128, i * 4 + j)
- """
-
- def axis(
- self,
- var_name: str,
- dom: Union[PrimExpr, Range],
- value: PrimExpr,
- iter_type: int,
- span: Optional[Span] = None,
- ) -> None:
- """
- Helper function for creating block axis
-
- Parameters
- ----------
- var_name : str
- The name_hint of var
-
- dom : Union[PrimExpr, Range]
- The iter domain.
-
- value : PrimExpr
- The binding value
-
- iter_type : int
- The iteration type.
-
- span : Optional[Span]
- The location of this for in the source code.
- """
- assert self.context, "call 'exit_scope' before 'enter_scope'"
- block_scope: BlockInfo = self.context.current_block_scope()
- if block_scope is None:
- self.context.report_error(
- "Expected to declare block axes inside a block.",
- self.node.span,
- )
- if var_name in [iter_var.var.name for iter_var in block_scope.iter_vars]:
- self.context.report_error("Duplicate block axis " + var_name, self.node.span)
-
- dom = tvm.runtime.convert(dom)
- if isinstance(dom, PrimExpr):
- dom = tvm.ir.Range(dom)
- elif isinstance(dom, tvm.ir.container.Array) and len(dom) == 2:
- dom = tvm.ir.Range(dom[0], dom[1])
- elif not isinstance(dom, tvm.ir.Range):
- self.context.report_error(
- f"Block axis domain expected PrimExpr or Range, but got {type(dom)}",
- self.node.span,
- )
- block_var = tvm.tir.Var(var_name, dtype=dom.extent.dtype)
- value = tvm.runtime.convert(value)
- if not isinstance(value, PrimExpr):
- self.context.report_error(
- f"Block axis value expected PrimExpr, but got {type(value)}",
- self.node.span,
- )
- iter_var = tvm.tir.IterVar(dom, block_var, iter_type)
- block_scope.iter_vars.append(iter_var)
- block_scope.iter_values.append(value)
- self.context.update_symbol(var_name, block_var, self.node)
-
-
-@register
-class BlockAxisSpatial(BlockAxis):
- """Special stmt for defining a spatial block axis
- axis.spatial(dom, iter_value)
-
- Example
- -------
- .. code-block:: python
-
- vi = T.axis.spatial(128, k)
- """
-
- def __init__(self):
- def axis_spatial(
- dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None
- ):
- if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
- self.context.report_error(
- "`axis.spatial` must be assigned to a var, e.g. vi = axis.spatial(...)",
- self.node.span,
- )
- self.axis(self.node.lhs[0].id.name, dom, value, IterVar.DataPar)
-
- super().__init__(axis_spatial, def_symbol=True)
-
- def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
- return "tir.axis.spatial", get_param_list(self.func)
-
-
-@register
-class BlockAxisS(BlockAxis):
- """The sugar special stmt for defining a spatial block axis
- axis.S(dom, iter_value)
-
- Example
- -------
- .. code-block:: python
-
- vi = T.axis.S(128, k)
- """
-
- def __init__(self):
- def axis_spatial(
- dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None
- ):
- if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
- self.context.report_error(
- "`axis.S` must be assigned to a var, e.g. vi = axis.S(...)",
- self.node.span,
- )
- self.axis(self.node.lhs[0].id.name, dom, value, IterVar.DataPar)
-
- super().__init__(axis_spatial, def_symbol=True)
-
- def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
- return "tir.axis.S", get_param_list(self.func)
-
-
-@register
-class BlockAxisReduce(BlockAxis):
- """Special stmt for defining a reduce block axis
- axis.reduce(dom, iter_value)
-
- Example
- -------
- .. code-block:: python
-
- vi = T.axis.reduce(128, k)
- """
-
- def __init__(self):
- def axis_reduce(
- dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None
- ):
- if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
- self.context.report_error(
- "`axis.reduce` must be assigned` to a var, e.g. vi = axis.reduce(...)",
- self.node.span,
- )
- self.axis(self.node.lhs[0].id.name, dom, value, IterVar.CommReduce)
-
- super().__init__(axis_reduce, def_symbol=True)
-
- def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
- return "tir.axis.reduce", get_param_list(self.func)
-
-
-@register
-class BlockAxisR(BlockAxis):
- """The sugar special stmt for defining a reduce block axis
- axis.R(dom, iter_value)
-
- Example
- -------
- .. code-block:: python
-
- vi = T.axis.R(128, k)
- """
-
- def __init__(self):
- def axis_reduce(
- dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None
- ):
- if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
- self.context.report_error(
- "`axis.R` must be assigned to a var, e.g. vi = axis.R(...)",
- self.node.span,
- )
- self.axis(self.node.lhs[0].id.name, dom, value, IterVar.CommReduce)
-
- super().__init__(axis_reduce, def_symbol=True)
-
- def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
- return "tir.axis.R", get_param_list(self.func)
-
-
-@register
-class BlockAxisScan(BlockAxis):
- """Special stmt for defining a ordered block axis
- axis.scan(dom, iter_value)
-
- Example
- -------
- .. code-block:: python
-
- vi = T.axis.scan(128, k)
- """
-
- def __init__(self):
- def axis_scan(
- dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None
- ):
- if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
- self.context.report_error(
- "`axis.scan` must be assigned to a var, e.g. vi = axis.scan(...)",
- self.node.span,
- )
- self.axis(self.node.lhs[0].id.name, dom, value, IterVar.Ordered)
-
- super().__init__(axis_scan, def_symbol=True)
-
- def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
- return "tir.axis.scan", get_param_list(self.func)
-
-
-@register
-class BlockAxisOpaque(BlockAxis):
- """Special stmt for defining a opaque block axis
- axis.opaque(dom, iter_value)
-
- Example
- -------
- .. code-block:: python
-
- vi = T.axis.opaque(128, k)
- """
-
- def __init__(self):
- def axis_opaque(
- dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None
- ):
- if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
- self.context.report_error(
- "`axis.opaque` must be assigned to a var, e.g. vi = axis.opaque(...)",
- self.node.span,
- )
- self.axis(self.node.lhs[0].id.name, dom, value, IterVar.DimInfo)
-
- super().__init__(axis_opaque, def_symbol=True)
-
- def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
- return "tir.axis.opaque", get_param_list(self.func)
-
-
-@register
-class BlockAxisRemap(BlockAxis):
- """Special stmt for remapping loops vars to block axes.
- axis.remap(iter_type, iter_value)
-
- Note
- ----
- Iter_type is a string consisting of 'S' and 'R', where 'S' means
- for spatial and 'R' means for reduce.
-
- Example
- -------
- .. code-block:: python
-
- vi, vj = T.axis.remap("SS", [i, j])
- """
-
- def __init__(self):
- def axis_remap(iter_types: str, loop_vars: List[tvm.tir.expr.Var], span: Span = None):
- if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) >= 1:
- self.context.report_error(
- "`axis.remap` must be assigned to one or more vars, "
- "e.g. vi, vj = axis.remap(...)",
- self.node.span,
- )
- var_num: int = len(self.node.lhs)
- if var_num != len(iter_types):
- self.context.report_error(
- f"`iter_type` expected {var_num} charactor(s), "
- f"but got {len(iter_types)}: {iter_types}",
- span,
- )
- if var_num != len(loop_vars):
- self.context.report_error(
- f"`iter_type` expected {var_num} loop var(s), "
- f"but got {len(loop_vars)}: {loop_vars}",
- span,
- )
- for var, iter_ty, loop_var in zip(self.node.lhs, iter_types, loop_vars):
- iter_type: int
- if iter_ty == "S":
- iter_type = IterVar.DataPar
- elif iter_ty == "R":
- iter_type = IterVar.CommReduce
- else:
- self.context.report_error(
- f'`iter_type` only expected "S" (for spatial) or "R" (for reduce), '
- f'but got "{iter_ty}"',
- span,
- )
-
- if not isinstance(loop_var, tvm.tir.expr.Var):
- self.context.report_error(
- f"Values of `axis.remap` expected single loop var, but got {loop_var}",
- loop_var.span,
- )
- loops = self.context.loop_stack
- if loop_var not in loops:
- self.context.report_error(
- f"Cannot find loop var {loop_var} in loop nesting.",
- span,
- )
- self.axis(var.id.name, loops[loop_var], loop_var, iter_type)
-
- super().__init__(axis_remap, def_symbol=True)
-
- def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
- return "tir.axis.remap", get_param_list(self.func)
-
-
-@register
-class BlockPredicate(SpecialStmt):
- """Special function where(predicate)
-
- Example
- -------
- .. code-block:: python
-
- T.where(i < 4)
- """
-
- def __init__(self):
- def where(predicate, span=None):
- assert self.context, "call 'exit_scope' before 'enter_scope'"
- block_scope = self.context.current_block_scope()
- if block_scope is None:
- self.context.report_error(
- "Expected to declare the predicate inside a block.",
- span,
- )
- if block_scope.predicate is not None:
- self.context.report_error(
- "Duplicate block predicate declaration, "
- + "previous one is "
- + str(block_scope.predicate),
- span,
- )
-
- block_scope.predicate = predicate
-
- super().__init__(where, def_symbol=False)
-
-
-@register
-class VarDef(SpecialStmt):
- """Special function for defining a Var"""
-
- def __init__(self):
- def var(dtype, span):
- assert isinstance(
- self.node, ast.Assign
- ), f"VarDef expected ast.Assign but got {type(self.node)}"
- names = [x.id.name for x in self.node.lhs]
- if len(names) != 1:
- self.context.report_error(
- f"VarDef expected assign to only one var, but got {names}", span
- )
- v = Var(names[0], dtype, span=span)
- self.context.update_symbol(v.name, v, self.node)
-
- super().__init__(var, def_symbol=True)
-
-
-@register
-class BufferVarDef(SpecialStmt):
- """Special function for defining a variable of pointer type"""
-
- def __init__(self):
- def buffer_var(dtype, storage_scope, span):
- assert isinstance(
- self.node, ast.Assign
- ), f"BufferVarDef expected ast.Assign but got {type(self.node)}"
- names = [x.id.name for x in self.node.lhs]
- if len(names) != 1:
- self.context.report_error(
- f"VarDef expected assign to only one var, but got {names}", span
- )
- ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope)
- v = Var(names[0], ptr_type, span=span)
- self.context.update_symbol(v.name, v, self.node)
-
- super().__init__(buffer_var, def_symbol=True)
-
-
-@register
-class EnvThread(SpecialStmt):
- """Bind a var to thread env"""
-
- def __init__(self):
- def env_thread(env_name, span):
- assert isinstance(
- self.node, ast.Assign
- ), f"EnvThread expected ast.Assign but got {type(self.node)}"
- names = [x.id.name for x in self.node.lhs]
- if len(names) != 1:
- self.context.report_error(
- f"VarDef expected assign to only one var, but got {names}", span
- )
- v = Var(names[0], dtype="int32", span=span)
- self.context.func_var_env_dict[v] = env_name
- self.context.update_symbol(v.name, v, self.node)
-
- super().__init__(env_thread, def_symbol=True)
-
-
-@register
-class FuncAttr(SpecialStmt):
- """Special Stmt for declaring the DictAttr of PrimFunc
- Example
- -------
- .. code-block:: python
- T.func_attr({"tir.noalias": True, "global_symbol"})
- """
-
- def __init__(self):
- def func_attr(dict_attr, span):
- self.context.func_dict_attr = dict_attr
-
- super().__init__(func_attr, def_symbol=False)
-
-
-@register
-class PreflattenedBufferMap(SpecialStmt):
- """Special Stmt for declaring the PrimFunc::preflattened_buffer_map
-
- Example
- -------
- .. code-block:: python
- A0 = T.match_buffer(A, (48,), dtype="float32")
- T.preflattened_buffer_map(A, (1, 4, 4, 3), elem_offset=1, align=4, dtype="float32")
- """
-
- def __init__(self):
- def preflattened_buffer(
- postflattened,
- shape,
- dtype="float32",
- data=None,
- strides=None,
- elem_offset=None,
- scope="global",
- align=-1,
- offset_factor=0,
- buffer_type="default",
- span=None,
- ):
-
- param = None
- for key, value in self.context.func_buffer_map.items():
- if value.same_as(postflattened):
- param = key
- break
-
- assert (
- param is not None
- ), f"Post-flatten buffer {postflattened.name} does not appear in the buffer map."
-
- if data is None:
- data = self.context.func_buffer_map[param].data
-
- buffer_name: str = f"{postflattened.name}_preflatten"
- if align != -1:
- if isinstance(align, IntImm):
- align = align.value
- else:
- assert isinstance(align, int), f"align: want int or IntImm, got {align!r}"
-
- if offset_factor != 0:
- if isinstance(offset_factor, IntImm):
- offset_factor = offset_factor.value
- else:
- assert isinstance(
- offset_factor, int
- ), f"offset_factor: want int or IntImm, got {offset_factor!r}"
-
- preflattened = tvm.tir.decl_buffer(
- shape,
- dtype,
- buffer_name,
- data,
- strides,
- elem_offset,
- scope,
- align,
- offset_factor,
- buffer_type,
- span=span,
- )
-
- self.context.func_preflattened_buffer_map[param] = preflattened
-
- super().__init__(preflattened_buffer, def_symbol=False)
-
-
-@register
-class TargetAttrValue(SpecialStmt):
- """Special Stmt for target attr value.
- Example
- -------
- .. code-block:: python
- T.target("llvm")
- """
-
- def __init__(self):
- def target(*args, span):
- self.context.report_error(f"T.target should not appear as a stmt", span)
-
- super().__init__(target, def_symbol=False)
-
- def __call__(self, target_config):
- if not isinstance(target_config, (str, dict)):
- raise ValueError(
- f"T.target expected a config dict or string, but got {type(target_config)}"
- )
- return Target(target_config)
diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/tir/ty.py
deleted file mode 100644
index 4548102a9e..0000000000
--- a/python/tvm/script/tir/ty.py
+++ /dev/null
@@ -1,216 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Script Parser Typing Class for TIR
-
-This module provides typing class for TVM script type annotation usage, it can be viewed as
-a wrapper for uniform Type system in IR
-"""
-# pylint: disable=invalid-name
-from numbers import Integral
-
-import tvm
-from .special_stmt import SpecialStmt, convert_to_int
-
-
-class TypeGeneric: # pylint: disable=too-few-public-methods
- """Base class for all the TVM script typing class"""
-
- def evaluate(self):
- """Return an actual ir.Type Object that this Generic class wraps"""
- raise TypeError("Cannot get tvm.Type from a generic type")
-
- def require_type_generic_at(self, idx): # pylint: disable=unused-argument
- """If True, the `idx`th type argument must be TypeGeneric"""
- return True
-
- # This function is added here to avoid a pylint error
- # for T.int/float below not being callable
- def __call__(self):
- raise NotImplementedError()
-
-
-class ConcreteType(TypeGeneric): # pylint: disable=too-few-public-methods, abstract-method
- """TVM script typing class for uniform Type objects
-
- Params
- ------
- vtype: Union[str, tvm.ir.Type]
-
- The IR type represented by the type annotation. If a string
- (e.g. "float32"), this represents a `ir.PrimType` generated
- from that string. If a `ir.Type` is provided, this represents
- the type provided.
- """
-
- def __init__(self, vtype):
- if isinstance(vtype, tvm.ir.Type):
- self.type = vtype
- else:
- self.type = tvm.ir.PrimType(vtype)
-
- def __call__(self, *args): # pylint: disable=arguments-differ
- pass
-
- def evaluate(self):
- return self.type
-
-
-class VoidType(ConcreteType): # pylint: disable=too-few-public-methods, abstract-method
- """TVM script typing class for void type"""
-
- def __init__(self):
- super().__init__("")
-
-
-class GenericPtrType(TypeGeneric): # pylint: disable=abstract-method
- """TVM script typing class generator for PtrType
-
- [] operator is overloaded, accepts a ConcreteType and an optional storage scope string,
- returns a ConcreteType wrapping PtrType
- """
-
- def __getitem__(self, args):
- if isinstance(args, TypeGeneric):
- args = [args]
- if len(args) == 1:
- vtype, scope = args[0], "global"
- elif len(args) == 2:
- vtype, scope = args[0], args[1]
- else:
- raise TypeError(f"Illegal type argument num for Ptr")
- if not isinstance(vtype, TypeGeneric):
- raise TypeError(f"Ptr expects a type argument, but received {type(vtype).__name__}")
- if not isinstance(scope, str):
- raise TypeError(f"Ptr expects storage scope argument be a string")
- return ConcreteType(tvm.ir.PointerType(vtype.evaluate(), scope))
-
- def require_type_generic_at(self, idx):
- return idx != 1 # the second argument is storage scope for Ptr
-
-
-class GenericTupleType(TypeGeneric): # pylint: disable=abstract-method
- """TVM script typing class generator for TupleType
-
- [] operator is overloaded, accepts a list of ConcreteType and returns a ConcreteType
- wrapping TupleType
- """
-
- def __getitem__(self, vtypes):
- if isinstance(vtypes, TypeGeneric):
- vtypes = [vtypes]
- return ConcreteType(tvm.ir.TupleType([vtype.evaluate() for vtype in vtypes]))
-
-
-class GenericBufferType(SpecialStmt): # pylint: disable=too-few-public-methods, abstract-method
- """TVM script typing class for uniform Type objects"""
-
- def __init__(self, vtype):
- def match_buffer_syntax_sugar(
- shape,
- dtype: str = "float32",
- name: str = None,
- data=None,
- strides=None,
- elem_offset=None,
- scope="global",
- align=-1,
- offset_factor=0,
- buffer_type="default",
- axis_separators=None,
- span=None,
- ):
- if strides is None:
- strides = []
- align = convert_to_int(align, "align", self.context.report_error, self.node.span)
- offset_factor = convert_to_int(
- offset_factor, "offset_factor", self.context.report_error, self.node.span
- )
- buffer = tvm.tir.decl_buffer(
- shape,
- dtype,
- name,
- data,
- strides,
- elem_offset,
- scope,
- align,
- offset_factor,
- buffer_type,
- axis_separators,
- span=span,
- )
- return buffer
-
- self.type = vtype
- super().__init__(match_buffer_syntax_sugar, def_symbol=True)
-
- def __call__(
- self,
- shape,
- dtype="float32",
- *,
- name: str = None,
- data=None,
- strides=None,
- elem_offset=None,
- scope="global",
- align=-1,
- offset_factor=0,
- buffer_type="default",
- axis_separators=None,
- span=None,
- ):
- """
- This function is for Buffer(...) syntax sugar.
- """
- pass # pylint: disable=unnecessary-pass
-
- def __getitem__(self, args):
- """
- This function is for Buffer[...] syntax sugar
- Note that args is the list of all arguments
- """
- if len(args) < 2:
- raise ValueError("T.Buffer[...] needs at least two arguments: shape and dtype.")
-
- shape = args[0]
- dtype = args[1]
-
- valid_shape = isinstance(shape, (tvm.ir.PrimExpr, Integral, tuple, list))
- valid_dtype = isinstance(dtype, str)
- if not (valid_shape and valid_dtype):
- raise ValueError(
- "The first argument of T.Buffer[...] needs to be a tuple, "
- "followed by the second argument dtype as a string"
- )
-
-
-# add all floating point and integer datatypes to the module
-for _dtype in ["float", "uint", "int"]:
- for _size in ["8", "16", "32", "64"]:
- for _lanes in ["", "x4", "x8", "x16", "x32"]:
- _name = _dtype + _size + _lanes
- globals()[_name] = ConcreteType(_name)
-
-boolean = ConcreteType("bool")
-handle = ConcreteType("handle")
-void = VoidType()
-Ptr = GenericPtrType()
-Tuple = GenericTupleType()
-# we don't have 'buffer' type on the cpp side
-# thus 'handle' is used here for convenience's sake
-Buffer = GenericBufferType("handle")
diff --git a/python/tvm/script/utils.py b/python/tvm/script/utils.py
deleted file mode 100644
index c655a62237..0000000000
--- a/python/tvm/script/utils.py
+++ /dev/null
@@ -1,105 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Helper functions in TVM Script Parser"""
-
-from typing import Callable, List, Any, Optional, Tuple
-
-import inspect
-import synr
-
-from tvm.ir import Span, SourceName
-from tvm.error import DiagnosticError
-
-
-def get_param_list(
- func: Callable,
-) -> Tuple[List[str], List[Tuple[str, Tuple[Any, ...]]], Optional[str]]:
- """Get the parameter list from definition of function"""
- full_arg_spec: inspect.FullArgSpec = inspect.getfullargspec(func)
-
- args: List[str]
- defaults: Optional[Tuple[Any, ...]]
- kwonlyargs: List[str]
- args, defaults, kwonlyargs = (
- full_arg_spec.args,
- full_arg_spec.defaults,
- full_arg_spec.kwonlyargs,
- )
-
- if defaults is None:
- defaults = tuple()
-
- if full_arg_spec.varkw is not None:
- raise RuntimeError(
- "TVM Script register error : variable keyword argument is not supported now"
- )
-
- if len(kwonlyargs) == 1 and kwonlyargs[0] == "span":
- pass
- elif not len(kwonlyargs) == 0:
- raise RuntimeError("TVM Script register error : keyword only argument is not supported now")
-
- pos_only: List[str] = list()
- for arg in args[: len(args) - len(defaults)]:
- if arg != "span":
- pos_only.append(arg)
- kwargs: List[Tuple[str, Tuple[Any, ...]]] = list()
- for default, arg in zip(defaults, args[len(args) - len(defaults) :]):
- if arg != "span":
- kwargs.append((arg, default))
-
- return pos_only, kwargs, full_arg_spec.varargs
-
-
-def tvm_span_from_synr(span: synr.ast.Span) -> Span:
- """Convert a synr span to a TVM span"""
- return Span(
- SourceName(span.filename),
- span.start_line,
- span.end_line,
- span.start_column,
- span.end_column,
- )
-
-
-def synr_span_from_tvm(span: Span) -> synr.ast.Span:
- """Convert a TVM span to a synr span"""
- return synr.ast.Span(
- span.source_name.name,
- span.line,
- span.column,
- span.end_line,
- span.end_column,
- )
-
-
-def call_with_error_reporting(
- report_error,
- node_span,
- func,
- *args,
- **kwargs,
-):
- """Call function with exception handling and report error using node_span"""
- try:
- return func(*args, **kwargs)
- except DiagnosticError:
- raise
- except Exception as err: # pylint: disable=broad-except
- # printing last non-empty row of error message.
- error_msg = list(filter(None, str(err).split("\n")))[-1]
- report_error(error_msg, node_span)
diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py
index 5454041713..e769559b47 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -20,11 +20,11 @@ from typing import Dict, List, Union
from tvm import Object
from tvm.ir import IRModule
-from tvm.tir.expr import Var
-from tvm.tir.stmt import Block, BufferRegion, PrimExpr
-from .. import Buffer, Stmt
+from ..buffer import Buffer
+from ..expr import Var
from ..function import PrimFunc
+from ..stmt import Block, BufferRegion, PrimExpr, Stmt
from . import _ffi_api
diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py
index beefcb0d28..5742999c67 100644
--- a/python/tvm/tir/expr.py
+++ b/python/tvm/tir/expr.py
@@ -28,15 +28,16 @@ For example, you can use addexp.a to get the left operand of an Add node.
assert(y.a == x)
"""
from typing import Optional, Union
-from tvm import ir
+
import tvm._ffi
+import tvm.ir._ffi_api
+from tvm import ir
+from tvm.ir import Op, PrimExpr
from tvm.ir.base import Span
+from tvm.runtime import DataType, DataTypeCode, Object, ObjectGeneric, const
-from tvm.runtime import Object, ObjectGeneric, DataType, DataTypeCode, const
-from tvm.ir import PrimExpr, Op
-import tvm.ir._ffi_api
-from . import generic as _generic
from . import _ffi_api
+from . import generic as _generic
def div_ambiguity_error():
@@ -66,8 +67,6 @@ def _dtype_is_float(value):
class ExprOp(object):
"""Operator overloading for Expr like expressions."""
- # TODO(tkonolige): use inspect to add source information to these objects
-
def __add__(self, other):
return _generic.add(self, other)
@@ -1005,6 +1004,8 @@ class Select(PrimExprWithOp):
"""
def __init__(self, condition, true_value, false_value, span=None):
+ if isinstance(condition, bool):
+ condition = IntImm("bool", condition)
self.__init_handle_by_constructor__(
_ffi_api.Select, condition, true_value, false_value, span # type: ignore
)
diff --git a/python/tvm/tir/schedule/block_scope.py b/python/tvm/tir/schedule/block_scope.py
index 30e047b4f7..0ebaf212d1 100644
--- a/python/tvm/tir/schedule/block_scope.py
+++ b/python/tvm/tir/schedule/block_scope.py
@@ -20,8 +20,8 @@ from typing import List, Optional, Union
from tvm._ffi import register_object
from tvm.runtime import Object
-from tvm.tir import Block, For
+from ..stmt import Block, For
from . import _ffi_api
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index fdc8717032..043f7922be 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -21,9 +21,11 @@ from tvm._ffi import register_object as _register_object
from tvm.error import TVMError, register_error
from tvm.ir import IRModule, PrimExpr
from tvm.runtime import Object, String
-from tvm.tir import Block, Buffer, FloatImm, For, IntImm, PrimFunc
-from ..function import IndexMap
+from ..buffer import Buffer
+from ..expr import FloatImm, IntImm
+from ..function import IndexMap, PrimFunc
+from ..stmt import Block, For
from . import _ffi_api
from ._type_checker import type_checked
from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod
diff --git a/python/tvm/tir/schedule/state.py b/python/tvm/tir/schedule/state.py
index fbf21843e7..3aed52fb50 100644
--- a/python/tvm/tir/schedule/state.py
+++ b/python/tvm/tir/schedule/state.py
@@ -22,8 +22,9 @@ from typing import Dict, Optional, Union
from tvm._ffi import register_object
from tvm.ir import IRModule
from tvm.runtime import Object
-from tvm.tir import Block, BlockRealize, For, PrimFunc
+from ..function import PrimFunc
+from ..stmt import Block, BlockRealize, For
from . import _ffi_api
from .block_scope import BlockScope, StmtSRef
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index 4847e377de..3c2228e6d9 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -754,3 +754,7 @@ def stmt_list(stmt):
res += stmt_list(x)
return res
return [stmt]
+
+
+def type_annotation(dtype, span=None):
+ return _ffi_api.TypeAnnotation(dtype, span)
diff --git a/python/tvm/tir/usmp/transform/transform.py b/python/tvm/tir/usmp/transform/transform.py
index f472172cf3..86d8bef356 100644
--- a/python/tvm/tir/usmp/transform/transform.py
+++ b/python/tvm/tir/usmp/transform/transform.py
@@ -20,8 +20,9 @@
from typing import Dict
import tvm
-from tvm.tir import Stmt
-from tvm.tir.usmp.utils import PoolAllocation
+
+from ...stmt import Stmt
+from ..utils import PoolAllocation
from . import _ffi_api
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index e21d014fe1..0cb8f79942 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -810,6 +810,14 @@ BufferRegion::BufferRegion(Buffer buffer, Array<Range> region) {
CHECK_EQ(buffer->shape.size(), region.size())
<< "The dimension between " << buffer << " and region " << region
<< " mismatched, the buffer is " << buffer;
+ for (const Range& r : region) {
+ ICHECK(r->min->dtype.is_int() || r->min->dtype.is_uint())
+ << "ValueError: ranges of BufferRegion should be int, but got type " << r->min->dtype
+ << " for range " << r << " in its min value " << r->min;
+ ICHECK(r->extent->dtype.is_int() || r->extent->dtype.is_uint())
+ << "ValueError: ranges of BufferRegion should be int, but got type " << r->extent->dtype
+ << " for range " << r << " in its extent value " << r->extent;
+ }
ObjectPtr<BufferRegionNode> node = make_object<BufferRegionNode>();
node->buffer = std::move(buffer);
node->region = std::move(region);
@@ -1091,5 +1099,7 @@ PrimExpr TypeAnnotation(DataType dtype, Span span) {
TVM_REGISTER_OP("tir.type_annotation")
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
+TVM_REGISTER_GLOBAL("tir.TypeAnnotation").set_body_typed(TypeAnnotation);
+
} // namespace tir
} // namespace tvm
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 509badbebb..0e605d1439 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -97,6 +97,17 @@ PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s, Span s
{x, y, q, s}, span);
}
+// address_of
+PrimExpr address_of(tir::BufferLoad buffer_load, Span span) {
+ return tir::Call(DataType::Handle(), tir::builtin::address_of(), {buffer_load}, span);
+}
+
+// lookup_param
+PrimExpr lookup_param(String param_name, Span span) {
+ return tir::Call(DataType::Handle(), tir::builtin::lookup_param(), {tir::StringImm(param_name)},
+ span);
+}
+
// The public function with a quick checking path.
void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*)
CHECK(lhs.defined()) << "ValueError: `lhs` is null in the binary operator";
@@ -705,6 +716,11 @@ PrimExpr isnan(PrimExpr x, Span span) {
}
}
+// isnullptr
+PrimExpr isnullptr(PrimExpr x, Span span) {
+ return tir::Call(DataType::Bool(1), tir::builtin::isnullptr(), {x}, span);
+}
+
// isinf
PrimExpr isinf(PrimExpr x, Span span) {
DataType t = DataType::Bool(x.dtype().lanes());
@@ -956,6 +972,10 @@ TVM_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc);
TVM_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast);
+TVM_REGISTER_GLOBAL("tir.infinity").set_body_typed(tvm::infinity);
+
+TVM_REGISTER_GLOBAL("tir.reinterpret").set_body_typed(tvm::reinterpret);
+
// operator overloading, smarter than make
#define REGISTER_MAKE_BINARY_OP(Node, Func) \
TVM_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { \
@@ -1004,6 +1024,8 @@ REGISTER_MAKE_BIT_OP(bitwise_xor, bitwise_xor);
REGISTER_MAKE_BIT_OP(left_shift, left_shift); // NOLINT(*)
REGISTER_MAKE_BIT_OP(right_shift, right_shift);
+TVM_REGISTER_GLOBAL("tir._OpNot").set_body_typed(logical_not);
+
TVM_REGISTER_GLOBAL("tir._OpIfThenElse")
.set_body_typed([](PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span) {
return if_then_else(cond, true_value, false_value, span);