You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/09/26 16:52:09 UTC

[tvm] branch main updated: [TVMScript] Import TIR methods into the IRBuilder (#12900)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8711ba44b9 [TVMScript] Import TIR methods into the IRBuilder (#12900)
8711ba44b9 is described below

commit 8711ba44b9bebc54bb4bc3c3f456ee3ce3d40eed
Author: Yaxing Cai <ca...@gmail.com>
AuthorDate: Mon Sep 26 09:52:02 2022 -0700

    [TVMScript] Import TIR methods into the IRBuilder (#12900)
    
    This PR introduces the remaining TIR methods into the IRBuilder.
    
    Co-authored-by: yongwww <yo...@gmail.com>
---
 include/tvm/script/ir_builder/tir/ir.h             |   8 +
 python/tvm/script/ir_builder/tir/ir.py             | 396 ++++++++++++++++++++-
 src/script/ir_builder/tir/ir.cc                    |  11 +
 .../unittest/test_tvmscript_ir_builder_tir.py      |  15 +
 4 files changed, 428 insertions(+), 2 deletions(-)

diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index dd289b6915..7460099f94 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -435,6 +435,14 @@ void Prefetch(Buffer buffer, Array<Range> bounds);
  */
 void Evaluate(PrimExpr value);
 
+/*!
+ * \brief Declare a pointer-typed variable.
+ * \param dtype The data type of the elements the pointer points to.
+ * \param storage_scope The storage scope of the pointer.
+ * \return A fresh unnamed Var whose type is a pointer to dtype in storage_scope.
+ */
+PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global");
+
 #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType)                             \
   inline PrimExpr FuncName(Optional<PrimExpr> expr = NullOpt) {                        \
     DataType dtype = DType;                                                            \
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index 625e1291ff..4ec1511f29 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -17,24 +17,35 @@
 # pylint: disable=missing-docstring
 """IRBuilder for TIR"""
 
+import inspect
+import functools
 from numbers import Integral
-from typing import Any, Dict, List, Optional, Union, Tuple
+from typing import Any, Callable, Dict, List, Optional, Union, Tuple
 import numpy as np  # type: ignore
 
 from tvm.ir import Range, Type
 from tvm.runtime import convert, ndarray
+from tvm.target.codegen import llvm_lookup_intrinsic_id
 from tvm.tir import (
     Buffer,
     BufferLoad,
     BufferRegion,
+    Cast,
+    CommReducer,
     IntImm,
     IterVar,
     Let,
     PrimExpr,
+    Select,
+    Shuffle,
     StringImm,
+    type_annotation,
     Var,
 )
+from tvm.tir import Broadcast as broadcast
 from tvm.tir import Ramp as ramp
+from tvm.tir import op as _tir_op
+from tvm.tir.generic import cast
 
 from . import _ffi_api, frame
 
@@ -1501,7 +1512,7 @@ def void(expr: Optional[PrimExpr] = None) -> PrimExpr:
     return _ffi_api.Void(expr)  # type: ignore[attr-defined] # pylint: disable=no-member
 
 
-def var(dtype, name="") -> Var:
+def var(dtype: str, name: str = "") -> Var:
     """Construct a new tir.Var.
 
     Parameters
@@ -1520,6 +1531,268 @@ def var(dtype, name="") -> Var:
     return Var(name, dtype)  # pylint: disable=no-member
 
 
+def ptr(dtype: str, storage_scope: str = "global") -> Var:
+    """Declare an unnamed pointer-typed variable.
+
+    Parameters
+    ----------
+    dtype : str
+        The data type of the elements the pointer points to.
+
+    storage_scope : str
+        The storage scope of the pointer, "global" by default.
+
+    Returns
+    -------
+    res : Var
+        A fresh Var whose type is a pointer to ``dtype`` in ``storage_scope``.
+    """
+    return _ffi_api.Ptr(dtype, storage_scope)  # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def min(a: PrimExpr, b: PrimExpr) -> PrimExpr:  # pylint: disable=redefined-builtin
+    """Build an expression for the minimum of two expressions.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        The left-hand operand.
+
+    b : PrimExpr
+        The right-hand operand.
+
+    Returns
+    -------
+    res : PrimExpr
+        An expression for the smaller of ``a`` and ``b``.
+    """
+    return _ffi_api.min(a, b)  # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def max(a: PrimExpr, b: PrimExpr) -> PrimExpr:  # pylint: disable=redefined-builtin
+    """Build an expression for the maximum of two expressions.
+
+    Parameters
+    ----------
+    a : PrimExpr
+        The left-hand operand.
+
+    b : PrimExpr
+        The right-hand operand.
+
+    Returns
+    -------
+    res : PrimExpr
+        An expression for the larger of ``a`` and ``b``.
+    """
+    return _ffi_api.max(a, b)  # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def iter_var(v: Union[Var, str], dom: Range, iter_type: str, thread_tag: str) -> IterVar:
+    """Construct a new tir.IterVar.
+
+    Parameters
+    ----------
+    v : Union[Var, str]
+        The internal variable that is used for iteration.
+
+    dom : Range
+        The domain of the iteration.
+
+    iter_type : str
+        The name of an ``IterVar`` class attribute giving the iteration type.
+
+    thread_tag : str
+        The thread tag of the iteration variable.
+
+    Returns
+    -------
+    res : IterVar
+        The iteration variable.
+    """
+    iter_type = getattr(IterVar, iter_type)  # an unknown name raises AttributeError
+    return IterVar(dom, v, iter_type, thread_tag)
+
+
+def comm_reducer(combiner: Callable, identity: List[PrimExpr]) -> CommReducer:
+    """
+    Create a CommReducer by tracing ``combiner`` on fresh Vars named after its parameters.
+
+    Parameters
+    ----------
+    combiner : Callable
+        Takes all lhs Vars then all rhs Vars; returns one PrimExpr or a tuple of them.
+
+    identity : List[PrimExpr]
+        Identity value of each output; its dtype (int32 for a Python int) sets the Var dtypes.
+
+    Returns
+    -------
+    res : CommReducer
+        The CommReducer.
+    """
+    params = inspect.signature(combiner).parameters
+    num_args = len(params)
+    args = []
+    for name, i in zip(params.keys(), identity + identity):  # lhs Vars, then rhs Vars
+        if isinstance(i, int):
+            args.append(Var(name, "int32"))
+        else:
+            args.append(Var(name, i.dtype))
+    res = combiner(*args)
+    if not isinstance(res, tuple):  # a single-output combiner may return a bare expression
+        res = (res,)
+    return CommReducer(args[: num_args // 2], args[num_args // 2 :], res, identity)
+
+
+def _op_wrapper(func):  # wrap an op so a ``dtype`` keyword is accepted and silently dropped
+    @functools.wraps(func)
+    def wrapped(*args, **kwargs):
+        if "dtype" in kwargs:
+            kwargs.pop("dtype")
+        return func(*args, **kwargs)
+
+    return wrapped
+
+
+def _dtype_forward(func):  # wrap an op so a ``dtype`` keyword becomes its first positional arg
+    @functools.wraps(func)
+    def wrapped(*args, **kwargs):
+        if "dtype" in kwargs:
+            args = (kwargs.pop("dtype"),) + args
+        return func(*args, **kwargs)
+
+    return wrapped
+
+
+# pylint: disable=invalid-name
+
+buffer_var = ptr  # alias of ptr
+abs = _op_wrapper(_tir_op.abs)  # pylint: disable=redefined-builtin
+fabs = abs
+acos = _op_wrapper(_tir_op.acos)
+acosh = _op_wrapper(_tir_op.acosh)
+address_of = _op_wrapper(_tir_op.address_of)
+asin = _op_wrapper(_tir_op.asin)
+asinh = _op_wrapper(_tir_op.asinh)
+atan = _op_wrapper(_tir_op.atan)
+atan2 = _op_wrapper(_tir_op.atan2)
+atanh = _op_wrapper(_tir_op.atanh)
+ceil = _op_wrapper(_tir_op.ceil)
+clz = _op_wrapper(_tir_op.clz)
+copysign = _op_wrapper(_tir_op.copysign)
+cos = _op_wrapper(_tir_op.cos)
+cosh = _op_wrapper(_tir_op.cosh)
+erf = _op_wrapper(_tir_op.erf)
+exp = _op_wrapper(_tir_op.exp)
+exp2 = _op_wrapper(_tir_op.exp2)
+exp10 = _op_wrapper(_tir_op.exp10)
+floor = _op_wrapper(_tir_op.floor)
+ceildiv = _op_wrapper(_tir_op.ceildiv)
+floordiv = _op_wrapper(_tir_op.floordiv)
+floormod = _op_wrapper(_tir_op.floormod)
+fmod = _op_wrapper(_tir_op.fmod)
+hypot = _op_wrapper(_tir_op.hypot)
+if_then_else = _op_wrapper(_tir_op.if_then_else)
+infinity = _op_wrapper(_tir_op.infinity)
+isfinite = _op_wrapper(_tir_op.isfinite)
+isinf = _op_wrapper(_tir_op.isinf)
+isnan = _op_wrapper(_tir_op.isnan)
+isnullptr = _op_wrapper(_tir_op.isnullptr)
+ldexp = _op_wrapper(_tir_op.ldexp)
+likely = _op_wrapper(_tir_op.likely)
+log = _op_wrapper(_tir_op.log)
+log1p = _op_wrapper(_tir_op.log1p)
+log2 = _op_wrapper(_tir_op.log2)
+log10 = _op_wrapper(_tir_op.log10)
+lookup_param = _op_wrapper(_tir_op.lookup_param)
+max_value = _op_wrapper(_tir_op.max_value)
+min_value = _op_wrapper(_tir_op.min_value)
+nearbyint = _op_wrapper(_tir_op.nearbyint)
+nextafter = _op_wrapper(_tir_op.nextafter)
+popcount = _op_wrapper(_tir_op.popcount)
+power = _op_wrapper(_tir_op.power)
+q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift)
+ret = _op_wrapper(_tir_op.ret)
+reinterpret = _dtype_forward(_tir_op.reinterpret)
+round = _op_wrapper(_tir_op.round)  # pylint: disable=redefined-builtin
+rsqrt = _op_wrapper(_tir_op.rsqrt)
+shift_left = _op_wrapper(_tir_op.shift_left)
+shift_right = _op_wrapper(_tir_op.shift_right)
+sigmoid = _op_wrapper(_tir_op.sigmoid)
+sin = _op_wrapper(_tir_op.sin)
+sinh = _op_wrapper(_tir_op.sinh)
+sqrt = _op_wrapper(_tir_op.sqrt)
+tan = _op_wrapper(_tir_op.tan)
+tanh = _op_wrapper(_tir_op.tanh)
+trunc = _op_wrapper(_tir_op.trunc)
+truncdiv = _op_wrapper(_tir_op.truncdiv)
+truncmod = _op_wrapper(_tir_op.truncmod)
+tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr)
+tvm_throw_last_error = _op_wrapper(_tir_op.tvm_throw_last_error)
+tvm_stack_alloca = _op_wrapper(_tir_op.tvm_stack_alloca)
+tvm_stack_make_shape = _op_wrapper(_tir_op.tvm_stack_make_shape)
+tvm_stack_make_array = _op_wrapper(_tir_op.tvm_stack_make_array)
+call_packed = _op_wrapper(_tir_op.call_packed)
+call_cpacked = _op_wrapper(_tir_op.call_cpacked)
+call_packed_lowered = _op_wrapper(_tir_op.call_packed_lowered)
+call_cpacked_lowered = _op_wrapper(_tir_op.call_cpacked_lowered)
+call_extern = _dtype_forward(_tir_op.call_extern)
+call_intrin = _dtype_forward(_tir_op.call_intrin)
+call_llvm_intrin = _dtype_forward(_tir_op.call_llvm_intrin)
+call_llvm_pure_intrin = _dtype_forward(_tir_op.call_llvm_pure_intrin)
+call_pure_extern = _dtype_forward(_tir_op.call_pure_extern)
+tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr)  # NOTE(review): duplicate of an earlier line
+tvm_tuple = _op_wrapper(_tir_op.tvm_tuple)
+tvm_struct_set = _op_wrapper(_tir_op.tvm_struct_set)
+tvm_struct_get = _tir_op.tvm_struct_get  # not wrapped: a dtype kwarg passes through
+tvm_thread_allreduce = _op_wrapper(_tir_op.tvm_thread_allreduce)
+tvm_load_matrix_sync = _op_wrapper(_tir_op.tvm_load_matrix_sync)
+tvm_mma_sync = _op_wrapper(_tir_op.tvm_mma_sync)
+tvm_bmma_sync = _op_wrapper(_tir_op.tvm_bmma_sync)
+tvm_fill_fragment = _op_wrapper(_tir_op.tvm_fill_fragment)
+tvm_store_matrix_sync = _op_wrapper(_tir_op.tvm_store_matrix_sync)
+ptx_mma = _dtype_forward(_tir_op.ptx_mma)
+ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp)
+ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix)
+ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async)
+ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group)
+ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group)
+mma_store = _dtype_forward(_tir_op.mma_store)
+mma_fill = _dtype_forward(_tir_op.mma_fill)
+vectorlow = _dtype_forward(_tir_op.vectorlow)
+vectorhigh = _dtype_forward(_tir_op.vectorhigh)
+vectorcombine = _dtype_forward(_tir_op.vectorcombine)
+assume = _op_wrapper(_tir_op.assume)
+undef = _op_wrapper(_tir_op.undef)
+tvm_call_packed = call_packed  # tvm_-prefixed aliases of the call_* helpers above
+tvm_call_cpacked = call_cpacked
+tvm_call_packed_lowered = call_packed_lowered
+tvm_call_cpacked_lowered = call_cpacked_lowered
+TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace)
+TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace)
+
+
+class inline:
+    """Wrap a Python value so that it is inlined, for meta-programming.
+
+    Parameters
+    ----------
+    value : Any
+        The wrapped Python value, stored as ``self.value``.
+    """
+
+    def __init__(self, value: Any) -> None:
+        self.value = value
+
+    def __iter__(self):  # iterate self.value, wrapping each element in inline
+        def f():
+            for i in self.value:
+                yield inline(i)
+
+        return f()
+
+
 # pylint: enable=invalid-name
 
 
@@ -1581,4 +1854,123 @@ __all__ = [
     "handle",
     "void",
     "var",
+    "ptr",
+    "min",
+    "max",
+    "iter_var",
+    "comm_reducer",
+    "buffer_var",
+    "abs",
+    "fabs",
+    "acos",
+    "acosh",
+    "address_of",
+    "asin",
+    "asinh",
+    "atan",
+    "atan2",
+    "atanh",
+    "ceil",
+    "clz",
+    "copysign",
+    "cos",
+    "cosh",
+    "erf",
+    "exp",
+    "exp2",
+    "exp10",
+    "floor",
+    "ceildiv",
+    "floordiv",
+    "floormod",
+    "fmod",
+    "hypot",
+    "if_then_else",
+    "infinity",
+    "isfinite",
+    "isinf",
+    "isnan",
+    "isnullptr",
+    "ldexp",
+    "likely",
+    "log",
+    "log1p",
+    "log2",
+    "log10",
+    "lookup_param",
+    "max_value",
+    "min_value",
+    "nearbyint",
+    "nextafter",
+    "popcount",
+    "power",
+    "q_multiply_shift",
+    "ret",
+    "reinterpret",
+    "round",
+    "rsqrt",
+    "shift_left",
+    "shift_right",
+    "sigmoid",
+    "sin",
+    "sinh",
+    "sqrt",
+    "tan",
+    "tanh",
+    "trunc",
+    "truncdiv",
+    "truncmod",
+    "tvm_access_ptr",
+    "tvm_throw_last_error",
+    "tvm_stack_alloca",
+    "tvm_stack_make_shape",
+    "tvm_stack_make_array",
+    "call_packed",
+    "call_cpacked",
+    "call_packed_lowered",
+    "call_cpacked_lowered",
+    "call_extern",
+    "call_intrin",
+    "call_llvm_intrin",
+    "call_llvm_pure_intrin",
+    "call_pure_extern",
+    "tvm_access_ptr",
+    "tvm_tuple",
+    "tvm_struct_set",
+    "tvm_struct_get",
+    "tvm_thread_allreduce",
+    "tvm_load_matrix_sync",
+    "tvm_mma_sync",
+    "tvm_bmma_sync",
+    "tvm_fill_fragment",
+    "tvm_store_matrix_sync",
+    "ptx_mma",
+    "ptx_mma_sp",
+    "ptx_ldmatrix",
+    "ptx_cp_async",
+    "ptx_wait_group",
+    "ptx_commit_group",
+    "mma_store",
+    "mma_fill",
+    "vectorlow",
+    "vectorhigh",
+    "vectorcombine",
+    "assume",
+    "undef",
+    "tvm_call_packed",
+    "tvm_call_cpacked",
+    "tvm_call_packed_lowered",
+    "tvm_call_cpacked_lowered",
+    "TVMBackendAllocWorkspace",
+    "TVMBackendFreeWorkspace",
+    "inline",
+    "llvm_lookup_intrinsic_id",
+    "Cast",
+    "Let",
+    "Select",
+    "Shuffle",
+    "type_annotation",
+    "broadcast",
+    "ramp",
+    "cast",
 ]
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index 28c3d69861..6be6e2619f 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -534,6 +534,10 @@ DeclBufferFrame DeclBuffer(Array<PrimExpr> shape, DataType dtype, String buffer_
 
 void Evaluate(PrimExpr value) { AddToParent(tvm::tir::Evaluate(value)); }
 
+PrimExpr Ptr(runtime::DataType dtype, String storage_scope) {  // unnamed pointer-typed Var
+  return tvm::tir::Var("", tvm::PointerType(PrimType(dtype), storage_scope));
+}
+
 using tvm::script::ir_builder::details::Namer;
 
 TVM_STATIC_IR_FUNCTOR(Namer, vtable)
@@ -632,6 +636,8 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferStore").set_body_typed(BufferSt
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Prefetch").set_body_typed(Prefetch);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate);
 
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.Ptr").set_body_typed(Ptr);  // presumably _ffi_api.Ptr
+
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int8").set_body_typed(Int8);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int16").set_body_typed(Int16);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int32").set_body_typed(Int32);
@@ -650,6 +656,11 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int32x16").set_body_typed(Int32x16);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle);
 TVM_REGISTER_GLOBAL("script.ir_builder.tir.Void").set_body_typed(Void);
+
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.min")  // lambda fixes the (PrimExpr, PrimExpr) signature
+    .set_body_typed([](PrimExpr a, PrimExpr b) -> PrimExpr { return tvm::min(a, b); });
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.max")
+    .set_body_typed([](PrimExpr a, PrimExpr b) -> PrimExpr { return tvm::max(a, b); });
 }  // namespace tir
 }  // namespace ir_builder
 }  // namespace script
diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
index 40e13a2fbe..dbc9b594fb 100644
--- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py
@@ -476,5 +476,20 @@ def test_ir_builder_tir_decl_buffer():
     assert_structural_equal(ir_actual, ir_expected, map_free_vars=True)
 
 
+def test_ir_builder_tir_inline():  # T.inline wraps plain Python values
+    with IRBuilder() as ib:
+        m, n = T.inline(1), T.inline(2)
+        a, b = T.inline([3, 4])  # unpacking yields inline(3), inline(4)
+        T.evaluate(m.value + n.value + a.value + b.value)
+    # the evaluate generated by IRBuilder
+    eval_actual = ib.get()
+
+    # the expected evaluate
+    eval_expected = tir.Evaluate(10)  # 1 + 2 + 3 + 4
+
+    # Check that the generated IR matches the expectation
+    assert_structural_equal(eval_actual, eval_expected, map_free_vars=True)
+
+
 if __name__ == "__main__":
     tvm.testing.main()