Posted to commits@tvm.apache.org by tq...@apache.org on 2021/03/05 18:54:02 UTC

[tvm] branch main updated: [Vulkan] Support passing 64 bit scalar (#7572)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c0b9688  [Vulkan] Support passing 64 bit scalar  (#7572)
c0b9688 is described below

commit c0b9688869f26794765fb26a8d9b51191f60d6b2
Author: masahi <ma...@gmail.com>
AuthorDate: Sat Mar 6 03:53:48 2021 +0900

    [Vulkan] Support passing 64 bit scalar  (#7572)
    
    
    
    Co-authored-by: Wuwei Lin <wu...@apache.org>
---
 src/runtime/metal/metal_module.mm            |  4 ++--
 src/runtime/pack_args.h                      | 36 +++++++++++++++++++---------
 src/runtime/vulkan/vulkan.cc                 | 14 ++++++-----
 src/target/source/codegen_metal.cc           |  7 +++++-
 src/target/spirv/intrin_rule_spirv.cc        |  6 +++++
 tests/python/topi/python/test_topi_cumsum.py |  7 ++++++
 tests/python/topi/python/test_topi_vision.py |  2 +-
 7 files changed, 55 insertions(+), 21 deletions(-)
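
In short, this patch widens the per-argument slot used for non-buffer (scalar) kernel
arguments from 4 to 8 bytes: a new ArgUnion64 union replaces ArgUnion in the Metal and
Vulkan runtimes, and the push-constant / setBytes sizes are computed from
sizeof(ArgUnion64), so int64/uint64/float64 scalars can be passed without truncation.
The standalone C++ sketch below only illustrates that layout; PackScalars and the main
driver are illustrative helpers, not TVM APIs.

    // Sketch of the 8-byte-per-argument packing scheme (mirrors ArgUnion64
    // in src/runtime/pack_args.h). Illustrative only, not TVM code.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    union ArgUnion64 {
      int32_t v_int32[2];
      uint32_t v_uint32[2];
      float v_float32[2];
      int64_t v_int64;
      uint64_t v_uint64;
      double v_float64;
    };

    // Pack one 64-bit and one 32-bit scalar into consecutive 8-byte slots,
    // the same layout the Vulkan runtime now hands to vkCmdPushConstants
    // (size = num_pack_args * sizeof(ArgUnion64)).
    std::vector<ArgUnion64> PackScalars(int64_t a, float b) {
      std::vector<ArgUnion64> holder(2);
      holder[0].v_int64 = a;       // 64-bit value occupies the full slot
      holder[1].v_float32[0] = b;  // 32-bit value lives in the low half
      return holder;
    }

    int main() {
      auto packed = PackScalars(int64_t(1) << 40, 2.5f);
      std::printf("bytes per arg = %zu, total bytes = %zu\n", sizeof(ArgUnion64),
                  packed.size() * sizeof(ArgUnion64));
      return 0;
    }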

diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm
index 981dd61..8f1fde8 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -180,7 +180,7 @@ class MetalWrappedFunc {
     scache_[dev_id] = m->GetPipelineState(dev_id, func_name);
   }
   // invoke the function with void arguments
-  void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const {
+  void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const {
     metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
     int device_id = t->context.device_id;
     if (scache_[device_id] == nil) {
@@ -197,7 +197,7 @@ class MetalWrappedFunc {
     }
     if (num_pack_args_ != 0) {
       [encoder setBytes:pack_args
-                 length:num_pack_args_ * sizeof(ArgUnion)
+                 length:num_pack_args_ * sizeof(ArgUnion64)
                 atIndex:num_buffer_args_];
     }
     // launch
diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h
index 45cde22..7c852da 100644
--- a/src/runtime/pack_args.h
+++ b/src/runtime/pack_args.h
@@ -40,13 +40,24 @@ namespace tvm {
 namespace runtime {
 /*!
  * \brief argument union type of 32bit.
- * Choose 32 bit because most GPU API do not work well with 64 bit.
  */
-union ArgUnion {
+union ArgUnion32 {
   int32_t v_int32;
   uint32_t v_uint32;
   float v_float32;
 };
+
+/*!
+ * \brief argument union type of 64 bit, for use by Vulkan and Metal runtime.
+ */
+union ArgUnion64 {
+  int32_t v_int32[2];
+  uint32_t v_uint32[2];
+  float v_float32[2];
+  int64_t v_int64;
+  uint64_t v_uint64;
+  double v_float64;
+};
 /*!
  * \brief Create a packed function from void addr types.
  *
@@ -140,9 +151,9 @@ inline PackedFunc PackFuncVoidAddr_(F f, const std::vector<ArgConvertCode>& code
   int num_args = static_cast<int>(codes.size());
   auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) {
     TempArray<void*, N> addr_(num_args);
-    TempArray<ArgUnion, N> holder_(num_args);
+    TempArray<ArgUnion32, N> holder_(num_args);
     void** addr = addr_.data();
-    ArgUnion* holder = holder_.data();
+    ArgUnion32* holder = holder_.data();
     for (int i = 0; i < num_args; ++i) {
       switch (codes[i]) {
         case INT64_TO_INT64:
@@ -177,25 +188,28 @@ template <int N, typename F>
 inline PackedFunc PackFuncNonBufferArg_(F f, int base, const std::vector<ArgConvertCode>& codes) {
   int num_args = static_cast<int>(codes.size());
   auto ret = [f, codes, base, num_args](TVMArgs args, TVMRetValue* ret) {
-    TempArray<ArgUnion, N> holder_(num_args);
-    ArgUnion* holder = holder_.data();
+    TempArray<ArgUnion64, N> holder_(num_args);
+    ArgUnion64* holder = holder_.data();
     for (int i = 0; i < num_args; ++i) {
       switch (codes[i]) {
-        case INT64_TO_INT64:
+        case INT64_TO_INT64: {
+          holder[i].v_int64 = args.values[base + i].v_int64;
+          break;
+        }
         case FLOAT64_TO_FLOAT64: {
-          LOG(FATAL) << "Do not support 64bit argument to device function";
+          holder[i].v_float64 = args.values[base + i].v_float64;
           break;
         }
         case INT64_TO_INT32: {
-          holder[i].v_int32 = static_cast<int32_t>(args.values[base + i].v_int64);
+          holder[i].v_int32[0] = static_cast<int32_t>(args.values[base + i].v_int64);
           break;
         }
         case INT64_TO_UINT32: {
-          holder[i].v_uint32 = static_cast<uint32_t>(args.values[base + i].v_int64);
+          holder[i].v_uint32[0] = static_cast<uint32_t>(args.values[base + i].v_int64);
           break;
         }
         case FLOAT64_TO_FLOAT32: {
-          holder[i].v_float32 = static_cast<float>(args.values[base + i].v_float64);
+          holder[i].v_float32[0] = static_cast<float>(args.values[base + i].v_float64);
           break;
         }
         case HANDLE_TO_HANDLE: {
diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc
index f40fd80..794f3c5 100644
--- a/src/runtime/vulkan/vulkan.cc
+++ b/src/runtime/vulkan/vulkan.cc
@@ -711,7 +711,7 @@ class VulkanWrappedFunc {
     thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags);
   }
 
-  void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const;
+  void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const;
 
  private:
   // internal module
@@ -875,7 +875,7 @@ class VulkanModuleNode final : public runtime::ModuleNode {
     VkPushConstantRange crange;
     crange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
     crange.offset = 0;
-    crange.size = sizeof(ArgUnion) * num_pack_args;
+    crange.size = sizeof(ArgUnion64) * num_pack_args;
 
     VkPipelineLayoutCreateInfo playout_cinfo;
     playout_cinfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
@@ -1046,7 +1046,8 @@ VulkanStream* VulkanThreadEntry::Stream(size_t device_id) {
   return streams_[device_id].get();
 }
 
-void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const {
+void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
+                                   const ArgUnion64* pack_args) const {
   int device_id = VulkanThreadEntry::ThreadLocal()->ctx.device_id;
   ICHECK_LT(device_id, kVulkanMaxNumDevice);
   const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
@@ -1075,7 +1076,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion
           descriptor_buffers.data());
       if (num_pack_args_ != 0) {
         vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout,
-                           VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion),
+                           VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion64),
                            pack_args);
       }
       vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
@@ -1093,7 +1094,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion
   }
 
   // Otherwise, the more expensive deferred path.
-  std::vector<ArgUnion> pack_args_storage(pack_args, pack_args + num_pack_args_);
+  std::vector<ArgUnion64> pack_args_storage(pack_args, pack_args + num_pack_args_);
   const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers]() {
     std::vector<VkWriteDescriptorSet> write_descriptor_sets;
     write_descriptor_sets.resize(descriptor_buffers.size());
@@ -1119,7 +1120,8 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion
                             nullptr);
     if (pack_args_storage.size() != 0) {
       vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT,
-                         0, pack_args_storage.size() * sizeof(ArgUnion), pack_args_storage.data());
+                         0, pack_args_storage.size() * sizeof(ArgUnion64),
+                         pack_args_storage.data());
     }
     vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
     VkMemoryBarrier barrier_info;
diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc
index baa3006..c95d578 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -47,7 +47,7 @@ CodeGenMetal::CodeGenMetal() {
   decl_stream << "#include <metal_stdlib>\n";
   decl_stream << "using namespace metal;\n\n";
   decl_stream << "union __TVMArgUnion {\n"
-              << " int v_int;\n"
+              << " int v_int[2];\n"
               << "};\n\n";
 }
 
@@ -104,6 +104,11 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
       if (v.dtype().bits() == 32) {
         decl_stream << "  ";
         PrintType(v.dtype(), decl_stream);
+        decl_stream << " " << vid << "[2];\n";
+        vref << varg << "." << vid << "[0]";
+      } else if (v.dtype().bits() == 64) {
+        decl_stream << "  ";
+        PrintType(v.dtype(), decl_stream);
         decl_stream << " " << vid << ";\n";
         vref << varg << "." << vid;
       } else {
diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc
index 90b2eb2..b75fb53 100644
--- a/src/target/spirv/intrin_rule_spirv.cc
+++ b/src/target/spirv/intrin_rule_spirv.cc
@@ -62,8 +62,14 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.fabs").set_body(DispatchGLSLPureIntr
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp").set_body(DispatchGLSLPureIntrin<GLSLstd450Exp>);
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sin").set_body(DispatchGLSLPureIntrin<GLSLstd450Sin>);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.cos").set_body(DispatchGLSLPureIntrin<GLSLstd450Cos>);
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log").set_body(DispatchGLSLPureIntrin<GLSLstd450Log>);
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log2").set_body(DispatchGLSLPureIntrin<GLSLstd450Log2>);
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sqrt").set_body(DispatchGLSLPureIntrin<GLSLstd450Sqrt>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow").set_body(DispatchGLSLPureIntrin<GLSLstd450Pow>);
diff --git a/tests/python/topi/python/test_topi_cumsum.py b/tests/python/topi/python/test_topi_cumsum.py
index a01a496..cfe5130 100644
--- a/tests/python/topi/python/test_topi_cumsum.py
+++ b/tests/python/topi/python/test_topi_cumsum.py
@@ -28,6 +28,8 @@ def test_cumsum(ctx, target):
             "generic": (lambda x: topi.cumsum(x, axis, dtype), topi.generic.schedule_extern),
             "cuda": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
             "nvptx": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
+            "vulkan": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
+            "metal": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
         }
         fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations)
         tvm.topi.testing.compare_numpy_tvm([data], np_ref, target, ctx, fcompute, fschedule)
@@ -44,6 +46,9 @@ def test_cumsum(ctx, target):
     check_cumsum(np.cumsum(data, dtype=np.int32), data, dtype="int32")
 
     for in_dtype in ["float32", "float64"]:
+        if target == "metal" and in_dtype == "float64":
+            # float64 is not supported in metal
+            continue
         data = np.random.randn(10, 10).astype(in_dtype)
         check_cumsum(np.cumsum(data), data)
         check_cumsum(np.cumsum(data, axis=0), data, axis=0)
@@ -70,3 +75,5 @@ if __name__ == "__main__":
     test_cumsum(tvm.context("cpu"), tvm.target.Target("llvm"))
     test_cumsum(tvm.context("cuda"), tvm.target.Target("cuda"))
     test_cumsum(tvm.context("nvptx"), tvm.target.Target("nvptx"))
+    test_cumsum(tvm.context("vulkan"), tvm.target.Target("vulkan"))
+    test_cumsum(tvm.context("metal"), tvm.target.Target("metal"))
diff --git a/tests/python/topi/python/test_topi_vision.py b/tests/python/topi/python/test_topi_vision.py
index 8393568..2fdf3cf 100644
--- a/tests/python/topi/python/test_topi_vision.py
+++ b/tests/python/topi/python/test_topi_vision.py
@@ -112,7 +112,7 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index):
         tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)
         tvm.testing.assert_allclose(tvm_out3.asnumpy(), np_out3, rtol=1e-3)
 
-    for device in ["llvm", "cuda", "opencl"]:
+    for device in ["llvm", "cuda", "opencl", "vulkan"]:
         check_device(device)