You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wu...@apache.org on 2021/03/04 22:04:22 UTC

[tvm] 01/11: introduce ArgUnion64

This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch vk-i64
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit ad075025a61bed23df0bf5f6c78853b5204725f4
Author: Masahiro Masuda <ma...@gmail.com>
AuthorDate: Wed Mar 3 07:37:52 2021 +0900

    introduce ArgUnion64
---
 src/runtime/pack_args.h      | 26 +++++++++++++++++++-------
 src/runtime/vulkan/vulkan.cc | 12 ++++++------
 2 files changed, 25 insertions(+), 13 deletions(-)

diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h
index 45cde22..54a75d6 100644
--- a/src/runtime/pack_args.h
+++ b/src/runtime/pack_args.h
@@ -47,6 +47,15 @@ union ArgUnion {
   uint32_t v_uint32;
   float v_float32;
 };
+
+union ArgUnion64 {
+  int32_t v_int32[2];
+  uint32_t v_uint32[2];
+  float v_float32[2];
+  int64_t v_int64;
+  uint64_t v_uint64;
+  double v_float64;
+};
 /*!
  * \brief Create a packed function from void addr types.
  *
@@ -177,25 +186,28 @@ template <int N, typename F>
 inline PackedFunc PackFuncNonBufferArg_(F f, int base, const std::vector<ArgConvertCode>& codes) {
   int num_args = static_cast<int>(codes.size());
   auto ret = [f, codes, base, num_args](TVMArgs args, TVMRetValue* ret) {
-    TempArray<ArgUnion, N> holder_(num_args);
-    ArgUnion* holder = holder_.data();
+    TempArray<ArgUnion64, N> holder_(num_args);
+    ArgUnion64* holder = holder_.data();
     for (int i = 0; i < num_args; ++i) {
       switch (codes[i]) {
-        case INT64_TO_INT64:
+        case INT64_TO_INT64: {
+          holder[i].v_int64 = args.values[base + i].v_int64;
+          break;
+        }
         case FLOAT64_TO_FLOAT64: {
-          LOG(FATAL) << "Do not support 64bit argument to device function";
+          holder[i].v_float64 = args.values[base + i].v_float64;
           break;
         }
         case INT64_TO_INT32: {
-          holder[i].v_int32 = static_cast<int32_t>(args.values[base + i].v_int64);
+          holder[i].v_int32[0] = static_cast<int32_t>(args.values[base + i].v_int64);
           break;
         }
         case INT64_TO_UINT32: {
-          holder[i].v_uint32 = static_cast<uint32_t>(args.values[base + i].v_int64);
+          holder[i].v_uint32[0] = static_cast<uint32_t>(args.values[base + i].v_int64);
           break;
         }
         case FLOAT64_TO_FLOAT32: {
-          holder[i].v_float32 = static_cast<float>(args.values[base + i].v_float64);
+          holder[i].v_float32[0] = static_cast<float>(args.values[base + i].v_float64);
           break;
         }
         case HANDLE_TO_HANDLE: {
diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc
index f40fd80..4eb3481 100644
--- a/src/runtime/vulkan/vulkan.cc
+++ b/src/runtime/vulkan/vulkan.cc
@@ -711,7 +711,7 @@ class VulkanWrappedFunc {
     thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags);
   }
 
-  void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const;
+  void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const;
 
  private:
   // internal module
@@ -875,7 +875,7 @@ class VulkanModuleNode final : public runtime::ModuleNode {
     VkPushConstantRange crange;
     crange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
     crange.offset = 0;
-    crange.size = sizeof(ArgUnion) * num_pack_args;
+    crange.size = sizeof(ArgUnion64) * num_pack_args;
 
     VkPipelineLayoutCreateInfo playout_cinfo;
     playout_cinfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
@@ -1046,7 +1046,7 @@ VulkanStream* VulkanThreadEntry::Stream(size_t device_id) {
   return streams_[device_id].get();
 }
 
-void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const {
+void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const {
   int device_id = VulkanThreadEntry::ThreadLocal()->ctx.device_id;
   ICHECK_LT(device_id, kVulkanMaxNumDevice);
   const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
@@ -1075,7 +1075,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion
           descriptor_buffers.data());
       if (num_pack_args_ != 0) {
         vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout,
-                           VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion),
+                           VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion64),
                            pack_args);
       }
       vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
@@ -1093,7 +1093,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion
   }
 
   // Otherwise, the more expensive deferred path.
-  std::vector<ArgUnion> pack_args_storage(pack_args, pack_args + num_pack_args_);
+  std::vector<ArgUnion64> pack_args_storage(pack_args, pack_args + num_pack_args_);
   const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers]() {
     std::vector<VkWriteDescriptorSet> write_descriptor_sets;
     write_descriptor_sets.resize(descriptor_buffers.size());
@@ -1119,7 +1119,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion
                             nullptr);
     if (pack_args_storage.size() != 0) {
       vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT,
-                         0, pack_args_storage.size() * sizeof(ArgUnion), pack_args_storage.data());
+                         0, pack_args_storage.size() * sizeof(ArgUnion64), pack_args_storage.data());
     }
     vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
     VkMemoryBarrier barrier_info;