You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/09/17 22:55:08 UTC

[tvm] 01/02: IRBuilder

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch ir-builder-v2
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit e6e27a4d179b36a492f83822c742739ea10f5a3c
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Mon Aug 15 12:33:04 2022 -0700

    IRBuilder
---
 include/tvm/tir/op.h                               |   33 +-
 python/tvm/script/__init__.py                      |    6 +-
 python/tvm/script/context_maintainer.py            |  251 ----
 python/tvm/script/diagnostics.py                   |   55 -
 python/tvm/script/meta_unparser.py                 |   45 -
 python/tvm/script/parser.py                        | 1392 --------------------
 python/tvm/script/{ => parser}/__init__.py         |   12 +-
 python/tvm/script/{__init__.py => parser/_core.py} |   13 +-
 python/tvm/script/{ => parser/core}/__init__.py    |    7 +-
 python/tvm/script/parser/core/diagnostics.py       |  175 +++
 python/tvm/script/parser/core/dispatch.py          |   63 +
 python/tvm/script/parser/core/doc.py               |  361 +++++
 .../script/{printer => parser/core}/doc_core.py    |    0
 .../{tir/prim_func.py => parser/core/entry.py}     |   46 +-
 python/tvm/script/parser/core/evaluator.py         |  284 ++++
 python/tvm/script/parser/core/parser.py            |  273 ++++
 .../{tir/prim_func.py => parser/core/utils.py}     |   39 +-
 python/tvm/script/{ => parser/ir}/__init__.py      |    8 +-
 .../script/{tir/__init__.py => parser/ir/entry.py} |   24 +-
 .../{tir/__init__.py => parser/ir/parser.py}       |   28 +-
 python/tvm/script/{ => parser/tir}/__init__.py     |   11 +-
 python/tvm/script/parser/tir/entry.py              |  101 ++
 python/tvm/script/parser/tir/operation.py          |   84 ++
 python/tvm/script/parser/tir/parser.py             |  268 ++++
 python/tvm/script/registry.py                      |   62 -
 python/tvm/script/tir/__init__.pyi                 |  487 -------
 python/tvm/script/tir/intrin.py                    |  231 ----
 python/tvm/script/tir/node.py                      |  218 ---
 python/tvm/script/tir/scope_handler.py             |  793 -----------
 python/tvm/script/tir/special_stmt.py              |  964 --------------
 python/tvm/script/tir/ty.py                        |  216 ---
 python/tvm/script/utils.py                         |  105 --
 python/tvm/tir/analysis/analysis.py                |    6 +-
 python/tvm/tir/expr.py                             |   15 +-
 python/tvm/tir/schedule/block_scope.py             |    2 +-
 python/tvm/tir/schedule/schedule.py                |    6 +-
 python/tvm/tir/schedule/state.py                   |    3 +-
 python/tvm/tir/stmt.py                             |    4 +
 python/tvm/tir/usmp/transform/transform.py         |    5 +-
 src/tir/ir/stmt.cc                                 |   10 +
 src/tir/op/op.cc                                   |   22 +
 41 files changed, 1793 insertions(+), 4935 deletions(-)

diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 0939e25efd..63b7baaea7 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -527,7 +527,13 @@ TVM_DLL PrimExpr isnan(PrimExpr x, Span span = Span());
  * \return The result expression.
  */
 TVM_DLL PrimExpr isfinite(PrimExpr x, Span span = Span());
