You are viewing a plain-text version of this content. (The hyperlink to the canonical version was not preserved in this plain-text copy.)
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/09/26 16:52:09 UTC
[tvm] branch main updated: [TVMScript] Import TIR methods into the IRBuilder (#12900)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8711ba44b9 [TVMScript] Import TIR methods into the IRBuilder (#12900)
8711ba44b9 is described below
commit 8711ba44b9bebc54bb4bc3c3f456ee3ce3d40eed
Author: Yaxing Cai <ca...@gmail.com>
AuthorDate: Mon Sep 26 09:52:02 2022 -0700
[TVMScript] Import TIR methods into the IRBuilder (#12900)
This PR imports the remaining TIR methods into the IRBuilder.
Co-authored-by: yongwww <yo...@gmail.com>
---
include/tvm/script/ir_builder/tir/ir.h | 8 +
python/tvm/script/ir_builder/tir/ir.py | 396 ++++++++++++++++++++-
src/script/ir_builder/tir/ir.cc | 11 +
.../unittest/test_tvmscript_ir_builder_tir.py | 15 +
4 files changed, 428 insertions(+), 2 deletions(-)
diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index dd289b6915..7460099f94 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -435,6 +435,14 @@ void Prefetch(Buffer buffer, Array<Range> bounds);
*/
void Evaluate(PrimExpr value);
+/*!
+ * \brief The pointer declaration function.
+ * \param dtype The data type of the pointer.
+ * \param storage_scope The storage scope of the pointer.
+ * \return The pointer.
+ */
+PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global");
+
#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \
inline PrimExpr FuncName(Optional<PrimExpr> expr = NullOpt) { \
DataType dtype = DType; \
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index 625e1291ff..4ec1511f29 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -17,24 +17,35 @@
# pylint: disable=missing-docstring
"""IRBuilder for TIR"""
+import inspect
+import functools
from numbers import Integral
-from typing import Any, Dict, List, Optional, Union, Tuple
+from typing import Any, Callable, Dict, List, Optional, Union, Tuple
import numpy as np # type: ignore
from tvm.ir import Range, Type
from tvm.runtime import convert, ndarray
+from tvm.target.codegen import llvm_lookup_intrinsic_id
from tvm.tir import (
Buffer,
BufferLoad,
BufferRegion,
+ Cast,
+ CommReducer,
IntImm,
IterVar,
Let,
PrimExpr,
+ Select,
+ Shuffle,
StringImm,
+ type_annotation,
Var,
)
+from tvm.tir import Broadcast as broadcast
from tvm.tir import Ramp as ramp
+from tvm.tir import op as _tir_op
+from tvm.tir.generic import cast
from . import _ffi_api, frame
@@ -1501,7 +1512,7 @@ def void(expr: Optional[PrimExpr] = None) -> PrimExpr:
return _ffi_api.Void(expr) # type: ignore[attr-defined] # pylint: disable=no-member
-def var(dtype, name="") -> Var:
+def var(dtype: str, name: str = "") -> Var:
"""Construct a new tir.Var.
Parameters
@@ -1520,6 +1531,268 @@ def var(dtype, name="") -> Var:
return Var(name, dtype) # pylint: disable=no-member
+def ptr(dtype: str, storage_scope: str = "global") -> Var:
+ """The pointer declaration function.
+
+ Parameters
+ ----------
+ dtype : str
+ The data type of the pointer.
+
+ storage_scope : str
+ The storage scope of the pointer.
+
+ Returns
+ -------
+ res : Var
+ The pointer.
+ """
+ return _ffi_api.Ptr(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: # pylint: disable=redefined-builtin
+ """Compute the minimum value of two expressions.
+
+ Parameters
+ ----------
+ a : PrimExpr
+ The left hand operand
+
+ b : PrimExpr
+ The right hand operand
+
+ Returns
+ -------
+ res : PrimExpr
+ The result expression.
+ """
+ return _ffi_api.min(a, b) # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: # pylint: disable=redefined-builtin
+ """Compute the maximum value of two expressions.
+
+ Parameters
+ ----------
+ a : PrimExpr
+ The left hand operand
+
+ b : PrimExpr
+ The right hand operand
+
+ Returns
+ -------
+ res : PrimExpr
+ The result expression.
+ """
+ return _ffi_api.max(a, b) # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def iter_var(v: Union[Var, str], dom: Range, iter_type: str, thread_tag: str) -> IterVar:
+ """The iteration variable.
+
+ Parameters
+ ----------
+ var : Union[Var, str]
+ The internal variable that is used for iteration.
+
+ dom : Range
+ The domain of the iteration.
+
+ iter_type : str
+ The iteration type.
+
+ thread_tag : str
+ The thread type tag.
+
+ Returns
+ -------
+ res : IterVar
+ The iteration variable.
+ """
+ iter_type = getattr(IterVar, iter_type)
+ return IterVar(dom, v, iter_type, thread_tag)
+
+
+def comm_reducer(combiner: Callable, identity: List[PrimExpr]) -> CommReducer:
+ """
+ Create a CommReducer from lambda inputs/outputs and the identities
+
+ Parameters
+ ----------
+ combiner : Callable
+ A binary function which takes two PrimExpr as input to return a PrimExpr.
+
+ identity : List[PrimExpr]
+ A list of types of output PrimExpr.
+
+ Returns
+ -------
+ res : CommReducer
+ The CommReducer.
+ """
+ params = inspect.signature(combiner).parameters
+ num_args = len(params)
+ args = []
+ for name, i in zip(params.keys(), identity + identity):
+ if isinstance(i, int):
+ args.append(Var(name, "int32"))
+ else:
+ args.append(Var(name, i.dtype))
+ res = combiner(*args)
+ if not isinstance(res, tuple):
+ res = (res,)
+ return CommReducer(args[: num_args // 2], args[num_args // 2 :], res, identity)
+
+
+def _op_wrapper(func):
+ @functools.wraps(func)
+ def wrapped(*args, **kwargs):
+ if "dtype" in kwargs:
+ kwargs.pop("dtype")
+ return func(*args, **kwargs)
+
+ return wrapped
+
+
+def _dtype_forward(func):
+ @functools.wraps(func)
+ def wrapped(*args, **kwargs):
+ if "dtype" in kwargs:
+ args = (kwargs.pop("dtype"),) + args
+ return func(*args, **kwargs)
+
+ return wrapped
+
+
+# pylint: disable=invalid-name
+
+buffer_var = ptr
+abs = _op_wrapper(_tir_op.abs) # pylint: disable=redefined-builtin
+fabs = abs
+acos = _op_wrapper(_tir_op.acos)
+acosh = _op_wrapper(_tir_op.acosh)
+address_of = _op_wrapper(_tir_op.address_of)
+asin = _op_wrapper(_tir_op.asin)
+asinh = _op_wrapper(_tir_op.asinh)
+atan = _op_wrapper(_tir_op.atan)
+atan2 = _op_wrapper(_tir_op.atan2)
+atanh = _op_wrapper(_tir_op.atanh)
+ceil = _op_wrapper(_tir_op.ceil)
+clz = _op_wrapper(_tir_op.clz)
+copysign = _op_wrapper(_tir_op.copysign)
+cos = _op_wrapper(_tir_op.cos)
+cosh = _op_wrapper(_tir_op.cosh)
+erf = _op_wrapper(_tir_op.erf)
+exp = _op_wrapper(_tir_op.exp)
+exp2 = _op_wrapper(_tir_op.exp2)
+exp10 = _op_wrapper(_tir_op.exp10)
+floor = _op_wrapper(_tir_op.floor)
+ceildiv = _op_wrapper(_tir_op.ceildiv)
+floordiv = _op_wrapper(_tir_op.floordiv)
+floormod = _op_wrapper(_tir_op.floormod)
+fmod = _op_wrapper(_tir_op.fmod)
+hypot = _op_wrapper(_tir_op.hypot)
+if_then_else = _op_wrapper(_tir_op.if_then_else)
+infinity = _op_wrapper(_tir_op.infinity)
+isfinite = _op_wrapper(_tir_op.isfinite)
+isinf = _op_wrapper(_tir_op.isinf)
+isnan = _op_wrapper(_tir_op.isnan)
+isnullptr = _op_wrapper(_tir_op.isnullptr)
+ldexp = _op_wrapper(_tir_op.ldexp)
+likely = _op_wrapper(_tir_op.likely)
+log = _op_wrapper(_tir_op.log)
+log1p = _op_wrapper(_tir_op.log1p)
+log2 = _op_wrapper(_tir_op.log2)
+log10 = _op_wrapper(_tir_op.log10)
+lookup_param = _op_wrapper(_tir_op.lookup_param)
+max_value = _op_wrapper(_tir_op.max_value)
+min_value = _op_wrapper(_tir_op.min_value)
+nearbyint = _op_wrapper(_tir_op.nearbyint)
+nextafter = _op_wrapper(_tir_op.nextafter)
+popcount = _op_wrapper(_tir_op.popcount)
+power = _op_wrapper(_tir_op.power)
+q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift)
+ret = _op_wrapper(_tir_op.ret)
+reinterpret = _dtype_forward(_tir_op.reinterpret)
+round = _op_wrapper(_tir_op.round) # pylint: disable=redefined-builtin
+rsqrt = _op_wrapper(_tir_op.rsqrt)
+shift_left = _op_wrapper(_tir_op.shift_left)
+shift_right = _op_wrapper(_tir_op.shift_right)
+sigmoid = _op_wrapper(_tir_op.sigmoid)
+sin = _op_wrapper(_tir_op.sin)
+sinh = _op_wrapper(_tir_op.sinh)
+sqrt = _op_wrapper(_tir_op.sqrt)
+tan = _op_wrapper(_tir_op.tan)
+tanh = _op_wrapper(_tir_op.tanh)
+trunc = _op_wrapper(_tir_op.trunc)
+truncdiv = _op_wrapper(_tir_op.truncdiv)
+truncmod = _op_wrapper(_tir_op.truncmod)
+tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr)
+tvm_throw_last_error = _op_wrapper(_tir_op.tvm_throw_last_error)
+tvm_stack_alloca = _op_wrapper(_tir_op.tvm_stack_alloca)
+tvm_stack_make_shape = _op_wrapper(_tir_op.tvm_stack_make_shape)
+tvm_stack_make_array = _op_wrapper(_tir_op.tvm_stack_make_array)
+call_packed = _op_wrapper(_tir_op.call_packed)
+call_cpacked = _op_wrapper(_tir_op.call_cpacked)
+call_packed_lowered = _op_wrapper(_tir_op.call_packed_lowered)
+call_cpacked_lowered = _op_wrapper(_tir_op.call_cpacked_lowered)
+call_extern = _dtype_forward(_tir_op.call_extern)
+call_intrin = _dtype_forward(_tir_op.call_intrin)
+call_llvm_intrin = _dtype_forward(_tir_op.call_llvm_intrin)
+call_llvm_pure_intrin = _dtype_forward(_tir_op.call_llvm_pure_intrin)
+call_pure_extern = _dtype_forward(_tir_op.call_pure_extern)
+tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr)
+tvm_tuple = _op_wrapper(_tir_op.tvm_tuple)
+tvm_struct_set = _op_wrapper(_tir_op.tvm_struct_set)
+tvm_struct_get = _tir_op.tvm_struct_get
+tvm_thread_allreduce = _op_wrapper(_tir_op.tvm_thread_allreduce)
+tvm_load_matrix_sync = _op_wrapper(_tir_op.tvm_load_matrix_sync)
+tvm_mma_sync = _op_wrapper(_tir_op.tvm_mma_sync)
+tvm_bmma_sync = _op_wrapper(_tir_op.tvm_bmma_sync)
+tvm_fill_fragment = _op_wrapper(_tir_op.tvm_fill_fragment)
+tvm_store_matrix_sync = _op_wrapper(_tir_op.tvm_store_matrix_sync)
+ptx_mma = _dtype_forward(_tir_op.ptx_mma)
+ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp)
+ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix)
+ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async)
+ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group)
+ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group)
+mma_store = _dtype_forward(_tir_op.mma_store)
+mma_fill = _dtype_forward(_tir_op.mma_fill)
+vectorlow = _dtype_forward(_tir_op.vectorlow)
+vectorhigh = _dtype_forward(_tir_op.vectorhigh)
+vectorcombine = _dtype_forward(_tir_op.vectorcombine)
+assume = _op_wrapper(_tir_op.assume)
+undef = _op_wrapper(_tir_op.undef)
+tvm_call_packed = call_packed
+tvm_call_cpacked = call_cpacked
+tvm_call_packed_lowered = call_packed_lowered
+tvm_call_cpacked_lowered = call_cpacked_lowered
+TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace)
+TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace)
+
+
+class inline:
+ """Inline function for meta-programming.
+
+ Parameters
+ ----------
+ value: Any
+ The value to be inlined.
+ """
+
+ def __init__(self, value: Any) -> None:
+ self.value = value
+
+ def __iter__(self):
+ def f():
+ for i in self.value:
+ yield inline(i)
+
+ return f()
+
+
# pylint: enable=invalid-name
@@ -1581,4 +1854,123 @@ __all__ = [
"handle",
"void",
"var",
+ "ptr",
+ "min",
+ "max",
+ "iter_var",
+ "comm_reducer",
+ "buffer_var",
+ "abs",
+ "fabs",
+ "acos",
+ "acosh",
+ "address_of",
+ "asin",
+ "asinh",
+ "atan",
+ "atan2",
+ "atanh",
+ "ceil",
+ "clz",
+ "copysign",
+ "cos",
+ "cosh",
+ "erf",
+ "exp",
+ "exp2",
+ "exp10",
+ "floor",
+ "ceildiv",
+ "floordiv",
+ "floormod",
+ "fmod",
+ "hypot",
+ "if_then_else",
+ "infinity",
+ "isfinite",
+ "isinf",
+ "isnan",
+ "isnullptr",
+ "ldexp",
+ "likely",
+ "log",
+ "log1p",
+ "log2",
+ "log10",
+ "lookup_param",
+ "max_value",
+ "min_value",
+ "nearbyint",
+ "nextafter",
+ "popcount",
+ "power",
+ "q_multiply_shift",
+ "ret",
+ "reinterpret",
+ "round",
+ "rsqrt",
+ "shift_left",
+ "shift_right",
+ "sigmoid",
+ "sin",
+ "sinh",
+ "sqrt",
+ "tan",
+ "tanh",
+ "trunc",
+ "truncdiv",
+ "truncmod",
+ "tvm_access_ptr",
+ "tvm_throw_last_error",
+ "tvm_stack_alloca",
+ "tvm_stack_make_shape",
+ "tvm_stack_make_array",
+ "call_packed",
+ "call_cpacked",
+ "call_packed_lowered",
+ "call_cpacked_lowered",
+ "call_extern",
+ "call_intrin",
+ "call_llvm_intrin",
+ "call_llvm_pure_intrin",
+ "call_pure_extern",
+ "tvm_access_ptr",
+ "tvm_tuple",
+ "tvm_struct_set",
+ "tvm_struct_get",
+ "tvm_thread_allreduce",
+ "tvm_load_matrix_sync",
+ "tvm_mma_sync",
+ "tvm_bmma_sync",
+ "tvm_fill_fragment",
+ "tvm_store_matrix_sync",
+ "ptx_mma",
+ "ptx_mma_sp",
+ "ptx_ldmatrix",
+ "ptx_cp_async",
+ "ptx_wait_group",
+ "ptx_commit_group",
+ "mma_store",
+ "mma_fill",
+ "vectorlow",
+ "vectorhigh",
+ "vectorcombine",
+ "assume",
+ "undef",
+ "tvm_call_packed",
+ "tvm_call_cpacked",
+ "tvm_call_packed_lowered",
+ "tvm_call_cpacked_lowered",
+ "TVMBackendAllocWorkspace",
+ "TVMBackendFreeWorkspace",
+ "inline",
+ "llvm_lookup_intrinsic_id",
+ "Cast",
+ "Let",
+ "Select",
+ "Shuffle",
+ "type_annotation",
+ "broadcast",
+ "ramp",
+ "cast",
]
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 28c3d69861..6be6e2619f 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -534,6 +534,10 @@ DeclBufferFrame DeclBuffer(Array<PrimExpr> shape, DataType dtype, String buffer_
void Evaluate(PrimExpr value) { AddToParent(tvm::tir::Evaluate(value)); }
+PrimExpr Ptr(runtime::DataType dtype, String storage_scope) {
+ return tvm::tir::Var("", tvm::PointerType(PrimType(dtype), storage_scope));
+}
+
using tvm::script::ir_builder::details::Namer;
TVM_STATIC_IR_FUNCTOR(Namer, vtable)
@@ -632,6 +636,8 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferStore").set_body_typed(BufferSt
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Prefetch").set_body_typed(Prefetch);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Ptr").set_body_typed(Ptr);
+
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int8").set_body_typed(Int8);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int16").set_body_typed(Int16);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int32").set_body_typed(Int32);
@@ -650,6 +656,11 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int32x16").set_body_typed(Int32x16);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Void").set_body_typed(Void);
+
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.min")
+ .set_body_typed([](PrimExpr a, PrimExpr b) -> PrimExpr { return tvm::min(a, b); });
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.max")
+ .set_body_typed([](PrimExpr a, PrimExpr b) -> PrimExpr { return tvm::max(a, b); });
} // namespace tir
} // namespace ir_builder
} // namespace script
diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
index 40e13a2fbe..dbc9b594fb 100644
--- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
@@ -476,5 +476,20 @@ def test_ir_builder_tir_decl_buffer():
assert_structural_equal(ir_actual, ir_expected, map_free_vars=True)
+def test_ir_builder_tir_inline():
+ with IRBuilder() as ib:
+ m, n = T.inline(1), T.inline(2)
+ a, b = T.inline([3, 4])
+ T.evaluate(m.value + n.value + a.value + b.value)
+ # the evaluate generated by IRBuilder
+ eval_actual = ib.get()
+
+ # the expected evaluate
+ eval_expected = tir.Evaluate(10)
+
+ # Check if the generated ir is expected
+ assert_structural_equal(eval_actual, eval_expected, map_free_vars=True)
+
+
if __name__ == "__main__":
tvm.testing.main()