You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this rendering.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/04/16 13:01:27 UTC

[tvm] branch main updated: [TIR] Add a new intrinsic count leading zeros for LLVM and SPIR-V (#7825)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new cc79e8f  [TIR] Add a new intrinsic count leading zeros for LLVM and SPIR-V (#7825)
cc79e8f is described below

commit cc79e8fe3548cf6ee9297a19ae62d8de6e501663
Author: masahi <ma...@gmail.com>
AuthorDate: Fri Apr 16 22:01:17 2021 +0900

    [TIR] Add a new intrinsic count leading zeros for LLVM and SPIR-V (#7825)
---
 include/tvm/tir/op.h                     |  1 +
 python/tvm/tir/__init__.py               |  2 +-
 python/tvm/tir/op.py                     | 16 ++++++++++++++
 src/target/llvm/intrin_rule_llvm.cc      | 15 +++++++++++++
 src/target/spirv/intrin_rule_spirv.cc    | 22 ++++++++++++++++--
 src/tir/op/op.cc                         |  2 ++
 tests/python/unittest/test_tir_intrin.py | 38 ++++++++++++++++++++++++++++++++
 7 files changed, 93 insertions(+), 3 deletions(-)

diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index b5a62c9..9cf7d0a 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -864,6 +864,7 @@ TVM_DECLARE_INTRIN_UNARY(atan);
 TVM_DECLARE_INTRIN_UNARY(acosh);
 TVM_DECLARE_INTRIN_UNARY(asinh);
 TVM_DECLARE_INTRIN_UNARY(atanh);
+TVM_DECLARE_INTRIN_UNARY(clz);
 
 #define TVM_DECLARE_INTRIN_BINARY(OpName)                              \
   inline PrimExpr OpName(PrimExpr x, PrimExpr y, Span span = Span()) { \
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 681fc31..b348da8 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -37,7 +37,7 @@ from .function import PrimFunc
 
 from .op import call_packed, call_intrin, call_pure_extern, call_extern
 from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
-from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp
+from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
 from .op import sin, sinh, asin, asinh
 from .op import cos, cosh, acos, acosh
 from .op import tan, tanh, atan, atan2, atanh
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index a986c2f..6dd6c79 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -752,6 +752,22 @@ def rsqrt(x):
     return call_intrin(x.dtype, "tir.rsqrt", x)
 
 
+def clz(x):
+    """Count leading zero bits of an integer x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument. The result is undefined if the input is 0.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return call_intrin("int32", "tir.clz", x)
+
+
 def floor(x, span=None):
     """Take floor of float input x.
 
diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc
index 4c8862b..093a746 100644
--- a/src/target/llvm/intrin_rule_llvm.cc
+++ b/src/target/llvm/intrin_rule_llvm.cc
@@ -160,6 +160,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sinh")
       *rv = ret;
     });
 
+// Lower tir.clz to LLVM's llvm.ctlz intrinsic.
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.clz").set_body([](const TVMArgs& targs, TVMRetValue* rv) {
+  PrimExpr e = targs[0];
+  const tir::CallNode* call = e.as<tir::CallNode>();
+  ICHECK(call != nullptr);
+  ICHECK_EQ(call->args.size(), 1U);
+  PrimExpr arg = call->args[0];
+  // llvm.ctlz only accepts integer operands; fail early with a readable message
+  // rather than with an intrinsic-signature mismatch during codegen.
+  ICHECK(arg.dtype().is_int() || arg.dtype().is_uint())
+      << "tir.clz expects an integer argument, but got " << arg.dtype();
+  Array<PrimExpr> cargs;
+  cargs.push_back(IntImm(DataType::UInt(32), ::llvm::Intrinsic::ctlz));
+  // num_signature: number of following arguments whose types are used to query
+  // the overloaded intrinsic signature (the operand and the i1 flag).
+  cargs.push_back(IntImm(DataType::UInt(32), 2));
+  cargs.push_back(arg);
+  cargs.push_back(IntImm(DataType::Int(1), 1));  // is_zero_undef: clz(0) is undefined
+  // LLVM requires the return type to match the first argument type, so emit the
+  // call in the operand's dtype and cast to the requested result dtype afterwards.
+  PrimExpr leading_zeros = tir::Call(arg.dtype(), tir::builtin::call_llvm_intrin(), cargs);
+  *rv = cast(call->dtype, leading_zeros);
+});
+
 }  // namespace llvm
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc
index b75fb53..f77e8f4 100644
--- a/src/target/spirv/intrin_rule_spirv.cc
+++ b/src/target/spirv/intrin_rule_spirv.cc
@@ -24,6 +24,7 @@
 #include <tvm/runtime/registry.h>
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
 
 namespace tvm {
 namespace codegen {
@@ -32,8 +33,9 @@ namespace spirv {
 using namespace runtime;
 
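-// num_signature means number of arguments used to query signature
+
+// Build a call_spirv_pure_glsl450 call for GLSL.std.450 instruction `id`, forwarding the intrinsic's arguments.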
 template <unsigned id>
-inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
+PrimExpr CallGLSLIntrin(const TVMArgs& targs, TVMRetValue* rv) {
   PrimExpr e = targs[0];
   const tir::CallNode* call = e.as<tir::CallNode>();
   ICHECK(call != nullptr);
@@ -44,7 +46,12 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
   for (PrimExpr arg : call->args) {
     cargs.push_back(arg);
   }
-  *rv = tir::Call(call->dtype, tir::builtin::call_spirv_pure_glsl450(), cargs);
+  return tir::Call(call->dtype, tir::builtin::call_spirv_pure_glsl450(), cargs);
+}
+
+// Packed-function body: lowers the TIR intrinsic call passed in targs[0] to the
+// GLSL.std.450 instruction `id` and stores the lowered expression in *rv.
+template <unsigned id>
+inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) {
+  PrimExpr lowered = CallGLSLIntrin<id>(targs, rv);
+  *rv = lowered;
+}
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor")
@@ -76,6 +83,17 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow").set_body(DispatchGLSLPureIntri
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.tanh").set_body(DispatchGLSLPureIntrin<GLSLstd450Tanh>);
 
+// Lower tir.clz for Vulkan as (bits - 1) - FindUMsb(x).
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.clz")
+    .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
+      PrimExpr e = targs[0];
+      const tir::CallNode* call = e.as<tir::CallNode>();
+      ICHECK(call != nullptr);
+      ICHECK_EQ(call->args.size(), 1U);
+      PrimExpr arg = call->args[0];
+      const int bits = arg.dtype().bits();
+      // GLSL.std.450 FindUMsb is only defined for 32-bit integer components, so
+      // 64-bit inputs are split into two 32-bit halves below.
+      ICHECK(bits == 32 || bits == 64)
+          << "SPIR-V lowering of tir.clz only supports 32-bit or 64-bit integers, but got "
+          << arg.dtype();
+      // Build a 32-bit FindUMsb call on a 32-bit operand (first arg is the instruction id).
+      auto find_umsb = [](PrimExpr x) -> PrimExpr {
+        Array<PrimExpr> cargs{IntImm(DataType::UInt(32), GLSLstd450FindUMsb), x};
+        return tir::Call(DataType::Int(32), tir::builtin::call_spirv_pure_glsl450(), cargs);
+      };
+      PrimExpr msb;
+      if (bits == 64) {
+        PrimExpr hi = cast(DataType::Int(32), arg >> 32);
+        PrimExpr lo = cast(DataType::Int(32), arg);
+        // When the upper half is zero the most significant set bit lies in the lower half.
+        msb = if_then_else(hi == 0, find_umsb(lo), find_umsb(hi) + 32);
+      } else {
+        msb = find_umsb(arg);
+      }
+      // Match the LLVM rule: the result has the dtype requested by the call.
+      *rv = cast(call->dtype, PrimExpr(bits - 1) - msb);
+    });
+
 // WebGPU rules.
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.floor")
     .set_body(DispatchGLSLPureIntrin<GLSLstd450Floor>);
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 9fcb071..af78804 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -858,6 +858,8 @@ TIR_REGISTER_PURE_UNARY_OP("tir.asinh");
 
 TIR_REGISTER_PURE_UNARY_OP("tir.atanh");
 
+TIR_REGISTER_PURE_UNARY_OP("tir.clz");
+
 // binary intrinsics
 TIR_REGISTER_PURE_BINARY_OP("tir.atan2");
 
diff --git a/tests/python/unittest/test_tir_intrin.py b/tests/python/unittest/test_tir_intrin.py
index 755ffdf..0a82c91 100644
--- a/tests/python/unittest/test_tir_intrin.py
+++ b/tests/python/unittest/test_tir_intrin.py
@@ -142,9 +142,47 @@ def test_ldexp():
     )
 
 
+def test_clz():
+    def clz_np(x, dtype):
+        ceil_log2 = np.ceil(np.log2(x)).astype(dtype)
+        bits = int(dtype[-2:])
+        clz = bits - ceil_log2
+        clz[np.bitwise_and(x, x - 1) == 0] -= 1
+        return clz
+
+    for target in ["llvm", "vulkan"]:
+        if not tvm.testing.device_enabled("vulkan"):
+            continue
+
+        for dtype in ["int32", "int64"]:
+            m = te.var("m")
+            A = te.placeholder((m,), name="A", dtype=dtype)
+            B = te.compute((m,), lambda *i: tvm.tir.clz(A(*i)), name="B")
+            s = te.create_schedule(B.op)
+
+            if target == "vulkan":
+                bx, tx = s[B].split(B.op.axis[0], factor=64)
+
+                s[B].bind(bx, te.thread_axis("blockIdx.x"))
+                s[B].bind(tx, te.thread_axis("threadIdx.x"))
+
+            f = tvm.build(s, [A, B], target)
+            dev = tvm.device(target, 0)
+            n = 10
+
+            for high in [10, 100, 1000, 10000, 100000, 1000000]:
+                a_np = np.random.randint(1, high=high, size=(n,)).astype(dtype)
+                a = tvm.nd.array(a_np, dev)
+                b = tvm.nd.array(np.zeros((n,)).astype("int32"), dev)
+                f(a, b)
+                ref = clz_np(a_np, dtype)
+                np.testing.assert_equal(b.asnumpy(), ref)
+
+
 if __name__ == "__main__":
     test_nearbyint()
     test_unary_intrin()
     test_round_intrinsics_on_int()
     test_binary_intrin()
     test_ldexp()
+    test_clz()