You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/08/17 19:06:26 UTC

[tvm] branch main updated: [TIR] Expose TVM Backend API-related Builtins and Misc (#12468)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b0b9bd976b [TIR] Expose TVM Backend API-related Builtins and Misc (#12468)
b0b9bd976b is described below

commit b0b9bd976ba14cfc8b224396ab065ff03a34f46a
Author: Yaxing Cai <ca...@gmail.com>
AuthorDate: Wed Aug 17 12:06:19 2022 -0700

    [TIR] Expose TVM Backend API-related Builtins and Misc (#12468)
    
    This PR exposes the following TIR operations in Python:
    
    `tvm_thread_allreduce`: tested [here](https://github.com/apache/tvm/blob/bcc7cde95c1e84b85f18c07110489350865b8cfe/tests/python/unittest/test_tvmscript_type.py#L135)
    `type_annotation`: tested [here](https://github.com/apache/tvm/blob/bcc7cde95c1e84b85f18c07110489350865b8cfe/tests/python/unittest/test_tvmscript_roundtrip.py#L718)
    `tvm_access_ptr`: tested [here](https://github.com/apache/tvm/blob/bcc7cde95c1e84b85f18c07110489350865b8cfe/tests/python/unittest/test_tvmscript_roundtrip.py#L717)
    `tvm_throw_last_error`: tested [here](https://github.com/apache/tvm/blob/bcc7cde95c1e84b85f18c07110489350865b8cfe/tests/python/unittest/test_tvmscript_roundtrip.py#L343)
    `TVMBackendAllocWorkspace`: tested [here](https://github.com/apache/tvm/blob/bcc7cde95c1e84b85f18c07110489350865b8cfe/tests/python/unittest/test_tvmscript_roundtrip.py#L340)
    `TVMBackendFreeWorkspace`: tested [here](https://github.com/apache/tvm/blob/bcc7cde95c1e84b85f18c07110489350865b8cfe/tests/python/unittest/test_tvmscript_roundtrip.py#L465)
    
    Co-Authored-By: yongwww <yo...@gmail.com>
---
 python/tvm/tir/__init__.py                 |   2 +
 python/tvm/tir/op.py                       | 159 ++++++++++++++++++++++++++++-
 tests/python/unittest/test_tir_op_types.py |  42 ++++++++
 3 files changed, 198 insertions(+), 5 deletions(-)

diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 82b1089ac1..7ea8c02bed 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -51,6 +51,7 @@ from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_valu
 from .op import tvm_stack_alloca, tvm_stack_make_shape, tvm_stack_make_array
 from .op import tvm_tuple, tvm_struct_get, tvm_struct_set
 from .op import address_of, lookup_param, assume, undef
+from .op import tvm_thread_allreduce, type_annotation, tvm_access_ptr, tvm_throw_last_error
 from .op import infinity, reinterpret
 from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
 from .op import sin, sinh, asin, asinh
@@ -62,6 +63,7 @@ from .op import likely, isnan, isnullptr, isfinite, isinf, copysign
 from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv
 from .op import comm_reducer, min, max, sum
 from .op import q_multiply_shift
+from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace
 
 from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
 
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 3e8dc52935..19ce4f4bc1 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=redefined-builtin, invalid-name
 """Operators used in TIR expression."""
+import warnings
 from typing import Any, Optional
 import tvm._ffi
 from tvm.ir.base import Span
@@ -262,10 +263,22 @@ def call_llvm_intrin(dtype, name, *args, span=None):
     # pylint: disable=import-outside-toplevel
     from tvm.target import codegen
 
-    llvm_id = codegen.llvm_lookup_intrinsic_id(name)
-    assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
+    from .expr import IntImm
+
+    if isinstance(name, str):
+        llvm_id = codegen.llvm_lookup_intrinsic_id(name)
+    elif isinstance(name, IntImm):
+        llvm_id = name.value
+    else:
+        llvm_id = name
+    if llvm_id == 0:
+        warnings.warn(f"Unknown llvm intrinsic function {name}, falling back to 0")
     return call_intrin(
-        dtype, Op.get("tir.call_llvm_intrin"), tvm.tir.const(llvm_id, "uint32"), *args, span=span
+        dtype,
+        Op.get("tir.call_llvm_intrin"),
+        tvm.tir.const(llvm_id, "uint32"),
+        *args,
+        span=span,
     )
 
 
@@ -294,8 +307,16 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
     # pylint: disable=import-outside-toplevel
     from tvm.target import codegen
 
-    llvm_id = codegen.llvm_lookup_intrinsic_id(name)
-    assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
+    from .expr import IntImm
+
+    if isinstance(name, str):
+        llvm_id = codegen.llvm_lookup_intrinsic_id(name)
+    elif isinstance(name, IntImm):
+        llvm_id = name.value
+    else:
+        llvm_id = name
+    if llvm_id == 0:
+        warnings.warn(f"Unknown llvm intrinsic function {name}, falling back to 0")
     return call_intrin(
         dtype,
         Op.get("tir.call_llvm_pure_intrin"),
@@ -504,6 +525,76 @@ def lookup_param(param_name, span=None):
     return call_intrin("handle", "tir.lookup_param", param_name, span=span)
 
 
+def tvm_thread_allreduce(*freduce_args):
+    """
+    Parameters
+    ----------
+    freduce_args : Expr
+        The args.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("handle", "tir.tvm_thread_allreduce", *freduce_args)
+
+
+def type_annotation(dtype):
+    """Create a type annotation expression
+
+    Parameters
+    ----------
+    dtype : Expr
+        The data type.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin(dtype, "tir.type_annotation")
+
+
+def tvm_access_ptr(ptype, data, offset, extent, rw_mask):
+    """Get head access address with memory access pattern info
+
+    Parameters
+    ----------
+    ptype : Expr
+        The data type of pointer.
+
+    data : DType*
+        The data of pointer.
+
+    offset : int
+        The offset of pointer.
+
+    extent : int
+        The extent of pointer.
+
+    rw_mask : int
+        The read write mask.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("handle", "tir.tvm_access_ptr", ptype, data, offset, extent, rw_mask)
+
+
+def tvm_throw_last_error():
+    """Throw TVMGetLastError()
+
+    Returns
+    -------
+    ret : PrimExpr
+        The return expression
+    """
+    return call_intrin("handle", "tir.tvm_throw_last_error")
+
+
 def ret(val):
     """Create a tir return expression
 
@@ -1857,6 +1948,64 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
     return reducer
 
 
+def TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint):
+    """Backend function to allocate temporal workspace
+
+    Parameters
+    ----------
+    device_type : int
+        The device type which the space will be allocated.
+
+    device_id : int
+        The device id which the space will be allocated.
+
+    nbytes : int
+        The size of the space requested.
+
+    dtype_code_hint : int
+        The type code of the array elements. Only used in certain backends such as OpenGL.
+
+    dtype_bits_hint : int
+        The type bits of the array elements. Only used in certain backends such as OpenGL.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin(
+        "handle",
+        "tir.TVMBackendAllocWorkspace",
+        device_type,
+        device_id,
+        nbytes,
+        dtype_code_hint,
+        dtype_bits_hint,
+    )
+
+
+def TVMBackendFreeWorkspace(device_type, device_id, ptr):
+    """Backend function to free temporal workspace.
+
+    Parameters
+    ----------
+    device_type : int
+        The device type which the space will be allocated.
+
+    device_id : int
+        The device id which the space will be allocated.
+
+    ptr : Var
+        The result allocated space pointer.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin("int32", "tir.TVMBackendFreeWorkspace", device_type, device_id, ptr)
+
+
 # pylint: disable=unnecessary-lambda
 sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum")
 min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min")  # type: ignore
diff --git a/tests/python/unittest/test_tir_op_types.py b/tests/python/unittest/test_tir_op_types.py
index e36eecf2ec..ffee3b3b57 100644
--- a/tests/python/unittest/test_tir_op_types.py
+++ b/tests/python/unittest/test_tir_op_types.py
@@ -79,6 +79,42 @@ def test_tir_op_call_likely():
     assert expr.op.name == "tir.likely"
 
 
+def test_tir_op_tvm_thread_allreduce():
+    x = tir.Var("x", "int32")
+    buffer = tir.decl_buffer((128), "float32")
+    y = tir.Var("y", "handle")
+    z = tir.Var("z", "int32")
+    expr = tir.tvm_thread_allreduce(x, buffer[0], True, y, z)
+    assert expr.op.name == "tir.tvm_thread_allreduce"
+
+
+def test_tir_op_type_annotation():
+    expr = tir.type_annotation("int32")
+    assert expr.op.name == "tir.type_annotation"
+
+
+def test_tir_op_tvm_access_ptr():
+    buffer = tir.decl_buffer((128), "float32")
+    expr = tir.tvm_access_ptr("float32", buffer.data, 0, 1, 2)
+    assert expr.op.name == "tir.tvm_access_ptr"
+
+
+def test_tir_op_tvm_throw_last_error():
+    expr = tir.tvm_throw_last_error()
+    assert expr.op.name == "tir.tvm_throw_last_error"
+
+
+def test_tir_op_TVMBackendAllocWorkspace():
+    expr = tir.TVMBackendAllocWorkspace(0, 1, 2, 3, 4)
+    assert expr.op.name == "tir.TVMBackendAllocWorkspace"
+
+
+def test_tir_op_TVMBackendFreeWorkspace():
+    buffer = tir.decl_buffer((128), "float32")
+    expr = tir.TVMBackendFreeWorkspace(0, 1, buffer.data)
+    assert expr.op.name == "tir.TVMBackendFreeWorkspace"
+
+
 if __name__ == "__main__":
     test_tir_op_tvm_tuple()
     test_tir_op_tvm_struct_get()
@@ -90,3 +126,9 @@ if __name__ == "__main__":
     test_tir_op_call_assume()
     test_tir_op_call_undef()
     test_tir_op_call_likely()
+    test_tir_op_tvm_thread_allreduce()
+    test_tir_op_type_annotation()
+    test_tir_op_tvm_access_ptr()
+    test_tir_op_tvm_throw_last_error()
+    test_tir_op_TVMBackendAllocWorkspace()
+    test_tir_op_TVMBackendFreeWorkspace()