You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ru...@apache.org on 2022/11/11 16:40:27 UTC

[tvm] branch main updated: [IRBuilder][Minor] Add intrinsics like `T.int32x4` (#13361)

This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8897983484 [IRBuilder][Minor] Add intrinsics like `T.int32x4` (#13361)
8897983484 is described below

commit 88979834842115ef9ea8487d9a631dc3275f7a7d
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Fri Nov 11 08:40:17 2022 -0800

    [IRBuilder][Minor] Add intrinsics like `T.int32x4` (#13361)
    
    This PR adds all common TIR dtype intrinsics, such as `T.int32x4` and `T.float32x4`.
    
    Co-authored-by: Yaxing Cai <ca...@gmail.com>
---
 include/tvm/script/ir_builder/tir/frame.h          |  16 +-
 include/tvm/script/ir_builder/tir/ir.h             |  46 +-
 python/tvm/script/ir_builder/tir/frame.py          |   4 +-
 python/tvm/script/ir_builder/tir/ir.py             | 473 +++++++++------------
 python/tvm/tir/op.py                               |  57 ++-
 src/script/ir_builder/tir/frame.cc                 |  14 +-
 src/script/ir_builder/tir/ir.cc                    |  56 ++-
 .../python/unittest/test_tvmscript_error_report.py |   2 +-
 .../unittest/test_tvmscript_ir_builder_tir.py      |  21 +-
 9 files changed, 348 insertions(+), 341 deletions(-)

diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h
index aa2386e7f1..b95d575360 100644
--- a/include/tvm/script/ir_builder/tir/frame.h
+++ b/include/tvm/script/ir_builder/tir/frame.h
@@ -453,8 +453,8 @@ class AllocateFrameNode : public TIRFrameNode {
   PrimExpr condition;
   /*! \brief Additional annotation hints. */
   Map<String, ObjectRef> annotations;
-  /*! \brief The buffer. */
-  tvm::tir::Buffer buffer;
+  /*! \brief The buffer var. */
+  tvm::tir::Var buffer_var;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
     TIRFrameNode::VisitAttrs(v);
@@ -463,7 +463,7 @@ class AllocateFrameNode : public TIRFrameNode {
     v->Visit("storage_scope", &storage_scope);
     v->Visit("condition", &condition);
     v->Visit("annotations", &annotations);
-    v->Visit("buffer", &buffer);
+    v->Visit("buffer_var", &buffer_var);
   }
 
   static constexpr const char* _type_key = "script.ir_builder.tir.AllocateFrame";
@@ -500,8 +500,8 @@ class AllocateConstFrameNode : public TIRFrameNode {
   Array<PrimExpr> extents;
   /*! \brief The data associated with the constant. */
   tvm::runtime::NDArray data;
-  /*! \brief The buffer */
-  tvm::tir::Buffer buffer;
+  /*! \brief The buffer var */
+  tvm::tir::Var buffer_var;
   /*! \brief Additional annotations about the allocation. */
   Map<String, ObjectRef> annotations;
 
@@ -510,7 +510,7 @@ class AllocateConstFrameNode : public TIRFrameNode {
     v->Visit("dtype", &dtype);
     v->Visit("extents", &extents);
     v->Visit("data", &data);
-    v->Visit("buffer", &buffer);
+    v->Visit("buffer_var", &buffer_var);
     v->Visit("annotations", &annotations);
   }
 
@@ -723,11 +723,15 @@ class ElseFrame : public TIRFrame {
 
 class DeclBufferFrameNode : public TIRFrameNode {
  public:
+  /*! \brief The declared buffer. */
   tvm::tir::Buffer buffer;
+  /*! \brief Whether the buffer is already allocated; if not, exiting also emits an Allocate. */
+  bool allocated;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
     TIRFrameNode::VisitAttrs(v);
     v->Visit("buffer", &buffer);
+    v->Visit("allocated", &allocated);
   }
 
   static constexpr const char* _type_key = "script.ir_builder.tir.DeclBufferFrame";
diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index 7460099f94..d9e1a1b490 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -339,9 +339,8 @@ AllocateFrame Allocate(Array<PrimExpr> extents, DataType dtype, String storage_s
  * \param annotations Additional annotation hints.
  * \return The created AllocateConstFrame.
  */
-AllocateConstFrame AllocateConst(
-    NDArray data, DataType dtype, Array<PrimExpr> extents,
-    Map<String, ObjectRef> annotations = NullValue<Map<String, ObjectRef>>());
+AllocateConstFrame AllocateConst(NDArray data, DataType dtype, Array<PrimExpr> extents,
+                                 Optional<Map<String, ObjectRef>> annotations = NullOpt);
 
 /*!
  * \brief Create an attribute.
@@ -449,21 +448,32 @@ PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global");
     return expr.defined() ? tvm::cast(dtype, expr.value()) : tvm::tir::Var("", dtype); \
   }
 
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int8, DataType::Int(8));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int16, DataType::Int(16));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32, DataType::Int(32));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int64, DataType::Int(64));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt8, DataType::UInt(8));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt16, DataType::UInt(16));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt32, DataType::UInt(32));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt64, DataType::UInt(64));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float8, DataType::Float(8));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float16, DataType::Float(16));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float32, DataType::Float(32));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float64, DataType::Float(64));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x4, DataType::Int(32, 4));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x8, DataType::Int(32, 8));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x16, DataType::Int(32, 16));
+#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(DType, FDType) \
+  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##8, FDType(8));      \
+  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##16, FDType(16));    \
+  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##32, FDType(32));    \
+  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##64, FDType(64));
+
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Float, DataType::Float);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(UInt, DataType::UInt);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Int, DataType::Int);
+
+#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(FuncName, FDType, Size) \
+  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x4, FDType(Size, 4));     \
+  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x8, FDType(Size, 8));     \
+  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x16, FDType(Size, 16));   \
+  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x32, FDType(Size, 32));   \
+  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x64, FDType(Size, 64));
+
+#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(DType, FDType) \
+  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##8, FDType, 8);      \
+  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##16, FDType, 16);    \
+  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##32, FDType, 32);    \
+  TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##64, FDType, 64);
+
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Float, DataType::Float);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(UInt, DataType::UInt);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int);
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Handle, DataType::Handle());
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py
index b9b50dfa98..a57c878bd9 100644
--- a/python/tvm/script/ir_builder/tir/frame.py
+++ b/python/tvm/script/ir_builder/tir/frame.py
@@ -69,14 +69,14 @@ class RealizeFrame(TIRFrame):
 class AllocateFrame(TIRFrame):
     def __enter__(self) -> Buffer:
         super().__enter__()
-        return self.buffer
+        return self.buffer_var
 
 
 @_register_object("script.ir_builder.tir.AllocateConstFrame")
 class AllocateConstFrame(TIRFrame):
     def __enter__(self) -> Buffer:
         super().__enter__()
-        return self.buffer
+        return self.buffer_var
 
 
 @_register_object("script.ir_builder.tir.AttrFrame")
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index 4ec1511f29..bd9e4e1db5 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -14,41 +14,75 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=missing-docstring
 """IRBuilder for TIR"""
 
-import inspect
 import functools
+import inspect
 from numbers import Integral
-from typing import Any, Callable, Dict, List, Optional, Union, Tuple
-import numpy as np  # type: ignore
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+# isort: off
+from typing_extensions import Literal
 
+# isort: on
+
+import numpy as np  # type: ignore
 from tvm.ir import Range, Type
 from tvm.runtime import convert, ndarray
+from tvm.target import Target
+
+# pylint: disable=unused-import
 from tvm.target.codegen import llvm_lookup_intrinsic_id
-from tvm.tir import (
-    Buffer,
+from tvm.tir import Buffer, BufferRegion, PrimExpr
+from tvm.tir import op as _tir_op
+from tvm.tir import type_annotation
+
+# import tir.expr for direct ir construction to pass structural_equal comparison
+from tvm.tir.expr import (
+    EQ,
+    GE,
+    GT,
+    LE,
+    LT,
+    NE,
+    Add,
+    And,
+    Broadcast,
     BufferLoad,
-    BufferRegion,
+    Call,
+    CallEffectKind,
     Cast,
     CommReducer,
+    Div,
+    FloatImm,
+    FloorDiv,
+    FloorMod,
     IntImm,
     IterVar,
     Let,
-    PrimExpr,
+    Load,
+    Max,
+    Min,
+    Mod,
+    Mul,
+    Not,
+    Or,
+    ProducerLoad,
+    Ramp,
+    Reduce,
     Select,
     Shuffle,
+    SizeVar,
     StringImm,
-    type_annotation,
+    Sub,
     Var,
 )
-from tvm.tir import Broadcast as broadcast
-from tvm.tir import Ramp as ramp
-from tvm.tir import op as _tir_op
 from tvm.tir.generic import cast
 
 from . import _ffi_api, frame
 
+# pylint: enable=unused-import
+
 
 def buffer_decl(
     shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral],
@@ -56,7 +90,7 @@ def buffer_decl(
     data: Var = None,
     strides: List[PrimExpr] = None,
     elem_offset: PrimExpr = None,
-    scope: str = "",
+    scope: str = "global",
     align: int = 0,
     offset_factor: int = 0,
     buffer_type: str = "",
@@ -187,7 +221,7 @@ def func_ret(ret_type: Type) -> Type:
 
 def match_buffer(
     param: Union[Var, BufferLoad, BufferRegion],
-    shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral],
+    shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] = None,
     dtype: str = "float32",
     data: Var = None,
     strides: List[PrimExpr] = None,
@@ -256,6 +290,12 @@ def match_buffer(
     res : Buffer
         The matched buffer.
     """
+    if shape is None:
+        if isinstance(param, BufferRegion):
+            dtype = param.buffer.dtype
+            shape = [region.extent for region in param.region]
+        else:
+            raise ValueError("Shape must be specified when binding input param")
     shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
     if strides is None:
         strides = []
@@ -447,7 +487,7 @@ def alloc_buffer(
     data: Var = None,
     strides: List[PrimExpr] = None,
     elem_offset: PrimExpr = None,
-    scope: str = "",
+    scope: str = "global",
     align: int = -1,
     offset_factor: int = 0,
     buffer_type: str = "default",
@@ -526,10 +566,14 @@ def _as_range(dom: Union[Range, List[PrimExpr]]) -> Range:
         return dom
     if isinstance(dom, (list, tuple)):
         return Range(dom[0], dom[1])
+    if hasattr(dom, "dtype"):
+        return Range(IntImm(dom.dtype, 0), dom)
     return Range(0, dom)
 
 
 class axis:  # pylint: disable=invalid-name
+    """The axis class"""
+
     @staticmethod
     def spatial(
         dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32"
@@ -686,7 +730,10 @@ def serial(
     """
     if stop is None:
         stop = start
-        start = 0
+        if hasattr(start, "dtype"):
+            start = IntImm(start.dtype, 0)
+        else:
+            start = 0
     return _ffi_api.Serial(start, stop, annotations)  # type: ignore[attr-defined] # pylint: disable=no-member
 
 
@@ -713,7 +760,10 @@ def parallel(
     """
     if stop is None:
         stop = start
-        start = 0
+        if hasattr(start, "dtype"):
+            start = IntImm(start.dtype, 0)
+        else:
+            start = 0
     return _ffi_api.Parallel(start, stop, annotations)  # type: ignore[attr-defined] # pylint: disable=no-member
 
 
@@ -740,7 +790,10 @@ def vectorized(
     """
     if stop is None:
         stop = start
-        start = 0
+        if hasattr(start, "dtype"):
+            start = IntImm(start.dtype, 0)
+        else:
+            start = 0
     return _ffi_api.Vectorized(start, stop, annotations)  # type: ignore[attr-defined] # pylint: disable=no-member
 
 
@@ -767,7 +820,10 @@ def unroll(
     """
     if stop is None:
         stop = start
-        start = 0
+        if hasattr(start, "dtype"):
+            start = IntImm(start.dtype, 0)
+        else:
+            start = 0
     return _ffi_api.Unroll(start, stop, annotations)  # type: ignore[attr-defined] # pylint: disable=no-member
 
 
@@ -804,10 +860,16 @@ def thread_binding(
             raise ValueError("Thread cannot be None for thread_binding")
         thread = stop
         stop = start
-        start = 0
+        if hasattr(start, "dtype"):
+            start = IntImm(start.dtype, 0)
+        else:
+            start = 0
     elif stop is None:
         stop = start
-        start = 0
+        if hasattr(start, "dtype"):
+            start = IntImm(start.dtype, 0)
+        else:
+            start = 0
     return _ffi_api.ThreadBinding(  # type: ignore[attr-defined] # pylint: disable=no-member
         start, stop, thread, annotations
     )
@@ -907,7 +969,7 @@ def realize(
 def allocate(
     extents: List[PrimExpr],
     dtype: str,
-    scope: str = "",
+    scope: str = "global",
     condition: PrimExpr = None,
     annotations=None,
 ) -> frame.AllocateFrame:
@@ -959,9 +1021,18 @@ def allocate_const(
     annotations : Optional[Map]
         Additional annotations about the allocation.
     """
+    np_data = np.asarray(data, dtype=dtype)
+    prod_extent = 1
+    for extent in extents:
+        prod_extent *= extent
+    prod_shape = 1
+    for shape in np_data.shape:
+        prod_shape *= shape
+    if prod_extent == prod_shape:
+        np_data = np_data.reshape(extents)
 
     return _ffi_api.AllocateConst(  # type: ignore[attr-defined] # pylint: disable=no-member
-        ndarray.array(np.asarray(data, dtype)), dtype, extents, annotations
+        ndarray.array(np_data), dtype, extents, annotations
     )
 
 
@@ -1054,7 +1125,7 @@ def decl_buffer(
     data=None,
     strides=None,
     elem_offset=None,
-    scope="",
+    scope="global",
     align=0,
     offset_factor=0,
     buffer_type="",
@@ -1221,247 +1292,41 @@ def evaluate(value: PrimExpr) -> None:
     """
     if isinstance(value, str):
         value = StringImm(value)
+    if isinstance(value, bool):
+        value = cast(value, "bool")
     return _ffi_api.Evaluate(value)  # type: ignore[attr-defined] # pylint: disable=no-member
 
 
-def int8(expr: Optional[PrimExpr] = None) -> PrimExpr:
-    """Construct a new tir.Var with type int8 or cast expression to type int8.
+__all__ = []
+for _dtype in ["Float", "UInt", "Int"]:
+    for _size in ["8", "16", "32", "64"]:
+        for _lanes in ["", "x4", "x8", "x16", "x32", "x64"]:
+            _name = _dtype + _size + _lanes  # pylint: disable=invalid-name
 
-    Parameters
-    ----------
-    expr: PrimExpr
-        The expression to be cast.
+            def func_gen(name: str):
+                """Generate a function for each PrimExpr dtype.
 
-    Returns
-    -------
-    res : PrimExpr
-        The new tir.Var with type int8 or casted expression with type int8.
-    """
-    return _ffi_api.Int8(expr)  # type: ignore[attr-defined] # pylint: disable=no-member
-
-
-def int16(expr: Optional[PrimExpr] = None) -> PrimExpr:
-    """Construct a new tir.Var with type int16 or cast expression to type int16.
+                Parameters
+                ----------
+                name: str
+                    The ffi function name to call.
+                """
 
-    Parameters
-    ----------
-    expr: PrimExpr
-        The expression to be cast.
+                def func(
+                    expr: Union[
+                        None,
+                        PrimExpr,
+                        Literal["inf", "-inf", "nan"],
+                    ] = None
+                ) -> PrimExpr:
+                    if isinstance(expr, str):
+                        expr = float(expr)
+                    return getattr(_ffi_api, name)(expr)
 
-    Returns
-    -------
-    res : PrimExpr
-        The new tir.Var with type int16 or casted expression with type int16.
-    """
-    return _ffi_api.Int16(expr)  # type: ignore[attr-defined] # pylint: disable=no-member
+                return func
 
-
-def int32(expr: Optional[PrimExpr] = None) -> PrimExpr:
-    """Construct a new tir.Var with type int32 or cast expression to type int32.
-
-    Parameters
-    ----------
-    expr: PrimExpr
-        The expression to be cast.
-
-    Returns
-    -------
-    res : PrimExpr
-        The new tir.Var with type int32 or casted expression with type int32.
-    """
-    return _ffi_api.Int32(expr)  # type: ignore[attr-defined] # pylint: disable=no-member
-
-
-def int64(expr: Optional[PrimExpr] = None) -> PrimExpr:
-    """Construct a new tir.Var with type int64 or cast expression to type int64.
-
-    Parameters
-    ----------
-    expr: PrimExpr
-        The expression to be cast.
-
-    Returns
-    -------
-    res : PrimExpr
-        The new tir.Var with type int64 or casted expression with type int64.
-    """
-    return _ffi_api.Int64(expr)  # type: ignore[attr-defined] # pylint: disable=no-member
-
-
-def uint8(expr: Optional[PrimExpr] = None) -> PrimExpr:
-    """Construct a new tir.Var with type uint8 or cast expression to type uint8.
-
-    Parameters
-    ----------
-    expr: PrimExpr
-        The expression to be cast.
-
-    Returns
-    -------
-    res : PrimExpr
-        The new tir.Var with type uint8 or casted expression with type uint8.
-    """
-    return _ffi_api.UInt8(expr)  # type: ignore[attr-defined] # pylint: disable=no-member
-
-
-def uint16(expr: Optional[PrimExpr] = None) -> PrimExpr:
-    """Construct a new tir.Var with type uint16 or cast expression to type uint16.
-
-    Parameters
-    ----------
-    expr: PrimExpr
-        The expression to be cast.
-
-    Returns
-    -------
-    res : PrimExpr
-        The new tir.Var with type uint16 or casted expression with type uint16.
-    """
-    return _ffi_api.UInt16(expr)  # type: ignore[attr-defined] # pylint: disable=no-member
-
-
-def uint32(expr: Optional[PrimExpr] = None) -> PrimExpr:
-    """Construct a new tir.Var with type uint32 or cast expression to type uint32.
-
-    Parameters
-    ----------
-    expr: PrimExpr
-        The expression to be cast.
-
-    Returns
-    -------
-    res : PrimExpr
-        The new tir.Var with type uint32 or casted expression with type uint32.
-    """
-    return _ffi_api.UInt32(expr)  # type: ignore[attr-defined] # pylint: disable=no-member
-
-
-def uint64(expr: Optional[PrimExpr] = None) -> PrimExpr:
-    """Construct a new tir.Var with type uint64 or cast expression to type uint64.
-
-    Parameters
-    ----------
-    expr: PrimExpr
-        The expression to be cast.
-
-    Returns
-    -------
-    res : PrimExpr
-        The new tir.Var with type uint64 or casted expression with type uint64.
-    """
-    return _ffi_api.UInt64(expr)  # type: ignore[attr-defined] # pylint: disable=no-member
-
-
-def float8(expr: Optional[PrimExpr] = None) -> PrimExpr:
-    """Construct a new tir.Var with type float8 or cast expression to type float8.
-
-    Parameters
-    ----------
-    expr: PrimExpr
-        The expression to be cast.
-
-    Returns
-    -------
-    res : PrimExpr
-        The new tir.Var with type float8 or casted expression with type float8.
-    """
-    return _ffi_api.Float8(expr)  # type: ignore[attr-defined] # pylint: disable=no-member
-
-
-def float16(expr: Optional[PrimExpr] = None) -> PrimExpr:
-    """Construct a new tir.Var with type float16 or cast expression to type float16.
-
-    Parameters
-    ----------
-    expr: PrimExpr
-        The expression to be cast.
-
-    Returns
-    -------
-    res : PrimExpr
-        The new tir.Var with type float16 or casted expression with type float16.
-    """
-    return _ffi_api.Float16(expr)  # type: ignore[attr-defined] # pylint: disable=no-member
-
-
-def float32(expr: Optional[PrimExpr] = None) -> PrimExpr:
-    """Construct a new tir.Var with type float32 or cast expression to type float32.
-
-    Parameters
-    ----------
-    expr: PrimExpr
-        The expression to be cast.
-
-    Returns
-    -------
-    res : PrimExpr
-        The new tir.Var with type float32 or casted expression with type float32.
-    """
-    return _ffi_api.Float32(expr)  # type: ignore[attr-defined] # pylint: disable=no-member
-
-
-def float64(expr: Optional[PrimExpr] = None) -> PrimExpr:
-    """Construct a new tir.Var with type float64 or cast expression to type float64.
-
-    Parameters
-    ----------
-    expr: PrimExpr
-        The expression to be cast.
-
-    Returns
-    -------
-    res : PrimExpr
-        The new tir.Var with type float64 or casted expression with type float64.
-    """
-    return _ffi_api.Float64(expr)  # type: ignore[attr-defined] # pylint: disable=no-member
-
-
-def int32x4(expr: Optional[PrimExpr] = None) -> PrimExpr:
-    """Construct a new tir.Var with type int32x4 or cast expression to type int32x4.
-
-    Parameters
-    ----------
-    expr: PrimExpr
-        The expression to be cast.
-
-    Returns
-    -------
-    res : PrimExpr
-        The new tir.Var with type int32x4 or casted expression with type int32x4.
-    """
-    return _ffi_api.Int32x4(expr)  # type: ignore[attr-defined] # pylint: disable=no-member
-
-
-def int32x8(expr: Optional[PrimExpr] = None) -> PrimExpr:
-    """Construct a new tir.Var with type int32x8 or cast expression to type int32x8.
-
-    Parameters
-    ----------
-    expr: PrimExpr
-        The expression to be cast.
-
-    Returns
-    -------
-    res : PrimExpr
-        The new tir.Var with type int32x8 or casted expression with type int32x8.
-    """
-    return _ffi_api.Int32x8(expr)  # type: ignore[attr-defined] # pylint: disable=no-member
-
-
-def int32x16(expr: Optional[PrimExpr] = None) -> PrimExpr:
-    """Construct a new tir.Var with type int32x16 or cast expression to type int32x16.
-
-    Parameters
-    ----------
-    expr: PrimExpr
-        The expression to be cast.
-
-    Returns
-    -------
-    res : PrimExpr
-        The new tir.Var with type int32x16 or casted expression with type int32x16.
-    """
-    return _ffi_api.Int32x16(expr)  # type: ignore[attr-defined] # pylint: disable=no-member
+            globals()[_name.lower()] = func_gen(_name)
+            __all__.append(_name.lower())
 
 
 def boolean(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -1645,6 +1510,27 @@ def comm_reducer(combiner: Callable, identity: List[PrimExpr]) -> CommReducer:
     return CommReducer(args[: num_args // 2], args[num_args // 2 :], res, identity)
 
 
+def target(target_config: Union[Dict, str]) -> Target:
+    """
+    Create a target
+
+    Parameters
+    ----------
+    target_config : Union[Dict, str]
+        The target configuration.
+
+    Returns
+    -------
+    res : Target
+        The target.
+    """
+    if not isinstance(target_config, (str, dict)):
+        raise ValueError(
+            f"T.target expected a config dict or string, but got {type(target_config)}"
+        )
+    return Target(target_config)
+
+
 def _op_wrapper(func):
     @functools.wraps(func)
     def wrapped(*args, **kwargs):
@@ -1667,6 +1553,9 @@ def _dtype_forward(func):
 
 # pylint: disable=invalid-name
 
+broadcast = Broadcast
+ramp = Ramp
+
 buffer_var = ptr
 abs = _op_wrapper(_tir_op.abs)  # pylint: disable=redefined-builtin
 fabs = abs
@@ -1713,6 +1602,7 @@ nextafter = _op_wrapper(_tir_op.nextafter)
 popcount = _op_wrapper(_tir_op.popcount)
 power = _op_wrapper(_tir_op.power)
 q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift)
+q_multiply_shift_per_axis = _op_wrapper(_tir_op.q_multiply_shift_per_axis)
 ret = _op_wrapper(_tir_op.ret)
 reinterpret = _dtype_forward(_tir_op.reinterpret)
 round = _op_wrapper(_tir_op.round)  # pylint: disable=redefined-builtin
@@ -1733,6 +1623,7 @@ tvm_throw_last_error = _op_wrapper(_tir_op.tvm_throw_last_error)
 tvm_stack_alloca = _op_wrapper(_tir_op.tvm_stack_alloca)
 tvm_stack_make_shape = _op_wrapper(_tir_op.tvm_stack_make_shape)
 tvm_stack_make_array = _op_wrapper(_tir_op.tvm_stack_make_array)
+tvm_check_return = _op_wrapper(_tir_op.tvm_check_return)
 call_packed = _op_wrapper(_tir_op.call_packed)
 call_cpacked = _op_wrapper(_tir_op.call_cpacked)
 call_packed_lowered = _op_wrapper(_tir_op.call_packed_lowered)
@@ -1742,7 +1633,6 @@ call_intrin = _dtype_forward(_tir_op.call_intrin)
 call_llvm_intrin = _dtype_forward(_tir_op.call_llvm_intrin)
 call_llvm_pure_intrin = _dtype_forward(_tir_op.call_llvm_pure_intrin)
 call_pure_extern = _dtype_forward(_tir_op.call_pure_extern)
-tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr)
 tvm_tuple = _op_wrapper(_tir_op.tvm_tuple)
 tvm_struct_set = _op_wrapper(_tir_op.tvm_struct_set)
 tvm_struct_get = _tir_op.tvm_struct_get
@@ -1771,6 +1661,8 @@ tvm_call_packed_lowered = call_packed_lowered
 tvm_call_cpacked_lowered = call_cpacked_lowered
 TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace)
 TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace)
+start_profile_intrinsic = _op_wrapper(_tir_op.start_profile_intrinsic)
+end_profile_intrinsic = _op_wrapper(_tir_op.end_profile_intrinsic)
 
 
 class inline:
@@ -1796,7 +1688,7 @@ class inline:
 # pylint: enable=invalid-name
 
 
-__all__ = [
+__all__ += [
     "buffer_decl",
     "prim_func",
     "arg",
@@ -1835,21 +1727,6 @@ __all__ = [
     "buffer_store",
     "prefetch",
     "evaluate",
-    "int8",
-    "int16",
-    "int32",
-    "int64",
-    "uint8",
-    "uint16",
-    "uint32",
-    "uint64",
-    "float8",
-    "float16",
-    "float32",
-    "float64",
-    "int32x4",
-    "int32x8",
-    "int32x16",
     "boolean",
     "handle",
     "void",
@@ -1859,6 +1736,7 @@ __all__ = [
     "max",
     "iter_var",
     "comm_reducer",
+    "target",
     "buffer_var",
     "abs",
     "fabs",
@@ -1905,6 +1783,7 @@ __all__ = [
     "popcount",
     "power",
     "q_multiply_shift",
+    "q_multiply_shift_per_axis",
     "ret",
     "reinterpret",
     "round",
@@ -1925,6 +1804,7 @@ __all__ = [
     "tvm_stack_alloca",
     "tvm_stack_make_shape",
     "tvm_stack_make_array",
+    "tvm_check_return",
     "call_packed",
     "call_cpacked",
     "call_packed_lowered",
@@ -1934,7 +1814,6 @@ __all__ = [
     "call_llvm_intrin",
     "call_llvm_pure_intrin",
     "call_pure_extern",
-    "tvm_access_ptr",
     "tvm_tuple",
     "tvm_struct_set",
     "tvm_struct_get",
@@ -1963,14 +1842,50 @@ __all__ = [
     "tvm_call_cpacked_lowered",
     "TVMBackendAllocWorkspace",
     "TVMBackendFreeWorkspace",
+    "start_profile_intrinsic",
+    "end_profile_intrinsic",
     "inline",
     "llvm_lookup_intrinsic_id",
-    "Cast",
-    "Let",
-    "Select",
-    "Shuffle",
     "type_annotation",
     "broadcast",
     "ramp",
     "cast",
+    # tvm.tir.expr
+    "Var",
+    "SizeVar",
+    "Reduce",
+    "FloatImm",
+    "IntImm",
+    "StringImm",
+    "Cast",
+    "Add",
+    "Sub",
+    "Mul",
+    "Div",
+    "Mod",
+    "FloorDiv",
+    "FloorMod",
+    "Min",
+    "Max",
+    "EQ",
+    "NE",
+    "LT",
+    "LE",
+    "GT",
+    "GE",
+    "And",
+    "Or",
+    "Not",
+    "Select",
+    "BufferLoad",
+    "ProducerLoad",
+    "Load",
+    "Ramp",
+    "Broadcast",
+    "Shuffle",
+    "Call",
+    "CallEffectKind",
+    "Let",
+    "IterVar",
+    "CommReducer",
 ]
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 588b40ae40..e1adc0a6bb 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -18,14 +18,15 @@
 """Operators used in TIR expression."""
 import warnings
 from typing import Any, Optional
+
 import tvm._ffi
-from tvm.ir.base import Span
-from tvm.runtime import convert, const
 from tvm.ir import Array, Op, PrimExpr
+from tvm.ir.base import Span
+from tvm.runtime import const, convert
 
-from .buffer import Buffer
-from .expr import Call, PrimExprWithOp, StringImm, Var, CommReducer, IntImm
 from . import _ffi_api
+from .buffer import Buffer
+from .expr import Call, CommReducer, IntImm, PrimExprWithOp, StringImm, Var
 
 
 def _pack_buffer(buf, span=None):
@@ -322,6 +323,24 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
     )
 
 
+def tvm_check_return(expected, return_unexpected, nested_call):
+    """Build a call that checks the return code of a nested call against an expected value.
+    Parameters
+    ----------
+    expected : int
+        The expected return code.
+    return_unexpected : int
+        The unexpected return code.
+    nested_call : PrimExpr
+        The call expression to check return.
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("int32", "tir.tvm_check_return", expected, return_unexpected, nested_call)
+
+
 def tvm_stack_alloca(dtype_str, num):
     """Return new on stack dtype[num]
 
@@ -403,7 +422,7 @@ def assume(cond=None):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("int32", "tir.assume", cond)
+    return call_intrin("bool", "tir.assume", cond)
 
 
 def undef():
@@ -417,6 +436,34 @@ def undef():
     return call_intrin("int32", "tir.undef")
 
 
+def start_profile_intrinsic(id):
+    """Start profile intrinsic.
+    Parameters
+    ----------
+    id : int
+        The intrinsic id.
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("handle", "tir.start_profile_intrinsic", id)
+
+
+def end_profile_intrinsic(id):
+    """End profile intrinsic.
+    Parameters
+    ----------
+    id : int
+        The intrinsic id.
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("handle", "tir.end_profile_intrinsic", id)
+
+
 def tvm_tuple(*value):
     """Create a tuple structure in value field of AttrStmt
 
diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc
index aa9efa653f..f48ee52506 100644
--- a/src/script/ir_builder/tir/frame.cc
+++ b/src/script/ir_builder/tir/frame.cc
@@ -117,14 +117,14 @@ void LaunchThreadFrameNode::ExitWithScope() {
 
 void AllocateFrameNode::ExitWithScope() {
   TIRFrameNode::ExitWithScope();
-  AddToParent(tvm::tir::Allocate(buffer->data, buffer->dtype, buffer->shape, condition,
-                                 AsStmt(stmts), annotations));
+  AddToParent(
+      tvm::tir::Allocate(buffer_var, dtype, extents, condition, AsStmt(stmts), annotations));
 }
 
 void AllocateConstFrameNode::ExitWithScope() {
   TIRFrameNode::ExitWithScope();
   AddToParent(
-      tvm::tir::AllocateConst(buffer->data, dtype, extents, data, AsStmt(stmts), annotations));
+      tvm::tir::AllocateConst(buffer_var, dtype, extents, data, AsStmt(stmts), annotations));
 }
 void AttrFrameNode::ExitWithScope() {
   TIRFrameNode::ExitWithScope();
@@ -182,7 +182,13 @@ void ElseFrameNode::ExitWithScope() {
 
 void DeclBufferFrameNode::ExitWithScope() {
   TIRFrameNode::ExitWithScope();
-  AddToParent(tvm::tir::DeclBuffer(buffer, AsStmt(stmts)));
+  if (allocated) {
+    AddToParent(tvm::tir::DeclBuffer(buffer, AsStmt(stmts)));
+  } else {
+    AddToParent(tvm::tir::Allocate(buffer->data, buffer->dtype, buffer->shape,
+                                   tvm::IntImm(DataType::Bool(), 1),
+                                   tvm::tir::DeclBuffer(buffer, AsStmt(stmts))));
+  }
 }
 
 TVM_REGISTER_NODE_TYPE(TIRFrameNode);
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 6be6e2619f..78107136d4 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -452,20 +452,19 @@ AllocateFrame Allocate(Array<PrimExpr> extents, DataType dtype, String storage_s
   n->storage_scope = storage_scope;
   n->condition = condition.value_or(tvm::Bool(true));
   n->annotations = annotations.value_or(Map<String, ObjectRef>());
-  n->buffer = BufferDecl(extents, dtype, "", NullOpt, NullOpt, NullOpt, storage_scope, 0, 0,
-                         "default", NullOpt);
+  n->buffer_var = Var("", tvm::PointerType(tvm::PrimType(dtype), storage_scope));
   return AllocateFrame(n);
 }
 
 AllocateConstFrame AllocateConst(tvm::runtime::NDArray data, DataType dtype,
-                                 Array<PrimExpr> extents, Map<String, ObjectRef> annotations) {
+                                 Array<PrimExpr> extents,
+                                 Optional<Map<String, ObjectRef>> annotations) {
   ObjectPtr<AllocateConstFrameNode> n = make_object<AllocateConstFrameNode>();
   n->dtype = dtype;
   n->extents = extents;
   n->data = data;
-  n->annotations = annotations;
-  n->buffer =
-      BufferDecl(extents, dtype, "", NullOpt, NullOpt, NullOpt, "", 0, 0, "default", NullOpt);
+  n->annotations = annotations.value_or(Map<String, ObjectRef>());
+  n->buffer_var = Var("", tvm::PointerType(tvm::PrimType(dtype)));
   return AllocateConstFrame(n);
 }
 
@@ -529,6 +528,7 @@ DeclBufferFrame DeclBuffer(Array<PrimExpr> shape, DataType dtype, String buffer_
   ObjectPtr<DeclBufferFrameNode> n = make_object<DeclBufferFrameNode>();
   n->buffer = BufferDecl(shape, dtype, buffer_name, data, strides, elem_offset, storage_scope,
                          align, offset_factor, buffer_type, axis_separators);
+  n->allocated = data.defined();
   return DeclBufferFrame(n);
 }
 
@@ -638,21 +638,35 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate);
 
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Ptr").set_body_typed(Ptr);
 
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int8").set_body_typed(Int8);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int16").set_body_typed(Int16);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int32").set_body_typed(Int32);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int64").set_body_typed(Int64);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.UInt8").set_body_typed(UInt8);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.UInt16").set_body_typed(UInt16);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.UInt32").set_body_typed(UInt32);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.UInt64").set_body_typed(UInt64);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float8").set_body_typed(Float8);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float16").set_body_typed(Float16);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float32").set_body_typed(Float32);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float64").set_body_typed(Float64);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int32x4").set_body_typed(Int32x4);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int32x8").set_body_typed(Int32x8);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int32x16").set_body_typed(Int32x16);
+#define TVM_TMP_STR(x) #x
+
+#define TVM_REGISTER_GLOBAL_SIZE(Prefix, DType)                          \
+  TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(8)).set_body_typed(DType##8);   \
+  TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(16)).set_body_typed(DType##16); \
+  TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(32)).set_body_typed(DType##32); \
+  TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(64)).set_body_typed(DType##64);
+
+TVM_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.Float", Float);
+TVM_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.UInt", UInt);
+TVM_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.Int", Int);
+
+#define TVM_REGISTER_GLOBAL_LANES(Prefix, Func)                           \
+  TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x4)).set_body_typed(Func##x4);   \
+  TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x8)).set_body_typed(Func##x8);   \
+  TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x16)).set_body_typed(Func##x16); \
+  TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x32)).set_body_typed(Func##x32); \
+  TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x64)).set_body_typed(Func##x64);
+
+#define TVM_REGISTER_GLOBAL_SIZES_LANES(Prefix, DType)          \
+  TVM_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(8), DType##8);   \
+  TVM_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(16), DType##16); \
+  TVM_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(32), DType##32); \
+  TVM_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(64), DType##64);
+
+TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Float", Float);
+TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt);
+TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int);
+
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Void").set_body_typed(Void);
diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py
index 36de35fa92..2ec52bfbfe 100644
--- a/tests/python/unittest/test_tvmscript_error_report.py
+++ b/tests/python/unittest/test_tvmscript_error_report.py
@@ -52,7 +52,7 @@ def check_error(func, rel_lineno):
         return
     error = errors[0]
     assert (
-        error.span.line - 1 == rel_lineno
+        error.span.line - 1 == rel_lineno or error.span.line == rel_lineno
     ), f"Expected error to be on line {rel_lineno}, but it was on {error.span.line - 1}"
 
     error_line = source_code.split("\n")[rel_lineno]
diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
index dbc9b594fb..a3df5a183b 100644
--- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
@@ -16,15 +16,15 @@
 # under the License.
 # pylint: disable=invalid-name, missing-docstring
 """Unittests for tvm.script.ir_builder.tir"""
-import pytest
 import numpy as np
+import pytest
 import tvm
 import tvm.testing
 from tvm import tir
+from tvm.ir.base import assert_structural_equal
 from tvm.runtime import ndarray
-from tvm.script.ir_builder import tir as T
 from tvm.script.ir_builder import IRBuilder
-from tvm.ir.base import assert_structural_equal
+from tvm.script.ir_builder import tir as T
 
 
 def test_ir_builder_tir_primfunc_base():
@@ -372,7 +372,12 @@ def test_ir_builder_tir_allocate_const():
     # the expected allocate const
     buffer_var = tir.Var("v", tvm.ir.PointerType(tvm.ir.PrimType("int32")))
     ir_expected = tir.AllocateConst(
-        buffer_var, "int32", [10], ndarray.array(np.asarray(data, "int32")), tir.Evaluate(1)
+        buffer_var,
+        "int32",
+        [10],
+        ndarray.array(np.asarray(data, "int32")),
+        tir.Evaluate(1),
+        annotations={},
     )
 
     # Check if the generated ir is expected
@@ -470,7 +475,13 @@ def test_ir_builder_tir_decl_buffer():
 
     # the expected decl_buffer
     buffer = T.buffer_decl((128, 128), "float32")
-    ir_expected = tir.DeclBuffer(buffer, tir.Evaluate(0))
+    ir_expected = tir.Allocate(
+        buffer.data,
+        "float32",
+        (128, 128),
+        tir.IntImm("bool", True),
+        tir.DeclBuffer(buffer, tir.Evaluate(0)),
+    )
 
     # Check if the generated ir is expected
     assert_structural_equal(ir_actual, ir_expected, map_free_vars=True)