You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/08/17 19:06:26 UTC
[tvm] branch main updated: [TIR] Expose TVM Backend API-related Builtins and Misc (#12468)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b0b9bd976b [TIR] Expose TVM Backend API-related Builtins and Misc (#12468)
b0b9bd976b is described below
commit b0b9bd976ba14cfc8b224396ab065ff03a34f46a
Author: Yaxing Cai <ca...@gmail.com>
AuthorDate: Wed Aug 17 12:06:19 2022 -0700
[TIR] Expose TVM Backend API-related Builtins and Misc (#12468)
This PR exposes the following TIR operations in Python:
`tvm_thread_allreduce`: tested [here](https://github.com/apache/tvm/blob/bcc7cde95c1e84b85f18c07110489350865b8cfe/tests/python/unittest/test_tvmscript_type.py#L135)
`type_annotation`: tested [here](https://github.com/apache/tvm/blob/bcc7cde95c1e84b85f18c07110489350865b8cfe/tests/python/unittest/test_tvmscript_roundtrip.py#L718)
`tvm_access_ptr`: tested [here](https://github.com/apache/tvm/blob/bcc7cde95c1e84b85f18c07110489350865b8cfe/tests/python/unittest/test_tvmscript_roundtrip.py#L717)
`tvm_throw_last_error`: tested [here](https://github.com/apache/tvm/blob/bcc7cde95c1e84b85f18c07110489350865b8cfe/tests/python/unittest/test_tvmscript_roundtrip.py#L343)
`TVMBackendAllocWorkspace`: tested [here](https://github.com/apache/tvm/blob/bcc7cde95c1e84b85f18c07110489350865b8cfe/tests/python/unittest/test_tvmscript_roundtrip.py#L340)
`TVMBackendFreeWorkspace`: tested [here](https://github.com/apache/tvm/blob/bcc7cde95c1e84b85f18c07110489350865b8cfe/tests/python/unittest/test_tvmscript_roundtrip.py#L465)
Co-Authored-By: yongwww <yo...@gmail.com>
---
python/tvm/tir/__init__.py | 2 +
python/tvm/tir/op.py | 159 ++++++++++++++++++++++++++++-
tests/python/unittest/test_tir_op_types.py | 42 ++++++++
3 files changed, 198 insertions(+), 5 deletions(-)
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 82b1089ac1..7ea8c02bed 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -51,6 +51,7 @@ from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_valu
from .op import tvm_stack_alloca, tvm_stack_make_shape, tvm_stack_make_array
from .op import tvm_tuple, tvm_struct_get, tvm_struct_set
from .op import address_of, lookup_param, assume, undef
+from .op import tvm_thread_allreduce, type_annotation, tvm_access_ptr, tvm_throw_last_error
from .op import infinity, reinterpret
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
from .op import sin, sinh, asin, asinh
@@ -62,6 +63,7 @@ from .op import likely, isnan, isnullptr, isfinite, isinf, copysign
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv
from .op import comm_reducer, min, max, sum
from .op import q_multiply_shift
+from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace
from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 3e8dc52935..19ce4f4bc1 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=redefined-builtin, invalid-name
"""Operators used in TIR expression."""
+import warnings
from typing import Any, Optional
import tvm._ffi
from tvm.ir.base import Span
@@ -262,10 +263,22 @@ def call_llvm_intrin(dtype, name, *args, span=None):
# pylint: disable=import-outside-toplevel
from tvm.target import codegen
- llvm_id = codegen.llvm_lookup_intrinsic_id(name)
- assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
+ from .expr import IntImm
+
+ if isinstance(name, str):
+ llvm_id = codegen.llvm_lookup_intrinsic_id(name)
+ elif isinstance(name, IntImm):
+ llvm_id = name.value
+ else:
+ llvm_id = name
+ if llvm_id == 0:
+ warnings.warn(f"Unknown llvm intrinsic function {name}, falling back to 0")
return call_intrin(
- dtype, Op.get("tir.call_llvm_intrin"), tvm.tir.const(llvm_id, "uint32"), *args, span=span
+ dtype,
+ Op.get("tir.call_llvm_intrin"),
+ tvm.tir.const(llvm_id, "uint32"),
+ *args,
+ span=span,
)
@@ -294,8 +307,16 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
# pylint: disable=import-outside-toplevel
from tvm.target import codegen
- llvm_id = codegen.llvm_lookup_intrinsic_id(name)
- assert llvm_id != 0, "%s is not an LLVM intrinsic" % name
+ from .expr import IntImm
+
+ if isinstance(name, str):
+ llvm_id = codegen.llvm_lookup_intrinsic_id(name)
+ elif isinstance(name, IntImm):
+ llvm_id = name.value
+ else:
+ llvm_id = name
+ if llvm_id == 0:
+ warnings.warn(f"Unknown llvm intrinsic function {name}, falling back to 0")
return call_intrin(
dtype,
Op.get("tir.call_llvm_pure_intrin"),
@@ -504,6 +525,76 @@ def lookup_param(param_name, span=None):
return call_intrin("handle", "tir.lookup_param", param_name, span=span)
+def tvm_thread_allreduce(*freduce_args):
+ """Create a ``tir.tvm_thread_allreduce`` call expression of dtype ``handle``.
+ Parameters
+ ----------
+ freduce_args : Expr
+ The arguments, forwarded in order to the ``tir.tvm_thread_allreduce`` intrinsic.
+
+ Returns
+ -------
+ call : PrimExpr
+ The call expression.
+ """
+ return call_intrin("handle", "tir.tvm_thread_allreduce", *freduce_args)
+
+
+def type_annotation(dtype):
+ """Create a ``tir.type_annotation`` call expression whose dtype is ``dtype``.
+
+ Parameters
+ ----------
+ dtype : str
+ The data type to annotate; it becomes the dtype of the returned call.
+
+ Returns
+ -------
+ call : PrimExpr
+ The call expression.
+ """
+ return call_intrin(dtype, "tir.type_annotation")
+
+
+def tvm_access_ptr(ptype, data, offset, extent, rw_mask):
+ """Get the head access address, annotated with memory access pattern info.
+
+ Parameters
+ ----------
+ ptype : Expr
+ The data type of the elements the pointer refers to.
+
+ data : Var
+ The data pointer being accessed.
+
+ offset : int
+ The offset of the access, relative to ``data``.
+
+ extent : int
+ The extent of the access.
+
+ rw_mask : int
+ The read/write access mask.
+
+ Returns
+ -------
+ call : PrimExpr
+ The call expression (dtype ``handle``).
+ """
+ return call_intrin("handle", "tir.tvm_access_ptr", ptype, data, offset, extent, rw_mask)
+
+
+def tvm_throw_last_error():
+ """Create a ``tir.tvm_throw_last_error`` call, i.e. throw TVMGetLastError().
+
+ Returns
+ -------
+ call : PrimExpr
+ The call expression (dtype ``handle``).
+ """
+ return call_intrin("handle", "tir.tvm_throw_last_error")
+
+
def ret(val):
"""Create a tir return expression
@@ -1857,6 +1948,64 @@ def comm_reducer(fcombine, fidentity, name="reduce"):
return reducer
+def TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dtype_bits_hint):
+ """Backend function to allocate temporary workspace.
+
+ Parameters
+ ----------
+ device_type : int
+ The device type on which the space will be allocated.
+
+ device_id : int
+ The device id on which the space will be allocated.
+
+ nbytes : int
+ The size of the requested space, in bytes.
+
+ dtype_code_hint : int
+ The type code of the array elements. Only used in certain backends such as OpenGL.
+
+ dtype_bits_hint : int
+ The type bits of the array elements. Only used in certain backends such as OpenGL.
+
+ Returns
+ -------
+ call : PrimExpr
+ The call expression (dtype ``handle``).
+ """
+ return call_intrin(
+ "handle",
+ "tir.TVMBackendAllocWorkspace",
+ device_type,
+ device_id,
+ nbytes,
+ dtype_code_hint,
+ dtype_bits_hint,
+ )
+
+
+def TVMBackendFreeWorkspace(device_type, device_id, ptr):
+ """Backend function to free temporary workspace.
+
+ Parameters
+ ----------
+ device_type : int
+ The device type on which the space was allocated.
+
+ device_id : int
+ The device id on which the space was allocated.
+
+ ptr : Var
+ The pointer to the previously allocated space that is to be freed.
+
+ Returns
+ -------
+ call : PrimExpr
+ The call expression (dtype ``int32``).
+ """
+ return call_intrin("int32", "tir.TVMBackendFreeWorkspace", device_type, device_id, ptr)
+
+
# pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min") # type: ignore
diff --git a/tests/python/unittest/test_tir_op_types.py b/tests/python/unittest/test_tir_op_types.py
index e36eecf2ec..ffee3b3b57 100644
--- a/tests/python/unittest/test_tir_op_types.py
+++ b/tests/python/unittest/test_tir_op_types.py
@@ -79,6 +79,42 @@ def test_tir_op_call_likely():
assert expr.op.name == "tir.likely"
+def test_tir_op_tvm_thread_allreduce():
+ x = tir.Var("x", "int32")
+ buffer = tir.decl_buffer((128,), "float32")  # 1-D shape tuple; `(128)` is just the int 128
+ y = tir.Var("y", "handle")
+ z = tir.Var("z", "int32")
+ expr = tir.tvm_thread_allreduce(x, buffer[0], True, y, z)
+ assert expr.op.name == "tir.tvm_thread_allreduce"
+
+
+def test_tir_op_type_annotation():
+ expr = tir.type_annotation("int32")
+ assert expr.op.name == "tir.type_annotation" and expr.dtype == "int32"  # dtype must propagate
+
+
+def test_tir_op_tvm_access_ptr():
+ buffer = tir.decl_buffer((128,), "float32")  # 1-D shape tuple
+ expr = tir.tvm_access_ptr("float32", buffer.data, 0, 1, 2)
+ assert expr.op.name == "tir.tvm_access_ptr"
+
+
+def test_tir_op_tvm_throw_last_error():
+ expr = tir.tvm_throw_last_error()
+ assert expr.op.name == "tir.tvm_throw_last_error"
+
+
+def test_tir_op_TVMBackendAllocWorkspace():
+ expr = tir.TVMBackendAllocWorkspace(0, 1, 2, 3, 4)
+ assert expr.op.name == "tir.TVMBackendAllocWorkspace"
+
+
+def test_tir_op_TVMBackendFreeWorkspace():
+ buffer = tir.decl_buffer((128,), "float32")  # 1-D shape tuple
+ expr = tir.TVMBackendFreeWorkspace(0, 1, buffer.data)
+ assert expr.op.name == "tir.TVMBackendFreeWorkspace"
+
+
if __name__ == "__main__":
test_tir_op_tvm_tuple()
test_tir_op_tvm_struct_get()
@@ -90,3 +126,9 @@ if __name__ == "__main__":
test_tir_op_call_assume()
test_tir_op_call_undef()
test_tir_op_call_likely()
+ test_tir_op_tvm_thread_allreduce()
+ test_tir_op_type_annotation()
+ test_tir_op_tvm_access_ptr()
+ test_tir_op_tvm_throw_last_error()
+ test_tir_op_TVMBackendAllocWorkspace()
+ test_tir_op_TVMBackendFreeWorkspace()