You are viewing a plain-text version of this content. The canonical link to it was a hyperlink ("here") in the original page and is not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by ju...@apache.org on 2021/04/24 19:36:44 UTC

[tvm] branch main updated: [TIR][SPIR-V] Fix computing clz on int64 input for vulkan (#7913)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a741652  [TIR][SPIR-V] Fix computing clz on int64 input for vulkan (#7913)
a741652 is described below

commit a741652f37c9e5f68e2c1eb6edcb103ba9b45c89
Author: masahi <ma...@gmail.com>
AuthorDate: Sun Apr 25 04:36:22 2021 +0900

    [TIR][SPIR-V] Fix computing clz on int64 input for vulkan (#7913)
    
    * Fix computing clz on int64 input for vulkan
    
    * rebase fix
    
    Co-authored-by: masa <ma...@pop-os.localdomain>
---
 python/tvm/tir/op.py                     |  3 ++-
 src/target/spirv/intrin_rule_spirv.cc    | 27 +++++++++++++++++++++++----
 tests/python/unittest/test_tir_intrin.py |  7 ++++++-
 3 files changed, 31 insertions(+), 6 deletions(-)

diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 6dd6c79..874724b 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -758,7 +758,8 @@ def clz(x):
     Parameters
     ----------
     x : PrimExpr
-        Input argument. The result is undefined if the input is 0.
+        Input 32 or 64 bit integer.
+        The result is undefined if the input is 0.
 
     Returns
     -------
diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc
index 3baa77e..fa38f8f 100644
--- a/src/target/spirv/intrin_rule_spirv.cc
+++ b/src/target/spirv/intrin_rule_spirv.cc
@@ -33,22 +33,28 @@ namespace spirv {
 using tir::FLowerIntrinsic;
 
 // num_signature means number of arguments used to query signature
-
 template <unsigned id>
-PrimExpr CallGLSLIntrin(const PrimExpr& e) {
+PrimExpr CallGLSLIntrin(PrimExpr e, const Array<PrimExpr>& args) {
   const tir::CallNode* call = e.as<tir::CallNode>();
   ICHECK(call != nullptr);
   Array<PrimExpr> cargs;
   // intrin id.
   cargs.push_back(IntImm(DataType::UInt(32), id));
 
-  for (PrimExpr arg : call->args) {
+  for (PrimExpr arg : args) {
     cargs.push_back(arg);
   }
   return tir::Call(call->dtype, tir::builtin::call_spirv_pure_glsl450(), cargs);
 }
 
 template <unsigned id>
+PrimExpr CallGLSLIntrin(PrimExpr e) {
+  const tir::CallNode* call = e.as<tir::CallNode>();
+  ICHECK(call != nullptr);
+  return CallGLSLIntrin<id>(e, call->args);
+}
+
+template <unsigned id>
 inline PrimExpr DispatchGLSLPureIntrin(const PrimExpr& e) {
   return CallGLSLIntrin<id>(e);
 }
@@ -98,7 +104,20 @@ TVM_REGISTER_OP("tir.clz").set_attr<FLowerIntrinsic>(
       ICHECK(call != nullptr);
       ICHECK_EQ(call->args.size(), 1);
       PrimExpr arg = call->args[0];
-      PrimExpr msb = CallGLSLIntrin<GLSLstd450FindUMsb>(e);
+      PrimExpr msb;
+      if (arg.dtype().bits() == 64) {
+        // SPIR-V FindUMsb intrinsic only supports 32 bit input
+        auto int32 = DataType::Int(32);
+        PrimExpr arg_hi32 = tvm::tir::Cast(int32, arg >> 32);
+        PrimExpr arg_lo32 = tvm::tir::Cast(int32, arg);
+        PrimExpr msb_hi = CallGLSLIntrin<GLSLstd450FindUMsb>(e, {arg_hi32});
+        PrimExpr msb_lo = CallGLSLIntrin<GLSLstd450FindUMsb>(e, {arg_lo32});
+        msb = tvm::if_then_else(arg_hi32 == 0, msb_lo, msb_hi + 32);
+      } else if (arg.dtype().bits() == 32) {
+        msb = CallGLSLIntrin<GLSLstd450FindUMsb>(e);
+      } else {
+        LOG(FATAL) << "SPIR-V clz only supports a 32 bit or 64 bit integer.";
+      }
       return PrimExpr(arg.dtype().bits() - 1) - msb;
     });
 
diff --git a/tests/python/unittest/test_tir_intrin.py b/tests/python/unittest/test_tir_intrin.py
index 0a82c91..8512d1c 100644
--- a/tests/python/unittest/test_tir_intrin.py
+++ b/tests/python/unittest/test_tir_intrin.py
@@ -170,7 +170,12 @@ def test_clz():
             dev = tvm.device(target, 0)
             n = 10
 
-            for high in [10, 100, 1000, 10000, 100000, 1000000]:
+            highs = [10, 100, 1000, 10000, 100000, 1000000]
+
+            if dtype == "int64":
+                highs.append((1 << 63) - 1)
+
+            for high in highs:
                 a_np = np.random.randint(1, high=high, size=(n,)).astype(dtype)
                 a = tvm.nd.array(a_np, dev)
                 b = tvm.nd.array(np.zeros((n,)).astype("int32"), dev)