You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/07/15 00:22:38 UTC

[tvm] branch main updated: [TVMSCRIPT] Make ceildiv available from tvmscript (#12096)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 37f9d3c496 [TVMSCRIPT] Make ceildiv available from tvmscript (#12096)
37f9d3c496 is described below

commit 37f9d3c496bd32387f190ee31e4fa9bb525e7b85
Author: Tristan Konolige <tk...@octoml.ai>
AuthorDate: Thu Jul 14 17:22:33 2022 -0700

    [TVMSCRIPT] Make ceildiv available from tvmscript (#12096)
---
 python/tvm/script/tir/__init__.pyi          |  1 +
 python/tvm/script/tir/intrin.py             |  5 +++++
 python/tvm/tir/__init__.py                  |  2 +-
 python/tvm/tir/op.py                        | 20 ++++++++++++++++++++
 src/tir/op/op.cc                            |  1 +
 tests/python/unittest/test_tvmscript_ops.py | 15 +++++++++++++++
 6 files changed, 43 insertions(+), 1 deletion(-)

diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi
index 1c5687da52..f03c5c06da 100644
--- a/python/tvm/script/tir/__init__.pyi
+++ b/python/tvm/script/tir/__init__.pyi
@@ -93,6 +93,7 @@ def min_value(dtype: str) -> PrimExpr: ...
 def max_value(dtype: str) -> PrimExpr: ...
 def floordiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
 def floormod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
+def ceildiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
 def truncmod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
 def truncdiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ...
 def abs(x: PrimExpr) -> PrimExpr: ...
diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/tir/intrin.py
index 2099b86dad..bd3b171127 100644
--- a/python/tvm/script/tir/intrin.py
+++ b/python/tvm/script/tir/intrin.py
@@ -137,6 +137,11 @@ def truncmod(x, y, span):
     return tvm.tir.truncmod(x, y, span)
 
 
+@register
+def ceildiv(x, y, span):
+    return tvm.tir.ceildiv(x, y, span)
+
+
 @register
 def abs(x, span):
     return tvm.tir.abs(x, span)
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 6db93b6ad0..a3798ccab4 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -53,7 +53,7 @@ from .op import tan, tanh, atan, atan2, atanh
 from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot
 from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else
 from .op import isnan, isfinite, isinf, copysign
-from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
+from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv
 from .op import comm_reducer, min, max, sum
 from .op import q_multiply_shift
 
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 5d15bf15da..17005b04a4 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -1350,6 +1350,26 @@ def floormod(a, b, span=None):
     return _ffi_api._OpFloorMod(a, b, span)  # type: ignore
 
 
+def ceildiv(lhs, rhs, span=None):
+    """Generic ceildiv operator.
+
+    Parameters
+    ----------
+    lhs : object
+        The left operand.
+    rhs : object
+        The right operand.
+    span : Optional[Span]
+        The location of this operator in the source.
+
+    Returns
+    -------
+    op : tvm.Expr
+        The result Expr of ceildiv operation.
+    """
+    return _ffi_api._OpCeilDiv(lhs, rhs, span)  # type: ignore
+
+
 def comm_reducer(fcombine, fidentity, name="reduce"):
     """Create a commutative reducer for reduction.
 
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 456453a274..114571218b 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -979,6 +979,7 @@ REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
 REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
 REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv);
 REGISTER_MAKE_BINARY_OP(_OpTruncMod, truncmod);
+REGISTER_MAKE_BINARY_OP(_OpCeilDiv, ceildiv);
 REGISTER_MAKE_BINARY_OP(_OpPow, pow);
 REGISTER_MAKE_BINARY_OP(_OpMin, min);
 REGISTER_MAKE_BINARY_OP(_OpMax, max);
diff --git a/tests/python/unittest/test_tvmscript_ops.py b/tests/python/unittest/test_tvmscript_ops.py
index 82f0fa5c86..3f30c6ddb0 100644
--- a/tests/python/unittest/test_tvmscript_ops.py
+++ b/tests/python/unittest/test_tvmscript_ops.py
@@ -162,6 +162,21 @@ def test_alloc_zero_dim_buffer_round_trip():
     _check_alloc_zero_dim_buffer(rt_mod_with_block)
 
 
+@T.prim_func
+def ceildiv_test(A: T.Buffer[16, "int32"]):
+    for i in range(16):
+        A[i] = T.ceildiv(A[i], 4)
+
+
+@tvm.testing.requires_llvm
+def test_ceildiv():
+    f = tvm.build(ceildiv_test, "llvm")
+    a = tvm.nd.array(np.arange(16).astype("int32"))
+    f(a)
+    ref = (np.arange(16) + 3) // 4
+    tvm.testing.assert_allclose(a.numpy(), ref)
+
+
 if __name__ == "__main__":
     test_get_valid_counts_script_func()
     test_alloc_zero_dim_buffer_round_trip()