You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/08/17 09:26:50 UTC

[tvm] branch main updated: Expose Missing TIR Builtins to Python (#12466)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new bcc7cde95c Expose Missing TIR Builtins to Python (#12466)
bcc7cde95c is described below

commit bcc7cde95c1e84b85f18c07110489350865b8cfe
Author: Yaxing Cai <ca...@gmail.com>
AuthorDate: Wed Aug 17 02:26:45 2022 -0700

    Expose Missing TIR Builtins to Python (#12466)
    
    This PR exposes the following TIR operations in Python (and adds the
    missing `log1p` declaration to `include/tvm/tir/op.h`):
    
    `address_of`: tested [here](https://github.com/apache/tvm/blob/d2f9f254d275df256dbcbc5a9f8b3a07cee1d81f/tests/python/unittest/test_tvmscript_roundtrip.py#L3247)
    `lookup_param`: tested [here](https://github.com/apache/tvm/blob/d2f9f254d275df256dbcbc5a9f8b3a07cee1d81f/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py#L171)
    `infinity`: tested by a new unit test (`test_infinity` in `tests/python/unittest/test_tir_nodes.py`)
    `reinterpret`: tested [here](https://github.com/apache/tvm/blob/d2f9f254d275df256dbcbc5a9f8b3a07cee1d81f/tests/python/unittest/test_tvmscript_roundtrip.py#L2991)
    `isnullptr`: tested [here](https://github.com/apache/tvm/blob/d2f9f254d275df256dbcbc5a9f8b3a07cee1d81f/tests/python/unittest/test_tvmscript_roundtrip.py#L260)
    
    Co-Authored-By: yongwww <yo...@gmail.com>
---
 include/tvm/tir/op.h                       |  1 +
 python/tvm/tir/__init__.py                 |  5 +-
 python/tvm/tir/op.py                       | 98 ++++++++++++++++++++++++++++++
 src/tir/op/op.cc                           |  2 +
 tests/python/unittest/test_tir_nodes.py    |  6 ++
 tests/python/unittest/test_tir_op_types.py | 27 ++++++++
 6 files changed, 137 insertions(+), 2 deletions(-)

diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 7236c6a611..94603307a7 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -701,6 +701,7 @@ TVM_DECLARE_INTRIN_UNARY(rsqrt);
 TVM_DECLARE_INTRIN_UNARY(log);
 TVM_DECLARE_INTRIN_UNARY(log2);
 TVM_DECLARE_INTRIN_UNARY(log10);
+TVM_DECLARE_INTRIN_UNARY(log1p);
 TVM_DECLARE_INTRIN_UNARY(popcount);
 TVM_DECLARE_INTRIN_UNARY(tan);
 TVM_DECLARE_INTRIN_UNARY(cos);
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 7980b5adaa..82b1089ac1 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -50,14 +50,15 @@ from .op import call_packed, call_cpacked, call_intrin, call_pure_extern, call_e
 from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
 from .op import tvm_stack_alloca, tvm_stack_make_shape, tvm_stack_make_array
 from .op import tvm_tuple, tvm_struct_get, tvm_struct_set
-from .op import assume, undef
+from .op import address_of, lookup_param, assume, undef
+from .op import infinity, reinterpret
 from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
 from .op import sin, sinh, asin, asinh
 from .op import cos, cosh, acos, acosh
 from .op import tan, tanh, atan, atan2, atanh
 from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot
 from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else
-from .op import likely, isnan, isfinite, isinf, copysign
+from .op import likely, isnan, isnullptr, isfinite, isinf, copysign
 from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv
 from .op import comm_reducer, min, max, sum
 from .op import q_multiply_shift
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 1bb185120e..3e8dc52935 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -466,6 +466,44 @@ def tvm_struct_set(arr, index, field, value):
     return call_intrin("handle", "tir.tvm_struct_set", arr, index, field, value)
 
 
+def address_of(buffer_load, span=None):
+    """Return the address of the buffer element referenced by buffer_load.
+
+    Parameters
+    ----------
+    buffer_load : BufferLoad
+        The buffer load whose element address is taken.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    call : PrimExpr
+        A handle-typed call to tir.address_of.
+    """
+    return call_intrin("handle", "tir.address_of", buffer_load, span=span)
+
+
+def lookup_param(param_name, span=None):
+    """Look up a parameter by name.
+
+    Parameters
+    ----------
+    param_name : str
+        The name of the parameter to look up.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    call : PrimExpr
+        A handle-typed call to tir.lookup_param.
+    """
+    return call_intrin("handle", "tir.lookup_param", param_name, span=span)
+
+
 def ret(val):
     """Create a tir return expression
 
@@ -610,6 +648,47 @@ def max_value(dtype: str, span: Optional[Span] = None) -> Any:
     return _ffi_api.max_value(dtype, span)  # type: ignore
 
 
+def infinity(dtype: str, span: Optional[Span] = None) -> Any:
+    """Return the infinity value of dtype.
+
+    Parameters
+    ----------
+    dtype : str
+        The data type, e.g. "float32".
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    value : tvm.Expr
+        The infinity value of dtype.
+    """
+    return _ffi_api.infinity(dtype, span)  # type: ignore
+
+
+def reinterpret(dtype, value, span: Optional[Span] = None) -> Any:
+    """Reinterpret the bits of value as the given dtype.
+
+    Parameters
+    ----------
+    dtype : str
+        The target data type.
+
+    value : PrimExpr
+        The input value.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    value : tvm.Expr
+        The value reinterpret-cast to dtype.
+    """
+    return call_intrin(dtype, "tir.reinterpret", value, span=span)
+
+
 def exp(x):
     """Take exponential of input x.
 
@@ -1253,6 +1332,25 @@ def isnan(x, span=None):
     return _ffi_api.isnan(x, span)  # type: ignore
 
 
+def isnullptr(x, span=None):
+    """Check if input value is nullptr.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    span : Optional[Span]
+        The location of this operator in the source code.
+
+    Returns
+    -------
+    y : PrimExpr
+        A bool-typed call to tir.isnullptr.
+    """
+    return call_intrin("bool", "tir.isnullptr", x, span=span)
+
+
 def isfinite(x, span=None):
     """Check if input value is finite.
 
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 7879c9fee9..69d1da5e8c 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -929,6 +929,8 @@ TVM_REGISTER_GLOBAL("tir.min_value").set_body_typed(min_value);
 
 TVM_REGISTER_GLOBAL("tir.max_value").set_body_typed(max_value);
 
+TVM_REGISTER_GLOBAL("tir.infinity").set_body_typed(infinity);
+
 TVM_REGISTER_GLOBAL("tir.abs").set_body_typed(tvm::abs);
 
 TVM_REGISTER_GLOBAL("tir.likely").set_body_typed(tvm::likely);
diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py
index b4295411bf..c4ab76cd26 100644
--- a/tests/python/unittest/test_tir_nodes.py
+++ b/tests/python/unittest/test_tir_nodes.py
@@ -301,6 +301,12 @@ def test_divide_by_zero():
             pass
 
 
+def test_infinity():
+    cases = {"float16": "inff16", "float32": "inff32", "float64": "inff64"}
+    for dtype, expected in cases.items():
+        assert str(tvm.tir.infinity(dtype)) == expected
+
+
 def test_isnan():
     x = te.var("x", "float32")
     assert str(tvm.tir.isnan(x)) == "@tir.isnan(x: float32, dtype=bool)"
diff --git a/tests/python/unittest/test_tir_op_types.py b/tests/python/unittest/test_tir_op_types.py
index 3bcdf8b086..e36eecf2ec 100644
--- a/tests/python/unittest/test_tir_op_types.py
+++ b/tests/python/unittest/test_tir_op_types.py
@@ -39,6 +39,29 @@ def test_tir_op_tvm_struct_set():
     assert expr.op.name == "tir.tvm_struct_set"
 
 
+def test_tir_op_address_of():
+    buffer = tir.decl_buffer((128,), "float32")
+    expr = tir.address_of(buffer[0])
+    assert expr.op.name == "tir.address_of"
+
+
+def test_tir_op_lookup_param():
+    call = tir.lookup_param(param_name="p0")
+    assert call.op.name == "tir.lookup_param"
+
+
+def test_tir_op_reinterpret():
+    src = tir.Var("x", dtype="int32")
+    call = tir.reinterpret(dtype="float32", value=src)
+    assert call.op.name == "tir.reinterpret"
+
+
+def test_tir_op_isnullptr():
+    x = tir.Var("x", dtype="handle")
+    expr = tir.isnullptr(x)
+    assert expr.op.name == "tir.isnullptr"
+
+
 def test_tir_op_call_assume():
     x = tir.Var("x", dtype="int32")
     expr = tir.assume(cond=x)
@@ -60,6 +83,10 @@ if __name__ == "__main__":
     test_tir_op_tvm_tuple()
     test_tir_op_tvm_struct_get()
     test_tir_op_tvm_struct_set()
+    test_tir_op_address_of()
+    test_tir_op_lookup_param()
+    test_tir_op_reinterpret()
+    test_tir_op_isnullptr()
     test_tir_op_call_assume()
     test_tir_op_call_undef()
     test_tir_op_call_likely()