You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ru...@apache.org on 2022/11/11 16:40:27 UTC
[tvm] branch main updated: [IRBuilder][Minor] Add intrinsics like `T.int32x4` (#13361)
This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8897983484 [IRBuilder][Minor] Add intrinsics like `T.int32x4` (#13361)
8897983484 is described below
commit 88979834842115ef9ea8487d9a631dc3275f7a7d
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Fri Nov 11 08:40:17 2022 -0800
[IRBuilder][Minor] Add intrinsics like `T.int32x4` (#13361)
This PR adds all common TIR dtype intrinsics, such as `T.int32x4` and `T.float32x4`.
Co-authored-by: Yaxing Cai <ca...@gmail.com>
---
include/tvm/script/ir_builder/tir/frame.h | 16 +-
include/tvm/script/ir_builder/tir/ir.h | 46 +-
python/tvm/script/ir_builder/tir/frame.py | 4 +-
python/tvm/script/ir_builder/tir/ir.py | 473 +++++++++------------
python/tvm/tir/op.py | 57 ++-
src/script/ir_builder/tir/frame.cc | 14 +-
src/script/ir_builder/tir/ir.cc | 56 ++-
.../python/unittest/test_tvmscript_error_report.py | 2 +-
.../unittest/test_tvmscript_ir_builder_tir.py | 21 +-
9 files changed, 348 insertions(+), 341 deletions(-)
diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h
index aa2386e7f1..b95d575360 100644
--- a/include/tvm/script/ir_builder/tir/frame.h
+++ b/include/tvm/script/ir_builder/tir/frame.h
@@ -453,8 +453,8 @@ class AllocateFrameNode : public TIRFrameNode {
PrimExpr condition;
/*! \brief Additional annotation hints. */
Map<String, ObjectRef> annotations;
- /*! \brief The buffer. */
- tvm::tir::Buffer buffer;
+ /*! \brief The buffer var. */
+ tvm::tir::Var buffer_var;
void VisitAttrs(tvm::AttrVisitor* v) {
TIRFrameNode::VisitAttrs(v);
@@ -463,7 +463,7 @@ class AllocateFrameNode : public TIRFrameNode {
v->Visit("storage_scope", &storage_scope);
v->Visit("condition", &condition);
v->Visit("annotations", &annotations);
- v->Visit("buffer", &buffer);
+ v->Visit("buffer_var", &buffer_var);
}
static constexpr const char* _type_key = "script.ir_builder.tir.AllocateFrame";
@@ -500,8 +500,8 @@ class AllocateConstFrameNode : public TIRFrameNode {
Array<PrimExpr> extents;
/*! \brief The data associated with the constant. */
tvm::runtime::NDArray data;
- /*! \brief The buffer */
- tvm::tir::Buffer buffer;
+ /*! \brief The buffer var */
+ tvm::tir::Var buffer_var;
/*! \brief Additional annotations about the allocation. */
Map<String, ObjectRef> annotations;
@@ -510,7 +510,7 @@ class AllocateConstFrameNode : public TIRFrameNode {
v->Visit("dtype", &dtype);
v->Visit("extents", &extents);
v->Visit("data", &data);
- v->Visit("buffer", &buffer);
+ v->Visit("buffer_var", &buffer_var);
v->Visit("annotations", &annotations);
}
@@ -723,11 +723,15 @@ class ElseFrame : public TIRFrame {
class DeclBufferFrameNode : public TIRFrameNode {
public:
+ /*! \brief The declared buffer. */
tvm::tir::Buffer buffer;
+ /*! \brief Whether the buffer is already allocated (i.e. its data var was provided). */
+ bool allocated;
void VisitAttrs(tvm::AttrVisitor* v) {
TIRFrameNode::VisitAttrs(v);
v->Visit("buffer", &buffer);
+ v->Visit("allocated", &allocated);
}
static constexpr const char* _type_key = "script.ir_builder.tir.DeclBufferFrame";
diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index 7460099f94..d9e1a1b490 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -339,9 +339,8 @@ AllocateFrame Allocate(Array<PrimExpr> extents, DataType dtype, String storage_s
* \param annotations Additional annotation hints.
* \return The created AllocateConstFrame.
*/
-AllocateConstFrame AllocateConst(
- NDArray data, DataType dtype, Array<PrimExpr> extents,
- Map<String, ObjectRef> annotations = NullValue<Map<String, ObjectRef>>());
+AllocateConstFrame AllocateConst(NDArray data, DataType dtype, Array<PrimExpr> extents,
+ Optional<Map<String, ObjectRef>> annotations = NullOpt);
/*!
* \brief Create an attribute.
@@ -449,21 +448,32 @@ PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global");
return expr.defined() ? tvm::cast(dtype, expr.value()) : tvm::tir::Var("", dtype); \
}
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int8, DataType::Int(8));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int16, DataType::Int(16));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32, DataType::Int(32));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int64, DataType::Int(64));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt8, DataType::UInt(8));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt16, DataType::UInt(16));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt32, DataType::UInt(32));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt64, DataType::UInt(64));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float8, DataType::Float(8));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float16, DataType::Float(16));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float32, DataType::Float(32));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float64, DataType::Float(64));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x4, DataType::Int(32, 4));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x8, DataType::Int(32, 8));
-TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x16, DataType::Int(32, 16));
+#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(DType, FDType) \
+ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##8, FDType(8)); \
+ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##16, FDType(16)); \
+ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##32, FDType(32)); \
+ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##64, FDType(64));
+
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Float, DataType::Float);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(UInt, DataType::UInt);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Int, DataType::Int);
+
+#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(FuncName, FDType, Size) \
+ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x4, FDType(Size, 4)); \
+ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x8, FDType(Size, 8)); \
+ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x16, FDType(Size, 16)); \
+ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x32, FDType(Size, 32)); \
+ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x64, FDType(Size, 64));
+
+#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(DType, FDType) \
+ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##8, FDType, 8); \
+ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##16, FDType, 16); \
+ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##32, FDType, 32); \
+ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##64, FDType, 64);
+
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Float, DataType::Float);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(UInt, DataType::UInt);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Handle, DataType::Handle());
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py
index b9b50dfa98..a57c878bd9 100644
--- a/python/tvm/script/ir_builder/tir/frame.py
+++ b/python/tvm/script/ir_builder/tir/frame.py
@@ -69,14 +69,14 @@ class RealizeFrame(TIRFrame):
class AllocateFrame(TIRFrame):
def __enter__(self) -> Buffer:
super().__enter__()
- return self.buffer
+ return self.buffer_var
@_register_object("script.ir_builder.tir.AllocateConstFrame")
class AllocateConstFrame(TIRFrame):
def __enter__(self) -> Buffer:
super().__enter__()
- return self.buffer
+ return self.buffer_var
@_register_object("script.ir_builder.tir.AttrFrame")
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index 4ec1511f29..bd9e4e1db5 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -14,41 +14,75 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=missing-docstring
"""IRBuilder for TIR"""
-import inspect
import functools
+import inspect
from numbers import Integral
-from typing import Any, Callable, Dict, List, Optional, Union, Tuple
-import numpy as np # type: ignore
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+# isort: off
+from typing_extensions import Literal
+# isort: on
+
+import numpy as np # type: ignore
from tvm.ir import Range, Type
from tvm.runtime import convert, ndarray
+from tvm.target import Target
+
+# pylint: disable=unused-import
from tvm.target.codegen import llvm_lookup_intrinsic_id
-from tvm.tir import (
- Buffer,
+from tvm.tir import Buffer, BufferRegion, PrimExpr
+from tvm.tir import op as _tir_op
+from tvm.tir import type_annotation
+
+# import tir.expr for direct ir construction to pass structural_equal comparison
+from tvm.tir.expr import (
+ EQ,
+ GE,
+ GT,
+ LE,
+ LT,
+ NE,
+ Add,
+ And,
+ Broadcast,
BufferLoad,
- BufferRegion,
+ Call,
+ CallEffectKind,
Cast,
CommReducer,
+ Div,
+ FloatImm,
+ FloorDiv,
+ FloorMod,
IntImm,
IterVar,
Let,
- PrimExpr,
+ Load,
+ Max,
+ Min,
+ Mod,
+ Mul,
+ Not,
+ Or,
+ ProducerLoad,
+ Ramp,
+ Reduce,
Select,
Shuffle,
+ SizeVar,
StringImm,
- type_annotation,
+ Sub,
Var,
)
-from tvm.tir import Broadcast as broadcast
-from tvm.tir import Ramp as ramp
-from tvm.tir import op as _tir_op
from tvm.tir.generic import cast
from . import _ffi_api, frame
+# pylint: enable=unused-import
+
def buffer_decl(
shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral],
@@ -56,7 +90,7 @@ def buffer_decl(
data: Var = None,
strides: List[PrimExpr] = None,
elem_offset: PrimExpr = None,
- scope: str = "",
+ scope: str = "global",
align: int = 0,
offset_factor: int = 0,
buffer_type: str = "",
@@ -187,7 +221,7 @@ def func_ret(ret_type: Type) -> Type:
def match_buffer(
param: Union[Var, BufferLoad, BufferRegion],
- shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral],
+ shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] = None,
dtype: str = "float32",
data: Var = None,
strides: List[PrimExpr] = None,
@@ -256,6 +290,12 @@ def match_buffer(
res : Buffer
The matched buffer.
"""
+ if shape is None:
+ if isinstance(param, BufferRegion):
+ dtype = param.buffer.dtype
+ shape = [region.extent for region in param.region]
+ else:
+ raise ValueError("Shape must be specified when binding input param")
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
if strides is None:
strides = []
@@ -447,7 +487,7 @@ def alloc_buffer(
data: Var = None,
strides: List[PrimExpr] = None,
elem_offset: PrimExpr = None,
- scope: str = "",
+ scope: str = "global",
align: int = -1,
offset_factor: int = 0,
buffer_type: str = "default",
@@ -526,10 +566,14 @@ def _as_range(dom: Union[Range, List[PrimExpr]]) -> Range:
return dom
if isinstance(dom, (list, tuple)):
return Range(dom[0], dom[1])
+ if hasattr(dom, "dtype"):
+ return Range(IntImm(dom.dtype, 0), dom)
return Range(0, dom)
class axis: # pylint: disable=invalid-name
+ """The axis class"""
+
@staticmethod
def spatial(
dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32"
@@ -686,7 +730,10 @@ def serial(
"""
if stop is None:
stop = start
- start = 0
+ if hasattr(start, "dtype"):
+ start = IntImm(start.dtype, 0)
+ else:
+ start = 0
return _ffi_api.Serial(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
@@ -713,7 +760,10 @@ def parallel(
"""
if stop is None:
stop = start
- start = 0
+ if hasattr(start, "dtype"):
+ start = IntImm(start.dtype, 0)
+ else:
+ start = 0
return _ffi_api.Parallel(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
@@ -740,7 +790,10 @@ def vectorized(
"""
if stop is None:
stop = start
- start = 0
+ if hasattr(start, "dtype"):
+ start = IntImm(start.dtype, 0)
+ else:
+ start = 0
return _ffi_api.Vectorized(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
@@ -767,7 +820,10 @@ def unroll(
"""
if stop is None:
stop = start
- start = 0
+ if hasattr(start, "dtype"):
+ start = IntImm(start.dtype, 0)
+ else:
+ start = 0
return _ffi_api.Unroll(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
@@ -804,10 +860,16 @@ def thread_binding(
raise ValueError("Thread cannot be None for thread_binding")
thread = stop
stop = start
- start = 0
+ if hasattr(start, "dtype"):
+ start = IntImm(start.dtype, 0)
+ else:
+ start = 0
elif stop is None:
stop = start
- start = 0
+ if hasattr(start, "dtype"):
+ start = IntImm(start.dtype, 0)
+ else:
+ start = 0
return _ffi_api.ThreadBinding( # type: ignore[attr-defined] # pylint: disable=no-member
start, stop, thread, annotations
)
@@ -907,7 +969,7 @@ def realize(
def allocate(
extents: List[PrimExpr],
dtype: str,
- scope: str = "",
+ scope: str = "global",
condition: PrimExpr = None,
annotations=None,
) -> frame.AllocateFrame:
@@ -959,9 +1021,18 @@ def allocate_const(
annotations : Optional[Map]
Additional annotations about the allocation.
"""
+ np_data = np.asarray(data, dtype=dtype)
+ prod_extent = 1
+ for extent in extents:
+ prod_extent *= extent
+ prod_shape = 1
+ for shape in np_data.shape:
+ prod_shape *= shape
+ if prod_extent == prod_shape:
+ np_data = np_data.reshape(extents)
return _ffi_api.AllocateConst( # type: ignore[attr-defined] # pylint: disable=no-member
- ndarray.array(np.asarray(data, dtype)), dtype, extents, annotations
+ ndarray.array(np_data), dtype, extents, annotations
)
@@ -1054,7 +1125,7 @@ def decl_buffer(
data=None,
strides=None,
elem_offset=None,
- scope="",
+ scope="global",
align=0,
offset_factor=0,
buffer_type="",
@@ -1221,247 +1292,41 @@ def evaluate(value: PrimExpr) -> None:
"""
if isinstance(value, str):
value = StringImm(value)
+ if isinstance(value, bool):
+ value = cast(value, "bool")
return _ffi_api.Evaluate(value) # type: ignore[attr-defined] # pylint: disable=no-member
-def int8(expr: Optional[PrimExpr] = None) -> PrimExpr:
- """Construct a new tir.Var with type int8 or cast expression to type int8.
+__all__ = []
+for _dtype in ["Float", "UInt", "Int"]:
+ for _size in ["8", "16", "32", "64"]:
+ for _lanes in ["", "x4", "x8", "x16", "x32", "x64"]:
+ _name = _dtype + _size + _lanes # pylint: disable=invalid-name
- Parameters
- ----------
- expr: PrimExpr
- The expression to be cast.
+ def func_gen(name: str):
+ """Generate a function for each PrimExpr dtype.
- Returns
- -------
- res : PrimExpr
- The new tir.Var with type int8 or casted expression with type int8.
- """
- return _ffi_api.Int8(expr) # type: ignore[attr-defined] # pylint: disable=no-member
-
-
-def int16(expr: Optional[PrimExpr] = None) -> PrimExpr:
- """Construct a new tir.Var with type int16 or cast expression to type int16.
+ Parameters
+ ----------
+ name: str
+ The ffi function name to call.
+ """
- Parameters
- ----------
- expr: PrimExpr
- The expression to be cast.
+ def func(
+ expr: Union[
+ None,
+ PrimExpr,
+ Literal["inf", "-inf", "nan"],
+ ] = None
+ ) -> PrimExpr:
+ if isinstance(expr, str):
+ expr = float(expr)
+ return getattr(_ffi_api, name)(expr)
- Returns
- -------
- res : PrimExpr
- The new tir.Var with type int16 or casted expression with type int16.
- """
- return _ffi_api.Int16(expr) # type: ignore[attr-defined] # pylint: disable=no-member
+ return func
-
-def int32(expr: Optional[PrimExpr] = None) -> PrimExpr:
- """Construct a new tir.Var with type int32 or cast expression to type int32.
-
- Parameters
- ----------
- expr: PrimExpr
- The expression to be cast.
-
- Returns
- -------
- res : PrimExpr
- The new tir.Var with type int32 or casted expression with type int32.
- """
- return _ffi_api.Int32(expr) # type: ignore[attr-defined] # pylint: disable=no-member
-
-
-def int64(expr: Optional[PrimExpr] = None) -> PrimExpr:
- """Construct a new tir.Var with type int64 or cast expression to type int64.
-
- Parameters
- ----------
- expr: PrimExpr
- The expression to be cast.
-
- Returns
- -------
- res : PrimExpr
- The new tir.Var with type int64 or casted expression with type int64.
- """
- return _ffi_api.Int64(expr) # type: ignore[attr-defined] # pylint: disable=no-member
-
-
-def uint8(expr: Optional[PrimExpr] = None) -> PrimExpr:
- """Construct a new tir.Var with type uint8 or cast expression to type uint8.
-
- Parameters
- ----------
- expr: PrimExpr
- The expression to be cast.
-
- Returns
- -------
- res : PrimExpr
- The new tir.Var with type uint8 or casted expression with type uint8.
- """
- return _ffi_api.UInt8(expr) # type: ignore[attr-defined] # pylint: disable=no-member
-
-
-def uint16(expr: Optional[PrimExpr] = None) -> PrimExpr:
- """Construct a new tir.Var with type uint16 or cast expression to type uint16.
-
- Parameters
- ----------
- expr: PrimExpr
- The expression to be cast.
-
- Returns
- -------
- res : PrimExpr
- The new tir.Var with type uint16 or casted expression with type uint16.
- """
- return _ffi_api.UInt16(expr) # type: ignore[attr-defined] # pylint: disable=no-member
-
-
-def uint32(expr: Optional[PrimExpr] = None) -> PrimExpr:
- """Construct a new tir.Var with type uint32 or cast expression to type uint32.
-
- Parameters
- ----------
- expr: PrimExpr
- The expression to be cast.
-
- Returns
- -------
- res : PrimExpr
- The new tir.Var with type uint32 or casted expression with type uint32.
- """
- return _ffi_api.UInt32(expr) # type: ignore[attr-defined] # pylint: disable=no-member
-
-
-def uint64(expr: Optional[PrimExpr] = None) -> PrimExpr:
- """Construct a new tir.Var with type uint64 or cast expression to type uint64.
-
- Parameters
- ----------
- expr: PrimExpr
- The expression to be cast.
-
- Returns
- -------
- res : PrimExpr
- The new tir.Var with type uint64 or casted expression with type uint64.
- """
- return _ffi_api.UInt64(expr) # type: ignore[attr-defined] # pylint: disable=no-member
-
-
-def float8(expr: Optional[PrimExpr] = None) -> PrimExpr:
- """Construct a new tir.Var with type float8 or cast expression to type float8.
-
- Parameters
- ----------
- expr: PrimExpr
- The expression to be cast.
-
- Returns
- -------
- res : PrimExpr
- The new tir.Var with type float8 or casted expression with type float8.
- """
- return _ffi_api.Float8(expr) # type: ignore[attr-defined] # pylint: disable=no-member
-
-
-def float16(expr: Optional[PrimExpr] = None) -> PrimExpr:
- """Construct a new tir.Var with type float16 or cast expression to type float16.
-
- Parameters
- ----------
- expr: PrimExpr
- The expression to be cast.
-
- Returns
- -------
- res : PrimExpr
- The new tir.Var with type float16 or casted expression with type float16.
- """
- return _ffi_api.Float16(expr) # type: ignore[attr-defined] # pylint: disable=no-member
-
-
-def float32(expr: Optional[PrimExpr] = None) -> PrimExpr:
- """Construct a new tir.Var with type float32 or cast expression to type float32.
-
- Parameters
- ----------
- expr: PrimExpr
- The expression to be cast.
-
- Returns
- -------
- res : PrimExpr
- The new tir.Var with type float32 or casted expression with type float32.
- """
- return _ffi_api.Float32(expr) # type: ignore[attr-defined] # pylint: disable=no-member
-
-
-def float64(expr: Optional[PrimExpr] = None) -> PrimExpr:
- """Construct a new tir.Var with type float64 or cast expression to type float64.
-
- Parameters
- ----------
- expr: PrimExpr
- The expression to be cast.
-
- Returns
- -------
- res : PrimExpr
- The new tir.Var with type float64 or casted expression with type float64.
- """
- return _ffi_api.Float64(expr) # type: ignore[attr-defined] # pylint: disable=no-member
-
-
-def int32x4(expr: Optional[PrimExpr] = None) -> PrimExpr:
- """Construct a new tir.Var with type int32x4 or cast expression to type int32x4.
-
- Parameters
- ----------
- expr: PrimExpr
- The expression to be cast.
-
- Returns
- -------
- res : PrimExpr
- The new tir.Var with type int32x4 or casted expression with type int32x4.
- """
- return _ffi_api.Int32x4(expr) # type: ignore[attr-defined] # pylint: disable=no-member
-
-
-def int32x8(expr: Optional[PrimExpr] = None) -> PrimExpr:
- """Construct a new tir.Var with type int32x8 or cast expression to type int32x8.
-
- Parameters
- ----------
- expr: PrimExpr
- The expression to be cast.
-
- Returns
- -------
- res : PrimExpr
- The new tir.Var with type int32x8 or casted expression with type int32x8.
- """
- return _ffi_api.Int32x8(expr) # type: ignore[attr-defined] # pylint: disable=no-member
-
-
-def int32x16(expr: Optional[PrimExpr] = None) -> PrimExpr:
- """Construct a new tir.Var with type int32x16 or cast expression to type int32x16.
-
- Parameters
- ----------
- expr: PrimExpr
- The expression to be cast.
-
- Returns
- -------
- res : PrimExpr
- The new tir.Var with type int32x16 or casted expression with type int32x16.
- """
- return _ffi_api.Int32x16(expr) # type: ignore[attr-defined] # pylint: disable=no-member
+ globals()[_name.lower()] = func_gen(_name)
+ __all__.append(_name.lower())
def boolean(expr: Optional[PrimExpr] = None) -> PrimExpr:
@@ -1645,6 +1510,27 @@ def comm_reducer(combiner: Callable, identity: List[PrimExpr]) -> CommReducer:
return CommReducer(args[: num_args // 2], args[num_args // 2 :], res, identity)
+def target(target_config: Union[Dict, str]) -> Target:
+ """
+ Create a target
+
+ Parameters
+ ----------
+ target_config : Union[Dict, str]
+ The target configuration.
+
+ Returns
+ -------
+ res : Target
+ The target.
+ """
+ if not isinstance(target_config, (str, dict)):
+ raise ValueError(
+ f"T.target expected a config dict or string, but got {type(target_config)}"
+ )
+ return Target(target_config)
+
+
def _op_wrapper(func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
@@ -1667,6 +1553,9 @@ def _dtype_forward(func):
# pylint: disable=invalid-name
+broadcast = Broadcast
+ramp = Ramp
+
buffer_var = ptr
abs = _op_wrapper(_tir_op.abs) # pylint: disable=redefined-builtin
fabs = abs
@@ -1713,6 +1602,7 @@ nextafter = _op_wrapper(_tir_op.nextafter)
popcount = _op_wrapper(_tir_op.popcount)
power = _op_wrapper(_tir_op.power)
q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift)
+q_multiply_shift_per_axis = _op_wrapper(_tir_op.q_multiply_shift_per_axis)
ret = _op_wrapper(_tir_op.ret)
reinterpret = _dtype_forward(_tir_op.reinterpret)
round = _op_wrapper(_tir_op.round) # pylint: disable=redefined-builtin
@@ -1733,6 +1623,7 @@ tvm_throw_last_error = _op_wrapper(_tir_op.tvm_throw_last_error)
tvm_stack_alloca = _op_wrapper(_tir_op.tvm_stack_alloca)
tvm_stack_make_shape = _op_wrapper(_tir_op.tvm_stack_make_shape)
tvm_stack_make_array = _op_wrapper(_tir_op.tvm_stack_make_array)
+tvm_check_return = _op_wrapper(_tir_op.tvm_check_return)
call_packed = _op_wrapper(_tir_op.call_packed)
call_cpacked = _op_wrapper(_tir_op.call_cpacked)
call_packed_lowered = _op_wrapper(_tir_op.call_packed_lowered)
@@ -1742,7 +1633,6 @@ call_intrin = _dtype_forward(_tir_op.call_intrin)
call_llvm_intrin = _dtype_forward(_tir_op.call_llvm_intrin)
call_llvm_pure_intrin = _dtype_forward(_tir_op.call_llvm_pure_intrin)
call_pure_extern = _dtype_forward(_tir_op.call_pure_extern)
-tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr)
tvm_tuple = _op_wrapper(_tir_op.tvm_tuple)
tvm_struct_set = _op_wrapper(_tir_op.tvm_struct_set)
tvm_struct_get = _tir_op.tvm_struct_get
@@ -1771,6 +1661,8 @@ tvm_call_packed_lowered = call_packed_lowered
tvm_call_cpacked_lowered = call_cpacked_lowered
TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace)
TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace)
+start_profile_intrinsic = _op_wrapper(_tir_op.start_profile_intrinsic)
+end_profile_intrinsic = _op_wrapper(_tir_op.end_profile_intrinsic)
class inline:
@@ -1796,7 +1688,7 @@ class inline:
# pylint: enable=invalid-name
-__all__ = [
+__all__ += [
"buffer_decl",
"prim_func",
"arg",
@@ -1835,21 +1727,6 @@ __all__ = [
"buffer_store",
"prefetch",
"evaluate",
- "int8",
- "int16",
- "int32",
- "int64",
- "uint8",
- "uint16",
- "uint32",
- "uint64",
- "float8",
- "float16",
- "float32",
- "float64",
- "int32x4",
- "int32x8",
- "int32x16",
"boolean",
"handle",
"void",
@@ -1859,6 +1736,7 @@ __all__ = [
"max",
"iter_var",
"comm_reducer",
+ "target",
"buffer_var",
"abs",
"fabs",
@@ -1905,6 +1783,7 @@ __all__ = [
"popcount",
"power",
"q_multiply_shift",
+ "q_multiply_shift_per_axis",
"ret",
"reinterpret",
"round",
@@ -1925,6 +1804,7 @@ __all__ = [
"tvm_stack_alloca",
"tvm_stack_make_shape",
"tvm_stack_make_array",
+ "tvm_check_return",
"call_packed",
"call_cpacked",
"call_packed_lowered",
@@ -1934,7 +1814,6 @@ __all__ = [
"call_llvm_intrin",
"call_llvm_pure_intrin",
"call_pure_extern",
- "tvm_access_ptr",
"tvm_tuple",
"tvm_struct_set",
"tvm_struct_get",
@@ -1963,14 +1842,50 @@ __all__ = [
"tvm_call_cpacked_lowered",
"TVMBackendAllocWorkspace",
"TVMBackendFreeWorkspace",
+ "start_profile_intrinsic",
+ "end_profile_intrinsic",
"inline",
"llvm_lookup_intrinsic_id",
- "Cast",
- "Let",
- "Select",
- "Shuffle",
"type_annotation",
"broadcast",
"ramp",
"cast",
+ # tvm.tir.expr
+ "Var",
+ "SizeVar",
+ "Reduce",
+ "FloatImm",
+ "IntImm",
+ "StringImm",
+ "Cast",
+ "Add",
+ "Sub",
+ "Mul",
+ "Div",
+ "Mod",
+ "FloorDiv",
+ "FloorMod",
+ "Min",
+ "Max",
+ "EQ",
+ "NE",
+ "LT",
+ "LE",
+ "GT",
+ "GE",
+ "And",
+ "Or",
+ "Not",
+ "Select",
+ "BufferLoad",
+ "ProducerLoad",
+ "Load",
+ "Ramp",
+ "Broadcast",
+ "Shuffle",
+ "Call",
+ "CallEffectKind",
+ "Let",
+ "IterVar",
+ "CommReducer",
]
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 588b40ae40..e1adc0a6bb 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -18,14 +18,15 @@
"""Operators used in TIR expression."""
import warnings
from typing import Any, Optional
+
import tvm._ffi
-from tvm.ir.base import Span
-from tvm.runtime import convert, const
from tvm.ir import Array, Op, PrimExpr
+from tvm.ir.base import Span
+from tvm.runtime import const, convert
-from .buffer import Buffer
-from .expr import Call, PrimExprWithOp, StringImm, Var, CommReducer, IntImm
from . import _ffi_api
+from .buffer import Buffer
+from .expr import Call, CommReducer, IntImm, PrimExprWithOp, StringImm, Var
def _pack_buffer(buf, span=None):
@@ -322,6 +323,24 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
)
+def tvm_check_return(expected, return_unexpected, nested_call):
+ """Check the return code of a nested call expression.
+ Parameters
+ ----------
+ expected : int
+ The expected return code.
+ return_unexpected : int
+ The unexpected return code.
+ nested_call : PrimExpr
+ The call expression to check return.
+ Returns
+ -------
+ call : PrimExpr
+ The call expression.
+ """
+ return call_intrin("int32", "tir.tvm_check_return", expected, return_unexpected, nested_call)
+
+
def tvm_stack_alloca(dtype_str, num):
"""Return new on stack dtype[num]
@@ -403,7 +422,7 @@ def assume(cond=None):
call : PrimExpr
The call expression.
"""
- return call_intrin("int32", "tir.assume", cond)
+ return call_intrin("bool", "tir.assume", cond)
def undef():
@@ -417,6 +436,34 @@ def undef():
return call_intrin("int32", "tir.undef")
+def start_profile_intrinsic(id):
+ """Start profile intrinsic.
+ Parameters
+ ----------
+ id : int
+ The intrinsic id.
+ Returns
+ -------
+ call : PrimExpr
+ The call expression.
+ """
+ return call_intrin("handle", "tir.start_profile_intrinsic", id)
+
+
+def end_profile_intrinsic(id):
+ """End profile intrinsic.
+ Parameters
+ ----------
+ id : int
+ The intrinsic id.
+ Returns
+ -------
+ call : PrimExpr
+ The call expression.
+ """
+ return call_intrin("handle", "tir.end_profile_intrinsic", id)
+
+
def tvm_tuple(*value):
"""Create a tuple structure in value field of AttrStmt
diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc
index aa9efa653f..f48ee52506 100644
--- a/src/script/ir_builder/tir/frame.cc
+++ b/src/script/ir_builder/tir/frame.cc
@@ -117,14 +117,14 @@ void LaunchThreadFrameNode::ExitWithScope() {
void AllocateFrameNode::ExitWithScope() {
TIRFrameNode::ExitWithScope();
- AddToParent(tvm::tir::Allocate(buffer->data, buffer->dtype, buffer->shape, condition,
- AsStmt(stmts), annotations));
+ AddToParent(
+ tvm::tir::Allocate(buffer_var, dtype, extents, condition, AsStmt(stmts), annotations));
}
void AllocateConstFrameNode::ExitWithScope() {
TIRFrameNode::ExitWithScope();
AddToParent(
- tvm::tir::AllocateConst(buffer->data, dtype, extents, data, AsStmt(stmts), annotations));
+ tvm::tir::AllocateConst(buffer_var, dtype, extents, data, AsStmt(stmts), annotations));
}
void AttrFrameNode::ExitWithScope() {
TIRFrameNode::ExitWithScope();
@@ -182,7 +182,13 @@ void ElseFrameNode::ExitWithScope() {
void DeclBufferFrameNode::ExitWithScope() {
TIRFrameNode::ExitWithScope();
- AddToParent(tvm::tir::DeclBuffer(buffer, AsStmt(stmts)));
+ if (allocated) {
+ AddToParent(tvm::tir::DeclBuffer(buffer, AsStmt(stmts)));
+ } else {
+ AddToParent(tvm::tir::Allocate(buffer->data, buffer->dtype, buffer->shape,
+ tvm::IntImm(DataType::Bool(), 1),
+ tvm::tir::DeclBuffer(buffer, AsStmt(stmts))));
+ }
}
TVM_REGISTER_NODE_TYPE(TIRFrameNode);
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 6be6e2619f..78107136d4 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -452,20 +452,19 @@ AllocateFrame Allocate(Array<PrimExpr> extents, DataType dtype, String storage_s
n->storage_scope = storage_scope;
n->condition = condition.value_or(tvm::Bool(true));
n->annotations = annotations.value_or(Map<String, ObjectRef>());
- n->buffer = BufferDecl(extents, dtype, "", NullOpt, NullOpt, NullOpt, storage_scope, 0, 0,
- "default", NullOpt);
+ n->buffer_var = Var("", tvm::PointerType(tvm::PrimType(dtype), storage_scope));
return AllocateFrame(n);
}
AllocateConstFrame AllocateConst(tvm::runtime::NDArray data, DataType dtype,
- Array<PrimExpr> extents, Map<String, ObjectRef> annotations) {
+ Array<PrimExpr> extents,
+ Optional<Map<String, ObjectRef>> annotations) {
ObjectPtr<AllocateConstFrameNode> n = make_object<AllocateConstFrameNode>();
n->dtype = dtype;
n->extents = extents;
n->data = data;
- n->annotations = annotations;
- n->buffer =
- BufferDecl(extents, dtype, "", NullOpt, NullOpt, NullOpt, "", 0, 0, "default", NullOpt);
+ n->annotations = annotations.value_or(Map<String, ObjectRef>());
+ n->buffer_var = Var("", tvm::PointerType(tvm::PrimType(dtype)));
return AllocateConstFrame(n);
}
@@ -529,6 +528,7 @@ DeclBufferFrame DeclBuffer(Array<PrimExpr> shape, DataType dtype, String buffer_
ObjectPtr<DeclBufferFrameNode> n = make_object<DeclBufferFrameNode>();
n->buffer = BufferDecl(shape, dtype, buffer_name, data, strides, elem_offset, storage_scope,
align, offset_factor, buffer_type, axis_separators);
+ n->allocated = data.defined();
return DeclBufferFrame(n);
}
@@ -638,21 +638,35 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Ptr").set_body_typed(Ptr);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int8").set_body_typed(Int8);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int16").set_body_typed(Int16);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int32").set_body_typed(Int32);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int64").set_body_typed(Int64);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.UInt8").set_body_typed(UInt8);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.UInt16").set_body_typed(UInt16);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.UInt32").set_body_typed(UInt32);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.UInt64").set_body_typed(UInt64);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float8").set_body_typed(Float8);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float16").set_body_typed(Float16);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float32").set_body_typed(Float32);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float64").set_body_typed(Float64);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int32x4").set_body_typed(Int32x4);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int32x8").set_body_typed(Int32x8);
-TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int32x16").set_body_typed(Int32x16);
+#define TVM_TMP_STR(x) #x  // turns the token `x` into a string literal
+// For N in {8,16,32,64}: register function DType##N under the global name Prefix "N".
+#define TVM_REGISTER_GLOBAL_SIZE(Prefix, DType) \
+ TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(8)).set_body_typed(DType##8); \
+ TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(16)).set_body_typed(DType##16); \
+ TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(32)).set_body_typed(DType##32); \
+ TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(64)).set_body_typed(DType##64);
+// Expands to the same 12 names (Float8..Float64, UInt8..UInt64, Int8..Int64) listed one by one before.
+TVM_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.Float", Float);
+TVM_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.UInt", UInt);
+TVM_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.Int", Int);
+// For L in {4,8,16,32,64}: register function Func##xL under the global name Prefix "xL".
+#define TVM_REGISTER_GLOBAL_LANES(Prefix, Func) \
+ TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x4)).set_body_typed(Func##x4); \
+ TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x8)).set_body_typed(Func##x8); \
+ TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x16)).set_body_typed(Func##x16); \
+ TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x32)).set_body_typed(Func##x32); \
+ TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x64)).set_body_typed(Func##x64);
+// For N in {8,16,32,64}: apply TVM_REGISTER_GLOBAL_LANES to DType##N with name Prefix "N".
+#define TVM_REGISTER_GLOBAL_SIZES_LANES(Prefix, DType) \
+ TVM_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(8), DType##8); \
+ TVM_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(16), DType##16); \
+ TVM_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(32), DType##32); \
+ TVM_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(64), DType##64);
+// Lane-suffixed names, e.g. "script.ir_builder.tir.Int32x4" -> Int32x4 (20 per DType, 60 in total).
+TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Float", Float);
+TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt);
+TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int);
+
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Void").set_body_typed(Void);
diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py
index 36de35fa92..2ec52bfbfe 100644
--- a/tests/python/unittest/test_tvmscript_error_report.py
+++ b/tests/python/unittest/test_tvmscript_error_report.py
@@ -52,7 +52,7 @@ def check_error(func, rel_lineno):
return
error = errors[0]
assert (
- error.span.line - 1 == rel_lineno
+ error.span.line - 1 == rel_lineno or error.span.line == rel_lineno
), f"Expected error to be on line {rel_lineno}, but it was on {error.span.line - 1}"
error_line = source_code.split("\n")[rel_lineno]
diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
index dbc9b594fb..a3df5a183b 100644
--- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
@@ -16,15 +16,15 @@
# under the License.
# pylint: disable=invalid-name, missing-docstring
"""Unittests for tvm.script.ir_builder.tir"""
-import pytest
import numpy as np
+import pytest
import tvm
import tvm.testing
from tvm import tir
+from tvm.ir.base import assert_structural_equal
from tvm.runtime import ndarray
-from tvm.script.ir_builder import tir as T
from tvm.script.ir_builder import IRBuilder
-from tvm.ir.base import assert_structural_equal
+from tvm.script.ir_builder import tir as T
def test_ir_builder_tir_primfunc_base():
@@ -372,7 +372,12 @@ def test_ir_builder_tir_allocate_const():
# the expected allocate const
buffer_var = tir.Var("v", tvm.ir.PointerType(tvm.ir.PrimType("int32")))
ir_expected = tir.AllocateConst(
- buffer_var, "int32", [10], ndarray.array(np.asarray(data, "int32")), tir.Evaluate(1)
+ buffer_var,
+ "int32",
+ [10],
+ ndarray.array(np.asarray(data, "int32")),
+ tir.Evaluate(1),
+ annotations={},
)
# Check if the generated ir is expected
@@ -470,7 +475,13 @@ def test_ir_builder_tir_decl_buffer():
# the expected decl_buffer
buffer = T.buffer_decl((128, 128), "float32")
- ir_expected = tir.DeclBuffer(buffer, tir.Evaluate(0))
+ ir_expected = tir.Allocate(
+ buffer.data,
+ "float32",
+ (128, 128),
+ tir.IntImm("bool", True),
+ tir.DeclBuffer(buffer, tir.Evaluate(0)),
+ )
# Check if the generated ir is expected
assert_structural_equal(ir_actual, ir_expected, map_free_vars=True)