-
+/*!
+ * \brief Check if x is nullptr.
+ * \param x The input data
+ * \param span The location of this operation in the source.
+ * \return The result expression.
+ */
+TVM_DLL PrimExpr isnullptr(PrimExpr x, Span span = Span());
 /*!
  * \brief Check if x is infinite.
  * \param x The input data
@@ -601,6 +607,15 @@ TVM_DLL PrimExpr min(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr>
 TVM_DLL PrimExpr prod(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {},
                       Span span = Span());
 
+/*!
+ * \brief Calculate fmod(x, y)
+ * \param x Left operand.
+ * \param y Right operand.
+ * \param span The location of this operation in the source.
+ * \return The result expression.
+ */
+TVM_DLL PrimExpr fmod(PrimExpr x, PrimExpr y, Span span = Span());
+
 /*!
  * \brief Calculate floor(x)
  * \param x The input expression.
@@ -675,6 +690,22 @@ TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high, Span sp
 TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s,
                                   Span span = Span());
 
+/*!
+ * \brief Returns the address of an element in the buffer.
+ * \param buffer_load The input BufferLoad.
+ * \param span The location of this operation in the source.
+ * \return The address of an element in the buffer.
+ */
+TVM_DLL PrimExpr address_of(tir::BufferLoad buffer_load, Span span = Span());
+
+/*!
+ * \brief Returns the handle of the param with the given name.
+ * \param param_name The param name.
+ * \param span The location of this operation in the source.
+ * \return The handle of the param.
+ */
+TVM_DLL PrimExpr lookup_param(String param_name, Span span = Span());
+
 // Intrinsic operators
 #define TVM_DECLARE_INTRIN_UNARY(OpName)                                \
   inline PrimExpr OpName(PrimExpr x, Span span = Span()) {              \
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/__init__.py
index 555659d0c5..3107ada88b 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/__init__.py
@@ -15,7 +15,5 @@
 # specific language governing permissions and limitations
 # under the License.
 """TVM Script APIs of TVM Python Package, aimed to support TIR"""
-
-from . import tir
-
-from .parser import ir_module, from_source
+from . import parser
+from .parser import ir, ir_module, parse, tir
diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/context_maintainer.py
deleted file mode 100644
index f7f16855c7..0000000000
--- a/python/tvm/script/context_maintainer.py
+++ /dev/null
@@ -1,251 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Script Context Maintainer for TIR"""
-
-from typing import List, Mapping, Union, Optional, Dict, Callable
-import synr
-
-
-import tvm
-from tvm.ir import Span
-from tvm.ir.expr import Range
-from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
-from tvm.runtime import Object
-from tvm.tir.expr import IterVar
-from .tir.node import BufferSlice
-
-
-class BlockInfo:
-    """Information for block and block_realize signature
-
-    Examples
-    ----------
-    .. code-block:: python
-
-        @T.prim_func
-        def example_func(a: T.handle, b: T.handle, c: T.handle) -> None:
-            A = T.match_buffer(a, (16, 16), "float32")
-            B = T.match_buffer(b, (16, 16), "float32")
-            C = T.match_buffer(a, (16, 16), "float32")
-
-            for i, j, k in T.grid(16, 16, 16):
-                with T.block("matmul"):
-                    vi = T.axis.S(16, i)
-                    vj = T.axis.S(16, j)
-                    vk = T.axis.R(16, k)         # iter_bindings = {vj: i, vj: j, vk: k}
-
-                    T.where(True)         # predicate of the block_realize
-
-                    T.reads(A[0:16, 0:16], B[0: 16, 0: 16])      # reads region of the block
-                    T.writes(C[0: 16, 0: 16])                    # writes region of the block
-                    T.block_attr({"attr_key": "attr_value"})     # block annotations
-
-                    # alloc_buffers inside the block
-                    CC = T.alloc_buffer((1, 1), dtype="float32")
-
-                    # match_buffers of the block,
-                    # which bind a sub-region of source buffer into a new buffer
-                    D = T.match_buffer(C[vi, vj], ())
-
-                    # init part of the block, executed when all reduce axes are the beginning value
-                    with T.init():
-                        C[vi, vj] = T.float32(0)
-
-                    # block body
-                    CC[0, 0] = A[vi, vk] * B[vj, vk]
-                    D[()] += CC[0, 0]         # The same as C[vi, vj] += CC[0, 0]
-    """
-
-    alloc_buffers: List[Buffer] = []
-    """List[Buffer]: list of T.alloc_buffer statements in the block signature"""
-    match_buffers: List[MatchBufferRegion] = []
-    """List[MatchBufferRegion]: list of T.match_buffer statements in the block signature"""
-    iter_values: List[PrimExpr] = []
-    """List[PrimExpr]: list of binding values for iter vars"""
-    iter_vars: List[IterVar] = []
-    """List[PrimExpr]: list of iter vars in the block"""
-    reads: Optional[List[BufferSlice]] = None
-    """Optional[List[BufferSlice]]:
-    list of T.reads statements in the block signature, None for not-visited"""
-    writes: Optional[List[BufferSlice]] = None
-    """Optional[List[BufferSlice]]:
-    list of T.writes statements in the block signature, None for not-visited"""
-    annotations: Optional[Mapping[str, Object]] = None
-    """Optional[Mapping[str, Object]]:
-    list of T.block_attr statements in the block signature, None for not-visited"""
-    predicate: Optional[PrimExpr] = None
-    """Optional[PrimExpr]: block realize predicate, None for not-visited"""
-    init: Optional[Stmt] = None
-    """Optional[Stmt]: init part of the block, None for not-visited"""
-
-    def __init__(self):
-        self.alloc_buffers = []
-        self.match_buffers = []
-        self.iter_values = []
-        self.iter_vars = []
-        self.reads = None
-        self.writes = None
-        self.annotations = None
-        self.predicate = None
-        self.init = None
-
-
-class ContextMaintainer:
-    """Maintain all the necessary context info
-    Parameters
-    ----------
-    _report_error : Callable[[str, Union[Span, synr.ast.Span]], None]
-        The report error function handle
-    """
-
-    # scope context
-    node_stack: List[List[synr.ast.Node]] = []
-    """List[List[synr.ast.Node]]: The ast nodes insides the current scope"""
-    block_info_stack: List[BlockInfo] = []
-    """List[BlockInfo]: The block info for the current block scope"""
-    loop_stack: Dict[Var, Range] = {}
-    """Dict[Var, Range]: The dict from loop var to its domain outside the block"""
-    symbols: List[Dict[str, Union[Var, Buffer]]] = []
-    """List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for the current scope"""
-    closure_vars: Dict[str, Object] = {}
-    """ClosureVars: The closure vars defined in Python interpreter"""
-
-    # function context
-    func_params: List[Var] = []
-    """List[Var]: The function parameters"""
-    func_buffer_map: Mapping[Var, Buffer] = {}
-    """Mapping[Var, Buffer]: The function buffer map"""
-    func_preflattened_buffer_map: Mapping[Var, Buffer] = {}
-    """Mapping[Var, Buffer]: The function buffer map, prior to any flattening."""
-    func_dict_attr: Mapping[str, Object] = {}
-    """Mapping[str, Object]: The function attrs"""
-    func_var_env_dict: Mapping[Var, str] = {}
-    """Mapping[Var, str]: The map from var to env thread"""
-
-    # parser and analyzer
-    analyzer: tvm.arith.Analyzer = tvm.arith.Analyzer()
-    """tvm.arith.Analyzer: The analyzer for simplifying"""
-    _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
-    """Callable[[str, Union[Span, synr.ast.Span]], None]: The report error function handle"""
-
-    # root alloc_buffer
-    root_alloc_buffers: List[Buffer] = []
-    """List[Buffer]: The buffers allocated under root block"""
-
-    def __init__(
-        self,
-        _report_error: Callable[[str, Union[Span, synr.ast.Span]], None],
-        closure_vars: Dict[str, Object],
-    ):
-        # scope context
-        self.node_stack = []
-        self.block_info_stack = []
-        self.loop_stack = {}
-        self.symbols = []
-        self.closure_vars = closure_vars
-        # function context
-        self.func_params = []
-        self.func_buffer_map = {}
-        self.func_preflattened_buffer_map = {}
-        self.func_dict_attr = {}
-        self.func_var_env_dict = {}
-        # parser and analyzer
-        self._report_error = _report_error
-        self.analyzer = tvm.arith.Analyzer()
-        # root alloc_buffer
-        self.root_alloc_buffers = []
-
-    def enter_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
-        """Creates a new scope
-
-        Note
-        ----
-        This function is used for normal scopes that do not involve
-        a `with block` scope. Use `enter_block_scope`
-        for block scope cases.
-
-        Parameters
-        ----------
-        nodes : Optional[List[synr.ast.Node]]
-            The synr AST nodes in new scope
-        """
-        if nodes is None:
-            nodes = []
-        self.node_stack.append(list(reversed(nodes)))
-        self.symbols.append(dict())
-
-    def enter_block_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
-        """Creates a new block scope, the function will call `enter_scope` implicitly
-        Besides the behaviors of `enter_scope`, it will update loop_stack and block_info_stack
-        to maintain block info.
-
-        Note
-        ----
-        This function should be used to handle a block scope,
-        aka the blocks that involve a `with block` scope.
-
-        Parameters
-        ----------
-        nodes : Optional[List[synr.ast.Node]]
-            The synr AST nodes in new scope
-        """
-        self.enter_scope(nodes)
-        # Create a new BlockInfo for the new block
-        self.block_info_stack.append(BlockInfo())
-
-    def exit_scope(self):
-        """Pop the inner most scope"""
-        self.symbols.pop()
-        self.node_stack.pop()
-
-    def exit_block_scope(self):
-        """Pop the inner most block scope, the function will call `exit_scope` implicitly"""
-        self.exit_scope()
-        # Pop block_info
-        self.block_info_stack.pop()
-
-    def update_symbol(self, name: str, symbol: Union[Buffer, Var], node: synr.ast.Node):
-        """Append a symbol into current scope"""
-        if isinstance(symbol, Buffer):
-            if name in self.symbols[0]:
-                self.report_error("Duplicate Buffer name: " + symbol.name, node.span)
-            self.symbols[0][name] = symbol
-        else:
-            self.symbols[-1][name] = symbol
-
-    def remove_symbol(self, name: str):
-        """Remove a symbol"""
-        for symbols in reversed(self.symbols):
-            if name in symbols:
-                symbols.pop(name)
-                return
-        raise RuntimeError("Internal error of tvm script parser: no symbol named " + name)
-
-    def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var]]:
-        """Look up symbol by name"""
-        for symbols in reversed(self.symbols):
-            if name in symbols:
-                return symbols[name]
-        return self.closure_vars.get(name)
-
-    def report_error(self, message: str, span: Union[Span, synr.ast.Span]):
-        self._report_error(message, span)
-
-    def current_block_scope(self) -> BlockInfo:
-        if self.block_info_stack:
-            return self.block_info_stack[-1]
-        return None
diff --git a/python/tvm/script/diagnostics.py b/python/tvm/script/diagnostics.py
deleted file mode 100644
index e676461ab3..0000000000
--- a/python/tvm/script/diagnostics.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Bridge from synr's (the library used for parsing the python AST)
-   DiagnosticContext to TVM's diagnostics
-"""
-from synr import DiagnosticContext, ast
-
-import tvm
-from tvm.ir.diagnostics import DiagnosticContext as TVMCtx
-from tvm.ir.diagnostics import get_renderer, DiagnosticLevel, Diagnostic
-
-
-class TVMDiagnosticCtx(DiagnosticContext):
-    """TVM diagnostics for synr"""
-
-    diag_ctx: TVMCtx
-
-    def __init__(self) -> None:
-        self.diag_ctx = TVMCtx(tvm.IRModule(), get_renderer())
-        self.source_name = None
-
-    def to_tvm_span(self, src_name, ast_span: ast.Span) -> tvm.ir.Span:
-        return tvm.ir.Span(
-            src_name,
-            ast_span.start_line,
-            ast_span.end_line,
-            ast_span.start_column,
-            ast_span.end_column,
-        )
-
-    def add_source(self, name: str, source: str) -> None:
-        src_name = self.diag_ctx.module.source_map.add(name, source)
-        self.source_name = src_name
-
-    def emit(self, _level, message, span):
-        span = self.to_tvm_span(self.source_name, span)
-        self.diag_ctx.emit(Diagnostic(DiagnosticLevel.ERROR, span, message))
-        self.diag_ctx.render()  # Raise exception on the first error we hit. TODO remove
-
-    def render(self):
-        self.diag_ctx.render()
diff --git a/python/tvm/script/meta_unparser.py b/python/tvm/script/meta_unparser.py
deleted file mode 100644
index b1472ccdc7..0000000000
--- a/python/tvm/script/meta_unparser.py
+++ /dev/null
@@ -1,45 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Unparse meta AST node into a dict"""
-# pylint: disable=invalid-name
-
-from synr import Transformer
-
-
-class MetaUnparser(Transformer):
-    """Python AST Visitor to unparse meta AST node into a dict"""
-
-    def transform(self, node):
-        method = "transform_" + node.__class__.__name__
-        visitor = getattr(self, method, None)
-        if visitor is None:
-            self.error(f"Unexpected node type {type(node)} when parsing __tvm_meta__", node.span)
-        return visitor(node)
-
-    def transform_DictLiteral(self, node):
-        keys = [self.visit(key) for key in node.keys]
-        values = [self.visit(value) for value in node.values]
-        return dict(zip(keys, values))
-
-    def transform_Tuple(self, node):
-        return tuple(self.visit(element) for element in node.elts)
-
-    def transform_ArrayLiteral(self, node):
-        return [self.visit(element) for element in node.elts]
-
-    def transform_Constant(self, node):
-        return node.value
diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
deleted file mode 100644
index c34aae2345..0000000000
--- a/python/tvm/script/parser.py
+++ /dev/null
@@ -1,1392 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Script Parser For TIR
-
-We use [synr](https://synr.readthedocs.io) to get an AST that is stable over
-different python versions. Synr also provides an error handling context that we
-use for error reporting.
-"""
-# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return, broad-except
-import types
-import json
-import operator
-import inspect
-from typing import Any, Callable, Dict, List, Optional, Union
-from synr import ast, Transformer, to_ast
-
-import tvm
-from tvm import IRModule
-from tvm._ffi.base import TVMError
-from tvm.ir import GlobalVar
-from tvm.ir.function import BaseFunc
-from tvm.tir import buffer
-from tvm.tir.function import PrimFunc
-from . import _ffi_api
-from . import tir
-
-from .context_maintainer import ContextMaintainer
-from .meta_unparser import MetaUnparser
-from .registry import Registry
-from .diagnostics import TVMDiagnosticCtx
-from .utils import tvm_span_from_synr, synr_span_from_tvm, call_with_error_reporting
-
-from .tir.intrin import Intrin
-from .tir.node import Slice, BufferSlice
-from .tir.scope_handler import ScopeHandler, WithScopeHandler, ForScopeHandler
-from .tir.special_stmt import SpecialStmt
-from .tir import ty
-
-
-class CallArgumentReader(object):
-    """Helper class to read required arguments from passed arguments.
-
-    When parsing a function call, we need to match the arguments provided in
-    the AST to the required arguments of the function. This class makes sure
-    all the positional arguments are filled and also fill keyword arguments
-    with thier default value if a different value was not provided.
-    """
-
-    def __init__(self, func_name, args, kwargs, parser, node):
-        self.func_name = func_name
-        self.args = args
-        self.kwargs = kwargs
-        self.parser = parser
-        self.node = node
-
-    def get_pos_only_arg(self, pos, name):
-        """Get corresponding position only function argument from argument list"""
-        if len(self.args) >= pos:
-            arg = self.args[pos - 1]
-        elif name not in self.kwargs:
-            # If no positional argument was found in the AST, we see if it was
-            # defined by name instead.
-            # TODO(tkonolige): this error message is not quite correct. The
-            # number of required arguments is >= pos
-            self.parser.report_error(
-                f"{self.func_name} requires {pos} arguments, but only {len(self.args)} were given.",
-                self.node.span,
-            )
-        else:
-            arg = self.kwargs[name]
-
-        return arg
-
-    def get_kwarg(self, pos, name, default):
-        """Get corresponding keyword function argument from argument list.
-
-        If the user hasn't provided the argument, set it to the default value.
-        """
-        if len(self.args) >= pos:
-            arg = self.args[pos - 1]
-        elif name in self.kwargs:
-            arg = self.kwargs[name]
-        else:
-            return default
-
-        return arg
-
-    def get_varargs(self, pos):
-        """Get corresponding variable argument from argument list"""
-        if len(self.args) >= pos and len(self.kwargs) == 0:
-            return self.args[pos - 1 :]
-        return []
-
-
-class TVMScriptParser(Transformer):
-    """Synr AST visitor pass which finally lowers to TIR.
-
-    Notes for Extension
-    -------------------
-    1. To support a new type of AST node, add a function transform_xxx().
-    2. To support new functions, add the function to the appropriate registry:
-        We divide allowed function calls in TVM script into 3 categories,
-        intrin, scope_handler and special_stmt.
-        1. intrin functions are low level functions like mod, load, and
-           constants. They correspond to a tir `IRNode`. They must have a
-           return value. The user can register intrin functions for the parser to
-           use.
-        2. scope_handler functions have no return value. They take two
-           arguments: the parser and the AST node. scope_handler functions are
-           used in with and for statements.
-        3. special_stmt functions handle cases that do not have a corresponding
-           tir `IRNode`. These functions take the parser and the AST node as
-           arguments and may return a value.
-        When visiting a Call node, we check the special_stmt registry first. If
-        no registered function is found, we then check the intrin registry.
-        When visiting With node, we check the with_scope registry.
-        When visiting For node, we check the for_scope registry.
-    """
-
-    _binop_maker = {
-        ast.BuiltinOp.Add: tvm.tir.Add,
-        ast.BuiltinOp.Sub: tvm.tir.Sub,
-        ast.BuiltinOp.Mul: tvm.tir.Mul,
-        ast.BuiltinOp.Div: tvm.tir.Div,
-        ast.BuiltinOp.FloorDiv: tvm.tir.FloorDiv,
-        ast.BuiltinOp.Mod: tvm.tir.FloorMod,
-        ast.BuiltinOp.BitOr: lambda lhs, rhs, span: operator.or_(lhs, rhs),
-        ast.BuiltinOp.BitAnd: lambda lhs, rhs, span: operator.and_(lhs, rhs),
-        ast.BuiltinOp.BitXor: lambda lhs, rhs, span: operator.xor(lhs, rhs),
-        ast.BuiltinOp.GT: tvm.tir.GT,
-        ast.BuiltinOp.GE: tvm.tir.GE,
-        ast.BuiltinOp.LT: tvm.tir.LT,
-        ast.BuiltinOp.LE: tvm.tir.LE,
-        ast.BuiltinOp.Eq: tvm.tir.EQ,
-        ast.BuiltinOp.NotEq: tvm.tir.NE,
-        ast.BuiltinOp.And: tvm.tir.And,
-        ast.BuiltinOp.Or: tvm.tir.Or,
-    }
-
-    _unaryop_maker = {
-        ast.BuiltinOp.USub: lambda rhs, span: operator.neg(rhs),
-        ast.BuiltinOp.Invert: lambda rhs, span: operator.invert(rhs),
-        ast.BuiltinOp.Not: tvm.tir.Not,
-    }
-
-    # pylint gets confused here with synr.Transformer which doesn't have a
-    # custom init, so just disable it
-    def __init__(
-        self, base_lineno, tir_namespace, closure_vars
-    ):  # pylint: disable=super-init-not-called
-        self.context = None
-
-        self.base_lineno = base_lineno
-        self.current_lineno = 0
-        self.current_col_offset = 0
-        self.tir_namespace = tir_namespace
-        self.closure_vars = closure_vars
-        self.meta = None
-        self._inside_buffer_sugar = False
-
-    def init_function_parsing_env(self):
-        """Initialize function parsing environment"""
-        self.context = ContextMaintainer(self.report_error, self.closure_vars)  # scope emitter
-
-    def init_meta(self, meta_dict):
-        if meta_dict is not None:
-            self.meta = tvm.ir.load_json(json.dumps(meta_dict))
-
-    def transform(self, node):
-        """Generic transformation for visiting the AST. Dispatches to
-        `transform_ClassName` for the appropriate ClassName."""
-        old_lineno, old_col_offset = self.current_lineno, self.current_col_offset
-
-        if hasattr(node, "lineno"):
-            self.current_lineno = self.base_lineno + node.lineno - 1
-        if hasattr(node, "col_offset"):
-            self.current_col_offset = node.col_offset
-
-        method = "transform_" + node.__class__.__name__
-        visitor = getattr(self, method, self.generic_visit)
-        transform_res = visitor(node)
-
-        self.current_lineno, self.current_col_offset = old_lineno, old_col_offset
-
-        return transform_res
-
-    def match_tir_namespace(self, identifier: str) -> bool:
-        """Check if the namespace is equal to tvm.script.tir"""
-        return identifier in self.tir_namespace
-
-    def report_error(self, message: str, span: Union[ast.Span, tvm.ir.Span]):
-        """Report an error occuring at a location.
-
-        This just dispatches to synr's DiagnosticContext.
-
-        Parameters
-        ----------
-        message : str
-            Error message
-        span : Union[synr.ast.Span, tvm.ir.Span]
-            Location of the error
-        """
-        if isinstance(span, tvm.ir.Span):
-            span = synr_span_from_tvm(span)
-        self.error(message, span)
-
-    def parse_body(self, parent):
-        """Parse remaining statements in this scope.
-
-        Parameters
-        ----------
-        parent : synr.ast.Node
-            Parent node of this scope. Errors will be reported here.
-        """
-        body = []
-        spans = []
-        stmt = parent
-        while len(self.context.node_stack[-1]) > 0:
-            stmt = self.context.node_stack[-1].pop()
-            spans.append(stmt.span)
-            res = self.transform(stmt)
-            if res is not None:
-                body.append(res)
-        if len(body) == 0:
-            self.report_error(
-                "Expected another statement at the end of this block. Perhaps you "
-                "used a concise statement and forgot to include a body afterwards.",
-                stmt.span,
-            )
-        else:
-            return (
-                tvm.tir.SeqStmt(body, tvm_span_from_synr(ast.Span.union(spans)))
-                if len(body) > 1
-                else body[0]
-            )
-
-    def parse_arg_list(self, func, node_call):
-        """Match the arguments of a function call in the AST to the required
-        arguments of the function. This handles positional arguments,
-        positional arguments specified by name, keyword arguments, and varargs.
-
-        Parameters
-        ----------
-        func : Function
-            The function that provides the signature
-
-        node_call: Union[ast.Call, ast.TypeApply, ast.TypeCall]
-            The AST call node that calls into the function.
-
-        Returns
-        -------
-        arg_list : list
-            The parsed positional argument.
-        """
-        assert isinstance(node_call, (ast.Call, ast.TypeApply, ast.TypeCall))
-        # collect arguments
-        args = [self.transform(arg) for arg in node_call.params]
-        if isinstance(node_call, ast.TypeApply):
-            kw_args = {}  # TypeApply (e.g. foo[bar]) doesn't have kwargs defined in synr
-        else:
-            kw_args = {
-                self.transform(k): self.transform(v) for k, v in node_call.keyword_params.items()
-            }
-        # get the name and parameter list of func
-        if isinstance(func, (Intrin, ScopeHandler, SpecialStmt)):
-            func_name, param_list = func.signature()
-        else:
-            self.report_error(
-                "Internal Error: function must be of type Intrin, ScopeHandler or SpecialStmt, "
-                f"but it is {type(func).__name__}",
-                node_call.span,
-            )
-        # check arguments and parameter list and get a list of arguments
-        reader = CallArgumentReader(func_name, args, kw_args, self, node_call)
-        pos_only, kwargs, varargs = param_list
-        internal_args = list()
-
-        for i, arg_name in enumerate(pos_only):
-            internal_args.append(reader.get_pos_only_arg(i + 1, arg_name))
-        for i, arg_info in enumerate(kwargs):
-            arg_name, default = arg_info
-            internal_args.append(reader.get_kwarg(i + 1 + len(pos_only), arg_name, default=default))
-        if varargs is not None:
-            internal_args.extend(reader.get_varargs(len(pos_only) + len(kwargs) + 1))
-        elif len(args) + len(kw_args) > len(pos_only) + len(kwargs):
-            self.report_error(
-                "Arguments mismatched. "
-                + f"Expected {len(pos_only) + len(kwargs)} args but got "
-                + f"{len(args) + len(kw_args)}",
-                node_call.span,
-            )
-        return internal_args
-
-    def parse_type(self, type_node, parent):
-        """Parse a type annotation.
-
-        We require the parent object to the type so that we have a place to
-        report the error message if the type does not exist.
-        """
-        if type_node is None:
-            self.report_error("A type annotation is required", parent.span)
-        res_type = self.transform(type_node)
-        return tvm.ir.TupleType([]) if res_type is None else res_type.evaluate()
-
-    def generic_visit(self, node):
-        """Fallback visitor if node type is not handled. Reports an error."""
-
-        self.report_error(type(node).__name__ + " AST node is not supported", node.span)
-
-    def transform_Module(self, node):
-        """Module visitor
-
-        Right now, we only support two formats for TVM Script.
-
-        Example
-        -------
-        1. Generate a PrimFunc (If the code is printed, then it may also contain metadata)
-        .. code-block:: python
-
-            import tvm
-
-            @tvm.script
-            def A(...):
-                ...
-
-            # returns a PrimFunc
-            func = A
-
-        2. Generate an IRModule
-        .. code-block:: python
-
-            import tvm
-
-            @tvm.script.ir_module
-            class MyMod():
-                @T.prim_func
-                def A(...):
-                    ...
-                @T.prim_func
-                def B(...):
-                    ...
-
-                __tvm_meta__ = ...
-
-            # returns an IRModule
-            mod = MyMod
-        """
-        if len(node.funcs) == 1:
-            return self.transform(next(iter(node.funcs.values())))
-        elif len(node.funcs) == 0:
-            self.report_error(
-                "You must supply at least one class or function definition", node.span
-            )
-        else:
-            self.report_error(
-                "Only one-function, one-class or function-with-meta source code is allowed",
-                ast.Span.union([x.span for x in list(node.funcs.values())[1:]]),
-            )
-
-    def transform_Class(self, node):
-        """Class definition visitor.
-
-        A class can have multiple function definitions and a single
-        :code:`__tvm_meta__` statement. Each class corresponds to a single
-        :code:`IRModule`.
-
-        Example
-        -------
-        .. code-block:: python
-
-            @tvm.script.ir_module
-            class MyClass:
-                __tvm_meta__ = {}
-                def A():
-                    T.evaluate(0)
-        """
-        if len(node.assignments) == 1:
-            if not (
-                len(node.assignments[0].lhs) == 1
-                and isinstance(node.assignments[0].lhs[0], ast.Var)
-                and node.assignments[0].lhs[0].id.name == "__tvm_meta__"
-            ):
-                self.report_error(
-                    "The only top level assignments allowed are `__tvm_meta__ = ...`",
-                    node.assignments[0].span,
-                )
-            self.init_meta(
-                MetaUnparser().do_transform(node.assignments[0].rhs, self._diagnostic_context)
-            )
-        elif len(node.assignments) > 1:
-            self.report_error(
-                "Only a single top level `__tvm_meta__` is allowed",
-                ast.Span.union([x.span for x in node.assignments[1:]]),
-            )
-
-        return IRModule(
-            {GlobalVar(name): self.transform(func) for name, func in node.funcs.items()}
-        )
-
-    def transform_Function(self, node):
-        """Function definition visitor.
-
-        Each function definition is translated to a single :code:`PrimFunc`,
-        which is then passed through :code:`_ffi_api.Complete` together with
-        the root allocation buffers collected while parsing.
-
-        There are a couple restrictions on TVM Script functions:
-        1. Function arguments must have their types specified. An argument
-           annotated with a call or subscript of a :code:`GenericBufferType`
-           is :code:`T.match_buffer` sugar: a parameter named
-           :code:`<name>_handle` of type ``handle`` is created and the buffer
-           is recorded in the function's buffer map.
-        2. The body of the function can contain :code:`func_attr` to specify
-           attributes of the function (like its name).
-        3. The body of the function can also contain multiple :code:`T.match_buffer`
-           calls, which give shape and dtype information to arguments.
-        4. Return statements are implicit.
-        5. The function must be decorated with :code:`T.prim_func`; an
-           :code:`as_torch` decorator may additionally appear as the first of
-           exactly two decorators.
-
-        Example
-        -------
-        .. code-block:: python
-
-            @T.prim_func
-            def my_function(x: T.handle):  # 1. Argument types
-                T.func_attr({"global_symbol": "mmult"})  # 2. Function attributes
-                X_1 = T.match_buffer(x, [1024, 1024])  # 3. Buffer binding
-                T.evaluate(0)  # 4. No explicit return statement
-        """
-
-        def check_as_torch_decorator(decorator: Union[ast.Call, ast.Var]):
-            """Return True if ``decorator`` is ``as_torch`` or ``as_torch(<one arg>)``.
-
-            Returns False for calls with a different number of arguments, and an
-            implicit None (falsy) when the callee is not a plain name.
-            """
-            if isinstance(decorator, ast.Call):
-                if len(decorator.params) != 1:
-                    return False
-                func_name = decorator.func_name
-            else:
-                func_name = decorator
-            if isinstance(func_name, ast.Var):
-                return func_name.id.name == "as_torch"
-
-        def check_decorator(decorators: List[ast.Expr]) -> bool:
-            """Check the decorators are `T.prim_func`, optionally preceded by `as_torch`."""
-            # Accept either [prim_func] or [as_torch, prim_func].
-            if len(decorators) > 2 or len(decorators) == 0:
-                return False
-            if len(decorators) == 2 and not check_as_torch_decorator(decorators[0]):
-                return False
-            d: ast.Expr = decorators[-1]
-            # The last decorator must be `<name>.prim_func`, where <name> is
-            # accepted by match_tir_namespace.
-            return (
-                isinstance(d, ast.Attr)
-                and isinstance(d.object, ast.Var)
-                and self.match_tir_namespace(d.object.id.name)
-                and d.field.name == "prim_func"
-            )
-
-        self.init_function_parsing_env()
-        self.context.enter_scope(nodes=node.body.stmts)
-
-        # add parameters of function
-        for arg in node.params:
-            # Note that this case is for T.match_buffer syntax sugar
-            if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)) and isinstance(
-                self.transform(arg.ty.func_name), ty.GenericBufferType
-            ):
-                result = self.handle_match_buffer_type(arg.ty, arg.name)
-                if not isinstance(result, buffer.Buffer):
-                    self.report_error(
-                        "The result type of evaluating TypeCall and TypeApply stmt"
-                        f" is wrong: {type(result)}. It should be a Buffer",
-                        node.span,
-                    )
-                # The PrimFunc parameter is a `<name>_handle` handle var; the
-                # source-level name is bound to the Buffer itself.
-                arg_name_with_handle = arg.name + "_handle"
-                arg_var = tvm.te.var(arg_name_with_handle, tvm.ir.PrimType("handle"))
-                self.context.func_buffer_map[arg_var] = result
-                self.context.update_symbol(arg.name, result, node)
-            else:
-                arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg))
-                self.context.update_symbol(arg.name, arg_var, node)
-            self.context.func_params.append(arg_var)
-
-        # NOTE(review): decorators are validated only after the parameters are
-        # processed, so parameter errors are reported before decorator errors.
-        if not check_decorator(node.decorators):
-            self.report_error(
-                "All functions should be decorated by `T.prim_func`",
-                node.span,
-            )
-
-        # fetch the body of root block
-        body = self.parse_body(node)
-
-        # return a tir.PrimFunc
-        dict_attr = self.context.func_dict_attr
-        ret_type = self.parse_type(node.ret_type, node) if node.ret_type is not None else None
-        func = tvm.tir.PrimFunc(
-            self.context.func_params,
-            body,
-            ret_type,
-            buffer_map=self.context.func_buffer_map,
-            preflattened_buffer_map=self.context.func_preflattened_buffer_map,
-            attrs=tvm.ir.make_node("DictAttrs", **dict_attr) if dict_attr else None,
-            span=tvm_span_from_synr(node.span),
-        )
-
-        # New Scope : Implicit root block
-        # Each function contains an implicit root block in TensorIR,
-        # so here we need a block scope for it.
-        # If the PrimFunc is not a TensorIR func (e.g. TE scheduled func or low-level func),
-        # the root block will not be added. The logic to add root block is in `_ffi_api.Complete`
-
-        # Fix the PrimFunc
-        # 1. generate root block if necessary
-        # 2. generate surrounding loops for blocks if necessary
-
-        func = call_with_error_reporting(
-            self.report_error,
-            node.span,
-            _ffi_api.Complete,
-            func,
-            self.context.root_alloc_buffers,
-        )
-
-        self.context.exit_scope()
-        return func
-
-    def transform_Lambda(self, node):
-        """Lambda visitor
-
-        Return an array of input parameters and the transformed lambda body.
-        """
-
-        self.context.enter_scope(nodes=[node.body])
-
-        # add parameters of the lambda
-        arg_vars = []
-        for arg in node.params:
-            # Use "void" for dtype here. The actual type is not yet known and will be
-            # determined later. Using void type will allow IRSubstitute to do the
-            # replacement without flagging a type-mismatch error.
-            arg_var = tvm.te.var(arg.name, dtype="")
-            arg_vars.append(arg_var)
-            self.context.update_symbol(arg.name, arg_var, node)
-
-        # the body of a lambda must be an expr
-        if not isinstance(node.body, ast.Expr):
-            self.report_error("The body of a lambda must be an expression", node.span)
-
-        # transform the body of the lambda
-        body = self.transform(node.body)
-
-        self.context.exit_scope()
-        return arg_vars, body
-
-    def transform_Assign(self, node):
-        """Assign visitor
-        AST abstract grammar:
-            Assign(expr* targets, expr value, string? type_comment)
-
-        By now 5 patterns of Assign is supported:
-            1. special stmts with return value
-                1.1 Buffer = T.match_buffer()/T.buffer_decl()
-                1.2 Var = T.var()
-                1.3 Var = T.env_thread()
-            2. (BufferStore) Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr
-            3. (Store)       Var[PrimExpr] = PrimExpr
-            4. with scope handlers with concise scoping and var def
-                4.1 var = T.allocate()
-            5. A call to a pure python function, consuming and producing TVMScript values.
-               The outputs are inlined into the following body (no variable is created).
-               x, y = f(...)
-        """
-
-        if isinstance(node.rhs, ast.Call):
-            # Pattern 1 & Pattern 4
-            if isinstance(node.rhs.func_name, ast.Op):
-                func = None
-            else:
-                func = self.transform(node.rhs.func_name)
-
-            if isinstance(func, WithScopeHandler):
-                if not func.concise_scope or not func.def_symbol:
-                    self.report_error(
-                        "with scope handler " + func.signature()[0] + " is not suitable here",
-                        node.rhs.span,
-                    )
-                # Pattern 4
-                arg_list = self.parse_arg_list(func, node.rhs)
-                func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
-                func.body = self.parse_body(node)
-                return func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span)
-            elif isinstance(func, SpecialStmt):
-                # Pattern 1
-                arg_list = self.parse_arg_list(func, node.rhs)
-                func.handle(node, self.context, arg_list, node.rhs.func_name.span)
-                return self.parse_body(node)
-            elif isinstance(func, types.FunctionType):
-                # Pattern 5
-                args = [self.transform(arg) for arg in node.rhs.params]
-                try:
-                    out = func(*args)
-                except Exception as e:
-                    self.report_error(
-                        "Error occurred when invoking the function "
-                        + func.__name__
-                        + ": \n"
-                        + str(e),
-                        node.rhs.span,
-                    )
-
-                if len(node.lhs) == 1 and not isinstance(out, list):
-                    out = [out]
-
-                assert len(out) == len(node.lhs)
-
-                for var, value in zip(node.lhs, out):
-                    self.context.update_symbol(var.id.name, value, node)
-
-                body = self.parse_body(node)
-
-                for var, value in zip(node.lhs, out):
-                    self.context.remove_symbol(var.id.name)
-
-                return body
-
-        if isinstance(node.rhs, (ast.Call, ast.Constant)):
-            # Pattern 4 of let binding
-            value = self.transform(node.rhs)
-            if len(node.lhs) == 1 and not isinstance(node.lhs[0], ast.Var):
-                # This is a little confusing because it only is true when
-                # we have taken this branch. We might need to clarify what
-                # exectly is allowed in Assignments in tvmscript.
-                self.report_error(
-                    "Left hand side of assignment must be an unqualified variable",
-                    node.span,
-                )
-            ast_var = node.lhs[0]
-
-            if node.ty is None and hasattr(value, "dtype"):
-                var_ty = value.dtype
-            else:
-                var_ty = self.parse_type(node.ty, ast_var)
-
-            var = tvm.te.var(
-                ast_var.id.name,
-                var_ty,
-                span=tvm_span_from_synr(ast_var.span),
-            )
-            self.context.update_symbol(var.name, var, node)
-            body = self.parse_body(node)
-            self.context.remove_symbol(var.name)
-            return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))
-
-        self.report_error(
-            """Assignments should be one of:
-            1. A "special statement" with return value
-                1.1 Buffer = T.match_buffer()/T.buffer_decl()
-                1.2 Var = T.var()
-                1.3 Var = T.env_thread()
-            2. A store into a buffer: Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr
-            3. A store into a variable: Var[PrimExpr] = PrimExpr
-            4. A with scope handler with concise scoping and var def
-                4.1 var = T.allocate()
-            5. The right-hand side being a call to a pure python function, consuming and
-               producing TVMScript values.
-               x, y = f(...)""",
-            node.span,
-        )
-
-    def transform_SubscriptAssign(self, node):
-        """Visitor for statements of the form :code:`x[1] = 2`."""
-        symbol = self.transform(node.params[0])
-        indexes = self.transform(node.params[1])
-        rhs = self.transform(node.params[2])
-        rhs_span = tvm_span_from_synr(node.params[2].span)
-        if isinstance(symbol, tvm.tir.Buffer):
-            if len(indexes) != len(symbol.shape):
-                self.report_error(
-                    f"Buffer {symbol.name} is {len(symbol.shape)}-dimensional, "
-                    f"cannot be indexed by {len(indexes)}-dimensional indices.",
-                    node.params[1].span,
-                )
-
-            def __convert_index(x):
-                if isinstance(x, Slice):
-                    return x.as_index_expr(self.report_error)
-                return x
-
-            # BufferStore
-            indexes = [__convert_index(x) for x in indexes]
-            return tvm.tir.BufferStore(
-                symbol,
-                tvm.runtime.convert(rhs, span=rhs_span),
-                indexes,
-                span=tvm_span_from_synr(node.span),
-            )
-        else:
-            if symbol.dtype == "handle" and len(indexes) != 1:
-                self.report_error(
-                    "Handles only support one-dimensional indexing. Use `T.match_buffer` to "
-                    "construct a multidimensional buffer from a handle.",
-                    node.params[0].span,
-                )
-            if len(indexes) != 1:
-                self.report_error(
-                    f"Store is only allowed with one index, but {len(indexes)} were provided.",
-                    node.params[1].span,
-                )
-            self.report_error(
-                "Use of tir.Store has been deprecated in favor of tir.BufferStore.", node.span
-            )
-
-    def transform_AttrAssign(self, node):
-        """Visitor for statements of the form :code:`x.y = 2`."""
-        obj = self.transform(node.params[0])
-        field = node.params[1]
-        value = self.transform(node.params[2])
-
-        if not hasattr(obj, field.name):
-            self.error(f"Field {field.name} does not exist", field.span)
-
-        var = getattr(obj, field.name)
-
-        if not isinstance(var, tvm.tir.Var):
-            self.error(
-                f"Can only assign to tir.Var attributes, not {type(var).__name__}", node.span
-            )
-
-        body = self.parse_body(node)
-        return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))
-
-    def transform_Assert(self, node):
-        """Assert visitor
-
-        Pattern corresponds to concise mode of :code:`with T.Assert()`.
-        """
-
-        condition = self.transform(node.condition)
-        if node.msg is None:
-            self.report_error("Assert statements must have an error message.", node.span)
-        message = self.transform(node.msg)
-        body = self.parse_body(node)
-        return tvm.tir.AssertStmt(
-            condition, tvm.runtime.convert(message), body, span=tvm_span_from_synr(node.span)
-        )
-
-    def transform_For(self, node):
-        """For visitor
-        AST abstract grammar:
-            For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment)
-        By now 1 pattern of For is supported:
-            1. for scope handler
-                for name in T.serial()/T.parallel()/T.vectorized()/T.unroll()/range()/
-                            T.grid()/T.thread_binding()
-        """
-
-        if not isinstance(node.rhs, ast.Call):
-            self.report_error("The loop iterator should be a function call.", node.rhs.span)
-        func = self.transform(node.rhs.func_name)
-        if not isinstance(func, ForScopeHandler):
-            self.report_error(
-                "Only For scope handlers can be used in a for statement.", node.rhs.func_name.span
-            )
-        # prepare for new for scope
-        old_lineno, old_col_offset = self.current_lineno, self.current_col_offset
-        self.current_lineno = node.span.start_line
-        self.current_col_offset = node.span.start_column
-        self.context.enter_scope(nodes=node.body.stmts)
-        # for scope handler process the scope
-        arg_list = [
-            tvm.runtime.convert(arg, span=tvm_span_from_synr(node.rhs.span))
-            for arg in self.parse_arg_list(func, node.rhs)
-        ]
-        func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
-        func.body = self.parse_body(node)
-        res = func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span)
-        # exit the scope
-        self.context.exit_scope()
-        self.current_lineno, self.current_col_offset = old_lineno, old_col_offset
-        return res
-
-    def transform_While(self, node):
-        """While visitor
-        AST abstract grammar:
-            While(expr condition, stmt* body)
-        """
-        condition = self.transform(node.condition)
-        # body
-        self.context.enter_scope(nodes=node.body.stmts)
-        body = self.parse_body(node)
-        self.context.exit_scope()
-
-        return tvm.tir.While(condition, body, span=tvm_span_from_synr(node.span))
-
-    def transform_With(self, node):
-        """With visitor
-        AST abstract grammar:
-            With(withitem* items, stmt* body, string? type_comment)
-            withitem = (expr context_expr, expr? optional_vars)
-        By now 2 patterns of With is supported:
-            1. with scope handler with symbol def
-                with T.allocate() as targets:
-            2. with scope handler without symbol def
-                with T.block(*axes)/T.let()/T.Assert()/T.attr()/T.realize()
-        """
-
-        if not isinstance(node.rhs, ast.Call):
-            self.report_error(
-                "The context expression of a `with` statement should be a function call.",
-                node.rhs.span,
-            )
-
-        func = self.transform(node.rhs.func_name)
-
-        if not isinstance(func, WithScopeHandler):
-            self.report_error(
-                f"Function {func} cannot be used in a `with` statement.", node.rhs.func_name.span
-            )
-        # prepare for new block scope
-        old_lineno, old_col_offset = self.current_lineno, self.current_col_offset
-        self.current_lineno = node.body.span.start_line
-        self.current_col_offset = node.body.span.start_column
-        self.context.enter_block_scope(nodes=node.body.stmts)
-        # with scope handler process the scope
-        arg_list = self.parse_arg_list(func, node.rhs)
-        func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
-        func.body = self.parse_body(node)
-        res = func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span)
-        # exit the scope
-        self.context.exit_block_scope()
-        self.current_lineno, self.current_col_offset = old_lineno, old_col_offset
-        return res
-
-    def transform_If(self, node):
-        """If visitor
-        AST abstract grammar:
-            If(expr test, stmt* body, stmt* orelse)
-        """
-
-        condition = self.transform(node.condition)
-        # then body
-        self.context.enter_scope(nodes=node.true.stmts)
-        then_body = self.parse_body(node)
-        self.context.exit_scope()
-
-        # else body
-        if len(node.false.stmts) > 0:
-            self.context.enter_scope(nodes=node.false.stmts)
-            else_body = self.parse_body(node)
-            self.context.exit_scope()
-        else:
-            else_body = None
-
-        return tvm.tir.IfThenElse(
-            condition, then_body, else_body, span=tvm_span_from_synr(node.span)
-        )
-
-    def transform_Call(self, node):
-        """Call visitor
-
-        3 different Call patterns are allowed:
-            1. Intrin representing a PrimExpr/IterVar
-                1.1 tir.int/uint/float8/16/32/64/floormod/floordiv/load/cast/ramp/broadcast/max
-                1.2 tir.range/reduce_axis/scan_axis/opaque_axis
-            2. tir.Op(dtype, ...)
-            3. other callable functions
-        """
-
-        if isinstance(node.func_name, ast.Op):
-            if node.func_name.name == ast.BuiltinOp.Subscript:
-                return self.transform_Subscript(node)
-            if node.func_name.name in self._binop_maker:
-                lhs = self.transform(node.params[0])
-                # There is no supertype for everything that can appear in
-                # an expression, so we manually add what we might get here.
-                if not isinstance(lhs, (tvm.tir.PrimExpr, BufferSlice)):
-                    # We would really like to report a more specific
-                    # error here, but this parser contains no distinction
-                    # between parsing statements and parsing expressions. All
-                    # rules just call `transform`.
-                    self.report_error(
-                        f"Left hand side of binary op must be a PrimExpr, "
-                        "but it is a {type(lhs).__name__}",
-                        node.params[0].span,
-                    )
-                rhs = self.transform(node.params[1])
-                if not isinstance(rhs, (tvm.tir.PrimExpr, BufferSlice)):
-                    self.report_error(
-                        f"Right hand side of binary op must be a PrimExpr, "
-                        "but it is a {type(rhs).__name__}",
-                        node.params[1].span,
-                    )
-                return call_with_error_reporting(
-                    self.report_error,
-                    node.span,
-                    lambda node, lhs, rhs, span: self._binop_maker[node.func_name.name](
-                        lhs, rhs, span=span
-                    ),
-                    node,
-                    lhs,
-                    rhs,
-                    tvm_span_from_synr(node.span),
-                )
-            if node.func_name.name in self._unaryop_maker:
-                rhs = self.transform(node.params[0])
-                if node.func_name.name == ast.BuiltinOp.USub and isinstance(
-                    node.params[0], ast.Constant
-                ):
-                    # '-literal' should be parsed together for proper literal type inference
-                    if not isinstance(rhs, (tvm.tir.IntImm, tvm.tir.FloatImm)):
-                        self.report_error("The literal is illegal after -", node.params[0].span)
-                    return tvm.tir.const(-rhs.value)
-                return self._unaryop_maker[node.func_name.name](
-                    rhs, span=tvm_span_from_synr(node.span)
-                )
-            self.report_error(f"Unsupported operator {node.func_name.name}.", node.func_name.span)
-        else:
-            func = self.transform(node.func_name)
-            if isinstance(func, Intrin) and not func.stmt:
-                # pattern 1
-                arg_list = self.parse_arg_list(func, node)
-                return call_with_error_reporting(
-                    self.report_error,
-                    node.func_name.span,
-                    func.handle,
-                    arg_list,
-                    node.func_name.span,
-                )
-            else:
-                args = [self.transform(arg) for arg in node.params]
-                kw_args = {
-                    self.transform(k): self.transform(v) for k, v in node.keyword_params.items()
-                }
-                if isinstance(func, tvm.tir.op.Op):
-                    if not "dtype" in kw_args.keys():
-                        self.report_error(f"{func} requires a dtype keyword argument.", node.span)
-                    # pattern 2
-                    return tvm.tir.Call(
-                        kw_args["dtype"], func, args, span=tvm_span_from_synr(node.span)
-                    )
-                elif callable(func):
-                    # pattern 3
-                    return func(*args, **kw_args)
-                else:
-                    self.report_error(
-                        f"Function is neither callable nor a tvm.tir.op.Op (it is a {type(func)}).",
-                        node.func_name.span,
-                    )
-
-    def transform_UnassignedCall(self, node):
-        """Visitor for statements that are function calls.
-
-        This handles function calls that appear on thier own line like `tir.realize`.
-
-        Examples
-        --------
-        .. code-block:: python
-
-            @T.prim_func
-            def f():
-                A = T.buffer_decl([10, 10])
-                T.realize(A[1:2, 1:2], "")  # This is an UnassignedCall
-                A[1, 1] = 2  # This is also an UnassignedCall
-        """
-        # Only allowed builtin operator that can be a statement is x[1] = 3 i.e. subscript assign.
-        if isinstance(node.call.func_name, ast.Op):
-            if node.call.func_name.name == ast.BuiltinOp.SubscriptAssign:
-                return self.transform_SubscriptAssign(node.call)
-
-            if node.call.func_name.name == ast.BuiltinOp.AttrAssign:
-                return self.transform_AttrAssign(node.call)
-
-            self.report_error(
-                "Binary and unary operators are not allowed as a statement", node.span
-            )
-
-        # handle a regular function call
-        func = self.transform(node.call.func_name)
-        arg_list = self.parse_arg_list(func, node.call)
-
-        if isinstance(func, tir.scope_handler.AssertHandler):
-            self.report_error(
-                "A standalone `T.Assert` is not allowed. Use `assert condition, message` "
-                "instead.",
-                node.call.func_name.span,
-            )
-
-        if isinstance(func, Intrin):
-            if func.stmt:
-                return call_with_error_reporting(
-                    self.report_error,
-                    node.call.func_name.span,
-                    func.handle,
-                    arg_list,
-                    node.call.func_name.span,
-                )
-            else:
-                self.report_error(f"This intrinsic cannot be used as a statement.", node.call.span)
-        elif isinstance(func, WithScopeHandler) and func.concise_scope and not func.def_symbol:
-            func.enter_scope(node, self.context, arg_list, node.call.func_name.span)
-            func.body = self.parse_body(node)
-            return func.exit_scope(node, self.context, arg_list, node.call.func_name.span)
-        elif isinstance(func, SpecialStmt) and not func.def_symbol:
-            func.handle(node, self.context, arg_list, node.call.func_name.span)
-            return
-
-        self.report_error(
-            "Unexpected statement. Expected an assert, an intrinsic, a with statement, or a "
-            f"special statement, but got {type(func).__name__}.",
-            node.call.func_name.span,
-        )
-
-    def transform_Slice(self, node):
-        """Index slice visitor."""
-        start = self.transform(node.start)
-        end = self.transform(node.end)
-        if not (
-            isinstance(node.step, ast.Constant)
-            and isinstance(node.step.value, int)
-            and node.step.value > 0
-        ):
-            self.report_error(
-                "Only positive integer step size is supported for slices.", node.step.span
-            )
-        return Slice(start, end, node.step.value, tvm_span_from_synr(node.span))
-
-    def transform_Subscript(self, node):
-        """Array access visitor.
-
-        By now only 3 types of Subscript are supported:
-            1. Buffer[index, index, ...], Buffer element access(BufferLoad & BufferStore)
-               Var[index] Buffer element access()
-            2. Buffer[start: stop, start: stop, ...], BufferRealize(realize(buffer[...]))
-            3. Array[index], Buffer element access
-        """
-
-        symbol = self.transform(node.params[0])
-        if symbol is None:
-            self.report_error(
-                f"Variable {node.params[0].id.name} is not defined.", node.params[0].span
-            )
-
-        indexes = [self.transform(x) for x in node.params[1].values]
-        if isinstance(symbol, tvm.tir.expr.Var):
-            if symbol.dtype == "handle":
-                self.report_error(
-                    "Cannot read directly from a handle, use `T.match_buffer` "
-                    "to create a buffer to read from.",
-                    node.params[0].span,
-                )
-            if len(indexes) > 1:
-                self.report_error(
-                    "Only a single index can be provided when indexing into a `var`.",
-                    node.params[1].span,
-                )
-            index = indexes[0]
-            if not isinstance(index, (tvm.tir.PrimExpr, int)):
-                self.report_error(
-                    "Var load index should be an int or PrimExpr, but it is a" + type(index),
-                    node.span,
-                )
-
-            self.report_error(
-                "Use of tir.Load has been deprecated in favor of tir.BufferLoad", node.span
-            )
-        elif isinstance(symbol, tvm.tir.Buffer):
-            return BufferSlice(
-                symbol, indexes, self.report_error, span=tvm_span_from_synr(node.span)
-            )
-        elif isinstance(symbol, tvm.container.Array):
-            if len(indexes) > 1:
-                self.report_error(
-                    "Array access should be one-dimension access, but the indices are "
-                    + str(indexes),
-                    node.span,
-                )
-            index = indexes[0]
-            if not isinstance(index, (int, tvm.tir.expr.IntImm)):
-                self.report_error(
-                    "Array access index expected int or IntImm, but got " + type(index),
-                    node.span,
-                )
-            if int(index) >= len(symbol):
-                self.report_error(
-                    f"Array access out of bound, size: {len(symbol)}, got index {index}.",
-                    node.span,
-                )
-            return symbol[int(index)]
-        else:
-            self.report_error(
-                f"Cannot subscript from a {type(symbol).__name__}. Only variables and "
-                "buffers are supported.",
-                node.params[0].span,
-            )
-
-    def transform_Attr(self, node):
-        """Visitor for field access of the form `x.y`.
-
-        This visitor is used to lookup function and symbol names. We have two
-        cases to handle here:
-        1. If we have a statement of the form `tir.something`, then we lookup
-           `tir.something` in the `Registry`. If the function is not in the
-           registry, then we try to find a `tvm.ir.op.Op` with the same name.
-        2. All other names `tvm.something` are lookup up in this current python
-           namespace.
-        """
-
-        def get_full_attr_name(node: ast.Attr) -> str:
-            reverse_field_names = [node.field.name]
-            while isinstance(node.object, ast.Attr):
-                node = node.object
-                reverse_field_names.append(node.field.name)
-            if isinstance(node.object, ast.Var):
-                reverse_field_names.append(node.object.id.name)
-            return ".".join(reversed(reverse_field_names))
-
-        if isinstance(node.object, (ast.Var, ast.Attr)):
-            full_attr_name = get_full_attr_name(node)
-            attr_object, fields = full_attr_name.split(".", maxsplit=1)
-            if self.match_tir_namespace(attr_object):
-                func_name = "tir." + fields
-                res = Registry.lookup(func_name)
-                if res is not None:
-                    return res
-                try:
-                    return tvm.ir.op.Op.get(func_name)
-                except TVMError as e:
-                    # Check if we got an attribute error
-                    if e.args[0].find("AttributeError"):
-                        self.report_error(f"Unregistered function `tir.{fields}`.", node.span)
-                    else:
-                        raise e
-
-        symbol = self.transform(node.object)
-        if symbol is None:
-            self.report_error("Unsupported Attribute expression.", node.object.span)
-        if not hasattr(symbol, node.field.name):
-            self.report_error(
-                f"Type {type(symbol)} does not have a field called `{node.field.name}`.", node.span
-            )
-        res = getattr(symbol, node.field.name)
-        return res
-
-    def transform_TypeAttr(self, node):
-        """Visitor for field access of the form `x.y` for types.
-
-        We have two cases here:
-        1. If the type is of the form `T.something`, we look up the type in
-           the `tir` namespace in this module.
-        2. If the type is of the form `tvm.x.something` then we look up
-           `tvm.x.something` in this modules namespace.
-        """
-        if isinstance(node.object, ast.TypeVar):
-            if self.match_tir_namespace(node.object.id.name):
-                if not hasattr(tir, node.field.name):
-                    self.report_error(
-                        f"Invalid type annotation `tir.{node.field.name}`.", node.span
-                    )
-                return getattr(tir, node.field.name)
-
-        symbol = self.transform(node.object)
-        if symbol is None:
-            self.report_error("Unsupported Attribute expression", node.object.span)
-        if not hasattr(symbol, node.field):
-            self.report_error(
-                f"Type {type(symbol)} does not have a field called `{node.field}`.", node.span
-            )
-        res = getattr(symbol, node.field)
-        return res
-
-    def transform_DictLiteral(self, node):
-        """Dictionary literal visitor.
-
-        Handles dictionary literals of the form `{x:y, z:2}`.
-        """
-
-        keys = [self.transform(key) for key in node.keys]
-        values = [self.transform(value) for value in node.values]
-
-        return dict(zip(keys, values))
-
-    def transform_Tuple(self, node):
-        """Tuple visitor.
-
-        Handles tuples of the form `(x, y, 2)`.
-        """
-
-        return tuple(self.transform(element) for element in node.values)
-
-    def transform_ArrayLiteral(self, node):
-        """List literal visitor.
-
-        Handles lists of the form `[x, 2, 3]`.
-        """
-
-        return [self.transform(element) for element in node.values]
-
-    def transform_Var(self, node):
-        """Variable visitor
-
-        Handles variables like `x` in `x = 2`.
-        """
-
-        name = node.id.name
-        if name == "meta":
-            return self.meta
-        symbol = Registry.lookup(name)
-        if symbol is not None:
-            return symbol
-        symbol = self.context.lookup_symbol(name)
-        if symbol is not None:
-            return symbol
-        self.report_error(f"Unknown identifier {name}.", node.span)
-
-    def transform_TypeVar(self, node):
-        """Type variable visitor.
-
-        Equivalent to `transform_Var` but for types.
-        """
-        name = node.id.name
-        symbol = Registry.lookup(name) or self.context.lookup_symbol(name)
-        if symbol is not None:
-            return symbol
-        self.report_error(f"Unknown identifier {name}.", node.span)
-
-    def transform_Constant(self, node):
-        """Constant value visitor.
-
-        Constant values include `None`, `"strings"`, `2` (integers), `4.2`
-        (floats), and `true` (booleans).
-        """
-        return tvm.runtime.convert(node.value, span=tvm_span_from_synr(node.span))
-
-    def transform_TypeConstant(self, node):
-        """Constant value visitor for types.
-
-        See `transform_Constant`.
-        """
-        if self._inside_buffer_sugar:
-            return self.transform_Constant(node)
-
-        return node.value
-
-    def transform_TypeTuple(self, node):
-        """Tuple value visitor for types.
-
-        Mostly used in `transform_TypeCall` and `transform_TypeApply`.
-        """
-        return [self.transform(value) for value in node.values]
-
-    def transform_TypeCall(self, node):
-        """TypeCall visitor
-
-        This occurs when an expression is used inside a T.Buffer
-        parameter annotation.
-        """
-
-        # ast.Call has the BuiltinOp as node.func_name.name, where
-        # ast.TypeCall has the BuiltinOp as node.func_name.  So we can
-        # delegate to self.transform_Call, but the error messages for
-        # unsupported operations will highlight the entire expression
-        # and not just the function itself.
-        op = ast.Op(node.span, node.func_name)
-        call = ast.Call(node.span, op, node.params, node.keyword_params)
-        return self.transform_Call(call)
-
-    def transform_TypeApply(self, node):
-        """Visitor for Type[Type] expressions.
-
-        Mostly used for ``T.Ptr`` expressions.
-        """
-        func = self.transform(node.func_name)
-
-        if not isinstance(func, ty.TypeGeneric) or not hasattr(func, "__getitem__"):
-            self.report_error(
-                f"Use of type arguments requires a type that accepts type arguments (e.g. T.Ptr), "
-                f"but found {type(func).__name__} instead.",
-                node.span,
-            )
-
-        param_types = []
-        for idx, param in enumerate(node.params):
-            param_type = self.transform(param)
-            if not isinstance(param_type, ty.TypeGeneric) and func.require_type_generic_at(idx):
-                self.report_error(
-                    f"Expected a type but found {type(param).__name__} "
-                    f"at {idx}th type argument",
-                    param.span,
-                )
-
-            param_types.append(param_type)
-
-        if len(param_types) == 1:
-            return func[param_types[0]]
-        else:
-            return func[param_types]
-
-    def handle_match_buffer_type(self, node, buffer_name):
-        """special function to handle syntax sugar for match buffer.
-
-        This method is for buffer declarations in the function parameters.
-        """
-        func = self.transform(node.func_name)
-        assert isinstance(func, SpecialStmt)
-
-        # parse args and kwargs for TypeCall and TypeApply
-        self._inside_buffer_sugar = True
-        try:
-            arg_list = self.parse_arg_list(func, node)
-        finally:
-            self._inside_buffer_sugar = False
-
-        # Note that the third element in arg_list would always be the 'name'
-        # TODO: This index is hardcoded as a workaround. Better to make it programmatic
-        if arg_list[2] is None:
-            arg_list[2] = buffer_name
-        buf = func.handle(node, self.context, arg_list, node.func_name.span)
-        return buf
-
-    def transform_Return(self, node):
-        self.report_error(
-            "TVM script does not support return statements. Instead the last statement in any "
-            "block is implicitly returned.",
-            node.span,
-        )
-
-
-def get_tir_namespace(script: Union[Callable, type]) -> List[str]:
-    assert inspect.isfunction(script) or inspect.isclass(script)
-    env: Dict[str, Any] = script.__globals__
-    return [key for key in env.keys() if env[key] == tir]
-
-
-def from_source(
-    input_func: Union[str, Callable], tir_prefix: Optional[List[str]] = None
-) -> Union[PrimFunc, IRModule]:
-    """Parse function or string into PrimFunc or IRModule.
-
-    If possible, pass the TVM script in as a function so that line numbers and
-    filename will be accurate.
-
-    Parameters
-    ----------
-    input_module : Union[str, Callable]
-        The python function to be parsed.
-
-    tir_prefix : Optional[List[str]]
-        The tir prefix list. Only works for str input, default by "tir" and "T".
-
-    Returns
-    -------
-    output : Union[Function, Module]
-        The Function or Module in IR.
-    """
-    if isinstance(input_func, str):
-        tir_prefix = ["T", "tir"] if tir_prefix is None else tir_prefix
-        return to_ast(input_func, TVMDiagnosticCtx(), TVMScriptParser(0, tir_prefix, {}))
-    elif inspect.isfunction(input_func):
-        _, start_line = inspect.getsourcelines(input_func)
-        env: Dict[str, Any] = input_func.__globals__
-        namespace = [key for key in env.keys() if env[key] is tir]
-        _closure_vars = inspect.getclosurevars(input_func)
-        closure_vars = {**_closure_vars.nonlocals, **_closure_vars.globals}
-        parser = TVMScriptParser(start_line, namespace, closure_vars)
-        result = to_ast(input_func, TVMDiagnosticCtx(), parser)
-        return result
-    else:
-        raise TypeError("Only function definitions are supported.")
-
-
-def ir_module(input_module: type) -> IRModule:
-    """Decorate a python class as tvm IRModule.
-
-    Parameters
-    ----------
-    input_module : type
-        The python class to be parsed.
-
-    Returns
-    -------
-    output : IRModule
-        The result IRModule.
-    """
-    if inspect.isclass(input_module):
-        func_dict = {
-            name: f for name, f in input_module.__dict__.items() if isinstance(f, BaseFunc)
-        }
-        return IRModule(func_dict)
-    raise TypeError("Only class definitions are supported.")
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/__init__.py
similarity index 83%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/__init__.py
index 555659d0c5..5161a2601c 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/__init__.py
@@ -13,9 +13,9 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
-# under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
-
-from . import tir
-
-from .parser import ir_module, from_source
+# under the License.
+"""The parser"""
+from . import _core, ir, tir
+from ._core import parse
+from .ir import ir_module
+from .tir import prim_func
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/_core.py
similarity index 76%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/_core.py
index 555659d0c5..4f5411dc36 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/_core.py
@@ -13,9 +13,10 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
-# under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
-
-from . import tir
-
-from .parser import ir_module, from_source
+# under the License.
+"""The core parser infra"""
+# pylint: disable=unused-import
+from .core import dispatch, doc, utils
+from .core.dispatch import OpMethod, register_op
+from .core.entry import parse
+from .core.parser import Parser
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/core/__init__.py
similarity index 85%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/core/__init__.py
index 555659d0c5..94d8dab032 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/core/__init__.py
@@ -14,8 +14,5 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
-
-from . import tir
-
-from .parser import ir_module, from_source
+"""The core parser infra"""
+from . import diagnostics, dispatch, doc, doc_core, entry, evaluator, parser, utils
diff --git a/python/tvm/script/parser/core/diagnostics.py b/python/tvm/script/parser/core/diagnostics.py
new file mode 100644
index 0000000000..51c26bbc24
--- /dev/null
+++ b/python/tvm/script/parser/core/diagnostics.py
@@ -0,0 +1,243 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""Source bookkeeping and diagnostic reporting for the TVMScript parser."""
+import inspect
+import re
+import sys
+from typing import Union
+
+from tvm.ir import IRModule, SourceName, Span, diagnostics
+
+from . import doc
+
+
+class Source:
+    """The source code of a program handed to the parser.
+
+    Attributes
+    ----------
+    source_name : str
+        File the program comes from, or ``"<str>"`` for string input.
+    start_line : int
+        1-based line number of the program's first line in that file.
+    start_column : int
+        Leading-whitespace width of the program's first line.
+    source : str
+        The program text; when ``start_column`` is non-zero, that many leading
+        characters are cut from every line.
+    full_source : str
+        Text of the whole file containing the program when it can be
+        recovered, otherwise ``source``.
+    """
+
+    source_name: str
+    start_line: int
+    start_column: int
+    source: str
+    full_source: str
+
+    def __init__(self, program: Union[str, doc.AST]):
+        if isinstance(program, str):
+            self.source_name = "<str>"
+            self.start_line = 1
+            self.start_column = 0
+            self.source = program
+            self.full_source = program
+            return
+
+        self.source_name = inspect.getsourcefile(program)  # type: ignore
+        lines, self.start_line = getsourcelines(program)  # type: ignore
+        if lines:
+            self.start_column = len(lines[0]) - len(lines[0].lstrip())
+        else:
+            self.start_column = 0
+        if self.start_column and lines:
+            # Dedent so an indented definition parses as top-level code.
+            self.source = "\n".join([l[self.start_column :].rstrip() for l in lines])
+        else:
+            self.source = "".join(lines)
+        try:
+            # It will cause a problem when running in Jupyter Notebook.
+            # `mod` will be <module '__main__'>, which is a built-in module
+            # and `getsource` will throw a TypeError
+            mod = inspect.getmodule(program)
+            if mod:
+                self.full_source = inspect.getsource(mod)
+            else:
+                self.full_source = self.source
+        except TypeError:
+            # It's a work around for Jupyter problem.
+            # Since `findsource` is an internal API of inspect, we just use it
+            # as a fallback method.
+            src, _ = inspect.findsource(program)  # type: ignore
+            self.full_source = "".join(src)
+
+    def as_ast(self) -> doc.AST:
+        """Parse ``self.source`` into a doc AST."""
+        return doc.parse(self.source)
+
+
+# Keep the originals so the class-aware replacements below can delegate to them.
+_getfile = inspect.getfile  # pylint: disable=invalid-name
+_findsource = inspect.findsource  # pylint: disable=invalid-name
+
+
+def _patched_inspect_getfile(obj):
+    """``inspect.getfile`` with an extra fallback for classes.
+
+    Non-class objects go to the original ``inspect.getfile``. For a class,
+    the defining module's ``__file__`` is tried first, then the file of any
+    function member whose ``__qualname__`` places it directly in the class.
+    """
+    if not inspect.isclass(obj):
+        return _getfile(obj)
+    mod = getattr(obj, "__module__", None)
+    if mod is not None:
+        # The module may be missing from ``sys.modules``; fall through to the
+        # member scan below instead of raising ``KeyError``.
+        file = getattr(sys.modules.get(mod), "__file__", None)
+        if file is not None:
+            return file
+    for _, member in inspect.getmembers(obj):
+        if inspect.isfunction(member):
+            if obj.__qualname__ + "." + member.__name__ == member.__qualname__:
+                return inspect.getfile(member)
+    raise TypeError(f"Source for {obj!r} not found")
+
+
+def findsource(obj):
+    """Return ``(lines, index)`` locating the source of ``obj``.
+
+    Non-class objects go to the original ``inspect.findsource``. For a class,
+    the lines of its file are scanned in order for the chain of ``def`` /
+    ``class`` headers spelled out by ``obj.__qualname__``; ``index`` is the
+    0-based position of the line matching the last (innermost) header.
+
+    Raises
+    ------
+    OSError
+        If the source file or the class definition cannot be found.
+    """
+    import linecache  # pylint: disable=import-outside-toplevel
+
+    if not inspect.isclass(obj):
+        return _findsource(obj)
+
+    file = inspect.getsourcefile(obj)
+    if file:
+        linecache.checkcache(file)
+    else:
+        file = inspect.getfile(obj)
+        if not (file.startswith("<") and file.endswith(">")):
+            raise OSError("source code not available")
+
+    module = inspect.getmodule(obj, file)
+    if module:
+        lines = linecache.getlines(file, module.__dict__)
+    else:
+        lines = linecache.getlines(file)
+    if not lines:
+        raise OSError("could not get source code")
+    # "f.<locals>.A.B" -> ["f<locals>", "A", "B"]: a "<locals>" suffix marks an
+    # enclosing function; every other component is a class.
+    qual_names = obj.__qualname__.replace(".<locals>", "<locals>").split(".")
+    pattern_list = []
+    for name in qual_names:
+        if name.endswith("<locals>"):
+            # Enclosing functions may be coroutines (``async def``).
+            func_name = re.escape(name[: -len("<locals>")])
+            pattern_list.append(re.compile(r"^(\s*)(?:async\s+)?def\s+" + func_name + r"\b"))
+        else:
+            pattern_list.append(re.compile(r"^(\s*)class\s+" + re.escape(name) + r"\b"))
+    for i, line in enumerate(lines):
+        match = pattern_list[0].match(line)
+        if match:
+            pattern_list.pop(0)
+        if not pattern_list:
+            return lines, i
+    raise OSError("could not find class definition")
+
+
+def getsourcelines(obj):
+    """``inspect.getsourcelines`` built on the class-aware ``findsource``.
+
+    Returns the lines of the definition block of ``obj`` and the 1-based
+    line number of its first line.
+    """
+    obj = inspect.unwrap(obj)
+    lines, l_num = findsource(obj)
+    return inspect.getblock(lines[l_num:]), l_num + 1
+
+
+# Installed process-wide: ``inspect.getsourcefile`` (used by ``Source``) goes
+# through ``inspect.getfile`` and so also gets the class fallback.
+inspect.getfile = _patched_inspect_getfile
+
+
+class Diagnostics:
+    """Emits parser diagnostics for a ``Source`` through a ``DiagnosticContext``.
+
+    ``full_source`` is registered under ``source_name`` in the source map of
+    a fresh ``IRModule`` that backs the context.
+    """
+
+    source: Source
+    ctx: diagnostics.DiagnosticContext
+
+    def __init__(self, source: Source):
+        mod = IRModule()
+        mod.source_map.add(source.source_name, source.full_source)
+        self.source = source
+        self.ctx = diagnostics.DiagnosticContext(mod, diagnostics.get_renderer())
+
+    def _emit(self, node: doc.AST, message: str, level: diagnostics.DiagnosticLevel) -> None:
+        """Emit one diagnostic at the position of ``node``.
+
+        ``node`` positions are relative to ``self.source.source`` (1-based
+        lines, 0-based columns) and are shifted here by the source's start
+        line and start column; a missing position defaults to the snippet start.
+        """
+        # Explicit ``None`` checks: column 0 is a valid offset and must not be
+        # replaced by a fallback, and fallbacks must be snippet-relative because
+        # every position is shifted by ``start_line``/``start_column`` below.
+        lineno = node.lineno if node.lineno is not None else 1
+        col_offset = node.col_offset if node.col_offset is not None else 0
+        end_lineno = node.end_lineno if node.end_lineno is not None else lineno
+        end_col_offset = node.end_col_offset if node.end_col_offset is not None else col_offset
+        lineno += self.source.start_line - 1
+        end_lineno += self.source.start_line - 1
+        col_offset += self.source.start_column + 1
+        end_col_offset += self.source.start_column + 1
+        self.ctx.emit(
+            diagnostics.Diagnostic(
+                level=level,
+                span=Span(
+                    source_name=SourceName(self.source.source_name),
+                    line=lineno,
+                    end_line=end_lineno,
+                    column=col_offset,
+                    end_column=end_col_offset,
+                ),
+                message=message,
+            )
+        )
+
+    def error(self, node: doc.AST, message: str) -> None:
+        """Emit an error at ``node``, then render the diagnostic context."""
+        self._emit(node, message, diagnostics.DiagnosticLevel.ERROR)
+        self.ctx.render()
diff --git a/python/tvm/script/parser/core/dispatch.py b/python/tvm/script/parser/core/dispatch.py
new file mode 100644
index 0000000000..f10b90961a
--- /dev/null
+++ b/python/tvm/script/parser/core/dispatch.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""Dispatch tables for the TVMScript parser.
+
+``ParseVTable`` maps ``(token, type_name)`` to the parse method registered
+for that dispatch token and doc AST node type name. ``OpVTable`` maps
+``(operand_type, op_node_type, operand_index)`` to the operator handler
+registered for that combination.
+"""
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type
+
+from .doc import AST
+
+if TYPE_CHECKING:
+    from .parser import Parser
+
+
+ParseMethod = Callable[["Parser", AST], None]
+ParseVTable: Dict[Tuple[str, str], ParseMethod] = {}
+
+OpMethod = Callable[..., Any]
+OpVTable: Dict[Tuple[Type, Type[AST], int], OpMethod] = {}
+
+
+def register(token: str, type_name: str):
+    """Register a method for a dispatch token and type name.
+
+    Returns a decorator that stores the decorated method in ``ParseVTable``
+    under ``(token, type_name)`` and hands it back unchanged, so the decorated
+    name stays bound to the function rather than to ``None``.
+    """
+
+    def f(method: ParseMethod) -> ParseMethod:
+        ParseVTable[(token, type_name)] = method
+        return method
+
+    return f
+
+
+def get(
+    token: str,
+    type_name: str,
+    default: Optional[ParseMethod] = None,
+) -> Optional[ParseMethod]:
+    """Look up the parse method for ``(token, type_name)``, else ``default``."""
+    return ParseVTable.get((token, type_name), default)
+
+
+def register_op(ty: Type, op: Type[AST], operand_index: int):  # pylint: disable=invalid-name
+    """Register an operator handler for ``(ty, op, operand_index)``.
+
+    Returns a decorator that stores the decorated method in ``OpVTable`` and
+    hands it back unchanged.
+    """
+
+    def f(method: OpMethod) -> OpMethod:
+        OpVTable[(ty, op, operand_index)] = method
+        return method
+
+    return f
+
+
+def get_op(  # pylint: disable=invalid-name
+    ty: Type,
+    op: Type[AST],
+    operand_index: int,
+    default: Optional[OpMethod] = None,
+) -> Optional[OpMethod]:
+    """Look up the operator handler for ``(ty, op, operand_index)``, else ``default``."""
+    return OpVTable.get((ty, op, operand_index), default)
diff --git a/python/tvm/script/parser/core/doc.py b/python/tvm/script/parser/core/doc.py
new file mode 100644
index 0000000000..f6a641cb64
--- /dev/null
+++ b/python/tvm/script/parser/core/doc.py
@@ -0,0 +1,361 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import ast
+import inspect
+import sys
+import typing
+from collections import defaultdict
+
+from . import doc_core as doc
+from .doc_core import *  # pylint: disable=unused-import,wildcard-import,redefined-builtin,W0614
+
+FnToDoc = typing.Callable[[ast.AST], doc.AST]
+FnFromDoc = typing.Callable[[doc.AST], ast.AST]
+
+
+class Entry:
+    to_doc: typing.Optional[FnToDoc]
+    from_doc: typing.Optional[FnFromDoc]
+
+    def __init__(self):
+        self.to_doc = None
+        self.from_doc = None
+
+
+class Registry:
+    _inst: typing.Optional["Registry"] = None
+    table: typing.Dict[str, Entry]
+
+    def __init__(self):
+        self.table = defaultdict(Entry)
+
+
+def register_to_doc(name: str):
+    def f(to_doc: FnToDoc):  # pylint: disable=redefined-outer-name
+        reg = Registry._inst  # pylint: disable=protected-access
+        reg.table[name].to_doc = to_doc
+
+    return f
+
+
+def register_from_doc(name: str):
+    def f(to_doc: FnFromDoc):  # pylint: disable=redefined-outer-name
+        reg = Registry._inst  # pylint: disable=protected-access
+        reg.table[name].from_doc = to_doc
+
+    return f
+
+
+def _is_atomic_type(node):
+    return (
+        node is None
+        or node in [..., True, False]
+        or isinstance(
+            node,
+            (
+                int,
+                float,
+                str,
+                bool,
+                bytes,
+                complex,
+            ),
+        )
+    )
+
+
+def _get_registry_entry(cls_name, attr):
+    cls_name = cls_name.split(".")[-1]
+    reg = Registry._inst  # pylint: disable=protected-access
+    if cls_name in reg.table:
+        entry = reg.table[cls_name]
+        return getattr(entry, attr, None)
+    return None
+
+
+def from_doc(node):
+    if _is_atomic_type(node):
+        return node
+    if isinstance(node, tuple):
+        return tuple(from_doc(n) for n in node)
+    if isinstance(node, list):
+        return [from_doc(n) for n in node]
+    func = _get_registry_entry(node.__class__.__name__, "from_doc")
+    if not func:
+        raise NotImplementedError(f"from_doc is not implemented for: {node.__class__.__name__}")
+    return func(node)
+
+
+def to_doc(node):
+    if _is_atomic_type(node):
+        return node
+    if isinstance(node, tuple):
+        return tuple(to_doc(n) for n in node)
+    if isinstance(node, list):
+        return [to_doc(n) for n in node]
+    func = _get_registry_entry(node.__class__.__name__, "to_doc")
+    if not func:
+        raise NotImplementedError(f"to_doc is not implemented for: {node.__class__.__name__}")
+    return func(node)
+
+
+def parse(
+    source,
+    filename="<unknown>",
+    mode="exec",
+) -> doc.AST:
+    try:
+        program = ast.parse(  # pylint: disable=unexpected-keyword-arg
+            source=source,
+            filename=filename,
+            mode=mode,
+            feature_version=(3, 8),
+        )
+    except:  # pylint: disable=bare-except
+        program = ast.parse(
+            source=source,
+            filename=filename,
+            mode=mode,
+        )
+    return to_doc(program)
+
+
+class NodeVisitor:
+    def visit(self, node: doc.AST) -> None:
+        if isinstance(node, (list, tuple)):
+            for item in node:
+                self.visit(item)
+            return
+        if not isinstance(node, doc.AST):
+            return
+        getattr(
+            self,
+            "visit_" + node.__class__.__name__.split(".")[-1],
+            self.generic_visit,
+        )(node)
+
+    def generic_visit(self, node: doc.AST) -> None:
+        for field in node.__class__._FIELDS:  # pylint: disable=protected-access
+            value = getattr(node, field, None)
+            if value is None:
+                pass
+            elif isinstance(value, (doc.AST, list, tuple)):
+                self.visit(value)
+
+
+class NodeTransformer:
+    def visit(self, node: doc.AST) -> doc.AST:
+        if isinstance(node, list):
+            return [self.visit(item) for item in node]
+        if isinstance(node, tuple):
+            return tuple(self.visit(item) for item in node)
+        if not isinstance(node, doc.AST):
+            return node
+        return getattr(
+            self,
+            "visit_" + node.__class__.__name__.split(".")[-1],
+            self.generic_visit,
+        )(node)
+
+    def generic_visit(self, node: doc.AST) -> doc.AST:
+        kv: typing.Dict[str, typing.Any] = {}
+        for field in node.__class__._FIELDS:  # pylint: disable=protected-access
+            value = getattr(node, field, None)
+            if value is None:
+                pass
+            elif isinstance(value, (doc.AST, list, tuple)):
+                value = self.visit(value)
+            kv[field] = value
+        return node.__class__(**kv)
+
+
+def _register_default():
+    class DefaultTranslator:
+        def __init__(self, doc_cls, func, fields):
+            self.doc_cls = doc_cls  # getattr(doc, name)
+            self.func = func
+            self.fields = fields
+
+        def __call__(self, node):
+            kv = {attr: self.func(getattr(node, attr, None)) for attr in self.fields}
+            return self.doc_cls(**kv)
+
+    Registry._inst = Registry()  # pylint: disable=protected-access
+    for cls_name in dir(doc):
+        doc_cls = getattr(doc, cls_name)
+        if not hasattr(ast, cls_name):
+            continue
+        if inspect.isclass(doc_cls) and issubclass(doc_cls, doc.AST):
+            assert "." not in cls_name
+            register_to_doc(cls_name)(
+                DefaultTranslator(
+                    getattr(doc, cls_name),
+                    to_doc,
+                    doc_cls._FIELDS,  # pylint: disable=protected-access
+                )
+            )
+            register_from_doc(cls_name)(
+                DefaultTranslator(
+                    getattr(ast, cls_name),
+                    from_doc,
+                    doc_cls._FIELDS,  # pylint: disable=protected-access
+                )
+            )
+
+
+def _py_version() -> typing.Tuple[int, int]:
+    return (sys.version_info.major, sys.version_info.minor)
+
+
+def _register_constant_handling():
+    # Python 3.6/3.7 parse literals into Str/NameConstant/Num/Bytes/Ellipsis nodes.
+    if _py_version() not in [(3, 6), (3, 7)]:
+        return
+
+    def as_constant(f) -> typing.Callable[[ast.AST], doc.Constant]:
+        def to_doc_func(x: ast.AST) -> doc.Constant:
+            return doc.Constant(
+                value=getattr(x, f) if isinstance(f, str) else f(x),
+                kind=None,
+                s=None,
+                n=None,
+                lineno=x.lineno,
+                col_offset=x.col_offset,
+                end_lineno=x.lineno,  # no end positions before 3.8: reuse the start
+                end_col_offset=x.col_offset,
+            )
+        return to_doc_func
+
+    register_to_doc("Str")(as_constant("s"))
+    register_to_doc("NameConstant")(as_constant("value"))
+    register_to_doc("Num")(as_constant("n"))
+    register_to_doc("Bytes")(as_constant("s"))
+    register_to_doc("Ellipsis")(as_constant(lambda _: ...))
+
+
+def _register_subscription_handling():
+    """Register ``Subscript`` translators for Python < 3.9.
+
+    Before 3.9 a subscript's slice is an ``ast.Index`` (single or tuple index),
+    an ``ast.Slice``, or an ``ast.ExtSlice`` (tuple containing slices); ``doc``
+    uses the 3.9+ layout, where the slice is a plain expression.
+    """
+    if _py_version() >= (3, 9):
+        return
+
+    def _loc(node) -> typing.Dict[str, typing.Any]:
+        # Location fields of ``node``; missing ones (e.g. on slice nodes) become None.
+        return {
+            key: getattr(node, key, None)
+            for key in ("lineno", "col_offset", "end_lineno", "end_col_offset")
+        }
+
+    def subscript_to_doc(x: ast.Subscript) -> doc.Subscript:
+        if isinstance(x.slice, ast.Slice):
+            # ``a[i:j:k]``; slice nodes may lack location fields before 3.9.
+            slice_ = doc.Slice(
+                lower=to_doc(x.slice.lower),
+                upper=to_doc(x.slice.upper),
+                step=to_doc(x.slice.step),
+                **_loc(x.slice),
+            )
+        elif isinstance(x.slice, ast.ExtSlice):
+            # ``a[i, j:k]``: expose the dimensions as a plain tuple, as 3.9+ does.
+            slice_ = doc.Tuple(
+                elts=[to_doc(i) for i in x.slice.dims],
+                ctx=doc.Load(
+                    lineno=None,
+                    col_offset=None,
+                    end_lineno=None,
+                    end_col_offset=None,
+                ),
+                **_loc(x),
+            )
+        elif isinstance(x.slice, ast.Index):
+            # Single or plain-tuple index: unwrap to the inner expression.
+            slice_ = to_doc(x.slice.value)
+        else:
+            raise TypeError(f"Unknown subscript type: {type(x.slice)}")
+        return doc.Subscript(
+            value=to_doc(x.value),
+            slice=slice_,
+            ctx=to_doc(x.ctx),
+            **_loc(x),
+        )
+
+    def subscript_from_doc(x: doc.Subscript) -> ast.Subscript:
+        if isinstance(x.slice, doc.Slice):
+            # The registered from_doc translator already yields an ``ast.Slice``.
+            slice_ = from_doc(x.slice)
+        elif isinstance(x.slice, doc.Tuple) and any(
+            isinstance(elt, doc.Slice) for elt in x.slice.elts
+        ):
+            # ``ast.ExtSlice`` accepts only slice nodes, so every non-slice
+            # dimension must be wrapped in ``ast.Index``; a bare expression
+            # there makes ``compile`` reject the tree.
+            slice_ = ast.ExtSlice(
+                dims=[
+                    from_doc(elt)
+                    if isinstance(elt, doc.Slice)
+                    else ast.Index(value=from_doc(elt))
+                    for elt in x.slice.elts
+                ],
+            )
+        else:
+            # A single index, or ``a[i, j]`` whose tuple goes into one ``ast.Index``.
+            slice_ = ast.Index(value=from_doc(x.slice))
+        result = ast.Subscript(
+            value=from_doc(x.value),
+            slice=slice_,
+            ctx=from_doc(x.ctx),
+        )
+        # Preserve the source location of the ``doc`` node.
+        result.lineno = x.lineno
+        result.col_offset = x.col_offset
+        result.end_lineno = x.end_lineno
+        result.end_col_offset = x.end_col_offset
+        return result
+
+    # Version-specific translators for ``Subscript`` in both directions.
+    register_to_doc("Subscript")(subscript_to_doc)
+    register_from_doc("Subscript")(subscript_from_doc)
+
+
+def _register_index_handling():  # ``ast.Index`` is only produced before Python 3.9
+    if _py_version() >= (3, 9):
+        return
+
+    def index_to_doc(x: ast.Index) -> doc.Expr:
+        return to_doc(x.value)
+
+    def index_from_doc(x: doc.Expr) -> ast.Index:
+        result = ast.Index(value=from_doc(x))  # Index takes only ``value``; x may lack ``ctx``
+        result.lineno = x.lineno
+        result.col_offset = x.col_offset
+        result.end_lineno = x.end_lineno
+        result.end_col_offset = x.end_col_offset
+        return result
+
+    register_to_doc("Index")(index_to_doc)
+    register_from_doc("Index")(index_from_doc)
+
+
+_register_default()  # must run first: it installs a fresh Registry._inst
+_register_constant_handling()  # Python 3.6/3.7 only
+_register_subscription_handling()  # Python < 3.9 only
+_register_index_handling()  # Python < 3.9 only
diff --git a/python/tvm/script/printer/doc_core.py b/python/tvm/script/parser/core/doc_core.py
similarity index 100%
rename from python/tvm/script/printer/doc_core.py
rename to python/tvm/script/parser/core/doc_core.py
diff --git a/python/tvm/script/tir/prim_func.py b/python/tvm/script/parser/core/entry.py
similarity index 51%
copy from python/tvm/script/tir/prim_func.py
copy to python/tvm/script/parser/core/entry.py
index 923eb97d27..ccf42e8c15 100644
--- a/python/tvm/script/tir/prim_func.py
+++ b/python/tvm/script/parser/core/entry.py
@@ -14,32 +14,30 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script Interface for PrimFunc"""
+# pylint: disable=missing-docstring
+"""The entry point of TVM parser."""
+from typing import Any, Union
 
-import inspect
-from typing import Callable
+from ...ir_builder import IRBuilder
+from . import doc
+from .diagnostics import Source
+from .parser import Parser
 
-from tvm.tir.function import PrimFunc
-from ..parser import from_source
 
+def parse(program: Union[doc.AST, Any, str], extra_vars=None):
+    """Parse ``program`` under a fresh IRBuilder and return ``builder.get()``.
+    When ``extra_vars`` is None, the ``ir`` and ``tir`` parser modules are
+    exposed as ``I``/``ir`` and ``T``/``tir``."""
 
-def prim_func(input_func: Callable) -> PrimFunc:
-    """Decorate a python function as tvm script.
+    if extra_vars is None:
+        # Imported lazily, presumably to avoid an import cycle — TODO confirm.
+        from tvm.script.parser import ir  # pylint: disable=import-outside-toplevel
+        from tvm.script.parser import tir  # pylint: disable=import-outside-toplevel
+
+        extra_vars = {"I": ir, "ir": ir, "T": tir, "tir": tir}
 
