You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2021/04/24 19:36:44 UTC
[tvm] branch main updated: [TIR][SPIR-V] Fix computing clz on int64 input for vulkan (#7913)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a741652 [TIR][SPIR-V] Fix computing clz on int64 input for vulkan (#7913)
a741652 is described below
commit a741652f37c9e5f68e2c1eb6edcb103ba9b45c89
Author: masahi <ma...@gmail.com>
AuthorDate: Sun Apr 25 04:36:22 2021 +0900
[TIR][SPIR-V] Fix computing clz on int64 input for vulkan (#7913)
* Fix computing clz on int64 input for vulkan
* rebase fix
Co-authored-by: masa <ma...@pop-os.localdomain>
---
python/tvm/tir/op.py | 3 ++-
src/target/spirv/intrin_rule_spirv.cc | 27 +++++++++++++++++++++++----
tests/python/unittest/test_tir_intrin.py | 7 ++++++-
3 files changed, 31 insertions(+), 6 deletions(-)
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 6dd6c79..874724b 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -758,7 +758,8 @@ def clz(x):
Parameters
----------
x : PrimExpr
- Input argument. The result is undefined if the input is 0.
+ Input 32 or 64 bit integer.
+ The result is undefined if the input is 0.
Returns
-------
diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc
index 3baa77e..fa38f8f 100644
--- a/src/target/spirv/intrin_rule_spirv.cc
+++ b/src/target/spirv/intrin_rule_spirv.cc
@@ -33,22 +33,28 @@ namespace spirv {
using tir::FLowerIntrinsic;
// num_signature means number of arguments used to query signature
-
template <unsigned id>
-PrimExpr CallGLSLIntrin(const PrimExpr& e) {
+PrimExpr CallGLSLIntrin(PrimExpr e, const Array<PrimExpr>& args) {
const tir::CallNode* call = e.as<tir::CallNode>();
ICHECK(call != nullptr);
Array<PrimExpr> cargs;
// intrin id.
cargs.push_back(IntImm(DataType::UInt(32), id));
- for (PrimExpr arg : call->args) {
+ for (PrimExpr arg : args) {
cargs.push_back(arg);
}
return tir::Call(call->dtype, tir::builtin::call_spirv_pure_glsl450(), cargs);
}
template <unsigned id>
+PrimExpr CallGLSLIntrin(PrimExpr e) {
+ const tir::CallNode* call = e.as<tir::CallNode>();
+ ICHECK(call != nullptr);
+ return CallGLSLIntrin<id>(e, call->args);
+}
+
+template <unsigned id>
inline PrimExpr DispatchGLSLPureIntrin(const PrimExpr& e) {
return CallGLSLIntrin<id>(e);
}
@@ -98,7 +104,20 @@ TVM_REGISTER_OP("tir.clz").set_attr<FLowerIntrinsic>(
ICHECK(call != nullptr);
ICHECK_EQ(call->args.size(), 1);
PrimExpr arg = call->args[0];
- PrimExpr msb = CallGLSLIntrin<GLSLstd450FindUMsb>(e);
+ PrimExpr msb;
+ if (arg.dtype().bits() == 64) {
+ // SPIR-V FindUMsb intrinsic only supports 32 bit input
+ auto int32 = DataType::Int(32);
+ PrimExpr arg_hi32 = tvm::tir::Cast(int32, arg >> 32);
+ PrimExpr arg_lo32 = tvm::tir::Cast(int32, arg);
+ PrimExpr msb_hi = CallGLSLIntrin<GLSLstd450FindUMsb>(e, {arg_hi32});
+ PrimExpr msb_lo = CallGLSLIntrin<GLSLstd450FindUMsb>(e, {arg_lo32});
+ msb = tvm::if_then_else(arg_hi32 == 0, msb_lo, msb_hi + 32);
+ } else if (arg.dtype().bits() == 32) {
+ msb = CallGLSLIntrin<GLSLstd450FindUMsb>(e);
+ } else {
+ LOG(FATAL) << "SPIR-V clz only supports a 32 bit or 64 bit integer.";
+ }
return PrimExpr(arg.dtype().bits() - 1) - msb;
});
diff --git a/tests/python/unittest/test_tir_intrin.py b/tests/python/unittest/test_tir_intrin.py
index 0a82c91..8512d1c 100644
--- a/tests/python/unittest/test_tir_intrin.py
+++ b/tests/python/unittest/test_tir_intrin.py
@@ -170,7 +170,12 @@ def test_clz():
dev = tvm.device(target, 0)
n = 10
- for high in [10, 100, 1000, 10000, 100000, 1000000]:
+ highs = [10, 100, 1000, 10000, 100000, 1000000]
+
+ if dtype == "int64":
+ highs.append((1 << 63) - 1)
+
+ for high in highs:
a_np = np.random.randint(1, high=high, size=(n,)).astype(dtype)
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(np.zeros((n,)).astype("int32"), dev)