You are viewing a plain-text version of this content. The canonical link to the original was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/07/15 00:22:38 UTC
[tvm] branch main updated: [TVMSCRIPT] Make ceildiv available from tvmscript (#12096)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 37f9d3c496 [TVMSCRIPT] Make ceildiv available from tvmscript (#12096)
37f9d3c496 is described below
commit 37f9d3c496bd32387f190ee31e4fa9bb525e7b85
Author: Tristan Konolige <tk...@octoml.ai>
AuthorDate: Thu Jul 14 17:22:33 2022 -0700
[TVMSCRIPT] Make ceildiv available from tvmscript (#12096)
---
python/tvm/script/tir/__init__.pyi | 1 +
python/tvm/script/tir/intrin.py | 5 +++++
python/tvm/tir/__init__.py | 2 +-
python/tvm/tir/op.py | 20 ++++++++++++++++++++
src/tir/op/op.cc | 1 +
tests/python/unittest/test_tvmscript_ops.py | 15 +++++++++++++++
6 files changed, 43 insertions(+), 1 deletion(-)
diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi
index 1c5687da52..f03c5c06da 100644
--- a/python/tvm/script/tir/__init__.pyi
+++ b/python/tvm/script/tir/__init__.pyi
@@ -93,6 +93,7 @@ def min_value(dtype: str) -> PrimExpr: ...
def max_value(dtype: str) -> PrimExpr: ...
def floordiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
def floormod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def ceildiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
def truncmod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
def truncdiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
def abs(x: PrimExpr) -> PrimExpr: ...
diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/tir/intrin.py
index 2099b86dad..bd3b171127 100644
--- a/python/tvm/script/tir/intrin.py
+++ b/python/tvm/script/tir/intrin.py
@@ -137,6 +137,11 @@ def truncmod(x, y, span):
return tvm.tir.truncmod(x, y, span)
+@register
+def ceildiv(x, y, span):
+ return tvm.tir.ceildiv(x, y, span)
+
+
@register
def abs(x, span):
return tvm.tir.abs(x, span)
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 6db93b6ad0..a3798ccab4 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -53,7 +53,7 @@ from .op import tan, tanh, atan, atan2, atanh
from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot
from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else
from .op import isnan, isfinite, isinf, copysign
-from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
+from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv
from .op import comm_reducer, min, max, sum
from .op import q_multiply_shift
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 5d15bf15da..17005b04a4 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -1350,6 +1350,26 @@ def floormod(a, b, span=None):
return _ffi_api._OpFloorMod(a, b, span) # type: ignore
+def ceildiv(lhs, rhs, span=None):
+ """Generic ceildiv operator.
+
+ Parameters
+ ----------
+ lhs : object
+ The left operand.
+ rhs : object
+ The right operand.
+ span : Optional[Span]
+ The location of this operator in the source.
+
+ Returns
+ -------
+ op : tvm.Expr
+ The result Expr of ceildiv operaton.
+ """
+ return _ffi_api._OpCeilDiv(lhs, rhs, span) # type: ignore
+
+
def comm_reducer(fcombine, fidentity, name="reduce"):
"""Create a commutative reducer for reduction.
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 456453a274..114571218b 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -979,6 +979,7 @@ REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv);
REGISTER_MAKE_BINARY_OP(_OpTruncMod, truncmod);
+REGISTER_MAKE_BINARY_OP(_OpCeilDiv, ceildiv);
REGISTER_MAKE_BINARY_OP(_OpPow, pow);
REGISTER_MAKE_BINARY_OP(_OpMin, min);
REGISTER_MAKE_BINARY_OP(_OpMax, max);
diff --git a/tests/python/unittest/test_tvmscript_ops.py b/tests/python/unittest/test_tvmscript_ops.py
index 82f0fa5c86..3f30c6ddb0 100644
--- a/tests/python/unittest/test_tvmscript_ops.py
+++ b/tests/python/unittest/test_tvmscript_ops.py
@@ -162,6 +162,21 @@ def test_alloc_zero_dim_buffer_round_trip():
_check_alloc_zero_dim_buffer(rt_mod_with_block)
+@T.prim_func
+def ceildiv_test(A: T.Buffer[16, "int32"]):
+ for i in range(16):
+ A[i] = T.ceildiv(A[i], 4)
+
+
+@tvm.testing.requires_llvm
+def test_ceildiv():
+ f = tvm.build(ceildiv_test, "llvm")
+ a = tvm.nd.array(np.arange(16).astype("int32"))
+ f(a)
+ ref = (np.arange(16) + 3) // 4
+ tvm.testing.assert_allclose(a.numpy(), ref)
+
+
if __name__ == "__main__":
test_get_valid_counts_script_func()
test_alloc_zero_dim_buffer_round_trip()