-    Parameters
-    ----------
-    func : input_func
-        The function to be parsed.
-
-    Returns
-    -------
-    output : PrimFunc
-        The result functions.
-    """
-    if inspect.isfunction(input_func):
-        result = from_source(input_func)
-        result.__name__ = input_func.__name__
-        result.__qualname__ = input_func.__qualname__
-        return result
-
-    raise TypeError("Only function definitions are supported.")
+    source = Source(program)
+    parser = Parser(source)
+    with IRBuilder() as builder:
+        parser.parse(extra_vars=extra_vars)
+    return builder.get()
diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py
new file mode 100644
index 0000000000..8cbde66e39
--- /dev/null
+++ b/python/tvm/script/parser/core/evaluator.py
@@ -0,0 +1,284 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""AST Evaluation"""
+import ast
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
+
+from . import dispatch, doc
+
+if TYPE_CHECKING:
+    from .parser import Parser
+
+DEFAULT_OP: Dict[Type, Callable[..., Any]] = {  # fallbacks used by _eval_op without a handler
+    doc.Add: lambda a, b: a + b,
+    doc.Sub: lambda a, b: a - b,
+    doc.Mult: lambda a, b: a * b,
+    doc.Div: lambda a, b: a / b,
+    doc.FloorDiv: lambda a, b: a // b,
+    doc.Mod: lambda a, b: a % b,
+    doc.LShift: lambda a, b: a << b,
+    doc.RShift: lambda a, b: a >> b,
+    doc.BitOr: lambda a, b: a | b,
+    doc.BitXor: lambda a, b: a ^ b,
+    doc.BitAnd: lambda a, b: a & b,
+    doc.MatMult: lambda a, b: a @ b,
+    # fmt: off
+    doc.Pow: lambda a, b: a ** b,
+    # fmt: on
+    doc.Eq: lambda a, b: a == b,
+    doc.NotEq: lambda a, b: a != b,
+    doc.Lt: lambda a, b: a < b,
+    doc.LtE: lambda a, b: a <= b,
+    doc.Gt: lambda a, b: a > b,
+    doc.GtE: lambda a, b: a >= b,
+    doc.Is: lambda a, b: a is b,
+    doc.IsNot: lambda a, b: a is not b,
+    doc.In: lambda a, b: a in b,
+    doc.NotIn: lambda a, b: a not in b,
+    doc.And: lambda a, b: a and b,
+    doc.Or: lambda a, b: a or b,
+    doc.Invert: lambda a: ~a,
+    doc.Not: lambda a: not a,
+    doc.UAdd: lambda a: +a,
+    doc.USub: lambda a: -a,
+}
+
+
+class ExprEvaluator:
+    """Evaluates a ``doc`` expression bottom-up, binding every intermediate result
+    in ``value_table`` to a fresh ``__tvm_tmp_value_<n>`` name."""
+
+    parser: "Parser"
+    value_table: Dict[str, Any]  # name -> value; also receives the temporaries
+    new_value_count: int  # number of temporaries created so far
+
+    def __init__(self, parser: "Parser", value_table: Dict[str, Any]) -> None:
+        super().__init__()
+        self.parser = parser
+        self.value_table = value_table
+        self.new_value_count = 0
+
+    @staticmethod
+    def eval(parser: "Parser", value_table: Dict[str, Any], node: doc.AST) -> Any:
+        """Evaluate ``node`` against ``value_table``, which also gains the temporaries."""
+        self = ExprEvaluator(parser, value_table)
+        result = self._visit(node)  # pylint: disable=protected-access
+        if isinstance(result, doc.Name):
+            if result.id not in self.value_table:
+                self.parser.report_error(result, f"Undefined variable: {result.id}")
+            return self.value_table[result.id]
+        if isinstance(result, doc.Constant):
+            return result.value
+        raise TypeError(f"Unexpected result type: {type(result)}")
+
+    def _add_intermediate_result(self, value: Any) -> doc.Name:
+        # Bind ``value`` to a fresh name; return a Load of that name located at (0, 0).
+        name = f"__tvm_tmp_value_{self.new_value_count}"
+        self.new_value_count += 1
+        self.value_table[name] = value
+        loc = dict(lineno=0, col_offset=0, end_lineno=None, end_col_offset=None)
+        return doc.Name(id=name, ctx=doc.Load(**loc), **loc)
+
+    def _visit(self, node: doc.AST) -> Any:
+        # Expressions become Name/Constant placeholders; other nodes pass through.
+        if isinstance(node, list):
+            return [self._visit(n) for n in node]
+        if isinstance(node, tuple):
+            return tuple(self._visit(n) for n in node)
+        assert isinstance(node, doc.AST)
+        if isinstance(node, doc.Name):
+            if node.id not in self.value_table:
+                self.parser.report_error(node, f"Undefined variable: {node.id}")
+            return node
+        if isinstance(
+            node,
+            (
+                doc.Constant,
+                doc.expr_context,
+                doc.operator,
+                doc.boolop,
+                doc.unaryop,
+                doc.cmpop,
+            ),
+        ):
+            return node
+        if not isinstance(node, (doc.expr, doc.slice)):
+            return node
+        if isinstance(node, doc.Lambda):
+            return self._eval_lambda(node)
+        # Visit children first so evaluated sub-expressions are replaced by placeholders.
+        fields = {}
+        for field in node.__class__._FIELDS:  # pylint: disable=protected-access
+            attr = getattr(node, field)
+            if isinstance(attr, (doc.AST, tuple, list)):
+                fields[field] = self._visit(attr)
+            else:
+                fields[field] = attr
+        try:
+            if isinstance(node, doc.BoolOp):
+                value = self._eval_bool_op(fields)
+            elif isinstance(node, doc.Compare):
+                value = self._eval_compare(fields)
+            elif isinstance(node, doc.UnaryOp):
+                value = self._eval_unary_op(fields)
+            elif isinstance(node, doc.BinOp):
+                value = self._eval_bin_op(fields)
+            elif isinstance(node, doc.Slice):
+                value = self._eval_slice(fields)
+            else:
+                value = self._eval_expr(node.__class__(**fields))
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.parser.report_error(node, str(e))
+            raise
+        return self._add_intermediate_result(value)
+
+    def _eval_lambda(self, node: doc.Lambda) -> Any:
+        # The lambda is compiled as a whole; its body is not evaluated piecewise.
+        try:
+            value = self._eval_expr(node)
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.parser.report_error(node, str(e))
+            raise
+        return self._add_intermediate_result(value)
+
+    def _eval_bool_op(self, fields: Dict[str, Any]) -> Any:
+        # Operands were already evaluated by ``_visit``: there is no short-circuiting.
+        op = fields["op"]
+        if not isinstance(op, (doc.And, doc.Or)):
+            raise TypeError(f"Unexpected operator: {op}")
+        value = self._eval_expr(fields["values"][0])
+        for rhs in fields["values"][1:]:
+            value = _eval_op(op, values=[value, self._eval_expr(rhs)])
+        return value
+
+    def _eval_compare(self, fields: Dict[str, Any]) -> Any:
+        # ``a < b < c`` means ``(a < b) and (b < c)``: compare adjacent pairs, AND them.
+        lhs = self._eval_expr(fields["left"])
+        value = None
+        for i, (op, rhs_node) in enumerate(zip(fields["ops"], fields["comparators"])):
+            rhs = self._eval_expr(rhs_node)
+            result = _eval_op(op, values=[lhs, rhs])
+            if i == 0:
+                value = result
+            else:
+                and_op = doc.And(
+                    lineno=None, col_offset=None, end_lineno=None, end_col_offset=None
+                )
+                value = _eval_op(and_op, values=[value, result])
+            lhs = rhs
+        return value
+
+    def _eval_unary_op(self, fields: Dict[str, Any]) -> Any:
+        return _eval_op(fields["op"], values=[self._eval_expr(fields["operand"])])
+
+    def _eval_bin_op(self, fields: Dict[str, Any]) -> Any:
+        lhs, rhs = self._eval_expr(fields["left"]), self._eval_expr(fields["right"])
+        return _eval_op(fields["op"], values=[lhs, rhs])
+
+    def _eval_slice(self, fields: Dict[str, Any]) -> Any:
+        # Missing bounds stay None, exactly as in a Python ``slice``.
+        lower, upper, step = (
+            self._eval_expr(fields[key]) if fields[key] is not None else None
+            for key in ("lower", "upper", "step")
+        )
+        return slice(lower, upper, step)
+
+    def _eval_expr(self, v: Any) -> Any:
+        return _eval_expr(v, self.value_table)
+
+
+def eval_expr(
+    parser: "Parser",
+    node: Union[doc.expr, doc.Expression],
+    dict_globals: Optional[Dict[str, Any]],
+) -> Any:
+    # Evaluate on a copy of ``dict_globals``: the evaluator adds its temporaries
+    # (``__tvm_tmp_value_<n>``) to ``value_table``.
+    value_table = dict(dict_globals) if dict_globals is not None else {}
+    return ExprEvaluator.eval(parser, value_table, node)
+
+
+def eval_assign(
+    parser: "Parser",
+    target: doc.expr,
+    source: Any,
+) -> Dict[str, Any]:  # maps each name bound by ``target = source`` to its value
+    try:
+        return _eval_assign(target, source)
+    except Exception as err:  # pylint: disable=broad-except
+        parser.report_error(target, f"Failed to evaluate assignment: {str(err)}")
+        raise
+
+
+def _eval_expr(
+    node: Union[doc.expr, doc.Expression],
+    dict_globals: Optional[Dict[str, Any]],
+) -> Any:
+    # Convert back to a Python ``ast.Expression``, compile it and eval it with ``dict_globals``.
+    tree = doc.from_doc(node)
+    if isinstance(tree, ast.expr):
+        tree = ast.Expression(body=tree)
+    assert isinstance(tree, ast.Expression), "Expects an ast.Expression, but gets: " + str(tree)
+    tree = ast.fix_missing_locations(tree)
+    code = compile(tree, filename="<ast>", mode="eval")
+    # NOTE: ``eval`` executes arbitrary code; only use on trusted script source.
+    return eval(code, {} if dict_globals is None else dict_globals)  # pylint: disable=eval-used
+
+
+def _eval_op(
+    op: doc.AST,
+    values: List[Any],
+):
+    # The first operand whose ``_dispatch_type`` has a handler for ``op`` wins; else DEFAULT_OP.
+    op_type = type(op)
+    for index, operand in enumerate(values):
+        ty = getattr(type(operand), "_dispatch_type", None)
+        if ty is not None:
+            handler = dispatch.get_op(ty=ty, op=op_type, operand_index=index, default=None)
+            if handler is not None:
+                return handler(*values)
+    return DEFAULT_OP[op_type](*values)
+
+
+def _eval_assign(
+    target: doc.expr,
+    source: Any,
+) -> Dict[str, Any]:
+    """Execute ``<target> = source`` in an empty namespace and return the bindings.
+
+    ``source`` is bound to a reserved local name that the generated assignment
+    reads; that name is removed before the resulting locals are returned, so the
+    result maps each name bound by ``target`` to its value.
+    """
+    rhs_name = "__tvm_rhs_var__"
+    lhs = doc.from_doc(target)
+    assert isinstance(lhs, ast.expr)
+    assign = ast.Assign(
+        targets=[lhs],
+        value=ast.Name(id=rhs_name, ctx=ast.Load()),
+    )
+    module = ast.fix_missing_locations(ast.Module(body=[assign], type_ignores=[]))
+    scope = {rhs_name: source}
+    # The generated code sees only ``scope``; its globals dict is empty.
+    exec(  # pylint: disable=exec-used
+        compile(module, filename="<ast>", mode="exec"),
+        {},
+        scope,
+    )
+    del scope[rhs_name]
+    return scope
diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py
new file mode 100644
index 0000000000..e26324262f
--- /dev/null
+++ b/python/tvm/script/parser/core/parser.py
@@ -0,0 +1,273 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""The core parser"""
+from collections import defaultdict
+from contextlib import contextmanager
+from typing import Any, Callable, Dict, List, Optional, Set, Union
+
+from tvm.error import DiagnosticError
+
+from . import dispatch, doc
+from .diagnostics import Diagnostics, Source
+from .evaluator import eval_assign, eval_expr
+
+DEFAULT_VISIT = {  # node types that Parser.visit walks with generic_visit
+    "Interactive",
+    "Module",
+    "Expression",
+    "Pass",
+}
+
+
+def _deferred(f: Callable[[], None]):
+    """Return a context manager that calls ``f`` on exit, even if the body raises."""
+
+    def _run_on_exit():
+        try:
+            yield
+        finally:
+            f()
+    return contextmanager(_run_on_exit)()
+
+
+class VarTableFrame:  # the set of names defined in one scope of a VarTable
+    vars: Set[str]
+
+    def __init__(self):
+        self.vars = set()
+
+    def add(self, var: str):
+        if var in self.vars:  # a name may be defined only once per scope
+            raise ValueError(f"Variable {var} already defined in current scope")
+        self.vars.add(var)
+
+    def pop_all(self, fn_pop: Callable[[str], None]):  # fn_pop(name) for every name, then empty
+        for name in self.vars:
+            fn_pop(name)
+        self.vars.clear()
+
+
+class VarTable:
+    """Scoped name -> value bindings; inner scopes shadow outer ones."""
+    frames: List[VarTableFrame]  # open scopes, innermost last
+    name2value: Dict[str, List[Any]]  # name -> stack of bound values, innermost last
+
+    def __init__(self):
+        self.frames = []
+        self.name2value = defaultdict(list)
+
+    def with_frame(self):
+        # Open a new scope; the returned context manager closes it and unbinds its names.
+        def pop_frame():
+            frame = self.frames.pop()
+            frame.pop_all(lambda name: self.name2value[name].pop())
+
+        self.frames.append(VarTableFrame())
+        return _deferred(pop_frame)
+
+    def add(self, var: str, value: Any):
+        self.frames[-1].add(var)
+        self.name2value[var].append(value)
+
+    def get(self) -> Dict[str, Any]:
+        return {key: values[-1] for key, values in self.name2value.items() if values}
+
+    def exist(self, value: Any):
+        # Values live in per-name stacks: test each bound value by identity,
+        # rather than the stack lists themselves.
+        return any(v is value for values in self.name2value.values() for v in values)
+
+
+def _dispatch_wrapper(func: dispatch.ParseMethod) -> dispatch.ParseMethod:
+    # Report unexpected exceptions from ``func`` at ``node``; diagnostics that were
+    # already raised propagate untouched so they are not reported twice.
+    def _wrapper(self: "Parser", node: doc.AST) -> None:
+        try:
+            return func(self, node)
+        except Exception as err:  # pylint: disable=broad-except
+            if not isinstance(err, DiagnosticError):
+                self.report_error(node, str(err))
+            raise
+    return _wrapper
+
+
+def _dispatch(self: "Parser", type_name: str) -> dispatch.ParseMethod:
+    # Innermost dispatch token first, then "default", then the generic visitor.
+    tokens = (self.dispatch_tokens[-1], "default")
+    found = (dispatch.get(token=t, type_name=type_name, default=None) for t in tokens)
+    func = next((f for f in found if f is not None), lambda self, node: self.generic_visit(node))
+    return _dispatch_wrapper(func)
+
+
+class Parser(doc.NodeVisitor):
+    """The TVMScript parser.
+
+    Most node types are routed through ``dispatch``: the last entry of
+    ``dispatch_tokens`` selects the registered handler, with "default" as fallback.
+    """
+
+    diag: Diagnostics  # error reporting against the parsed source
+    dispatch_tokens: List[str]  # stack of dispatch tokens; the last one is active
+    var_table: VarTable  # scoped name -> value bindings visible to evaluated exprs
+
+    def __init__(self, source: Source) -> None:
+        self.diag = Diagnostics(source)
+        self.dispatch_tokens = ["default"]
+        self.var_table = VarTable()
+
+    def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any:
+        # Visit the whole source with ``extra_vars`` bound in an outermost scope.
+        with self.var_table.with_frame():
+            for k, v in (extra_vars or {}).items():
+                self.var_table.add(k, v)
+            self.visit(self.diag.source.as_ast())
+
+    def with_dispatch_token(self, token: str):
+        # Make ``token`` active until the returned context manager exits.
+        self.dispatch_tokens.append(token)
+        return _deferred(self.dispatch_tokens.pop)
+
+    def eval_expr(
+        self,
+        node: Union[doc.Expression, doc.expr],
+        extra_vars: Optional[Dict[str, Any]] = None,
+    ) -> Any:
+        # Evaluate ``node`` with the visible variables; ``extra_vars`` take precedence.
+        var_values = self.var_table.get()
+        if extra_vars is not None:
+            var_values.update(extra_vars)
+        return eval_expr(self, node, var_values)
+
+    def _duplicate_lhs_check(self, target: doc.expr) -> Union[bool, Set[str]]:
+        # Returns True if a name occurs twice in ``target``, else the set of names.
+        if isinstance(target, (doc.Tuple, doc.List)):
+            vars: Set[str] = set()  # pylint: disable=redefined-builtin
+            for i in target.elts:
+                res = self._duplicate_lhs_check(i)
+                if isinstance(res, bool) and res:
+                    return True
+                assert isinstance(res, set)
+                if vars & res:
+                    return True
+                vars = vars.union(res)
+            return vars
+        if isinstance(target, doc.Name):
+            return {target.id}
+        self.report_error(target, "Invalid type in assign statement")
+        raise NotImplementedError
+
+    def eval_assign(
+        self,
+        target: doc.expr,
+        source: Any,
+        bind_value: Callable[["Parser", doc.expr, str, Any], Any],
+    ) -> Dict[str, Any]:
+        # Bind ``target = source`` in the current scope; ``bind_value`` turns each
+        # Python value into the object stored in the variable table.
+        if self._duplicate_lhs_check(target) is True:
+            self.report_error(target, "Duplicate vars assigned.")
+        var_values = eval_assign(self, target, source)
+        for k, v in var_values.items():
+            var = bind_value(self, target, k, v)
+            self.var_table.add(k, var)
+        return var_values
+
+    def report_error(self, node: doc.AST, msg: str) -> None:  # pylint: disable=no-self-use
+        # NOTE(review): callers assume ``Diagnostics.error`` raises — confirm.
+        self.diag.error(node, msg)
+
+    def visit(self, node: doc.AST) -> None:
+        if isinstance(node, (list, tuple)):
+            for item in node:
+                self.visit(item)
+            return
+        if not isinstance(node, doc.AST):
+            return
+        name = node.__class__.__name__.split(".")[-1]
+        default = name in DEFAULT_VISIT
+        func = self.generic_visit if default else getattr(self, "visit_" + name, None)
+        if func is None:
+            raise NotImplementedError(f"Visitor of AST node is not implemented: {name}")
+        try:
+            func(node)
+        except DiagnosticError:  # already reported; re-reporting would repeat it per node
+            raise
+        except Exception as e:  # pylint: disable=broad-except,invalid-name
+            self.report_error(node, str(e))
+            raise
+
+    def visit_body(self, node: List[doc.stmt]) -> Any:
+        for stmt in node:
+            self.visit(stmt)
+
+    def visit_tvm_annotation(self, node: doc.expr) -> Any:
+        return _dispatch(self, "tvm_annotation")(self, node)
+
+    def visit_FunctionDef(self, node: doc.FunctionDef) -> Any:  # pylint: disable=invalid-name
+        # The decorator's ``dispatch_token`` selects the FunctionDef handler.
+        if not node.decorator_list:
+            self.report_error(node, "Function must be decorated")
+        # TODO: only the last decorator is parsed
+        decorator = self.eval_expr(node.decorator_list[-1])
+        if not hasattr(decorator, "dispatch_token"):
+            self.report_error(node, "The parser does not understand the decorator")
+        token = decorator.dispatch_token
+        func = dispatch.get(token=token, type_name="FunctionDef", default=None)
+        if func is None:
+            self.report_error(node, "The parser does not understand the decorator")
+        _dispatch_wrapper(func)(self, node)
+
+    def visit_ClassDef(self, node: doc.ClassDef) -> Any:  # pylint: disable=invalid-name
+        # Classes always go to the "ir" handler (the token of ``ir_module``).
+        func = dispatch.get(token="ir", type_name="ClassDef", default=None)
+        if func is None:
+            self.report_error(node, "The parser does not understand the decorator")
+        _dispatch_wrapper(func)(self, node)
+
+    def visit_arguments(self, node: doc.arguments) -> Any:
+        return _dispatch(self, "arguments")(self, node)
+
+    def visit_For(self, node: doc.For) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "For")(self, node)
+
+    def visit_While(self, node: doc.While) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "While")(self, node)
+
+    def visit_With(self, node: doc.With) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "With")(self, node)
+
+    def visit_Assign(self, node: doc.Assign) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Assign")(self, node)
+
+    def visit_Expr(self, node: doc.Expr) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Expr")(self, node)
+
+    def visit_If(self, node: doc.If) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "If")(self, node)
+
+    def visit_AnnAssign(self, node: doc.AnnAssign) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "AnnAssign")(self, node)
+
+    def visit_AugAssign(self, node: doc.AugAssign) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "AugAssign")(self, node)
+
+    def visit_Assert(self, node: doc.Assert) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Assert")(self, node)
+
+    def visit_Return(self, node: doc.Return) -> Any:  # pylint: disable=invalid-name
+        return _dispatch(self, "Return")(self, node)
diff --git a/python/tvm/script/tir/prim_func.py b/python/tvm/script/parser/core/utils.py
similarity index 54%
rename from python/tvm/script/tir/prim_func.py
rename to python/tvm/script/parser/core/utils.py
index 923eb97d27..aae45fe6ff 100644
--- a/python/tvm/script/tir/prim_func.py
+++ b/python/tvm/script/parser/core/utils.py
@@ -14,32 +14,23 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script Interface for PrimFunc"""
-
+# pylint: disable=missing-docstring
 import inspect
-from typing import Callable
-
-from tvm.tir.function import PrimFunc
-from ..parser import from_source
-
+from typing import Any, Callable, Dict
 
-def prim_func(input_func: Callable) -> PrimFunc:
-    """Decorate a python function as tvm script.
 
-    Parameters
-    ----------
-    func : input_func
-        The function to be parsed.
+def inspect_function_capture(func: Callable) -> Dict[str, Any]:
+    # Nonlocals come last so they shadow same-named globals, as in Python's name lookup.
+    return {
+        **func.__globals__,  # type: ignore
+        **inspect.getclosurevars(func).nonlocals,
+    }
 
-    Returns
-    -------
-    output : PrimFunc
-        The result functions.
-    """
-    if inspect.isfunction(input_func):
-        result = from_source(input_func)
-        result.__name__ = input_func.__name__
-        result.__qualname__ = input_func.__qualname__
-        return result
 
-    raise TypeError("Only function definitions are supported.")
+def inspect_class_capture(cls: type) -> Dict[str, Any]:
+    # Merge the captures of every plain function in the class body; for a name
+    # captured by several functions, the one defined last wins.
+    result: Dict[str, Any] = {}
+    for member in filter(inspect.isfunction, cls.__dict__.values()):
+        result.update(**inspect_function_capture(member))
+    return result
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/ir/__init__.py
similarity index 85%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/ir/__init__.py
index 555659d0c5..4cbd9910a2 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/ir/__init__.py
@@ -14,8 +14,8 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+# pylint: disable=missing-docstring
+from . import parser as _parser
+from .entry import ir_module
 
-from . import tir
-
-from .parser import ir_module, from_source
+__all__ = ["ir_module"]
diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/parser/ir/entry.py
similarity index 63%
copy from python/tvm/script/tir/__init__.py
copy to python/tvm/script/parser/ir/entry.py
index 2f2b4bbc25..3c1e4de5a7 100644
--- a/python/tvm/script/tir/__init__.py
+++ b/python/tvm/script/parser/ir/entry.py
@@ -14,18 +14,20 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVMScript for TIR"""
+# pylint: disable=missing-docstring
+import inspect
+from typing import Type
 
-# Type system
-from .ty import void, boolean, handle, Ptr, Tuple, Buffer
+from tvm.ir import IRModule
 
-from .prim_func import prim_func
+from .._core import parse, utils
 
-# add all floating point and integer datatypes to the module
-for _dtype in ["float", "uint", "int"]:
-    for _size in ["8", "16", "32", "64"]:
-        for _lanes in ["", "x4", "x8", "x16", "x32"]:
-            from . import ty
 
-            _name = _dtype + _size + _lanes
-            globals()[_name] = getattr(ty, _name)
+def ir_module(f: Type) -> IRModule:
+    """Parse the decorated class into an IRModule; names its methods capture are visible."""
+    if inspect.isclass(f):
+        return parse(f, utils.inspect_class_capture(f))
+    raise TypeError(f"Expect a class, but got: {f}")
+
+
+setattr(ir_module, "dispatch_token", "ir")
diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/parser/ir/parser.py
similarity index 55%
rename from python/tvm/script/tir/__init__.py
rename to python/tvm/script/parser/ir/parser.py
index 2f2b4bbc25..8871d3b415 100644
--- a/python/tvm/script/tir/__init__.py
+++ b/python/tvm/script/parser/ir/parser.py
@@ -14,18 +14,24 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVMScript for TIR"""
+# pylint: disable=missing-docstring
+from ...ir_builder import ir as I
+from .._core import Parser, dispatch, doc
 
-# Type system
-from .ty import void, boolean, handle, Ptr, Tuple, Buffer
 
-from .prim_func import prim_func
+@dispatch.register(token="ir", type_name="ClassDef")
+def _visit_class_def(self: Parser, node: doc.ClassDef) -> None:
+    # A new variable scope plus an IRModule frame wrap the whole class body.
+    with self.var_table.with_frame(), I.ir_module():
+        with self.with_dispatch_token("ir"):
+            self.visit_body(node.body)
 
-# add all floating point and integer datatypes to the module
-for _dtype in ["float", "uint", "int"]:
-    for _size in ["8", "16", "32", "64"]:
-        for _lanes in ["", "x4", "x8", "x16", "x32"]:
-            from . import ty
 
-            _name = _dtype + _size + _lanes
-            globals()[_name] = getattr(ty, _name)
+@dispatch.register(token="ir", type_name="Assign")
+def _visit_assign(_self: Parser, _node: doc.Assign) -> None:
+    return None  # assignments seen under the "ir" token are ignored
+
+
+@dispatch.register(token="ir", type_name="Expr")
+def _visit_expr(_self: Parser, _node: doc.Expr) -> None:
+    return None  # expression statements seen under the "ir" token are ignored as well
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/parser/tir/__init__.py
similarity index 71%
copy from python/tvm/script/__init__.py
copy to python/tvm/script/parser/tir/__init__.py
index 555659d0c5..930764f73d 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/parser/tir/__init__.py
@@ -14,8 +14,11 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
+# pylint: disable=missing-docstring
+from ...ir_builder.tir import *  # pylint: disable=redefined-builtin
+from ...ir_builder.tir import ir as _tir
+from . import operation as _operation
+from . import parser as _parser
+from .entry import Buffer, Ptr, prim_func
 
-from . import tir
-
-from .parser import ir_module, from_source
+__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func"]  # ir_builder.tir API + parser entries
diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py
new file mode 100644
index 0000000000..db4e2dd9a3
--- /dev/null
+++ b/python/tvm/script/parser/tir/entry.py
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import inspect
+from typing import Callable, Union
+
+from tvm.tir import Buffer, PrimFunc
+
+from ...ir_builder.tir import buffer_decl, ptr
+from .._core import parse, utils
+
+
+def _is_defined_in_class(frames):  # `frames` is inspect.stack() as passed by prim_func
+    if len(frames) > 2:
+        maybe_class_frame = frames[2]  # [0] is prim_func itself; [2] is two frames further up
+        statement_list = maybe_class_frame[4]  # FrameInfo.code_context: source lines or None
+        if not statement_list:  # None or empty: no source line to inspect
+            return False
+        first_statement = statement_list[0]
+        line = first_statement.strip()
+        if line.startswith("class "):
+            return True
+        if line.startswith("@") and "ir_module" in line:
+            return True
+    return False
+
+
+def prim_func(f: Callable) -> Union[PrimFunc, Callable]:
+    if inspect.isfunction(f):
+        if _is_defined_in_class(inspect.stack()):  # must run in prim_func itself (frame depth)
+            return f  # left unparsed (presumably handled when the enclosing class is parsed)
+        return parse(f, utils.inspect_function_capture(f))
+    raise TypeError(f"Expect a function, but got: {f}")
+
+
+setattr(prim_func, "dispatch_token", "tir")  # presumably selects the "tir" dispatch handlers
+
+
+class BufferProxy:  # callable/subscriptable object bound to the name `Buffer` below
+    def __call__(
+        self,
+        shape,
+        dtype="float32",
+        data=None,
+        strides=None,
+        elem_offset=None,
+        scope="global",
+        align=0,
+        offset_factor=0,
+        buffer_type="",
+        axis_separators=None,
+    ) -> Buffer:
+        return buffer_decl(
+            shape,
+            dtype=dtype,
+            data=data,
+            strides=strides,
+            elem_offset=elem_offset,
+            scope=scope,
+            align=align,
+            offset_factor=offset_factor,
+            buffer_type=buffer_type,
+            axis_separators=axis_separators,
+        )
+
+    def __getitem__(self, keys) -> Buffer:
+        # A bare key, or a tuple whose second item is not a dtype string, is the whole shape.
+        is_shape = not isinstance(keys, tuple) or (len(keys) >= 2 and not isinstance(keys[1], str))
+        if is_shape:
+            return self(keys)
+        return self(*keys)  # pylint: disable=no-member # type: ignore
+
+
+class PtrProxy:
+    def __call__(self, dtype, storage_scope="global"):
+        # A callable `dtype` is called and the `.dtype` of its result is used instead.
+        resolved = dtype().dtype if callable(dtype) else dtype
+        return ptr(resolved, storage_scope)  # pylint: disable=no-member # type: ignore
+
+    def __getitem__(self, keys):
+        # Ptr[x] forwards as Ptr(x); Ptr[x, y, ...] unpacks the tuple into positional args.
+        args = keys if isinstance(keys, tuple) else (keys,)
+        return self(*args)
+
+
+Buffer = BufferProxy()  # pylint: disable=invalid-name
+Ptr = PtrProxy()  # pylint: disable=invalid-name
diff --git a/python/tvm/script/parser/tir/operation.py b/python/tvm/script/parser/tir/operation.py
new file mode 100644
index 0000000000..716525b984
--- /dev/null
+++ b/python/tvm/script/parser/tir/operation.py
@@ -0,0 +1,84 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+from typing import Type
+
+from tvm import tir
+from tvm.tir import IntImm
+
+from .._core import OpMethod, doc, register_op
+
+
+def _register_expr_op(ty: Type):  # pylint: disable=invalid-name
+    ty._dispatch_type = ty  # pylint: disable=protected-access
+
+    def _as_bool(x):
+        # `and` / `or` operands may be plain Python bools; lift those to
+        # IntImm("bool", ...) so tir.And / tir.Or get TIR immediates.
+        if isinstance(x, bool):
+            return IntImm("bool", x)
+        return x
+
+    def _and(a, b):
+        return tir.And(_as_bool(a), _as_bool(b))
+
+    def _or(a, b):
+        return tir.Or(_as_bool(a), _as_bool(b))
+
+    # reg(op, i, method) == register_op(ty, op, i)(method); `i` is presumably the operand slot.
+    def reg(op: Type, i: int, method: OpMethod):  # pylint: disable=invalid-name
+        register_op(ty, op, i)(method)
+
+    for i in (0, 1):
+        # Case 1. binop
+        reg(doc.Add, i, tir.Add)
+        reg(doc.Sub, i, tir.Sub)
+        reg(doc.Mult, i, tir.Mul)
+        reg(doc.Div, i, tir.Div)
+        reg(doc.FloorDiv, i, tir.FloorDiv)
+        reg(doc.Mod, i, tir.FloorMod)
+        reg(doc.LShift, i, lambda a, b: a << b)
+        reg(doc.RShift, i, lambda a, b: a >> b)
+        reg(doc.BitOr, i, lambda a, b: a | b)
+        reg(doc.BitXor, i, lambda a, b: a ^ b)
+        reg(doc.BitAnd, i, lambda a, b: a & b)
+        # doc.MatMult <-- not implemented
+        # doc.Pow <-- not implemented
+        # Case 2. cmpop
+        reg(doc.Eq, i, tir.EQ)
+        reg(doc.NotEq, i, tir.NE)
+        reg(doc.Lt, i, tir.LT)
+        reg(doc.LtE, i, tir.LE)
+        reg(doc.Gt, i, tir.GT)
+        reg(doc.GtE, i, tir.GE)
+        # doc.Is <-- not implemented
+        # doc.IsNot <-- not implemented
+        # doc.In <-- not implemented
+        # doc.NotIn <-- not implemented
+        # Case 3. boolop
+        reg(doc.And, i, _and)
+        reg(doc.Or, i, _or)
+    for i in (0,):
+        # Case 4. unaryop
+        reg(doc.Invert, i, lambda a: ~a)
+        reg(doc.Not, i, tir.Not)
+        reg(doc.UAdd, i, lambda a: +a)
+        reg(doc.USub, i, lambda a: -a)
+
+
+_register_expr_op(tir.PrimExpr)
+_register_expr_op(tir.IterVar)
diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py
new file mode 100644
index 0000000000..351238c06f
--- /dev/null
+++ b/python/tvm/script/parser/tir/parser.py
@@ -0,0 +1,268 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import contextlib
+from functools import partial
+from typing import Any
+
+from tvm.ir import PrimType
+from tvm.tir import Buffer, IterVar, PrimExpr, Var
+
+from ...ir_builder import tir as T
+from ...ir_builder.base import IRBuilderFrame as Frame
+from ...ir_builder.base import name
+from .._core import Parser, dispatch, doc
+
+
+def bind_with_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
+    if isinstance(value, (Buffer, Var)):
+        name(var_name, value)
+        return value
+    if isinstance(value, (list, tuple)):
+        # Each element is bound recursively under the name `<var_name>_<index>`.
+        for i, v in enumerate(value):
+            bind_with_value(self, node, f"{var_name}_{i}", v)
+        return value
+    self.report_error(node, f"Do not know how to bind type: {type(value)} in with statement")
+    raise NotImplementedError
+
+
+def bind_for_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
+    if isinstance(value, (list, tuple)):  # bind element i as `<var_name>_<i>`
+        for i, v in enumerate(value):
+            bind_for_value(self, node, f"{var_name}_{i}", v)  # recurse with the for-loop rules
+        return value
+    elif isinstance(value, Var):  # only Vars are valid for-loop targets
+        name(var_name, value)
+        return value
+    else:
+        self.report_error(node, f"Do not know how to bind type: {type(value)} in for statement")
+        raise NotImplementedError
+
+
+def bind_assign_value(self: Parser, _node: doc.expr, var_name: str, value: Any) -> Any:
+    if isinstance(value, T.inline):  # T.inline wrapper: return its `.value` as-is
+        return value.value
+    elif isinstance(value, (list, tuple)):  # bind element i as `<var_name>_<i>`
+        for i, v in enumerate(value):
+            bind_assign_value(self, _node, f"{var_name}_{i}", v)  # recurse with assign rules
+        return value
+    elif isinstance(value, Frame):  # enter now; the exit is registered as a callback
+        value.add_callback(partial(value.__exit__, None, None, None))
+        res = value.__enter__()
+        name(var_name, res)
+        return res
+    elif isinstance(value, (Buffer, IterVar)) or (
+        isinstance(value, Var) and not self.var_table.exist(value)
+    ):  # Buffer/IterVar, or a Var not yet in the var table: just name it
+        name(var_name, value)
+        return value
+    elif isinstance(value, PrimExpr):  # other PrimExprs (incl. known Vars): bind via T.let
+        var = T.var(value.dtype)
+        name(var_name, var)
+        frame = T.let(var, value)
+        frame.add_callback(partial(frame.__exit__, None, None, None))
+        frame.__enter__()
+        return var
+    return value
+
+
+@dispatch.register(token="tir", type_name="For")
+def visit_for(self: Parser, node: doc.For) -> None:
+    loop_frame = self.eval_expr(node.iter)
+    if not isinstance(loop_frame, T.frame.ForFrame):
+        self.report_error(
+            node.iter,
+            "Expect the for loop to be one of the following: "
+            "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding",
+        )
+    with self.var_table.with_frame(), loop_frame as loop_vars:
+        # Bind the loop target(s) to whatever entering the ForFrame returned.
+        self.eval_assign(target=node.target, source=loop_vars, bind_value=bind_for_value)
+        self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="While")
+def visit_while(self: Parser, node: doc.While) -> None:
+    with self.var_table.with_frame():
+        # The loop condition is evaluated inside the new variable scope.
+        with T.While(self.eval_expr(node.test)):
+            self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="Assign")
+def visit_assign(self: Parser, node: doc.Assign) -> None:
+    if len(node.targets) != 1:
+        self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.")
+    lhs = node.targets[0]
+    rhs = self.eval_expr(node.value)
+    if not isinstance(lhs, doc.Subscript):
+        self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value)
+        return
+    # `buf[idx, ...] = rhs`: evaluate indices left to right, then emit T.buffer_store.
+    if isinstance(lhs.slice, doc.Tuple):
+        index_nodes = lhs.slice.elts
+    else:
+        index_nodes = [lhs.slice]
+    indices = [self.eval_expr(index) for index in index_nodes]
+    T.buffer_store(self.eval_expr(lhs.value), rhs, indices)
+
+
+@dispatch.register(token="tir", type_name="AugAssign")
+def visit_aug_assign(self: Parser, node: doc.AugAssign) -> None:
+    lhs_pos = (  # source span of the target; reused for the synthesized nodes below
+        node.target.lineno,
+        node.target.col_offset,
+        node.target.end_lineno,
+        node.target.end_col_offset,
+    )
+    rhs_pos = (  # source span of the right-hand operand
+        node.value.lineno,
+        node.value.col_offset,
+        node.value.end_lineno,
+        node.value.end_col_offset,
+    )
+    node.target.ctx = doc.Load(*lhs_pos)  # evaluate as `x = x op y`; load x first
+    with self.var_table.with_frame():  # scope for the two temporary names
+        lhs_name = "__tvm_tmp_value_aug_assign_lhs"
+        rhs_name = "__tvm_tmp_value_aug_assign_rhs"
+        lhs_expr = self.eval_expr(node.target)
+        rhs_expr = self.eval_expr(node.value)
+        self.var_table.add(lhs_name, lhs_expr)  # the operands are bound to temporaries so
+        self.var_table.add(rhs_name, rhs_expr)  # the synthetic BinOp can refer to them by name
+        op = doc.BinOp(
+            doc.Name(lhs_name, doc.Load(*lhs_pos), *lhs_pos),
+            node.op,
+            doc.Name(rhs_name, doc.Load(*rhs_pos), *rhs_pos),
+            *lhs_pos,
+        )
+        rhs = self.eval_expr(op)  # value of `<target> <node.op> <value>`
+    lhs = node.target
+    lhs.ctx = doc.Store(*lhs_pos)  # switch the target back to a store context for the write
+    if isinstance(lhs, doc.Subscript):  # `buf[...] op= y` -> T.buffer_store
+        if isinstance(lhs.slice, doc.Tuple):
+            indices = []
+            for index in lhs.slice.elts:
+                indices.append(self.eval_expr(index))
+        else:
+            indices = [self.eval_expr(lhs.slice)]
+        T.buffer_store(self.eval_expr(lhs.value), rhs, indices)
+    else:
+        self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value)
+
+
+@dispatch.register(token="tir", type_name="AnnAssign")
+def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None:
+    rhs = self.eval_expr(node.value)
+    ann_var = self.visit_tvm_annotation(node.annotation)
+    if not isinstance(ann_var, Var):
+        self.report_error(node.annotation, "Annotation should be Var")
+    # Name the target after the annotation's Var, then open `let ann_var = rhs`.
+    self.eval_assign(target=node.target, source=ann_var, bind_value=bind_assign_value)
+    let_frame = T.let(ann_var, rhs)
+    let_frame.add_callback(partial(let_frame.__exit__, None, None, None))
+    let_frame.__enter__()
+
+
+@dispatch.register(token="tir", type_name="With")
+def visit_with(self: Parser, node: doc.With) -> None:
+    with contextlib.ExitStack() as stack:  # unwinds every entered frame in reverse order
+        stack.enter_context(self.var_table.with_frame())
+        for item in node.items:
+            frame = self.eval_expr(item.context_expr)
+            if not isinstance(frame, Frame):
+                err = "Invalid context expression in the with-statement."
+                self.report_error(item.context_expr, err)
+            rhs = stack.enter_context(frame)
+            if item.optional_vars is None:
+                continue
+            self.eval_assign(target=item.optional_vars, source=rhs, bind_value=bind_with_value)
+        self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="FunctionDef")
+def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
+    with self.var_table.with_frame():
+        # Within the function, the name `range` resolves to T.serial.
+        self.var_table.add("range", T.serial)
+        with T.prim_func():
+            T.func_name(node.name)
+            if node.returns is not None:
+                ret_type = self.eval_expr(node.returns)
+                T.func_ret(PrimType(ret_type().dtype) if callable(ret_type) else ret_type)
+            # Parameters and body are visited with the "tir" dispatch token active.
+            with self.with_dispatch_token("tir"):
+                self.visit(node.args)
+                self.visit_body(node.body)
+
+
+@dispatch.register(token="tir", type_name="arguments")
+def visit_arguments(self: Parser, node: doc.arguments) -> None:
+    # TODO: only plain positional parameters (`node.args`) are handled; still ignored:
+    # - vararg: arg | None
+    # - kwonlyargs: list[arg]
+    # - kw_defaults: list[expr | None]
+    # - kwarg: arg | None
+    # - defaults: list[expr]
+    # - posonlyargs: list[arg]
+    param_node: doc.arg
+    for param_node in node.args:
+        if param_node.annotation is None:
+            self.report_error(param_node, "Type annotation is required for function parameters.")
+        param = T.arg(param_node.arg, self.visit_tvm_annotation(param_node.annotation))
+        self.var_table.add(param_node.arg, param)
+
+
+@dispatch.register(token="tir", type_name="tvm_annotation")
+def visit_tvm_annotation(self: Parser, node: doc.expr):
+    # Evaluate the annotation; if the result is callable, call it (no arguments)
+    # and use the returned object as the annotation value.
+    annotation = self.eval_expr(node)
+    return annotation() if callable(annotation) else annotation
+
+
+@dispatch.register(token="tir", type_name="Expr")
+def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
+    value = self.eval_expr(node.value)
+    if isinstance(value, Frame):  # NOTE(review): non-Frame results are silently discarded
+        value.add_callback(partial(value.__exit__, None, None, None))
+        value.__enter__()
+
+
+@dispatch.register(token="tir", type_name="If")
+def visit_if(self: Parser, node: doc.If) -> None:
+    with self.var_table.with_frame(), T.If(self.eval_expr(node.test)):
+        with T.Then():
+            self.visit_body(node.body)
+        # An Else frame is opened only when the statement has an `else` body.
+        if node.orelse:
+            with T.Else():
+                self.visit_body(node.orelse)
+
+
+@dispatch.register(token="tir", type_name="Assert")
+def visit_assert(self: Parser, node: doc.Assert) -> None:
+    # NOTE(review): `node.msg` is presumably None for a bare `assert cond`; confirm eval_expr.
+    assert_frame = T.Assert(self.eval_expr(node.test), self.eval_expr(node.msg))
+    # Enter now; the matching exit is registered as a callback on the frame itself.
+    assert_frame.add_callback(partial(assert_frame.__exit__, None, None, None))
+    assert_frame.__enter__()
+
+
+@dispatch.register(token="tir", type_name="Return")
+def visit_return(self: Parser, node: doc.Return) -> None:  # every `return` is rejected
+    self.report_error(node, "Return is not allowed.")
diff --git a/python/tvm/script/registry.py b/python/tvm/script/registry.py
deleted file mode 100644
index e7d90dd515..0000000000
--- a/python/tvm/script/registry.py
+++ /dev/null
@@ -1,62 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Script Parser Function Registry """
-# pylint: disable=inconsistent-return-statements, relative-beyond-top-level, import-outside-toplevel
-import types
-from typing import Union, Callable, Dict, Optional, Any
-
-
-class Registry(object):
-    """Registration map
-    All these maps are static
-    """
-
-    registrations: Dict[str, type] = dict()
-
-    @staticmethod
-    def lookup(name: str) -> Optional[Any]:
-        if name in Registry.registrations:
-            # every time we create a new handler
-            # since we may want to keep some local info inside it
-            return Registry.registrations[name]()
-        return None
-
-
-def register(inputs: Union[Callable, type]) -> type:
-    """Register Intrin/ScopeHandler/SpecialStmt"""
-    registration: type
-    if isinstance(inputs, types.FunctionType):
-        # is function
-        from .tir.intrin import Intrin
-
-        def create_new_intrin(func) -> type:
-            class NewIntrin(Intrin):
-                def __init__(self):
-                    super().__init__(func)
-
-            return NewIntrin
-
-        registration = create_new_intrin(inputs)
-    elif isinstance(inputs, type):
-        # is class
-        registration = inputs
-    else:
-        raise ValueError()
-
-    key: str = registration().signature()[0]
-    Registry.registrations[key] = registration
-    return registration
diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi
deleted file mode 100644
index a64eed055a..0000000000
--- a/python/tvm/script/tir/__init__.pyi
+++ /dev/null
@@ -1,487 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=redefined-builtin
-from typing import (
-    Any,
-    Callable,
-    ContextManager,
-    Dict,
-    Iterable,
-    Optional,
-    Tuple,
-    Union,
-    Sequence,
-    List,
-    Mapping,
-    overload,
-)
-from numbers import Number
-import builtins
-
-from tvm.tir.function import PrimFunc
-from tvm.tir import Range
-from tvm.runtime import Object
-from tvm.target import Target
-from .node import BufferSlice
-
-"""
-redefine types
-"""
-
-class PrimExpr:
-    def __init__(self: PrimExpr) -> None: ...
-    @overload
-    def __add__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
-    @overload
-    def __add__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    @overload
-    def __sub__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
-    @overload
-    def __sub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    @overload
-    def __mul__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
-    @overload
-    def __mul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    @overload
-    def __div__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ...
-    @overload
-    def __div__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    def __mod__(self: PrimExpr, other: Union[int, float, PrimExpr]) -> PrimExpr: ...
-    def __radd__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    def __rsub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    def __rmul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    def __rdiv__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ...
-    def __floordiv__(self: PrimExpr, other: Union[int, float, PrimExpr]) -> PrimExpr: ...
-    def __index__(self: PrimExpr) -> int: ...  # so range doesn't complain
-
-class Var(PrimExpr): ...
-class IterVar(Var): ...
-
-class Buffer:
-    @overload
-    def __getitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int, slice]]) -> PrimExpr: ...
-    @overload
-    def __getitem__(self: Buffer, pos: Union[PrimExpr, int, slice]) -> PrimExpr: ...
-    @overload
-    def __setitem__(
-        self: Buffer, pos: Sequence[Union[PrimExpr, int, slice]], value: PrimExpr
-    ) -> None: ...
-    @overload
-    def __setitem__(self: Buffer, pos: Union[PrimExpr, int, slice], value: PrimExpr) -> None: ...
-    @property
-    def data(self: Buffer) -> Ptr: ...
-
-"""
-Intrinsic
-"""
-
-def min_value(dtype: str) -> PrimExpr: ...
-def max_value(dtype: str) -> PrimExpr: ...
-def floordiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
-def floormod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
-def ceildiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
-def truncmod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
-def truncdiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
-def abs(x: PrimExpr) -> PrimExpr: ...
-def load(
-    dtype: str, var: Var, index: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = None
-) -> PrimExpr: ...
-def cast(value: PrimExpr, dtype: str) -> PrimExpr: ...
-def ramp(base: PrimExpr, stride: Any, lanes: int) -> PrimExpr: ...
-def broadcast(value: PrimExpr, lanes: int) -> PrimExpr: ...
-def iter_var(var: Union[Var, str], dom: Range, iter_type: int, thread_tag: str) -> IterVar: ...
-def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
-def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: ...
-def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr: ...
-def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ...
-def evaluate(value: PrimExpr) -> None: ...
-def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ...
-def vectorlow(value: PrimExpr, dtype: str) -> PrimExpr: ...
-def vectorhigh(value: PrimExpr, dtype: str) -> PrimExpr: ...
-def store(
-    var: Var, index: PrimExpr, value: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = True
-) -> None: ...
-def comm_reducer(lambda_io: Callable[[Any, Any], Any], identities: List[PrimExpr]) -> PrimExpr: ...
-def llvm_lookup_intrinsic_id(name: str) -> PrimExpr: ...
-def preflattened_buffer(
-    buf: Buffer,
-    shape: Sequence[PrimExpr],
-    dtype: str = "float32",
-    data: Optional[Ptr] = None,
-    strides: Optional[Sequence[int]] = None,
-    elem_offset: Optional[int] = None,
-    scope: str = "global",
-    align: int = -1,
-    offset_factor: int = 0,
-    buffer_type: str = "default",
-) -> Buffer: ...
-
-"""
-Intrinsics - tvm builtin
-"""
-
-def tvm_thread_allreduce(
-    *freduceargs: Union[PrimExpr, builtins.bool, Ptr], dtype: str
-) -> PrimExpr: ...
-
-"""
-Unary operator
-Note that any intrinsics not registered in script.tir.intrin
-should add "dtype" as an argument. This is different from their
-definition but intentional.
-"""
-
-def exp(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def exp2(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def exp10(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def erf(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def tanh(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def sigmoid(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def log(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def log2(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def log10(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def log1p(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def tan(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def cos(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def cosh(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def acos(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def acosh(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def sin(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def sinh(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def asin(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def asinh(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def atan(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def atanh(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def atan2(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def sqrt(x: PrimExpr, dtype: str) -> PrimExpr: ...
-def rsqrt(x: PrimExpr, dtype: str) -> PrimExpr: ...
-
-"""
-special_stmt - Buffers
-"""
-
-def match_buffer(
-    param: Union[Var, BufferSlice],
-    shape: Sequence[Union[PrimExpr, int]],
-    dtype: str = "float32",
-    data: Var = None,
-    strides: Optional[Sequence[int]] = None,
-    elem_offset: Optional[int] = None,
-    scope: str = "global",
-    align: int = -1,
-    offset_factor: int = 0,
-    buffer_type: str = "default",
-    axis_separators: Optional[List[int]] = None,
-) -> Buffer: ...
-def decl_buffer(
-    shape: Sequence[Union[PrimExpr, int]],
-    dtype: str = "float32",
-    data: Var = None,
-    strides: Optional[Sequence[int]] = None,
-    elem_offset: Optional[int] = None,
-    scope: str = "global",
-    align: int = -1,
-    offset_factor: int = 0,
-    buffer_type: str = "default",
-    axis_separators: Optional[List[int]] = None,
-) -> Buffer: ...
-def buffer_decl(
-    shape: Sequence[Union[PrimExpr, int]],
-    dtype: str = "float32",
-    data: Var = None,
-    strides: Optional[Sequence[int]] = None,
-    elem_offset: Optional[int] = None,
-    scope: str = "global",
-    align: int = -1,
-    offset_factor: int = 0,
-    buffer_type: str = "default",
-    axis_separators: Optional[List[int]] = None,
-) -> Buffer: ...
-def alloc_buffer(
-    shape: Sequence[Union[PrimExpr, int]],
-    dtype: str = "float32",
-    data: Var = None,
-    strides: Optional[Sequence[int]] = None,
-    elem_offset: Optional[int] = None,
-    scope: str = "global",
-    align: int = -1,
-    offset_factor: int = 0,
-    buffer_type: str = "default",
-    axis_separators: Optional[List[int]] = None,
-) -> Buffer: ...
-
-"""
-special_stmt - Reads/Writes
-"""
-
-@overload
-def reads(read_regions: List[BufferSlice]) -> None: ...
-@overload
-def reads(*read_regions: BufferSlice) -> None: ...
-@overload
-def writes(write_region: List[BufferSlice]) -> None: ...
-@overload
-def writes(*write_region: BufferSlice) -> None: ...
-def block_attr(attrs: Mapping[str, Object]) -> None: ...
-
-"""
-special_stmt - Axis
-"""
-
-class axis:
-    @overload
-    @staticmethod
-    def spatial(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
-    @overload
-    @staticmethod
-    def spatial(
-        dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr
-    ) -> IterVar: ...
-    @overload
-    @staticmethod
-    def S(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
-    @overload
-    @staticmethod
-    def S(dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr) -> IterVar: ...
-    @overload
-    @staticmethod
-    def reduce(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
-    @overload
-    @staticmethod
-    def reduce(
-        dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr
-    ) -> IterVar: ...
-    @overload
-    @staticmethod
-    def R(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
-    @overload
-    @staticmethod
-    def R(dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr) -> IterVar: ...
-    @overload
-    @staticmethod
-    def scan(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
-    @overload
-    @staticmethod
-    def scan(
-        dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr
-    ) -> IterVar: ...
-    @overload
-    @staticmethod
-    def opaque(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ...
-    @overload
-    @staticmethod
-    def opaque(
-        dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr
-    ) -> IterVar: ...
-    @staticmethod
-    def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ...
-
-def get_axis(begin: PrimExpr, end: PrimExpr, iter_type: int) -> IterVar: ...
-
-"""
-special_stmt - Annotations
-"""
-
-def buffer_var(dtype: str, storage_scope: str) -> Var: ...
-def func_attr(attrs: Mapping[str, Union[Object, str, bool, int, float]]) -> None: ...
-def prim_func(input_func: Callable) -> PrimFunc: ...
-
-"""
-special_stmt - Threads and Bindings
-"""
-
-def env_thread(env_name: str) -> IterVar: ...
-def bind(iter_var: IterVar, expr: PrimExpr) -> None: ...
-
-"""
-Scope handler
-"""
-
-class block(ContextManager):
-    def __init__(self, name_hint: str = "") -> None: ...
-    def __enter__(self) -> Sequence[IterVar]: ...
-
-class init(ContextManager):
-    def __init__(self) -> None: ...
-
-class let(ContextManager):
-    def __init__(self, var: Var, value: PrimExpr) -> None: ...
-
-def where(cond: PrimExpr) -> None: ...
-def allocate(
-    extents: List[PrimExpr],
-    dtype: str,
-    scope: str,
-    condition: Union[PrimExpr, builtins.bool] = True,
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Buffer: ...
-def launch_thread(env_var: Var, extent: Union[int, PrimExpr]) -> Var: ...
-def realize(
-    buffer_slice: BufferSlice, scope: str, condition: Union[PrimExpr, builtins.bool] = True
-) -> None: ...
-def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ...
-def Assert(condition: Union[PrimExpr, builtins.bool], message: str) -> PrimExpr: ...
-
-"""
-Scope handler - Loops
-"""
-
-@overload
-def serial(
-    begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def serial(
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def parallel(
-    begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def parallel(
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def vectorized(
-    begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def vectorized(
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def unroll(
-    begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def unroll(
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def thread_binding(
-    begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int],
-    thread: str,
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def thread_binding(
-    end: Union[PrimExpr, int],
-    thread: str,
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def for_range(
-    begin: Union[PrimExpr, int],
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-@overload
-def for_range(
-    end: Union[PrimExpr, int],
-    annotations: Optional[Mapping[str, Object]] = None,
-) -> Iterable[IterVar]: ...
-def grid(*extents: Union[PrimExpr, int]) -> Iterable[Sequence[IterVar]]: ...
-
-"""
-ty - redefine types
-"""
-
-class boolean: ...
-
-class handle(Var):
-    @overload
-    def __getitem__(self: handle, pos: Sequence[Union[int, PrimExpr, slice]]) -> Buffer: ...
-    @overload
-    def __getitem__(self: handle, pos: Union[int, PrimExpr, slice]) -> Buffer: ...
-    @overload
-    def __setitem__(
-        self: handle, pos: Sequence[Union[int, PrimExpr, slice]], value: Buffer
-    ) -> None: ...
-    @overload
-    def __setitem__(self: handle, pos: Union[int, PrimExpr, slice], value: Buffer) -> None: ...
-    @property
-    def data(self: handle) -> Ptr: ...
-
-class Ptr: ...
-
-def target(target_str: Union[str, Mapping[str, Object]]) -> Target: ...
-
-class var(Var):
-    def __init__(self: Var, dtype: str): ...
-
-class bool(PrimExpr):
-    def __init__(self: bool, imm: Union[PrimExpr, builtins.bool, builtins.int]): ...
-
-class int8(PrimExpr):
-    def __init__(self: int8, imm: Union[PrimExpr, int]): ...
-
-class int16(PrimExpr):
-    def __init__(self: int16, imm: Union[PrimExpr, int]): ...
-
-class int32(PrimExpr):
-    def __init__(self: int32, imm: Union[PrimExpr, int]): ...
-
-class int64(PrimExpr):
-    def __init__(self: int64, imm: Union[PrimExpr, int]): ...
-
-class uint8(PrimExpr):
-    def __init__(self: uint8, imm: Union[PrimExpr, int]): ...
-
-class uint16(PrimExpr):
-    def __init__(self: uint16, imm: Union[PrimExpr, int]): ...
-
-class uint32(PrimExpr):
-    def __init__(self: uint32, imm: Union[PrimExpr, int]): ...
-
-class uint64(PrimExpr):
-    def __init__(self: uint64, imm: Union[PrimExpr, int]): ...
-
-# use typing.Literal instead for python 3.8 or higher
-import sys
-
-if sys.version_info >= (3, 8):
-    from typing import Literal
-
-    SpecialFloatLiteral = Literal["inf", "-inf", "nan"]
-else:
-    SpecialFloatLiteral = str
-
-class float8(PrimExpr):
-    def __init__(self: float8, imm: Union[PrimExpr, int, float, SpecialFloatLiteral]): ...
-
-class float16(PrimExpr):
-    def __init__(self: float16, imm: Union[PrimExpr, int, float, SpecialFloatLiteral]): ...
-
-class float32(PrimExpr):
-    def __init__(self: float32, imm: Union[PrimExpr, int, float, SpecialFloatLiteral]): ...
-
-class float64(PrimExpr):
-    def __init__(self: float64, imm: Union[PrimExpr, int, float, SpecialFloatLiteral]): ...
diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/tir/intrin.py
deleted file mode 100644
index bd9aa1fdad..0000000000
--- a/python/tvm/script/tir/intrin.py
+++ /dev/null
@@ -1,231 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Script Parser Intrinsic Classes"""
-# pylint: disable=redefined-builtin, relative-beyond-top-level
-import builtins
-from typing import List, Any
-
-import tvm.tir
-from tvm.tir import FloatImm
-from ..registry import register
-from ...target import codegen
-from ..utils import get_param_list, tvm_span_from_synr
-
-
-class Intrin:
-    def __init__(self, intrin, stmt=False):
-        self.intrin = intrin
-        self.stmt = stmt
-
-    def signature(self):
-        return "tir." + self.intrin.__name__, get_param_list(self.intrin)
-
-    def handle(self, arg_list: List[Any], span: tvm.ir.Span):
-        return self.intrin(*arg_list, span=tvm_span_from_synr(span))
-
-
-@register
-def bool(imm, span):
-    return imm.astype("bool", span)
-
-
-# register all datatypes
-for _dtype in ["float", "uint", "int"]:
-    for _size in ["8", "16", "32", "64"]:
-        for _lanes in ["", "x4", "x8", "x16", "x32"]:
-            _name = _dtype + _size + _lanes
-
-            # nest closures so we copy the name string
-            def wrap(name):
-                def f(imm, span):
-                    if name.startswith("float"):
-                        if imm in {"inf", "-inf", "nan"}:
-                            return FloatImm(dtype=name, value=float(imm), span=span)
-                    return imm.astype(name, span)
-
-                f.__name__ = name
-                return f
-
-            _intrin = wrap(_name)
-            register(_intrin)
-
-
-@register
-def min_value(dtype, span):
-    return tvm.tir.min_value(dtype, span)
-
-
-@register
-def max_value(dtype, span):
-    return tvm.tir.max_value(dtype, span)
-
-
-@register
-def floordiv(x, y, span):
-    return tvm.tir.floordiv(x, y, span)
-
-
-@register
-def floormod(x, y, span):
-    return tvm.tir.floormod(x, y, span)
-
-
-@register
-def truncmod(x, y, span):
-    return tvm.tir.truncmod(x, y, span)
-
-
-@register
-def truncdiv(x, y, span):
-    return tvm.tir.truncdiv(x, y, span)
-
-
-@register
-def ceildiv(x, y, span):
-    return tvm.tir.ceildiv(x, y, span)
-
-
-@register
-def abs(x, span):
-    return tvm.tir.abs(x, span)
-
-
-@register
-def load(dtype, var, index, predicate=None, span=None):
-    return tvm.tir.Load(dtype, var, index, predicate, span)
-
-
-@register
-def cast(value, dtype, span):
-    return tvm.tir.Cast(dtype, value, span)
-
-
-@register
-def ramp(base, stride, lanes, span):
-    return tvm.tir.Ramp(base, stride, lanes.value, span)
-
-
-@register
-def broadcast(value, lanes, span):
-    return tvm.tir.Broadcast(value, lanes.value, span)
-
-
-@register
-def iter_var(var, dom, iter_type, thread_tag, span):
-    iter_type = getattr(tvm.tir.IterVar, iter_type)
-    return tvm.tir.IterVar(dom, var, iter_type, thread_tag, span)
-
-
-@register
-def max(a, b, span):  # pylint: disable=redefined-builtin
-    return tvm.tir.Max(a, b, span)
-
-
-@register
-def min(a, b, span):  # pylint: disable=redefined-builtin
-    return tvm.tir.Min(a, b, span)
-
-
-def get_axis(begin, end, iter_type, span):
-    ana = tvm.arith.Analyzer()
-    extent = ana.simplify(end - begin)
-    block_var_dom = tvm.ir.Range.from_min_extent(begin, extent)
-
-    iter_type_dict = {"data_par": 0, "reduce": 2, "scan": 3, "opaque": 4}
-    return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type], span=span)
-
-
-@register
-def range(begin, end, span):
-    return get_axis(begin, end, "data_par", span)
-
-
-@register
-def reduce_axis(begin, end, span):
-    return get_axis(begin, end, "reduce", span)
-
-
-@register
-def scan_axis(begin, end, span):
-    return get_axis(begin, end, "scan", span)
-
-
-@register
-def opaque_axis(begin, end, span):
-    return get_axis(begin, end, "opaque", span)
-
-
-@register
-def Select(cond, if_body, else_body, span):  # pylint: disable=invalid-name
-    return tvm.tir.Select(cond, if_body, else_body, span)
-
-
-@register
-def Let(var, value, body, span):  # pylint: disable=invalid-name
-    return tvm.tir.Let(var, value, body, span)
-
-
-@register
-class EvaluateIntrin(Intrin):
-    def __init__(self):
-        def evaluate(value, span):
-            return tvm.tir.Evaluate(value, span)
-
-        super().__init__(evaluate, stmt=True)
-
-
-@register
-class StoreIntrin(Intrin):
-    def __init__(self):
-        def store(var, index, value, predicate=True, span=None):
-            return tvm.tir.Store(var, value, index, predicate, span)
-
-        super().__init__(store, stmt=True)
-
-
-@register
-class AssumeIntrin(Intrin):
-    def __init__(self):
-        def assume(constraint, span):
-            return tvm.tir.Evaluate(
-                tvm.tir.call_intrin("bool", "tir.assume", constraint, span=span)
-            )
-
-        super().__init__(assume, stmt=True)
-
-
-@register
-def comm_reducer(lambda_io, identities, span):
-    """Create a CommReducer from lambda inputs/outputs and the identities"""
-    lambda_input = lambda_io[0]
-    lambda_output = lambda_io[1]
-
-    num_args = len(lambda_input)
-    num_arg_per_group = num_args // 2
-    x = [lambda_input[i] for i in builtins.range(0, num_arg_per_group)]
-    y = [lambda_input[i] for i in builtins.range(num_arg_per_group, num_args)]
-
-    if not isinstance(lambda_output, tuple):
-        lambda_output = (lambda_output,)
-
-    return tvm.tir.CommReducer(x, y, lambda_output, identities, span)
-
-
-@register
-def llvm_lookup_intrinsic_id(name, span):
-    # pylint: disable=unused-argument
-    return codegen.llvm_lookup_intrinsic_id(name)
diff --git a/python/tvm/script/tir/node.py b/python/tvm/script/tir/node.py
deleted file mode 100644
index 29e79607fb..0000000000
--- a/python/tvm/script/tir/node.py
+++ /dev/null
@@ -1,218 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=redefined-builtin
-"""TVM Script nodes."""
-
-from typing import Optional, Union, List, Callable
-import synr
-from tvm.arith import Analyzer
-from tvm.runtime import ObjectGeneric, convert
-from tvm.tir import PrimExpr, Buffer, BufferLoad, IntImm, Ramp, BufferRegion
-from tvm.ir import Span, Range
-
-
-class Slice:
-    """A helper class to present slice information for BufferSlice
-
-    Parameters
-    ----------
-    start : Union[PrimExpr, int]
-        The start index.
-
-    stop : Optional[Union[PrimExpr, int]]
-        The stop index, None means the Slice is an element-wise index
-
-    step : int
-        The slice step
-
-    span : Optional[Span]
-        The location of the slice in the source.
-    """
-
-    start: Union[PrimExpr, int]
-    stop: Optional[Union[PrimExpr, int]]
-    step: int
-    span: Optional[Span]
-
-    def __init__(
-        self,
-        start: Union[PrimExpr, int],
-        stop: Optional[Union[PrimExpr, int]] = None,
-        step: int = 1,
-        span: Optional[Span] = None,
-    ):
-        self.start = start
-        self.stop = stop
-        self.step = step
-        self.span = span
-
-    def as_index_expr(self, report_error: Callable[[str, Union[Span, synr.ast.Span]], None]):
-        """Helper to create index PrimExpr from slice object
-        Parameters
-        ----------
-        report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
-            The error report func
-        """
-        if self.stop is None:
-            # scalar index
-            return self.start
-        if self.step < 1:
-            report_error("Slice's step should be positive integer", self.span)
-        lanes = Analyzer().simplify((self.stop - self.start + self.step - 1) // self.step)
-        if not isinstance(lanes, (int, IntImm)):
-            report_error("Slice's lanes should be constant for buffer indices", self.span)
-        if lanes == 1:
-            return self.start
-        return Ramp(self.start, self.step, int(lanes), self.span)
-
-
-class BufferSlice(ObjectGeneric):
-    """A generic object for representing general buffer access. Following cases are supported:
-        - element wise access buffer[i, j], which can be converted to BufferLoad if necessary
-        - slice access buffer[i: i + 1, j : j + 2]
-        - union of element and slice buffer[i, j: j + 2]
-
-        This node is used in TVMScript to parse BufferLoad, BufferRegion and Realize
-
-    Parameters
-    ----------
-    buffer : Buffer
-        The buffer.
-
-    indices : List[Union[Slice, PrimExpr, int]]
-        The access indexes can be slice, PrimExpr or int.
-
-    report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
-        The error report func
-
-    span : Optional[Span]
-        The location of the buffer access in the source.
-    """
-
-    buffer: Buffer
-    slices: List[Slice]
-    report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
-    span: Optional[Span]
-
-    def __init__(
-        self,
-        buffer: Buffer,
-        indices: List[Union[Slice, PrimExpr, int]],
-        report_error: Callable[[str, Union[Span, synr.ast.Span]], None],
-        span: Optional[Span] = None,
-    ):
-        def check_index(index: Union[int, PrimExpr]):
-            """Check input index is non-negative integer or PrimExpr"""
-            if isinstance(index, int):
-                if index < 0:
-                    report_error("Negative index is not allowed during buffer access", span)
-            elif isinstance(index, PrimExpr):
-                element_dtype = index.dtype.split("x", maxsplit=1)[0]
-                if element_dtype[:3] != "int":
-                    report_error(
-                        "index expected an integer type PrimExpr but got " + str(index.dtype),
-                        index.span,
-                    )
-            else:
-                report_error(
-                    "Unsupported index type, expected int or tvm.tir.PrimExpr, but got "
-                    + str(type(index)),
-                    span,
-                )
-
-        slices: List[Union[Slice, BufferSlice]] = []
-        for index in indices:
-            if isinstance(index, Slice):
-                index.start, index.stop = [convert(_) for _ in [index.start, index.stop]]
-                check_index(index.start)
-                check_index(index.stop)
-                slices.append(index)
-            elif isinstance(index, (PrimExpr, int)):
-                check_index(index)
-                slices.append(Slice(index))
-            elif isinstance(index, BufferSlice):
-                buffer_load = index.asobject()
-                check_index(buffer_load)
-                slices.append(Slice(buffer_load))
-            else:
-                report_error(
-                    "Unsupported index type for BufferSlice, "
-                    + "expected int, tvm.tir.PrimExpr, tvm.tir.Slice, but got "
-                    + str(type(index)),
-                    span,
-                )
-
-        self.buffer = buffer
-        self.slices = slices
-        self.report_error = report_error
-        self.span = span
-
-    def __str__(self):
-        regions: List[str] = []
-        for s in self.slices:
-            if s.stop is None:
-                regions.append(str(s.start))
-            else:
-                regions.append(str(s.start) + ": " + str(s.stop))
-
-        return self.buffer.name + "[" + ", ".join(regions) + "]"
-
-    def asobject(self) -> BufferLoad:
-        """Convert object."""
-        indices = [s.as_index_expr(self.report_error) for s in self.slices]
-        return BufferLoad(self.buffer, indices, span=self.span)
-
-    def as_buffer_region(self, analyzer: Optional[Analyzer] = None) -> BufferRegion:
-        """Construct BufferRegion from BufferSlice
-
-        Parameters
-        ----------
-        analyzer : Optional[tvm.arith.Analyzer]
-            The analyzer for simplifying. If not provided, the method will construct a new one
-
-        Returns
-        -------
-        buffer_region : BufferRegion
-            The constructed BufferRegion.
-        """
-        region: List[Range] = []
-        for s in self.slices:
-            start = s.start if isinstance(s.start, PrimExpr) else IntImm("int32", s.start)
-            extent = IntImm(start.dtype, 1) if s.stop is None else s.stop - s.start
-            if not analyzer:
-                analyzer = Analyzer()
-            if isinstance(extent, PrimExpr):
-                extent = analyzer.simplify(extent)
-            if s.step != 1:
-                self.report_error("BufferRegion do not support non-trivial stride", s.span)
-            region.append(Range.from_min_extent(start, extent, span=s.span))
-        return BufferRegion(self.buffer, region)
-
-    def astype(self, dtype: str, span: Optional[Span] = None) -> PrimExpr:
-        return self.asobject().astype(dtype, span)
-
-    @property
-    def dtype(self) -> str:
-        """Return the dtype referenced by the slice.
-
-        Implemented as a property so that ``slice.dtype`` has the same
-        calling convention as ``primexpr.dtype``.  This allows a
-        BufferSlice object can be assigned to a variable without
-        requiring a type annotation on the variable, similar to other
-        expressions.
-        """
-        return self.asobject().dtype
diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py
deleted file mode 100644
index 1d2550eecd..0000000000
--- a/python/tvm/script/tir/scope_handler.py
+++ /dev/null
@@ -1,793 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Script Parser Scope Handler Classes"""
-# pylint: disable=redefined-builtin, unused-argument, invalid-name, relative-beyond-top-level
-from typing import Tuple, Any, Callable, Optional, List, Union, Mapping
-
-import synr
-import numpy as np
-import tvm.tir
-from tvm.runtime import Object, String, convert
-from tvm.ir import Span, Range
-from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind, IntImm
-
-from .node import BufferSlice
-
-from ..context_maintainer import ContextMaintainer
-from ..registry import register
-from ..utils import (
-    get_param_list,
-    tvm_span_from_synr,
-    call_with_error_reporting,
-)
-
-
-class ScopeHandler:
-    """Base class for all scope handlers"""
-
-    def __init__(self, func: Callable):
-        self.func: Callable = func
-        self.body: Optional[Stmt] = None
-        self.node: Optional[synr.ast.Node] = None
-        self.context: Optional[ContextMaintainer] = None
-
-    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
-        return "tir." + self.func.__name__, get_param_list(self.func)
-
-    def enter_scope(
-        self,
-        node: synr.ast.Node,
-        context: ContextMaintainer,
-        arg_list: List[Any],
-        span: synr.ast.Span,
-    ):
-        pass
-
-    def exit_scope(
-        self,
-        node: synr.ast.Node,
-        context: ContextMaintainer,
-        arg_list: List[Any],
-        span: synr.ast.Span,
-    ):
-        self.node = node
-        self.context = context
-        return call_with_error_reporting(
-            context.report_error, span, self.func, *arg_list, span=tvm_span_from_synr(span)
-        )
-
-
-class WithScopeHandler(ScopeHandler):
-    """Base class for all with scope handlers"""
-
-    def __init__(self, func, concise_scope, def_symbol):
-        super().__init__(func)
-        self.concise_scope = concise_scope
-        self.def_symbol = def_symbol
-
-    @staticmethod
-    def get_optional_vars(node, context):
-        """Get a list synr.ast.With's optional_vars"""
-        assert isinstance(
-            node, synr.ast.With
-        ), f"WithScopeHandler expected synr.ast.With but got {type(node)}"
-
-        if isinstance(node.lhs, list):
-            for var in node.lhs:
-                if not isinstance(var, synr.ast.Var):
-                    context.report_error(
-                        f"Invalid optional var definition, expected Var but got {type(var)}",
-                        node.span,
-                    )
-            vars = node.lhs
-        else:
-            context.report_error(
-                f"Invalid optional var definition, expected list of Var but got {type(node.lhs)}",
-                node.span,
-            )
-        return vars
-
-
-@register
-class Allocate(WithScopeHandler):
-    """With scope handler T.allocate(extents, dtype, scope, condition, annotations)"""
-
-    def __init__(self):
-        def allocate(extents, dtype, scope, condition=True, annotations=None, span=None):
-            condition = tvm.runtime.convert(condition)
-            scope = tvm.runtime.convert(scope)
-
-            return tvm.tir.Allocate(
-                self.buffer_var,
-                dtype,
-                extents,
-                condition,
-                self.body,
-                annotations=annotations,
-                span=span,
-            )
-
-        super().__init__(allocate, concise_scope=True, def_symbol=True)
-        self.buffer_var = None
-
-    def enter_scope(
-        self,
-        node: synr.ast.Node,
-        context: ContextMaintainer,
-        arg_list: List[Any],
-        span: synr.ast.Span,
-    ):
-        # define buffer vars in symbol table
-        if isinstance(node, synr.ast.With):
-            vars = WithScopeHandler.get_optional_vars(node, context)
-            if len(vars) != 1:
-                context.report_error(f"Unexpected number of vars: 1 vs. {len(vars)}", node.span)
-            name = vars[0].id.name
-            var_span = vars[0].id.span
-        elif isinstance(node, synr.ast.Assign):
-            if len(node.lhs) != 1:
-                context.report_error(f"Unexpected number of vars: 1 vs. {len(node.lhs)}", node.span)
-            name = node.lhs[0].id.name
-            var_span = node.lhs[0].id.span
-        else:
-            raise Exception("Internal Bug")
-
-        def setup_buffer_var(
-            extents, dtype, scope, condition=True, annotations=None, span: Span = None
-        ):
-            """Setup buffer var for a given type."""
-            buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), scope)
-            self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span)
-
-        setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span))
-        context.update_symbol(name, self.buffer_var, node)
-
-
-@register
-class AllocateConst(WithScopeHandler):
-    """With scope handler T.allocate_const(data, extents, dtype, condition)
-
-    TIR constant node to represent non-scalar constant
-    """
-
-    def __init__(self):
-        def allocate_const(raw_data, dtype, shape, annotations=None, span=None):
-            list_data = []
-            for i in raw_data:
-                list_data.append(i.value)
-            nd_data = tvm.nd.array(np.asarray(list_data, dtype=dtype))
-            n = tvm.tir.AllocateConst(
-                self.buffer_var,
-                dtype,
-                shape,
-                nd_data,
-                self.body,
-                annotations=annotations,
-                span=span,
-            )
-            return n
-
-        super().__init__(allocate_const, concise_scope=True, def_symbol=True)
-        self.buffer_var = None
-
-    def enter_scope(
-        self,
-        node: synr.ast.Node,
-        context: ContextMaintainer,
-        arg_list: List[Any],
-        span: synr.ast.Span,
-    ):
-        # define buffer vars in symbol table
-        if isinstance(node, synr.ast.With):
-            vars = WithScopeHandler.get_optional_vars(node, context)
-            if len(vars) != 1:
-                context.report_error(f"Unexpected number of vars: 1 vs. {len(vars)}", node.span)
-            name = vars[0].id.name
-            var_span = vars[0].id.span
-        elif isinstance(node, synr.ast.Assign):
-            if len(node.lhs) != 1:
-                context.report_error(f"Unexpected number of vars: 1 vs. {len(node.lhs)}", node.span)
-            name = node.lhs[0].id.name
-            var_span = node.lhs[0].id.span
-        else:
-            raise Exception("Internal Bug")
-
-        def setup_buffer_var(data, dtype, shape, annotations: dict = None, span: Span = None):
-            """Setup buffer var for a given type."""
-            buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype))
-            self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span)
-
-        setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span))
-        context.update_symbol(name, self.buffer_var, node)
-
-
-@register
-class DeclBuffer(WithScopeHandler):
-    """Special Stmt decl_buffer(shape, dtype, data, strides, elem_offset, scope, align,
-                                offset_factor, buffer_type, axis_separators)
-    Example
-    -------
-    .. code-block:: python
-        A = T.decl_buffer((128, 128), dtype="float32")
-    """
-
-    def __init__(self):
-        def decl_buffer(
-            shape,
-            dtype="float32",
-            data=None,
-            strides=None,
-            elem_offset=None,
-            scope="global",
-            align=-1,
-            offset_factor=0,
-            buffer_type="default",
-            axis_separators=None,
-            span=None,
-        ):
-            decl_buffer = tvm.tir.DeclBuffer(self.buffer, self.body, span=span)
-            if data is None:
-                # when data is not specified, the buffer is implicitly allocated
-                return tvm.tir.Allocate(
-                    self.buffer.data,
-                    dtype,
-                    shape,
-                    tvm.runtime.convert(True),
-                    decl_buffer,
-                    span=span,
-                )
-            return decl_buffer
-
-        super().__init__(decl_buffer, concise_scope=True, def_symbol=True)
-
-    def enter_scope(
-        self,
-        node: synr.ast.Node,
-        context: ContextMaintainer,
-        arg_list: List[Any],
-        span: synr.ast.Span,
-    ):
-        # define buffer vars in symbol table
-        if isinstance(node, synr.ast.With):
-            vars = WithScopeHandler.get_optional_vars(node, context)
-            if len(vars) != 1:
-                context.report_error(f"Unexpected number of vars: 1 vs. {len(vars)}", node.span)
-            name = vars[0].id.name
-            var_span = vars[0].id.span
-        elif isinstance(node, synr.ast.Assign):
-            if len(node.lhs) != 1:
-                context.report_error(f"Unexpected number of vars: 1 vs. {len(node.lhs)}", node.span)
-            name = node.lhs[0].id.name
-            var_span = node.lhs[0].id.span
-        else:
-            raise Exception("Internal Bug")
-
-        def setup_buffer(
-            shape,
-            dtype,
-            data,
-            strides,
-            elem_offset,
-            scope,
-            align,
-            offset_factor,
-            buffer_type,
-            axis_separators,
-            span: Span = None,
-        ):
-            self.buffer = tvm.tir.decl_buffer(
-                shape=shape,
-                dtype=dtype,
-                data=data,
-                strides=strides,
-                elem_offset=elem_offset,
-                scope=scope,
-                data_alignment=align,
-                offset_factor=offset_factor,
-                buffer_type=buffer_type,
-                axis_separators=axis_separators,
-                name=name,
-                span=span,
-            )
-
-        setup_buffer(*arg_list, span=tvm_span_from_synr(var_span))
-        context.update_symbol(name, self.buffer, node)
-
-
-@register
-class LaunchThread(WithScopeHandler):
-    """With scope handler T.launch_thread(env_var, extent)"""
-
-    def __init__(self):
-        def launch_thread(env_var, extent, span):
-            extent = tvm.runtime.convert(extent, span=span)
-            thread_id = self.context.func_var_env_dict[env_var]
-            attr_key = "virtual_thread" if thread_id == "vthread" else "thread_extent"
-            return tvm.tir.AttrStmt(
-                IterVar(
-                    (0, extent),
-                    env_var,
-                    getattr(IterVar, "ThreadIndex"),
-                    thread_id,
-                    span=span,
-                ),
-                attr_key,
-                extent,
-                self.body,
-                span=span,
-            )
-
-        super().__init__(launch_thread, concise_scope=True, def_symbol=False)
-
-
-@register
-class Realize(WithScopeHandler):
-    """With scope handler T.realize(buffer_bounds, scope, condition)"""
-
-    def __init__(self):
-        def realize(
-            buffer_slice: BufferSlice, scope: str, condition: bool = True, span: bool = None
-        ):
-            assert self.context, "call 'exit_scope' before 'enter_scope'"
-            buffer: Buffer = buffer_slice.buffer
-            bounds: List[Range] = []
-            for s in buffer_slice.slices:
-                min: Union[PrimExpr, int] = s.start
-                extent: Union[PrimExpr, int] = 1 if s.stop is None else s.stop - s.start
-                if isinstance(extent, PrimExpr):
-                    extent = self.context.analyzer.simplify(extent)
-                bounds.append(Range.from_min_extent(min, extent, span=s.span))
-
-            scope = tvm.runtime.convert(scope, span=span)
-            return tvm.tir.AttrStmt(
-                buffer,
-                "realize_scope",
-                scope,
-                tvm.tir.BufferRealize(buffer, bounds, condition, self.body, span=span),
-                span=span,
-            )
-
-        super().__init__(realize, concise_scope=True, def_symbol=False)
-
-
-@register
-class Attr(WithScopeHandler):
-    """With scope handler T.attr(attr_node, attr_key, value)"""
-
-    def __init__(self):
-        def attr(attr_node, attr_key, value, span):
-            attr_node = tvm.runtime.convert(attr_node, span=span)
-            value = tvm.runtime.convert(value, span=span)
-            return tvm.tir.AttrStmt(attr_node, attr_key, value, self.body, span=span)
-
-        super().__init__(attr, concise_scope=True, def_symbol=False)
-
-
-@register
-class AssertHandler(WithScopeHandler):
-    """With scope handler T.Assert(condition, message)"""
-
-    def __init__(self):
-        def Assert(condition, message, span):
-            return tvm.tir.AssertStmt(condition, tvm.runtime.convert(message), self.body, span=span)
-
-        super().__init__(Assert, concise_scope=True, def_symbol=False)
-
-
-@register
-class Let(WithScopeHandler):
-    """With scope handler T.let(var, value)"""
-
-    def __init__(self):
-        def let(var, value, span):
-            return tvm.tir.LetStmt(var, value, self.body, span=span)
-
-        super().__init__(let, concise_scope=False, def_symbol=False)
-
-    def __call__(self, var: tvm.tir.Var, value: tvm.tir.PrimExpr, body: tvm.tir.PrimExpr):
-        return tvm.tir.Let(var, value, body)
-
-
-@register
-class Block(WithScopeHandler):
-    """With scope handler T.block(name)"""
-
-    def __init__(self):
-        def block(name_hint: str = "", span: Optional[Span] = None):
-            assert (
-                self.node and self.context and self.body
-            ), "call 'exit_scope' before 'enter_scope'"
-            block_info = self.context.block_info_stack[-1]
-
-            # create block read/write regions
-            reads: List[BufferRegion] = (
-                [read.as_buffer_region() for read in block_info.reads] if block_info.reads else []
-            )
-            writes: List[BufferRegion] = (
-                [write.as_buffer_region() for write in block_info.writes]
-                if block_info.writes
-                else []
-            )
-
-            region_detect_mask: int = (block_info.reads is None) | (
-                (block_info.writes is None) << 1
-            )
-            annotations = {} if block_info.annotations is None else block_info.annotations
-            if region_detect_mask != 0:
-                annotations["tir.script_parsing_detect_access"] = region_detect_mask
-            inner = tvm.tir.Block(
-                block_info.iter_vars,
-                reads,
-                writes,
-                name_hint,
-                self.body,
-                block_info.init,
-                block_info.alloc_buffers,
-                block_info.match_buffers,
-                annotations,
-                span,
-            )
-            assert len(block_info.iter_vars) == len(block_info.iter_values)
-            predicate = (
-                tvm.tir.const(True, "bool")
-                if block_info.predicate is None
-                else block_info.predicate
-            )
-            body = tvm.tir.BlockRealize(block_info.iter_values, predicate, inner, span)
-            return body
-
-        super().__init__(func=block, concise_scope=False, def_symbol=True)
-        self.block_vars = None
-
-    def enter_scope(
-        self,
-        node: synr.ast.Node,
-        context: ContextMaintainer,
-        arg_list: List[Any],
-        span: synr.ast.Span,
-    ):
-        # define block vars
-        assert isinstance(
-            node, synr.ast.With
-        ), f"BlockScopeHandler expected to work on synr.ast.With but got {type(node)}"
-
-        optional_vars = [var.id.name for var in WithScopeHandler.get_optional_vars(node, context)]
-        if optional_vars:
-            context.report_error(
-                f"Block expected no optional_vars (e.g., `x` in `with block() as x`), "
-                f"but got {optional_vars}",
-                node.span,
-            )
-
-
-@register
-class InitBlock(WithScopeHandler):
-    """With scope handler T.init()"""
-
-    def __init__(self):
-        def init(span: Span = None):
-            assert self.context, "call 'exit_scope' before 'enter_scope'"
-            if self.context.block_info_stack[-2].init is not None:
-                self.context.report_error("Duplicate init block declaration", span)
-            self.context.block_info_stack[-2].init = self.body
-
-        super().__init__(func=init, concise_scope=False, def_symbol=True)
-
-
-class LoopInfo:
-    """Helper class for loop information"""
-
-    loop_var: Var
-    begin: PrimExpr
-    extent: PrimExpr
-    kind: ForKind
-    thread_binding: Optional[str]
-    annotations: Optional[Mapping[str, Object]]
-
-    def __init__(
-        self,
-        begin: PrimExpr,
-        extent: PrimExpr,
-        kind: ForKind,
-        thread_binding: Optional[str] = None,
-        annotations: Optional[Mapping[str, Object]] = None,
-    ) -> None:
-        self.begin = begin
-        self.extent = extent
-        self.kind = kind
-        self.thread_binding = thread_binding
-        self.annotations = annotations
-
-
-class ForScopeHandler(ScopeHandler):
-    """Base class for all for scope handlers"""
-
-    def __init__(self, func):
-        super().__init__(func)
-        self.loop_vars: List[Var] = []
-        self.loop_info: List[LoopInfo] = []
-
-    def enter_scope(
-        self,
-        node: synr.ast.Node,
-        context: ContextMaintainer,
-        arg_list: List[Any],
-        span: synr.ast.Span,
-    ):
-        assert isinstance(
-            node, synr.ast.For
-        ), f"ForScopeHandler expected synr.ast.For but got {type(node)}"
-
-        loop_var_names = list()
-        spans = list()
-        if isinstance(node.lhs, synr.ast.Var):
-            loop_var_names.append(node.lhs.id.name)
-            spans.append(tvm_span_from_synr(node.lhs.id.span))
-        elif isinstance(node.lhs, list):
-            for elt in node.lhs:
-                if not isinstance(elt, synr.ast.Var):
-                    context.report_error(
-                        f"Invalid loop var. Expected a var, but got {type(elt)}", elt.span
-                    )
-                loop_var_names.append(elt.id.name)
-                spans.append(tvm_span_from_synr(elt.id.span))
-        else:
-            context.report_error(
-                f"Invalid loop var. Expected var or list of vars as lhs, but got {type(node.lhs)}",
-                span,
-            )
-
-        self.node = node
-        self.context = context
-        # collect loop infos by calling self.func
-        call_with_error_reporting(context.report_error, span, self.func, *arg_list)
-        if len(loop_var_names) != len(self.loop_info):
-            self.context.report_error(
-                f"Inconsistent number of vars and loops, got {len(loop_var_names)} "
-                + f"vs {len(self.loop_info)}",
-                self.node.span,
-            )
-        # generate loop vars
-        self.loop_vars = []
-        for name, lv_span, li in zip(loop_var_names, spans, self.loop_info):
-            if not li.begin.dtype.startswith("int"):
-                raise NotImplementedError(f"Unsupported dtype in loop begin: {li.begin.dtype}")
-            if not li.extent.dtype.startswith("int"):
-                raise NotImplementedError(f"Unsupported dtype in loop extent: {li.extent.dtype}")
-            dtype = "int64" if "int64" in [li.begin.dtype, li.extent.dtype] else "int32"
-            self.loop_vars.append(tvm.te.var(name, dtype=dtype, span=lv_span))
-
-        for loop_var, loop_info in zip(self.loop_vars, self.loop_info):
-            context.update_symbol(loop_var.name, loop_var, node)
-            context.loop_stack[loop_var] = Range.from_min_extent(loop_info.begin, loop_info.extent)
-
-    def exit_scope(
-        self,
-        node: synr.ast.Node,
-        context: ContextMaintainer,
-        arg_list: List[Any],
-        span: synr.ast.Span,
-    ):
-        assert self.loop_vars, "call 'exit_scope' before 'enter_scope'"
-        for loop_var in self.loop_vars:
-            context.loop_stack.pop(loop_var)
-        # Use assert here since we have check it in `enter_scope`
-        assert len(self.loop_vars) == len(self.loop_info)
-
-        body = self.body
-        for var, info in zip(reversed(self.loop_vars), reversed(self.loop_info)):
-            body = tvm.tir.For(
-                var,
-                info.begin,
-                info.extent,
-                info.kind,
-                body,
-                info.thread_binding,
-                info.annotations,
-                span=tvm_span_from_synr(span),
-            )
-
-        return body
-
-    def create_loop_info(
-        self,
-        begin: PrimExpr,
-        end: PrimExpr,
-        kind: ForKind,
-        thread_binding: Optional[str] = None,
-        annotations: Optional[Mapping[str, Object]] = None,
-    ) -> None:
-        """
-        Helper function for creating For in TVM Script parser.
-
-        Parameters
-        ----------
-        begin : PrimExpr
-            The beginning value.
-
-        end : PrimExpr
-            The endding value.
-
-        kind : ForKind
-            The type of the for.
-
-        thread_binding: Optional[str]
-            The thread this loop binds to.
-
-        annotations : Optional[Mapping[str, Object]]
-            Additional annotation hints.
-
-        span : Optional[Span]
-            The location of this for in the source code.
-
-        Returns
-        -------
-        for : For
-            The constructed For.
-        """
-        begin, end = [convert(_) for _ in [begin, end]]
-        assert self.context and self.node, "call 'exit_scope' before 'enter_scope'"
-        extent = end if begin == 0 else self.context.analyzer.simplify(end - begin)
-        if begin == 0 and isinstance(extent, PrimExpr):
-            begin = IntImm(extent.dtype, 0, begin.span)
-        self.annotations: Mapping[str, Object] = {}
-        if annotations is not None:
-            self.annotations = {
-                key: String(val) if isinstance(val, str) else val
-                for key, val in annotations.items()
-            }
-
-        self.loop_info.append(LoopInfo(begin, extent, kind, thread_binding, annotations))
-
-
-@register
-class Serial(ForScopeHandler):
-    """For scope handler T.serial(begin, end, annotations)"""
-
-    def __init__(self):
-        def serial(
-            begin: PrimExpr,
-            end: PrimExpr = None,
-            annotations: Optional[Mapping[str, Object]] = None,
-        ):
-            if end is None:
-                end = begin
-                begin = 0
-            self.create_loop_info(begin, end, ForKind.SERIAL, annotations=annotations)
-
-        super().__init__(serial)
-
-
-@register
-class Parallel(ForScopeHandler):
-    """For scope handler T.parallel(begin, end, annotations)"""
-
-    def __init__(self):
-        def parallel(
-            begin: PrimExpr,
-            end: PrimExpr = None,
-            annotations: Optional[Mapping[str, Object]] = None,
-        ):
-            if end is None:
-                end = begin
-                begin = 0
-            self.create_loop_info(begin, end, ForKind.PARALLEL, annotations=annotations)
-
-        super().__init__(parallel)
-
-
-@register
-class Vectorized(ForScopeHandler):
-    """For scope handler T.vectorized(begin, end, annotations)"""
-
-    def __init__(self):
-        def vectorized(
-            begin: PrimExpr,
-            end: PrimExpr = None,
-            annotations: Optional[Mapping[str, Object]] = None,
-        ):
-            if end is None:
-                end = begin
-                begin = 0
-            self.create_loop_info(begin, end, ForKind.VECTORIZED, annotations=annotations)
-
-        super().__init__(vectorized)
-
-
-@register
-class Unroll(ForScopeHandler):
-    """For scope handler T.unroll(begin, end, annotations)"""
-
-    def __init__(self):
-        def unroll(
-            begin: PrimExpr,
-            end: PrimExpr = None,
-            annotations: Optional[Mapping[str, Object]] = None,
-        ):
-            if end is None:
-                end = begin
-                begin = 0
-            self.create_loop_info(begin, end, ForKind.UNROLLED, annotations=annotations)
-
-        super().__init__(unroll)
-
-
-@register
-class ThreadBinding(ForScopeHandler):
-    """For scope handler T.thread_binding(begin, end, thread, annotations)"""
-
-    def __init__(self):
-        def thread_binding(
-            begin: PrimExpr,
-            end: PrimExpr = None,
-            thread: str = None,
-            annotations: Optional[Mapping[str, Object]] = None,
-        ):
-            if thread is None:
-                if isinstance(end, str):  # handle case like thread_binding(128, "threadIdx.x")
-                    thread = end
-                    end = None
-                else:
-                    raise ValueError("Thread cannot be None for thread_binding")
-            if end is None:
-                end = begin
-                begin = 0
-            thread_iter_var = IterVar(None, None, IterVar.ThreadIndex, thread)
-            self.create_loop_info(
-                begin,
-                end,
-                ForKind.THREAD_BINDING,
-                thread_binding=thread_iter_var,
-                annotations=annotations,
-            )
-
-        super().__init__(thread_binding)
-
-
-@register
-class RangeHandler(ForScopeHandler):
-    """For scope handler range(begin, end, annotations)
-    Note that tir.range is totally the same as T.serial
-    """
-
-    def __init__(self):
-        def for_range(
-            begin: PrimExpr,
-            end: PrimExpr = None,
-            annotations: Optional[Mapping[str, Object]] = None,
-        ):
-            if end is None:
-                end = begin
-                begin = 0
-            self.create_loop_info(begin, end, ForKind.SERIAL, annotations=annotations)
-
-        super().__init__(for_range)
-
-    def signature(self):
-        return "range", get_param_list(self.func)
-
-
-@register
-class Grid(ForScopeHandler):
-    """For scope handler T.grid(extents)"""
-
-    def __init__(self):
-        def grid(*extents: List[PrimExpr]):
-            for extent in extents:
-                self.create_loop_info(0, extent, ForKind.SERIAL)
-
-        super().__init__(grid)
diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py
deleted file mode 100644
index 15502055b7..0000000000
--- a/python/tvm/script/tir/special_stmt.py
+++ /dev/null
@@ -1,964 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Script Parser Special Stmt Classes"""
-# pylint: disable=unused-argument, no-self-argument, inconsistent-return-statements
-# pylint: disable=relative-beyond-top-level
-from typing import Callable, List, Optional, Tuple, Any, Mapping, Union
-
-import synr
-from synr import ast
-from tvm.ir.expr import PrimExpr, Range
-
-import tvm.tir
-from tvm.runtime import Object, String
-from tvm.target import Target
-from tvm.ir import Span
-from tvm.tir import IntImm, IterVar, Var
-
-from .node import BufferSlice
-
-from ..context_maintainer import BlockInfo, ContextMaintainer
-from ..registry import register
-from ..utils import (
-    get_param_list,
-    tvm_span_from_synr,
-    call_with_error_reporting,
-)
-
-
-def convert_to_int(
-    value: Union[IntImm, int],
-    arg_name: str,
-    report_error: Callable,
-    span: Union[Span, synr.ast.Span],
-) -> int:
-    """convert a const int or TVM IntImm to Python int.
-    Reports an error when input cannot be converted to int.
-
-    Parameters
-    ----------
-    value : Union[tvm.tir.IntImm, int]
-        The input value to be converted.
-    arg_name : str
-        Function argument name for error reporting.
-    report_error: Callable
-        The report error function handle
-    span : Union[synr.ast.Span, tvm.ir.Span]
-        Location of the error
-    """
-    if isinstance(value, IntImm):
-        return value.value
-    if isinstance(value, int):
-        return value
-    report_error(
-        f"Expected int or IntImm for {arg_name}, but got {str(type(value))}",
-        span,
-    )
-
-
-class SpecialStmt:
-    """Base class for all Special Stmts"""
-
-    def __init__(self, func: Callable, def_symbol: bool):
-        self.func: Callable = func
-        self.def_symbol: bool = def_symbol
-        self.node: Optional[synr.ast.Node] = None
-        self.context: Optional[ContextMaintainer] = None
-
-    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
-        return "tir." + self.func.__name__, get_param_list(self.func)
-
-    def handle(
-        self,
-        node: ast.Node,
-        context: ContextMaintainer,
-        arg_list: List[Any],
-        span: synr.ast.Span,
-    ):
-        self.node = node
-        self.context = context
-        return call_with_error_reporting(
-            context.report_error, span, self.func, *arg_list, span=tvm_span_from_synr(span)
-        )
-
-
-@register
-class MatchBuffer(SpecialStmt):
-    """Special Stmt match_buffer(param, shape, dtype, data, strides, elem_offset, scope, align,
-                                 offset_factor, buffer_type, axis_separators)
-
-    Note
-    ----
-    This Special Stmt will perform different behavior depends on the type of param.
-    If the param is a var in function parameter, it will create a buffer from DLTensor.
-    Else if the param is a subregion of other buffers, then create a subregion match inside a block.
-
-    Example
-    -------
-    Match buffer from function parameter
-    .. code-block:: python
-        A = T.match_buffer(a, (128, 128), dtype="float32")
-
-    Match buffer from Buffer subregion
-    .. code-block:: python
-        A = T.match_buffer(B[0:128, i * 128 : i * 128 + 128], (128, 128), dtype="float32")
-    """
-
-    def __init__(self):
-        def match_buffer(
-            param,
-            shape,
-            dtype="float32",
-            data=None,
-            strides=None,
-            elem_offset=None,
-            scope="global",
-            align=-1,
-            offset_factor=0,
-            buffer_type="default",
-            axis_separators=None,
-            span=None,
-        ):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
-                self.context.report_error(
-                    "`match_buffer` must be assigned to a single buffer, "
-                    "e.g. A = match_buffer(...)",
-                    self.node.span,
-                )
-            if strides is None:
-                strides = []
-            align = convert_to_int(align, "align", self.context.report_error, self.node.span)
-            offset_factor = convert_to_int(
-                offset_factor, "offset_factor", self.context.report_error, self.node.span
-            )
-            buffer_name: str = self.node.lhs[0].id.name
-            buffer = tvm.tir.decl_buffer(
-                shape,
-                dtype,
-                buffer_name,
-                data,
-                strides,
-                elem_offset,
-                scope,
-                align,
-                offset_factor,
-                buffer_type,
-                axis_separators,
-                span=span,
-            )
-            if isinstance(param, tvm.tir.Var):
-                if param not in self.context.func_params:
-                    self.context.report_error(
-                        "Can not bind non-input param to buffer", self.node.rhs.params[0].span
-                    )
-                self.context.func_buffer_map[param] = buffer
-            elif isinstance(param, BufferSlice):
-                buffer_region = param.as_buffer_region()
-                self.context.current_block_scope().match_buffers.append(
-                    tvm.tir.MatchBufferRegion(buffer, buffer_region)
-                )
-            else:
-                self.context.report_error(
-                    "The source of match_buffer expected Var or BufferSlice, but got "
-                    + str(type(param)),
-                    self.node.rhs.params[0].span,
-                )
-            self.context.update_symbol(buffer_name, buffer, self.node)
-
-        super().__init__(match_buffer, def_symbol=True)
-
-
-@register
-class BufferDeclare(SpecialStmt):
-    """Special Stmt buffer_decl(shape, dtype, data, strides, elem_offset, scope, align,
-                                offset_factor, buffer_type, axis_separators)
-    Example
-    -------
-    .. code-block:: python
-        A = T.buffer_decl((128, 128), dtype="float32")
-    """
-
-    def __init__(self):
-        def buffer_decl(
-            shape,
-            dtype="float32",
-            data=None,
-            strides=None,
-            elem_offset=None,
-            scope="global",
-            align=-1,
-            offset_factor=0,
-            buffer_type="default",
-            axis_separators=None,
-            span=None,
-        ):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
-                self.context.report_error(
-                    "`buffer_decl` must be assigned to a single buffer, e.g. A = buffer_decl(...)",
-                    self.node.span,
-                )
-
-            if strides is None:
-                strides = []
-            align = convert_to_int(align, "align", self.context.report_error, self.node.span)
-            offset_factor = convert_to_int(
-                offset_factor, "offset_factor", self.context.report_error, self.node.span
-            )
-            buffer_name: str = self.node.lhs[0].id.name
-            buffer = tvm.tir.decl_buffer(
-                shape,
-                dtype,
-                buffer_name,
-                data,
-                strides,
-                elem_offset,
-                scope,
-                align,
-                offset_factor,
-                buffer_type,
-                axis_separators,
-                span=span,
-            )
-            self.context.update_symbol(buffer_name, buffer, self.node)
-            return buffer
-
-        super().__init__(buffer_decl, def_symbol=True)
-
-
-@register
-class AllocBuffer(SpecialStmt):
-    """Special function alloc_buffer(shape, dtype, data, strides, elem_offset, scope, align,
-                                     offset_factor, buffer_type, axis_separators)
-
-    Example
-    -------
-    .. code-block:: python
-
-        A = T.alloc_buffer((128, 128), dtype="float32")
-    """
-
-    def __init__(self):
-        def alloc_buffer(
-            shape,
-            dtype="float32",
-            data=None,
-            strides=None,
-            elem_offset=None,
-            scope="global",
-            align=-1,
-            offset_factor=0,
-            buffer_type="default",
-            axis_separators=None,
-            span=None,
-        ):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
-                self.context.report_error(
-                    "`alloc_buffer` must be assigned to a single buffer, "
-                    "e.g. A = alloc_buffer(...)",
-                    self.node.span,
-                )
-
-            if strides is None:
-                strides = []
-            align = convert_to_int(align, "align", self.context.report_error, self.node.span)
-            offset_factor = convert_to_int(
-                offset_factor, "offset_factor", self.context.report_error, self.node.span
-            )
-            buffer_name: str = self.node.lhs[0].id.name
-            buffer = tvm.tir.decl_buffer(
-                shape,
-                dtype,
-                buffer_name,
-                data,
-                strides,
-                elem_offset,
-                scope,
-                align,
-                offset_factor,
-                buffer_type,
-                axis_separators,
-                span=span,
-            )
-            if self.context.current_block_scope():
-                self.context.current_block_scope().alloc_buffers.append(buffer)
-            else:
-                # If it is allocated outside all blocks, allocate it under root block.
-                self.context.root_alloc_buffers.append(buffer)
-            self.context.update_symbol(buffer_name, buffer, self.node)
-
-        super().__init__(alloc_buffer, def_symbol=True)
-
-
-@register
-class BlockReads(SpecialStmt):
-    """Special function reads([read_regions], *other_regions)
-
-    Note
-    ----
-    *other_region is an unpackable list of BufferSlice to support
-    reads syntax sugar like reads(BufferRegion1, BufferRegion2, ...)
-
-    Example
-    -------
-    .. code-block:: python
-
-        T.reads([A[vi: vi + 4, vk: vk + 4], B[vk: vk + 4, vj]])
-    """
-
-    def __init__(self):
-        def reads(
-            *read_regions: Union[BufferSlice, List[BufferSlice]],
-            span: Span = None,
-        ):
-            assert self.context, "call 'exit_scope' before 'enter_scope'"
-            block_scope = self.context.current_block_scope()
-            if block_scope is None:
-                self.context.report_error(
-                    "Expected to declare read regions inside a block.",
-                    span,
-                )
-            if block_scope.reads is not None:
-                self.context.report_error(
-                    "Duplicate write region declaration, "
-                    + "previous one is "
-                    + str(", ".join(str(x) for x in block_scope.reads)),
-                    span,
-                )
-            if len(read_regions) > 1:
-                for read_region in read_regions:
-                    if not isinstance(read_region, BufferSlice):
-                        self.context.report_error(
-                            "Incorrect input type. Expected *BufferSlice or List[BufferSlice],"
-                            + f" but got {type(read_regions)}",
-                            span,
-                        )
-            elif len(read_regions) == 1:
-                if isinstance(read_regions[0], list):
-                    read_regions = read_regions[0]
-
-            block_scope.reads = read_regions
-
-        super().__init__(reads, def_symbol=False)
-
-
-@register
-class BlockWrites(SpecialStmt):
-    """Special function writes([write_regions], *other_regions)
-
-    Note
-    ----
-    *other_region is an unpackable list of BufferSlice to support
-    writes syntax sugar like writes(BufferRegion1, BufferRegion2, ...)
-
-    Example
-    -------
-    .. code-block:: python
-
-        T.writes([C[vi: vi + 4, vj])
-    """
-
-    def __init__(self):
-        def writes(
-            *write_regions: Union[BufferSlice, List[BufferSlice]],
-            span: Span = None,
-        ):
-            assert self.context, "call 'exit_scope' before 'enter_scope'"
-            block_scope = self.context.current_block_scope()
-            if block_scope is None:
-                self.context.report_error(
-                    "Expected to declare write regions inside a block.",
-                    span,
-                )
-            if block_scope.writes is not None:
-                self.context.report_error(
-                    "Duplicate write region declaration, "
-                    + "previous one is "
-                    + str(", ".join(str(x) for x in block_scope.writes)),
-                    span,
-                )
-            if len(write_regions) > 1:
-                for write_region in write_regions:
-                    if not isinstance(write_region, BufferSlice):
-                        self.context.report_error(
-                            "Incorrect input type. Expected *BufferSlice or List[BufferSlice],"
-                            + f" but got {type(write_regions)}",
-                            span,
-                        )
-            elif len(write_regions) == 1:
-                if isinstance(write_regions[0], list):
-                    write_regions = write_regions[0]
-            block_scope.writes = write_regions
-
-        super().__init__(writes, def_symbol=False)
-
-
-@register
-class BlockAttr(SpecialStmt):
-    """Special function block_attr({attr_key: attr_value})
-
-    Example
-    -------
-    .. code-block:: python
-
-        T.block_attr({"double_buffer_scope": 1})
-    """
-
-    def __init__(self):
-        def block_attr(attrs: Mapping[str, Object], span: Span = None):
-            assert self.context, "call 'exit_scope' before 'enter_scope'"
-            block_scope = self.context.current_block_scope()
-            if block_scope is None:
-                self.context.report_error(
-                    "Expected to declare block annotations inside a block.",
-                    span,
-                )
-            if block_scope.annotations is not None:
-                self.context.report_error(
-                    "Duplicate block annotations declaration, "
-                    + "previous one is "
-                    + str(block_scope.annotations),
-                    span,
-                )
-            attrs = {
-                key: String(val) if isinstance(val, str) else val for key, val in attrs.items()
-            }
-            block_scope.annotations = attrs
-
-        super().__init__(block_attr, def_symbol=False)
-
-
-class BlockAxis(SpecialStmt):
-    """Special stmt for defining a spatial block axis
-    axis.S(dom, iter_value)
-
-    Example
-    -------
-    .. code-block:: python
-
-        vi = T.axis.S(128, i * 4 + j)
-    """
-
-    def axis(
-        self,
-        var_name: str,
-        dom: Union[PrimExpr, Range],
-        value: PrimExpr,
-        iter_type: int,
-        span: Optional[Span] = None,
-    ) -> None:
-        """
-        Helper function for creating block axis
-
-        Parameters
-        ----------
-        var_name : str
-            The name_hint of var
-
-        dom : Union[PrimExpr, Range]
-            The iter domain.
-
-        value : PrimExpr
-            The binding value
-
-        iter_type : int
-            The iteration type.
-
-        span : Optional[Span]
-            The location of this for in the source code.
-        """
-        assert self.context, "call 'exit_scope' before 'enter_scope'"
-        block_scope: BlockInfo = self.context.current_block_scope()
-        if block_scope is None:
-            self.context.report_error(
-                "Expected to declare block axes inside a block.",
-                self.node.span,
-            )
-        if var_name in [iter_var.var.name for iter_var in block_scope.iter_vars]:
-            self.context.report_error("Duplicate block axis " + var_name, self.node.span)
-
-        dom = tvm.runtime.convert(dom)
-        if isinstance(dom, PrimExpr):
-            dom = tvm.ir.Range(dom)
-        elif isinstance(dom, tvm.ir.container.Array) and len(dom) == 2:
-            dom = tvm.ir.Range(dom[0], dom[1])
-        elif not isinstance(dom, tvm.ir.Range):
-            self.context.report_error(
-                f"Block axis domain expected PrimExpr or Range, but got {type(dom)}",
-                self.node.span,
-            )
-        block_var = tvm.tir.Var(var_name, dtype=dom.extent.dtype)
-        value = tvm.runtime.convert(value)
-        if not isinstance(value, PrimExpr):
-            self.context.report_error(
-                f"Block axis value expected PrimExpr, but got {type(value)}",
-                self.node.span,
-            )
-        iter_var = tvm.tir.IterVar(dom, block_var, iter_type)
-        block_scope.iter_vars.append(iter_var)
-        block_scope.iter_values.append(value)
-        self.context.update_symbol(var_name, block_var, self.node)
-
-
-@register
-class BlockAxisSpatial(BlockAxis):
-    """Special stmt for defining a spatial block axis
-    axis.spatial(dom, iter_value)
-
-    Example
-    -------
-    .. code-block:: python
-
-        vi = T.axis.spatial(128, k)
-    """
-
-    def __init__(self):
-        def axis_spatial(
-            dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None
-        ):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
-                self.context.report_error(
-                    "`axis.spatial` must be assigned to a var, e.g. vi = axis.spatial(...)",
-                    self.node.span,
-                )
-            self.axis(self.node.lhs[0].id.name, dom, value, IterVar.DataPar)
-
-        super().__init__(axis_spatial, def_symbol=True)
-
-    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
-        return "tir.axis.spatial", get_param_list(self.func)
-
-
-@register
-class BlockAxisS(BlockAxis):
-    """The sugar special stmt for defining a spatial block axis
-    axis.S(dom, iter_value)
-
-    Example
-    -------
-    .. code-block:: python
-
-        vi = T.axis.S(128, k)
-    """
-
-    def __init__(self):
-        def axis_spatial(
-            dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None
-        ):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
-                self.context.report_error(
-                    "`axis.S` must be assigned to a var, e.g. vi = axis.S(...)",
-                    self.node.span,
-                )
-            self.axis(self.node.lhs[0].id.name, dom, value, IterVar.DataPar)
-
-        super().__init__(axis_spatial, def_symbol=True)
-
-    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
-        return "tir.axis.S", get_param_list(self.func)
-
-
-@register
-class BlockAxisReduce(BlockAxis):
-    """Special stmt for defining a reduce block axis
-    axis.reduce(dom, iter_value)
-
-    Example
-    -------
-    .. code-block:: python
-
-        vi = T.axis.reduce(128, k)
-    """
-
-    def __init__(self):
-        def axis_reduce(
-            dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None
-        ):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
-                self.context.report_error(
-                    "`axis.reduce` must be assigned` to a var, e.g. vi = axis.reduce(...)",
-                    self.node.span,
-                )
-            self.axis(self.node.lhs[0].id.name, dom, value, IterVar.CommReduce)
-
-        super().__init__(axis_reduce, def_symbol=True)
-
-    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
-        return "tir.axis.reduce", get_param_list(self.func)
-
-
-@register
-class BlockAxisR(BlockAxis):
-    """The sugar special stmt for defining a reduce block axis
-    axis.R(dom, iter_value)
-
-    Example
-    -------
-    .. code-block:: python
-
-        vi = T.axis.R(128, k)
-    """
-
-    def __init__(self):
-        def axis_reduce(
-            dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None
-        ):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
-                self.context.report_error(
-                    "`axis.R` must be assigned to a var, e.g. vi = axis.R(...)",
-                    self.node.span,
-                )
-            self.axis(self.node.lhs[0].id.name, dom, value, IterVar.CommReduce)
-
-        super().__init__(axis_reduce, def_symbol=True)
-
-    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
-        return "tir.axis.R", get_param_list(self.func)
-
-
-@register
-class BlockAxisScan(BlockAxis):
-    """Special stmt for defining a ordered block axis
-    axis.scan(dom, iter_value)
-
-    Example
-    -------
-    .. code-block:: python
-
-        vi = T.axis.scan(128, k)
-    """
-
-    def __init__(self):
-        def axis_scan(
-            dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None
-        ):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
-                self.context.report_error(
-                    "`axis.scan` must be assigned to a var, e.g. vi = axis.scan(...)",
-                    self.node.span,
-                )
-            self.axis(self.node.lhs[0].id.name, dom, value, IterVar.Ordered)
-
-        super().__init__(axis_scan, def_symbol=True)
-
-    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
-        return "tir.axis.scan", get_param_list(self.func)
-
-
-@register
-class BlockAxisOpaque(BlockAxis):
-    """Special stmt for defining a opaque block axis
-    axis.opaque(dom, iter_value)
-
-    Example
-    -------
-    .. code-block:: python
-
-        vi = T.axis.opaque(128, k)
-    """
-
-    def __init__(self):
-        def axis_opaque(
-            dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None
-        ):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1:
-                self.context.report_error(
-                    "`axis.opaque` must be assigned to a var, e.g. vi = axis.opaque(...)",
-                    self.node.span,
-                )
-            self.axis(self.node.lhs[0].id.name, dom, value, IterVar.DimInfo)
-
-        super().__init__(axis_opaque, def_symbol=True)
-
-    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
-        return "tir.axis.opaque", get_param_list(self.func)
-
-
-@register
-class BlockAxisRemap(BlockAxis):
-    """Special stmt for remapping loops vars to block axes.
-    axis.remap(iter_type, iter_value)
-
-    Note
-    ----
-    Iter_type is a string consisting of 'S' and 'R', where 'S' means
-    for spatial and 'R' means for reduce.
-
-    Example
-    -------
-    .. code-block:: python
-
-        vi, vj = T.axis.remap("SS", [i, j])
-    """
-
-    def __init__(self):
-        def axis_remap(iter_types: str, loop_vars: List[tvm.tir.expr.Var], span: Span = None):
-            if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) >= 1:
-                self.context.report_error(
-                    "`axis.remap` must be assigned to one or more vars, "
-                    "e.g. vi, vj = axis.remap(...)",
-                    self.node.span,
-                )
-            var_num: int = len(self.node.lhs)
-            if var_num != len(iter_types):
-                self.context.report_error(
-                    f"`iter_type` expected {var_num} charactor(s), "
-                    f"but got {len(iter_types)}: {iter_types}",
-                    span,
-                )
-            if var_num != len(loop_vars):
-                self.context.report_error(
-                    f"`iter_type` expected {var_num} loop var(s), "
-                    f"but got {len(loop_vars)}: {loop_vars}",
-                    span,
-                )
-            for var, iter_ty, loop_var in zip(self.node.lhs, iter_types, loop_vars):
-                iter_type: int
-                if iter_ty == "S":
-                    iter_type = IterVar.DataPar
-                elif iter_ty == "R":
-                    iter_type = IterVar.CommReduce
-                else:
-                    self.context.report_error(
-                        f'`iter_type` only expected "S" (for spatial) or "R" (for reduce), '
-                        f'but got "{iter_ty}"',
-                        span,
-                    )
-
-                if not isinstance(loop_var, tvm.tir.expr.Var):
-                    self.context.report_error(
-                        f"Values of `axis.remap` expected single loop var, but got {loop_var}",
-                        loop_var.span,
-                    )
-                loops = self.context.loop_stack
-                if loop_var not in loops:
-                    self.context.report_error(
-                        f"Cannot find loop var {loop_var} in loop nesting.",
-                        span,
-                    )
-                self.axis(var.id.name, loops[loop_var], loop_var, iter_type)
-
-        super().__init__(axis_remap, def_symbol=True)
-
-    def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
-        return "tir.axis.remap", get_param_list(self.func)
-
-
-@register
-class BlockPredicate(SpecialStmt):
-    """Special function where(predicate)
-
-    Example
-    -------
-    .. code-block:: python
-
-        T.where(i < 4)
-    """
-
-    def __init__(self):
-        def where(predicate, span=None):
-            assert self.context, "call 'exit_scope' before 'enter_scope'"
-            block_scope = self.context.current_block_scope()
-            if block_scope is None:
-                self.context.report_error(
-                    "Expected to declare the predicate inside a block.",
-                    span,
-                )
-            if block_scope.predicate is not None:
-                self.context.report_error(
-                    "Duplicate block predicate declaration, "
-                    + "previous one is "
-                    + str(block_scope.predicate),
-                    span,
-                )
-
-            block_scope.predicate = predicate
-
-        super().__init__(where, def_symbol=False)
-
-
-@register
-class VarDef(SpecialStmt):
-    """Special function for defining a Var"""
-
-    def __init__(self):
-        def var(dtype, span):
-            assert isinstance(
-                self.node, ast.Assign
-            ), f"VarDef expected ast.Assign but got {type(self.node)}"
-            names = [x.id.name for x in self.node.lhs]
-            if len(names) != 1:
-                self.context.report_error(
-                    f"VarDef expected assign to only one var, but got {names}", span
-                )
-            v = Var(names[0], dtype, span=span)
-            self.context.update_symbol(v.name, v, self.node)
-
-        super().__init__(var, def_symbol=True)
-
-
-@register
-class BufferVarDef(SpecialStmt):
-    """Special function for defining a variable of pointer type"""
-
-    def __init__(self):
-        def buffer_var(dtype, storage_scope, span):
-            assert isinstance(
-                self.node, ast.Assign
-            ), f"BufferVarDef expected ast.Assign but got {type(self.node)}"
-            names = [x.id.name for x in self.node.lhs]
-            if len(names) != 1:
-                self.context.report_error(
-                    f"VarDef expected assign to only one var, but got {names}", span
-                )
-            ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope)
-            v = Var(names[0], ptr_type, span=span)
-            self.context.update_symbol(v.name, v, self.node)
-
-        super().__init__(buffer_var, def_symbol=True)
-
-
-@register
-class EnvThread(SpecialStmt):
-    """Bind a var to thread env"""
-
-    def __init__(self):
-        def env_thread(env_name, span):
-            assert isinstance(
-                self.node, ast.Assign
-            ), f"EnvThread expected ast.Assign but got {type(self.node)}"
-            names = [x.id.name for x in self.node.lhs]
-            if len(names) != 1:
-                self.context.report_error(
-                    f"VarDef expected assign to only one var, but got {names}", span
-                )
-            v = Var(names[0], dtype="int32", span=span)
-            self.context.func_var_env_dict[v] = env_name
-            self.context.update_symbol(v.name, v, self.node)
-
-        super().__init__(env_thread, def_symbol=True)
-
-
-@register
-class FuncAttr(SpecialStmt):
-    """Special Stmt for declaring the DictAttr of PrimFunc
-    Example
-    -------
-    .. code-block:: python
-         T.func_attr({"tir.noalias": True, "global_symbol"})
-    """
-
-    def __init__(self):
-        def func_attr(dict_attr, span):
-            self.context.func_dict_attr = dict_attr
-
-        super().__init__(func_attr, def_symbol=False)
-
-
-@register
-class PreflattenedBufferMap(SpecialStmt):
-    """Special Stmt for declaring the PrimFunc::preflattened_buffer_map
-
-    Example
-    -------
-    .. code-block:: python
-         A0 = T.match_buffer(A, (48,), dtype="float32")
-         T.preflattened_buffer_map(A, (1, 4, 4, 3), elem_offset=1, align=4, dtype="float32")
-    """
-
-    def __init__(self):
-        def preflattened_buffer(
-            postflattened,
-            shape,
-            dtype="float32",
-            data=None,
-            strides=None,
-            elem_offset=None,
-            scope="global",
-            align=-1,
-            offset_factor=0,
-            buffer_type="default",
-            span=None,
-        ):
-
-            param = None
-            for key, value in self.context.func_buffer_map.items():
-                if value.same_as(postflattened):
-                    param = key
-                    break
-
-            assert (
-                param is not None
-            ), f"Post-flatten buffer {postflattened.name} does not appear in the buffer map."
-
-            if data is None:
-                data = self.context.func_buffer_map[param].data
-
-            buffer_name: str = f"{postflattened.name}_preflatten"
-            if align != -1:
-                if isinstance(align, IntImm):
-                    align = align.value
-                else:
-                    assert isinstance(align, int), f"align: want int or IntImm, got {align!r}"
-
-            if offset_factor != 0:
-                if isinstance(offset_factor, IntImm):
-                    offset_factor = offset_factor.value
-                else:
-                    assert isinstance(
-                        offset_factor, int
-                    ), f"offset_factor: want int or IntImm, got {offset_factor!r}"
-
-            preflattened = tvm.tir.decl_buffer(
-                shape,
-                dtype,
-                buffer_name,
-                data,
-                strides,
-                elem_offset,
-                scope,
-                align,
-                offset_factor,
-                buffer_type,
-                span=span,
-            )
-
-            self.context.func_preflattened_buffer_map[param] = preflattened
-
-        super().__init__(preflattened_buffer, def_symbol=False)
-
-
-@register
-class TargetAttrValue(SpecialStmt):
-    """Special Stmt for target attr value.
-    Example
-    -------
-    .. code-block:: python
-        T.target("llvm")
-    """
-
-    def __init__(self):
-        def target(*args, span):
-            self.context.report_error(f"T.target should not appear as a stmt", span)
-
-        super().__init__(target, def_symbol=False)
-
-    def __call__(self, target_config):
-        if not isinstance(target_config, (str, dict)):
-            raise ValueError(
-                f"T.target expected a config dict or string, but got {type(target_config)}"
-            )
-        return Target(target_config)
diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/tir/ty.py
deleted file mode 100644
index 4548102a9e..0000000000
--- a/python/tvm/script/tir/ty.py
+++ /dev/null
@@ -1,216 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM Script Parser Typing Class for TIR
-
-This module provides typing class for TVM script type annotation usage, it can be viewed as
-a wrapper for uniform Type system in IR
-"""
-# pylint: disable=invalid-name
-from numbers import Integral
-
-import tvm
-from .special_stmt import SpecialStmt, convert_to_int
-
-
-class TypeGeneric:  # pylint: disable=too-few-public-methods
-    """Base class for all the TVM script typing class"""
-
-    def evaluate(self):
-        """Return an actual ir.Type Object that this Generic class wraps"""
-        raise TypeError("Cannot get tvm.Type from a generic type")
-
-    def require_type_generic_at(self, idx):  # pylint: disable=unused-argument
-        """If True, the `idx`th type argument must be TypeGeneric"""
-        return True
-
-    # This function is added here to avoid a pylint error
-    # for T.int/float below not being callable
-    def __call__(self):
-        raise NotImplementedError()
-
-
-class ConcreteType(TypeGeneric):  # pylint: disable=too-few-public-methods, abstract-method
-    """TVM script typing class for uniform Type objects
-
-    Params
-    ------
-    vtype: Union[str, tvm.ir.Type]
-
-        The IR type represented by the type annotation.  If a string
-        (e.g. "float32"), this represents a `ir.PrimType` generated
-        from that string.  If a `ir.Type` is provided, this represents
-        the type provided.
-    """
-
-    def __init__(self, vtype):
-        if isinstance(vtype, tvm.ir.Type):
-            self.type = vtype
-        else:
-            self.type = tvm.ir.PrimType(vtype)
-
-    def __call__(self, *args):  # pylint: disable=arguments-differ
-        pass
-
-    def evaluate(self):
-        return self.type
-
-
-class VoidType(ConcreteType):  # pylint: disable=too-few-public-methods, abstract-method
-    """TVM script typing class for void type"""
-
-    def __init__(self):
-        super().__init__("")
-
-
-class GenericPtrType(TypeGeneric):  # pylint: disable=abstract-method
-    """TVM script typing class generator for PtrType
-
-    [] operator is overloaded, accepts a ConcreteType and an optional storage scope string,
-    returns a ConcreteType wrapping PtrType
-    """
-
-    def __getitem__(self, args):
-        if isinstance(args, TypeGeneric):
-            args = [args]
-        if len(args) == 1:
-            vtype, scope = args[0], "global"
-        elif len(args) == 2:
-            vtype, scope = args[0], args[1]
-        else:
-            raise TypeError(f"Illegal type argument num for Ptr")
-        if not isinstance(vtype, TypeGeneric):
-            raise TypeError(f"Ptr expects a type argument, but received {type(vtype).__name__}")
-        if not isinstance(scope, str):
-            raise TypeError(f"Ptr expects storage scope argument be a string")
-        return ConcreteType(tvm.ir.PointerType(vtype.evaluate(), scope))
-
-    def require_type_generic_at(self, idx):
-        return idx != 1  # the second argument is storage scope for Ptr
-
-
-class GenericTupleType(TypeGeneric):  # pylint: disable=abstract-method
-    """TVM script typing class generator for TupleType
-
-    [] operator is overloaded, accepts a list of ConcreteType and returns a ConcreteType
-    wrapping TupleType
-    """
-
-    def __getitem__(self, vtypes):
-        if isinstance(vtypes, TypeGeneric):
-            vtypes = [vtypes]
-        return ConcreteType(tvm.ir.TupleType([vtype.evaluate() for vtype in vtypes]))
-
-
-class GenericBufferType(SpecialStmt):  # pylint: disable=too-few-public-methods, abstract-method
-    """TVM script typing class for uniform Type objects"""
-
-    def __init__(self, vtype):
-        def match_buffer_syntax_sugar(
-            shape,
-            dtype: str = "float32",
-            name: str = None,
-            data=None,
-            strides=None,
-            elem_offset=None,
-            scope="global",
-            align=-1,
-            offset_factor=0,
-            buffer_type="default",
-            axis_separators=None,
-            span=None,
-        ):
-            if strides is None:
-                strides = []
-            align = convert_to_int(align, "align", self.context.report_error, self.node.span)
-            offset_factor = convert_to_int(
-                offset_factor, "offset_factor", self.context.report_error, self.node.span
-            )
-            buffer = tvm.tir.decl_buffer(
-                shape,
-                dtype,
-                name,
-                data,
-                strides,
-                elem_offset,
-                scope,
-                align,
-                offset_factor,
-                buffer_type,
-                axis_separators,
-                span=span,
-            )
-            return buffer
-
-        self.type = vtype
-        super().__init__(match_buffer_syntax_sugar, def_symbol=True)
-
-    def __call__(
-        self,
-        shape,
-        dtype="float32",
-        *,
-        name: str = None,
-        data=None,
-        strides=None,
-        elem_offset=None,
-        scope="global",
-        align=-1,
-        offset_factor=0,
-        buffer_type="default",
-        axis_separators=None,
-        span=None,
-    ):
-        """
-        This function is for Buffer(...) syntax sugar.
-        """
-        pass  # pylint: disable=unnecessary-pass
-
-    def __getitem__(self, args):
-        """
-        This function is for Buffer[...] syntax sugar
-        Note that args is the list of all arguments
-        """
-        if len(args) < 2:
-            raise ValueError("T.Buffer[...] needs at least two arguments: shape and dtype.")
-
-        shape = args[0]
-        dtype = args[1]
-
-        valid_shape = isinstance(shape, (tvm.ir.PrimExpr, Integral, tuple, list))
-        valid_dtype = isinstance(dtype, str)
-        if not (valid_shape and valid_dtype):
-            raise ValueError(
-                "The first argument of T.Buffer[...] needs to be a tuple, "
-                "followed by the second argument dtype as a string"
-            )
-
-
-# add all floating point and integer datatypes to the module
-for _dtype in ["float", "uint", "int"]:
-    for _size in ["8", "16", "32", "64"]:
-        for _lanes in ["", "x4", "x8", "x16", "x32"]:
-            _name = _dtype + _size + _lanes
-            globals()[_name] = ConcreteType(_name)
-
-boolean = ConcreteType("bool")
-handle = ConcreteType("handle")
-void = VoidType()
-Ptr = GenericPtrType()
-Tuple = GenericTupleType()
-# we don't have 'buffer' type on the cpp side
-# thus 'handle' is used here for convenience's sake
-Buffer = GenericBufferType("handle")
diff --git a/python/tvm/script/utils.py b/python/tvm/script/utils.py
deleted file mode 100644
index c655a62237..0000000000
--- a/python/tvm/script/utils.py
+++ /dev/null
@@ -1,105 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Helper functions in TVM Script Parser"""
-
-from typing import Callable, List, Any, Optional, Tuple
-
-import inspect
-import synr
-
-from tvm.ir import Span, SourceName
-from tvm.error import DiagnosticError
-
-
-def get_param_list(
-    func: Callable,
-) -> Tuple[List[str], List[Tuple[str, Tuple[Any, ...]]], Optional[str]]:
-    """Get the parameter list from definition of function"""
-    full_arg_spec: inspect.FullArgSpec = inspect.getfullargspec(func)
-
-    args: List[str]
-    defaults: Optional[Tuple[Any, ...]]
-    kwonlyargs: List[str]
-    args, defaults, kwonlyargs = (
-        full_arg_spec.args,
-        full_arg_spec.defaults,
-        full_arg_spec.kwonlyargs,
-    )
-
-    if defaults is None:
-        defaults = tuple()
-
-    if full_arg_spec.varkw is not None:
-        raise RuntimeError(
-            "TVM Script register error : variable keyword argument is not supported now"
-        )
-
-    if len(kwonlyargs) == 1 and kwonlyargs[0] == "span":
-        pass
-    elif not len(kwonlyargs) == 0:
-        raise RuntimeError("TVM Script register error : keyword only argument is not supported now")
-
-    pos_only: List[str] = list()
-    for arg in args[: len(args) - len(defaults)]:
-        if arg != "span":
-            pos_only.append(arg)
-    kwargs: List[Tuple[str, Tuple[Any, ...]]] = list()
-    for default, arg in zip(defaults, args[len(args) - len(defaults) :]):
-        if arg != "span":
-            kwargs.append((arg, default))
-
-    return pos_only, kwargs, full_arg_spec.varargs
-
-
-def tvm_span_from_synr(span: synr.ast.Span) -> Span:
-    """Convert a synr span to a TVM span"""
-    return Span(
-        SourceName(span.filename),
-        span.start_line,
-        span.end_line,
-        span.start_column,
-        span.end_column,
-    )
-
-
-def synr_span_from_tvm(span: Span) -> synr.ast.Span:
-    """Convert a TVM span to a synr span"""
-    return synr.ast.Span(
-        span.source_name.name,
-        span.line,
-        span.column,
-        span.end_line,
-        span.end_column,
-    )
-
-
-def call_with_error_reporting(
-    report_error,
-    node_span,
-    func,
-    *args,
-    **kwargs,
-):
-    """Call function with exception handling and report error using node_span"""
-    try:
-        return func(*args, **kwargs)
-    except DiagnosticError:
-        raise
-    except Exception as err:  # pylint: disable=broad-except
-        # printing last non-empty row of error message.
-        error_msg = list(filter(None, str(err).split("\n")))[-1]
-        report_error(error_msg, node_span)
diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py
index 5454041713..e769559b47 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -20,11 +20,11 @@ from typing import Dict, List, Union
 
 from tvm import Object
 from tvm.ir import IRModule
-from tvm.tir.expr import Var
-from tvm.tir.stmt import Block, BufferRegion, PrimExpr
 
-from .. import Buffer, Stmt
+from ..buffer import Buffer
+from ..expr import Var
 from ..function import PrimFunc
+from ..stmt import Block, BufferRegion, PrimExpr, Stmt
 from . import _ffi_api
 
 
diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py
index beefcb0d28..5742999c67 100644
--- a/python/tvm/tir/expr.py
+++ b/python/tvm/tir/expr.py
@@ -28,15 +28,16 @@ For example, you can use addexp.a to get the left operand of an Add node.
   assert(y.a == x)
 """
 from typing import Optional, Union
-from tvm import ir
+
 import tvm._ffi
+import tvm.ir._ffi_api
+from tvm import ir
+from tvm.ir import Op, PrimExpr
 from tvm.ir.base import Span
+from tvm.runtime import DataType, DataTypeCode, Object, ObjectGeneric, const
 
-from tvm.runtime import Object, ObjectGeneric, DataType, DataTypeCode, const
-from tvm.ir import PrimExpr, Op
-import tvm.ir._ffi_api
-from . import generic as _generic
 from . import _ffi_api
+from . import generic as _generic
 
 
 def div_ambiguity_error():
@@ -66,8 +67,6 @@ def _dtype_is_float(value):
 class ExprOp(object):
     """Operator overloading for Expr like expressions."""
 
-    # TODO(tkonolige): use inspect to add source information to these objects
-
     def __add__(self, other):
         return _generic.add(self, other)
 
@@ -1005,6 +1004,8 @@ class Select(PrimExprWithOp):
     """
 
     def __init__(self, condition, true_value, false_value, span=None):
+        if isinstance(condition, bool):
+            condition = IntImm("bool", condition)
         self.__init_handle_by_constructor__(
             _ffi_api.Select, condition, true_value, false_value, span  # type: ignore
         )
diff --git a/python/tvm/tir/schedule/block_scope.py b/python/tvm/tir/schedule/block_scope.py
index 30e047b4f7..0ebaf212d1 100644
--- a/python/tvm/tir/schedule/block_scope.py
+++ b/python/tvm/tir/schedule/block_scope.py
@@ -20,8 +20,8 @@ from typing import List, Optional, Union
 
 from tvm._ffi import register_object
 from tvm.runtime import Object
-from tvm.tir import Block, For
 
+from ..stmt import Block, For
 from . import _ffi_api
 
 
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index fdc8717032..043f7922be 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -21,9 +21,11 @@ from tvm._ffi import register_object as _register_object
 from tvm.error import TVMError, register_error
 from tvm.ir import IRModule, PrimExpr
 from tvm.runtime import Object, String
-from tvm.tir import Block, Buffer, FloatImm, For, IntImm, PrimFunc
 
-from ..function import IndexMap
+from ..buffer import Buffer
+from ..expr import FloatImm, IntImm
+from ..function import IndexMap, PrimFunc
+from ..stmt import Block, For
 from . import _ffi_api
 from ._type_checker import type_checked
 from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod
diff --git a/python/tvm/tir/schedule/state.py b/python/tvm/tir/schedule/state.py
index fbf21843e7..3aed52fb50 100644
--- a/python/tvm/tir/schedule/state.py
+++ b/python/tvm/tir/schedule/state.py
@@ -22,8 +22,9 @@ from typing import Dict, Optional, Union
 from tvm._ffi import register_object
 from tvm.ir import IRModule
 from tvm.runtime import Object
-from tvm.tir import Block, BlockRealize, For, PrimFunc
 
+from ..function import PrimFunc
+from ..stmt import Block, BlockRealize, For
 from . import _ffi_api
 from .block_scope import BlockScope, StmtSRef
 
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index 4847e377de..3c2228e6d9 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -754,3 +754,22 @@ def stmt_list(stmt):
             res += stmt_list(x)
         return res
     return [stmt]
+
+
+def type_annotation(dtype, span=None):
+    """Create a TIR type annotation expression.
+
+    Parameters
+    ----------
+    dtype : DataType
+        The data type to annotate.
+
+    span : Optional[Span]
+        The location of this expression in the source code.
+
+    Returns
+    -------
+    expr : PrimExpr
+        The type annotation expression.
+    """
+    return _ffi_api.TypeAnnotation(dtype, span)
diff --git a/python/tvm/tir/usmp/transform/transform.py b/python/tvm/tir/usmp/transform/transform.py
index f472172cf3..86d8bef356 100644
--- a/python/tvm/tir/usmp/transform/transform.py
+++ b/python/tvm/tir/usmp/transform/transform.py
@@ -20,8 +20,9 @@
 from typing import Dict
 
 import tvm
-from tvm.tir import Stmt
-from tvm.tir.usmp.utils import PoolAllocation
+
+from ...stmt import Stmt
+from ..utils import PoolAllocation
 from . import _ffi_api
 
 
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index e21d014fe1..0cb8f79942 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -810,6 +810,14 @@ BufferRegion::BufferRegion(Buffer buffer, Array<Range> region) {
   CHECK_EQ(buffer->shape.size(), region.size())
       << "The dimension between " << buffer << " and region " << region
       << " mismatched, the buffer is " << buffer;
+  for (const Range& r : region) {  // min and extent of every range must be integer-typed
+    CHECK(r->min->dtype.is_int() || r->min->dtype.is_uint())
+        << "ValueError: ranges of BufferRegion should be int, but got type " << r->min->dtype
+        << " for range " << r << " in its min value " << r->min;
+    CHECK(r->extent->dtype.is_int() || r->extent->dtype.is_uint())
+        << "ValueError: ranges of BufferRegion should be int, but got type " << r->extent->dtype
+        << " for range " << r << " in its extent value " << r->extent;
+  }
   ObjectPtr<BufferRegionNode> node = make_object<BufferRegionNode>();
   node->buffer = std::move(buffer);
   node->region = std::move(region);
@@ -1091,5 +1099,7 @@ PrimExpr TypeAnnotation(DataType dtype, Span span) {
 TVM_REGISTER_OP("tir.type_annotation")
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
 
+TVM_REGISTER_GLOBAL("tir.TypeAnnotation").set_body_typed(TypeAnnotation);  // used by tir/stmt.py
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 509badbebb..0e605d1439 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -97,6 +97,17 @@ PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s, Span s
                    {x, y, q, s}, span);
 }
 
+// address_of: handle-typed call to builtin::address_of on the given BufferLoad
+PrimExpr address_of(tir::BufferLoad buffer_load, Span span) {
+  return tir::Call(DataType::Handle(), tir::builtin::address_of(), {buffer_load}, span);
+}
+
+// lookup_param: handle-typed call to builtin::lookup_param on StringImm(param_name)
+PrimExpr lookup_param(String param_name, Span span) {
+  return tir::Call(DataType::Handle(), tir::builtin::lookup_param(), {tir::StringImm(param_name)},
+                   span);
+}
+
 // The public function with a quick checking path.
 void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) {  // NOLINT(*)
   CHECK(lhs.defined()) << "ValueError: `lhs` is null in the binary operator";
@@ -705,6 +716,11 @@ PrimExpr isnan(PrimExpr x, Span span) {
   }
 }
 
+// isnullptr: call builtin::isnullptr on x; the result is always a 1-lane bool
+PrimExpr isnullptr(PrimExpr x, Span span) {
+  return tir::Call(DataType::Bool(1), tir::builtin::isnullptr(), {x}, span);
+}
+
 // isinf
 PrimExpr isinf(PrimExpr x, Span span) {
   DataType t = DataType::Bool(x.dtype().lanes());
@@ -956,6 +972,10 @@ TVM_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc);
 
 TVM_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast);
 
+TVM_REGISTER_GLOBAL("tir.infinity").set_body_typed(tvm::infinity);  // _ffi_api.infinity
+
+TVM_REGISTER_GLOBAL("tir.reinterpret").set_body_typed(tvm::reinterpret);  // _ffi_api.reinterpret
+
 // operator overloading, smarter than make
 #define REGISTER_MAKE_BINARY_OP(Node, Func)                                                \
   TVM_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { \
@@ -1004,6 +1024,8 @@ REGISTER_MAKE_BIT_OP(bitwise_xor, bitwise_xor);
 REGISTER_MAKE_BIT_OP(left_shift, left_shift);  // NOLINT(*)
 REGISTER_MAKE_BIT_OP(right_shift, right_shift);
 
+TVM_REGISTER_GLOBAL("tir._OpNot").set_body_typed(logical_not);  // _ffi_api._OpNot
+
 TVM_REGISTER_GLOBAL("tir._OpIfThenElse")
     .set_body_typed([](PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span) {
       return if_then_else(cond, true_value, false_value, span);