You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mo...@apache.org on 2021/04/11 01:10:20 UTC

[tvm] branch main updated: [Vulkan] Support uniform buffer object for passing many scalar arguments (#7717)

This is an automated email from the ASF dual-hosted git repository.

moreau pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5bc1cec  [Vulkan] Support uniform buffer object for passing many scalar arguments (#7717)
5bc1cec is described below

commit 5bc1cec4c4acf0a54889227c1d19a6b65b6803c2
Author: masahi <ma...@gmail.com>
AuthorDate: Sun Apr 11 10:10:02 2021 +0900

    [Vulkan] Support uniform buffer object for passing many scalar arguments (#7717)
    
    * ubo codegen first cut
    
    * begin runtime change for UBO
    
    * allocate and bind ubo
    
    * query memory type for uniform
    
    * refactor
    
    * do not use float64
    
    * trying an approach similar to push constant
    
    * add more log
    
    * do not delete ubo when not using it
    
    * cumsum and nms test working with ubo
    
    * remove log
    
    * cleaning up
    
    * formatting
    
    * revert BufferArgument change
    
    * refactored codegen
    
    * minor fix
    
    * introduce value kind for ubo
    
    * fix cpplint and revert float64 change
    
    * query push constant size using runtime API
    
    * let vkmap/unmap allocate and delete host_buf
    
    * doc update
    
    * fix typo
    
    Co-authored-by: Masahiro Masuda <ma...@gmail.com>
---
 src/runtime/vulkan/vulkan.cc       | 268 +++++++++++++++++++++++++------------
 src/runtime/vulkan/vulkan_common.h |   3 +
 src/target/spirv/codegen_spirv.cc  |  23 +++-
 src/target/spirv/ir_builder.cc     |  31 ++++-
 src/target/spirv/ir_builder.h      |  32 ++++-
 5 files changed, 255 insertions(+), 102 deletions(-)

diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc
index 5cd4812..c8a0858 100644
--- a/src/runtime/vulkan/vulkan.cc
+++ b/src/runtime/vulkan/vulkan.cc
@@ -91,6 +91,11 @@ struct VulkanBuffer {
   VkDeviceMemory memory{VK_NULL_HANDLE};
 };
 
+struct UniformBuffer {
+  VulkanBuffer* vk_buf{nullptr};  // buffer + memory holding the scalar kernel arguments
+  void* host_buf{nullptr};        // host pointer returned by vkMapMemory on vk_buf->memory
+};
+
 struct VulkanPipeline {
   VulkanContext* vctx_{nullptr};
   VkShaderModule shader{VK_NULL_HANDLE};
@@ -100,10 +105,113 @@ struct VulkanPipeline {
   VkPipelineLayout pipeline_layout{VK_NULL_HANDLE};
   VkPipeline pipeline{VK_NULL_HANDLE};
   VkDescriptorUpdateTemplateKHR descriptor_update_template{VK_NULL_HANDLE};
+  UniformBuffer ubo;
 };
 
 typedef dmlc::ThreadLocalStore<VulkanThreadEntry> VulkanThreadStore;
 
+// Index of a memory type allowed for `buffer` that has all `req_prop` flags; LOG(FATAL) if none.
+uint32_t FindMemoryType(VkDevice logical_device, VkPhysicalDevice phy_device, VkBuffer buffer,
+                        VkMemoryPropertyFlags req_prop) {
+  VkMemoryRequirements mem_reqs;
+  vkGetBufferMemoryRequirements(logical_device, buffer, &mem_reqs);
+  uint32_t type_bits = mem_reqs.memoryTypeBits;
+  VkPhysicalDeviceMemoryProperties phy_mem_prop;
+  vkGetPhysicalDeviceMemoryProperties(phy_device, &phy_mem_prop);
+  for (uint32_t i = 0; i < phy_mem_prop.memoryTypeCount; i++) {
+    if ((type_bits & 1) == 1 &&
+        (phy_mem_prop.memoryTypes[i].propertyFlags & req_prop) == req_prop) {
+      return i;
+    }
+    type_bits >>= 1;
+  }
+  LOG(FATAL) << "Requested memory type not found";
+  return 0;
+}
+
+// Create a VkBuffer of `nbytes` with `usage` and allocate + bind memory for it. Uniform buffers
+// get host-visible, host-coherent memory so the runtime can map them.
+VulkanBuffer* CreateBuffer(const VulkanContext& vctx, size_t nbytes, VkBufferUsageFlags usage) {
+  VkBufferCreateInfo info;
+  info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
+  info.pNext = nullptr;
+  info.flags = 0;
+  info.size = nbytes;
+  info.queueFamilyIndexCount = 1;
+  info.pQueueFamilyIndices = &(vctx.queue_family_index);
+  info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
+  info.usage = usage;
+  // create buffer
+  VkBuffer buffer;
+  VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer));
+
+  // Size the allocation from the driver-reported requirements: the memory bound to the buffer
+  // must cover VkMemoryRequirements::size, which may exceed nbytes because of alignment padding.
+  VkMemoryRequirements mem_reqs;
+  vkGetBufferMemoryRequirements(vctx.device, buffer, &mem_reqs);
+
+  uint32_t mem_type_index = vctx.compute_mtype_index;
+
+  if (usage & VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT) {
+    // Find a memory type that supports UBO
+    auto prop = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT;
+    mem_type_index = FindMemoryType(vctx.device, vctx.phy_device, buffer, prop);
+  }
+
+  // bind to memory
+  bool dedicated_allocation = false;
+  VkMemoryRequirements2KHR req2;
+
+  if (vctx.get_buffer_memory_requirements_2_functions) {
+    VkBufferMemoryRequirementsInfo2KHR req_info2;
+    req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR;
+    req_info2.pNext = 0;
+    req_info2.buffer = buffer;
+
+    req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR;
+    req2.pNext = 0;
+
+    VkMemoryDedicatedRequirementsKHR dedicated_req;
+    dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR;
+    dedicated_req.pNext = 0;
+    req2.pNext = &dedicated_req;
+
+    vctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR(
+        vctx.device, &req_info2, &req2);
+    dedicated_allocation =
+        dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation;
+  }
+
+  VkDeviceMemory memory;
+  if (!dedicated_allocation) {
+    VkMemoryAllocateInfo minfo;
+    minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
+    minfo.pNext = nullptr;
+    minfo.allocationSize = mem_reqs.size;
+    minfo.memoryTypeIndex = mem_type_index;
+    VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
+  } else {
+    VkMemoryAllocateInfo minfo;
+    minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
+    minfo.pNext = nullptr;
+    minfo.allocationSize = req2.memoryRequirements.size;
+    minfo.memoryTypeIndex = mem_type_index;
+
+    VkMemoryDedicatedAllocateInfoKHR mdinfo;
+    mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR;
+    mdinfo.pNext = 0;
+    mdinfo.image = 0;
+    mdinfo.buffer = buffer;
+    minfo.pNext = &mdinfo;
+    VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
+  }
+  VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0));
+  VulkanBuffer* pbuf = new VulkanBuffer();
+  pbuf->memory = memory;
+  pbuf->buffer = buffer;
+  return pbuf;
+}
+
 class VulkanDeviceAPI final : public DeviceAPI {
  public:
   VulkanDeviceAPI();
@@ -124,70 +224,9 @@ class VulkanDeviceAPI final : public DeviceAPI {
       nbytes = 1;
     }
     const auto& vctx = context(dev.device_id);
-    VkBufferCreateInfo info;
-    info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
-    info.pNext = nullptr;
-    info.flags = 0;
-    info.size = nbytes;
-    info.queueFamilyIndexCount = 1;
-    info.pQueueFamilyIndices = &(vctx.queue_family_index);
-    info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
-    info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
+    auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
                  VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
-    // create buffer
-    VkBuffer buffer;
-    VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer));
-    // bind to memory
-    VkBufferMemoryRequirementsInfo2KHR req_info2;
-    req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR;
-    req_info2.pNext = 0;
-    req_info2.buffer = buffer;
-
-    VkMemoryRequirements2KHR req2;
-    req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR;
-    req2.pNext = 0;
-
-    VkMemoryDedicatedRequirementsKHR dedicated_req;
-    dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR;
-    dedicated_req.pNext = 0;
-    req2.pNext = &dedicated_req;
-
-    bool dedicated_allocation = false;
-    if (vctx.get_buffer_memory_requirements_2_functions) {
-      vctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR(
-          vctx.device, &req_info2, &req2);
-      dedicated_allocation =
-          dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation;
-    }
-
-    VkDeviceMemory memory;
-    if (!dedicated_allocation) {
-      VkMemoryAllocateInfo minfo;
-      minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
-      minfo.pNext = nullptr;
-      minfo.allocationSize = nbytes;
-      minfo.memoryTypeIndex = vctx.compute_mtype_index;
-      VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
-    } else {
-      VkMemoryAllocateInfo minfo;
-      minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
-      minfo.pNext = nullptr;
-      minfo.allocationSize = req2.memoryRequirements.size;
-      minfo.memoryTypeIndex = vctx.compute_mtype_index;
-
-      VkMemoryDedicatedAllocateInfoKHR mdinfo;
-      mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR;
-      mdinfo.pNext = 0;
-      mdinfo.image = 0;
-      mdinfo.buffer = buffer;
-      minfo.pNext = &mdinfo;
-      VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
-    }
-    VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0));
-    VulkanBuffer* pbuf = new VulkanBuffer();
-    pbuf->memory = memory;
-    pbuf->buffer = buffer;
-    return pbuf;
+    return CreateBuffer(vctx, nbytes, usage);  // memory type: vctx.compute_mtype_index
   }
 
   void FreeDataSpace(Device dev, void* ptr) final {
@@ -747,7 +786,7 @@ class VulkanModuleNode final : public runtime::ModuleNode {
  public:
   explicit VulkanModuleNode(std::unordered_map<std::string, VulkanShader> smap,
                             std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
-      : smap_(smap), fmap_(fmap), source_(source) {}
+      : smap_(smap), fmap_(fmap), source_(source), max_push_constants_(GetMaxPushConstantsSize()) {}
 
   const char* type_key() const final { return "vulkan"; }
 
@@ -781,6 +820,13 @@ class VulkanModuleNode final : public runtime::ModuleNode {
         vkDestroyDescriptorPool(vctx.device, pe->descriptor_pool, nullptr);
         vkDestroyDescriptorSetLayout(vctx.device, pe->descriptor_set_layout, nullptr);
         vkDestroyShaderModule(vctx.device, pe->shader, nullptr);
+        // Release the UBO allocated and mapped alongside this pipeline (if any)
+        if (pe->ubo.vk_buf) {
+          vkUnmapMemory(vctx.device, pe->ubo.vk_buf->memory);
+          vkDestroyBuffer(vctx.device, pe->ubo.vk_buf->buffer, nullptr);
+          vkFreeMemory(vctx.device, pe->ubo.vk_buf->memory, nullptr);
+          delete pe->ubo.vk_buf;
+        }
       }
     }
   }
@@ -812,30 +858,35 @@ class VulkanModuleNode final : public runtime::ModuleNode {
     std::vector<VkDescriptorUpdateTemplateEntryKHR> arg_template;
     uint32_t num_pod = 0, num_buffer = 0;
 
+    auto push_arg_info = [&arg_binding, &arg_template](uint32_t binding,
+                                                       VkDescriptorType desc_type) {
+      {
+        VkDescriptorSetLayoutBinding bd;
+        bd.binding = binding;
+        bd.descriptorType = desc_type;
+        bd.descriptorCount = 1;
+        bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
+        bd.pImmutableSamplers = nullptr;
+        arg_binding.push_back(bd);
+      }
+      {
+        VkDescriptorUpdateTemplateEntryKHR tpl;
+        tpl.dstBinding = binding;
+        tpl.dstArrayElement = 0;
+        tpl.descriptorCount = 1;
+        tpl.descriptorType = desc_type;
+        tpl.offset = binding * sizeof(VkDescriptorBufferInfo);
+        tpl.stride = sizeof(VkDescriptorBufferInfo);
+        arg_template.push_back(tpl);
+      }
+    };
+
     {
       auto fit = fmap_.find(func_name);
       ICHECK(fit != fmap_.end());
       for (DLDataType arg_type : fit->second.arg_types) {
         if (arg_type.code == kTVMOpaqueHandle) {
-          {
-            VkDescriptorSetLayoutBinding bd;
-            bd.binding = num_buffer;
-            bd.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
-            bd.descriptorCount = 1;
-            bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
-            bd.pImmutableSamplers = nullptr;
-            arg_binding.push_back(bd);
-          }
-          {
-            VkDescriptorUpdateTemplateEntryKHR tpl;
-            tpl.dstBinding = num_buffer;
-            tpl.dstArrayElement = 0;
-            tpl.descriptorCount = 1;
-            tpl.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
-            tpl.offset = num_buffer * sizeof(VkDescriptorBufferInfo);
-            tpl.stride = sizeof(VkDescriptorBufferInfo);
-            arg_template.push_back(tpl);
-          }
+          push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
           ++num_buffer;
         } else {
           ++num_pod;
@@ -843,6 +894,11 @@ class VulkanModuleNode final : public runtime::ModuleNode {
       }
     }
 
+    size_t nbytes_scalars = num_pod * sizeof(ArgUnion64);  // each scalar takes 64 bits
+    if (nbytes_scalars > max_push_constants_) {  // UBO at binding num_buffer
+      push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER);
+    }
+
     {
       VkDescriptorSetLayoutCreateInfo descrip_cinfo;
       descrip_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
@@ -894,7 +950,7 @@ class VulkanModuleNode final : public runtime::ModuleNode {
     playout_cinfo.setLayoutCount = 1;
     playout_cinfo.pSetLayouts = &(pe->descriptor_set_layout);
 
-    if (num_pack_args != 0) {
+    if (0 < nbytes_scalars && nbytes_scalars <= max_push_constants_) {
       playout_cinfo.pushConstantRangeCount = 1;
       playout_cinfo.pPushConstantRanges = &crange;
       ICHECK_LE(crange.size, vctx.phy_device_prop.limits.maxPushConstantsSize);
@@ -923,6 +979,15 @@ class VulkanModuleNode final : public runtime::ModuleNode {
     VULKAN_CALL(vkCreateComputePipelines(vctx.device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr,
                                          &(pe->pipeline)));
 
+    if (nbytes_scalars > max_push_constants_) {
+      // The scalars do not fit in push constants: allocate a UBO and keep it mapped; it is
+      // unmapped and freed together with the other pipeline resources.
+      UniformBuffer& ubo = pe->ubo;
+      ubo.vk_buf = CreateBuffer(vctx, nbytes_scalars, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT);
+      VULKAN_CALL(
+          vkMapMemory(vctx.device, ubo.vk_buf->memory, 0, nbytes_scalars, 0, &(ubo.host_buf)));
+    }
+
     if (vctx.UseImmediate()) {
       VkDescriptorUpdateTemplateCreateInfoKHR descrip_template_cinfo;
       descrip_template_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_CREATE_INFO_KHR;
@@ -966,6 +1029,8 @@ class VulkanModuleNode final : public runtime::ModuleNode {
     return source_;
   }
 
+  uint32_t MaxPushConstantsSize() const { return max_push_constants_; }  // bytes
+
  private:
   // function information table.
   std::unordered_map<std::string, VulkanShader> smap_;
@@ -975,6 +1040,8 @@ class VulkanModuleNode final : public runtime::ModuleNode {
   std::string fmt_{"vulkan"};
   // The source
   std::string source_;
+  // maxPushConstantsSize (bytes) of the device current when this module was constructed
+  const uint32_t max_push_constants_;
 
   // Guards accesses to `ecache_`
   std::mutex mutex_;
@@ -1076,6 +1143,23 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
     binfo.range = VK_WHOLE_SIZE;
     descriptor_buffers[i] = binfo;
   }
+  const size_t nbytes_scalars = num_pack_args_ * sizeof(ArgUnion64);
+  // Scalars that do not fit in push constants are passed through the pipeline's UBO; this
+  // mirrors the choice made when the pipeline layout was created.
+  const bool use_ubo = num_pack_args_ != 0 && nbytes_scalars > m_->MaxPushConstantsSize();
+  if (use_ubo) {
+    ICHECK(pipeline->ubo.host_buf) << "The UBO host buffer is not allocated";
+    // A single UBO is shared by every launch of this kernel. Drain the stream first so that
+    // launches recorded earlier cannot observe these new values (NOTE(review): assumes recorded
+    // commands execute only when the stream is synchronized).
+    VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Synchronize();
+    memcpy(pipeline->ubo.host_buf, pack_args, nbytes_scalars);
+    VkDescriptorBufferInfo binfo;
+    binfo.buffer = pipeline->ubo.vk_buf->buffer;
+    binfo.offset = 0;
+    binfo.range = VK_WHOLE_SIZE;
+    descriptor_buffers.push_back(binfo);
+  }
   if (vctx.UseImmediate()) {
     // Can safely capture by reference as this lambda is immediately executed on the calling thread.
     VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Launch([&](VulkanStreamState* state) {
@@ -1084,7 +1168,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
       vctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR(
           state->cmd_buffer_, pipeline->descriptor_update_template, pipeline->pipeline_layout, 0,
           descriptor_buffers.data());
-      if (num_pack_args_ != 0) {
+      if (num_pack_args_ != 0 && !use_ubo) {  // UBO layouts have no push constant range
         vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout,
                            VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion64),
                            pack_args);
@@ -1105,7 +1189,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
 
   // Otherwise, the more expensive deferred path.
   std::vector<ArgUnion64> pack_args_storage(pack_args, pack_args + num_pack_args_);
-  const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers]() {
+  const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers, use_ubo]() {
     std::vector<VkWriteDescriptorSet> write_descriptor_sets;
     write_descriptor_sets.resize(descriptor_buffers.size());
     for (size_t i = 0; i < write_descriptor_sets.size(); i++) {
@@ -1115,20 +1199,27 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
       write_descriptor_sets[i].dstBinding = i;
       write_descriptor_sets[i].dstArrayElement = 0;
       write_descriptor_sets[i].descriptorCount = 1;
-      write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
       write_descriptor_sets[i].pImageInfo = 0;
       write_descriptor_sets[i].pBufferInfo = &(descriptor_buffers[i]);
       write_descriptor_sets[i].pTexelBufferView = 0;
+
+      if (use_ubo && i == write_descriptor_sets.size() - 1) {
+        // The last binding is for UBO
+        write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
+      } else {
+        write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
+      }
     }
     vkUpdateDescriptorSets(vctx.device, write_descriptor_sets.size(), write_descriptor_sets.data(),
                            0, 0);
   };
-  const auto& deferred_kernel = [pipeline, wl, pack_args_storage](VulkanStreamState* state) {
+  const auto& deferred_kernel = [pipeline, wl, pack_args_storage,
+                                 use_ubo](VulkanStreamState* state) {
     vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline);
     vkCmdBindDescriptorSets(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE,
                             pipeline->pipeline_layout, 0, 1, &(pipeline->descriptor_set), 0,
                             nullptr);
-    if (pack_args_storage.size() != 0) {
+    if (!use_ubo && !pack_args_storage.empty()) {  // UBO layouts have no push constant range
       vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT,
                          0, pack_args_storage.size() * sizeof(ArgUnion64),
                          pack_args_storage.data());
@@ -1183,6 +1267,12 @@ Module VulkanModuleLoadBinary(void* strm) {
   return VulkanModuleCreate(smap, fmap, "");
 }
 
+uint32_t GetMaxPushConstantsSize() {  // NOTE(review): also called by SPIR-V codegen
+  int device_id = VulkanThreadEntry::ThreadLocal()->device.device_id;  // calling thread's device
+  const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
+  return vctx.phy_device_prop.limits.maxPushConstantsSize;  // in bytes
+}
+
 TVM_REGISTER_GLOBAL("runtime.module.loadfile_vulkan").set_body_typed(VulkanModuleLoadFile);
 
 TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary);
diff --git a/src/runtime/vulkan/vulkan_common.h b/src/runtime/vulkan/vulkan_common.h
index 3083ba6..e94a9fe 100644
--- a/src/runtime/vulkan/vulkan_common.h
+++ b/src/runtime/vulkan/vulkan_common.h
@@ -142,6 +142,9 @@ struct VulkanContext {
   bool UseImmediate() const { return descriptor_template_khr_functions.get() != nullptr; }
 };
 
+/*! \brief maxPushConstantsSize (bytes) of the calling thread's current Vulkan device */
+uint32_t GetMaxPushConstantsSize();  // NOTE(review): also called by SPIR-V codegen
+
 }  // namespace vulkan
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index 24608eb..4d55f4c 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -30,6 +30,9 @@
 
 #include <string>
 
+#include "../../runtime/pack_args.h"  // runtime::ArgUnion64
+#include "../../runtime/vulkan/vulkan_common.h"  // runtime::vulkan::GetMaxPushConstantsSize
+
 namespace tvm {
 namespace codegen {
 
@@ -66,16 +69,26 @@ std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::
   spirv::Value func_ptr = builder_->NewFunction();
   builder_->StartFunction(func_ptr);
 
-  // All the POD arguments are passed in through PushConstant
   if (pod_args.size() != 0) {
     std::vector<spirv::SType> value_types;
     for (size_t i = 0; i < pod_args.size(); ++i) {
       value_types.push_back(builder_->GetSType(pod_args[i].dtype()));
     }
-    spirv::Value ptr = builder_->DeclarePushConstant(value_types);
-    for (size_t i = 0; i < pod_args.size(); ++i) {
-      spirv::Value value = builder_->GetPushConstant(ptr, value_types[i], static_cast<uint32_t>(i));
-      var_map_[pod_args[i].get()] = value;
+    const auto max_push_constants = runtime::vulkan::GetMaxPushConstantsSize();
+    if (pod_args.size() * sizeof(runtime::ArgUnion64) <= max_push_constants) {
+      spirv::Value ptr = builder_->DeclarePushConstant(value_types);
+      for (size_t i = 0; i < pod_args.size(); ++i) {
+        spirv::Value value =
+            builder_->GetPushConstant(ptr, value_types[i], static_cast<uint32_t>(i));
+        var_map_[pod_args[i].get()] = value;
+      }
+    } else {
+      // Too many scalars for push constants: read them from a UBO at binding num_buffer.
+      spirv::Value ptr = builder_->DeclareUniformBuffer(value_types, num_buffer);
+      for (size_t i = 0; i < pod_args.size(); ++i) {
+        spirv::Value value = builder_->GetUniform(ptr, value_types[i], static_cast<uint32_t>(i));
+        var_map_[pod_args[i].get()] = value;
+      }
     }
   }
   this->VisitStmt(f->body);
diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc
index 5a14573..cd48c93 100644
--- a/src/target/spirv/ir_builder.cc
+++ b/src/target/spirv/ir_builder.cc
@@ -205,8 +205,8 @@ Value IRBuilder::BufferArgument(const SType& value_type, uint32_t descriptor_set
   return val;
 }
 
-Value IRBuilder::DeclarePushConstant(const std::vector<SType>& value_types) {
-  ICHECK_EQ(push_const_.id, 0);
+Value IRBuilder::DeclareStorageVariable(const std::vector<SType>& value_types,
+                                        spv::StorageClass storage_class, ValueKind kind) {
   SType struct_type;
   struct_type.id = id_counter_++;
   struct_type.type = DataType::Handle();
@@ -226,22 +226,26 @@ Value IRBuilder::DeclarePushConstant(const std::vector<SType>& value_types) {
     ICHECK_EQ(nbits % 8, 0);
     uint32_t bytes = (nbits / 8);
     if (t.bits() == 32) {
-      // In our Vulkan runtime, each push constant always occupies 64 bit.
+      // In our Vulkan runtime, each scalar argument always occupies 64 bits (one ArgUnion64).
       offset += bytes * 2;
     } else {
       ICHECK_EQ(t.bits(), 64);
       offset += bytes;
     }
   }
-  // Decorate push constants as UBO
   this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock);
 
-  SType ptr_type = GetPointerType(struct_type, spv::StorageClassPushConstant);
-  Value val = NewValue(ptr_type, kPushConstantPtr);
-  ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, spv::StorageClassPushConstant).Commit(&global_);
+  SType ptr_type = GetPointerType(struct_type, storage_class);
+  Value val = NewValue(ptr_type, kind);
+  ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, storage_class).Commit(&global_);
   return val;
 }
 
+Value IRBuilder::DeclarePushConstant(const std::vector<SType>& value_types) {
+  ICHECK_EQ(push_const_.id, 0);
+  return DeclareStorageVariable(value_types, spv::StorageClassPushConstant, kPushConstantPtr);
+}
+
 Value IRBuilder::GetPushConstant(Value ptr_push_const, const SType& v_type, uint32_t index) {
   SType ptr_vtype = this->GetPointerType(v_type, spv::StorageClassPushConstant);
   Value ptr = this->MakeValue(spv::OpAccessChain, ptr_vtype, ptr_push_const,
@@ -249,6 +253,22 @@ Value IRBuilder::GetPushConstant(Value ptr_push_const, const SType& v_type, uint
   return this->MakeValue(spv::OpLoad, v_type, ptr);
 }
 
+Value IRBuilder::DeclareUniformBuffer(const std::vector<SType>& value_types, uint32_t binding) {
+  Value val = DeclareStorageVariable(value_types, spv::StorageClassUniform, kUniformPtr);
+  // Uniform-class variables must carry both DescriptorSet and Binding decorations; the Vulkan
+  // runtime binds all kernel arguments through a single descriptor set (set 0).
+  this->Decorate(spv::OpDecorate, val, spv::DecorationDescriptorSet, static_cast<uint32_t>(0));
+  this->Decorate(spv::OpDecorate, val, spv::DecorationBinding, binding);
+  return val;
+}
+
+Value IRBuilder::GetUniform(Value ptr_ubo, const SType& v_type, uint32_t index) {
+  SType ptr_vtype = this->GetPointerType(v_type, spv::StorageClassUniform);
+  Value ptr = this->MakeValue(spv::OpAccessChain, ptr_vtype, ptr_ubo,
+                              IntImm(t_int32_, static_cast<int64_t>(index)));
+  return this->MakeValue(spv::OpLoad, v_type, ptr);
+}
+
 Value IRBuilder::NewFunction() { return NewValue(t_void_func_, kFunction); }
 
 void IRBuilder::CommitKernelFunction(const Value& func, const std::string& name) {
diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h
index 8a08048..05a2bc6 100644
--- a/src/target/spirv/ir_builder.h
+++ b/src/target/spirv/ir_builder.h
@@ -60,7 +60,8 @@ enum ValueKind {
   kStructArrayPtr,
   kPushConstantPtr,
   kFunction,
-  kExtInst
+  kExtInst,
+  kUniformPtr  // pointer to a Uniform storage-class variable (the UBO of POD arguments)
 };
 
 /*! \brief Represent the SPIRV Value */
@@ -473,6 +474,7 @@ class IRBuilder {
    * \param The argument type.
    */
   Value BufferArgument(const SType& value_type, uint32_t descriptor_set, uint32_t binding);
+
   /*!
    * \brief Declare POD arguments through push constants.
    *
@@ -488,6 +490,24 @@ class IRBuilder {
    * \return the value of push constant
    */
   Value GetPushConstant(Value ptr_push_const, const SType& v_type, uint32_t index);
+
+  /*!
+   * \brief Declare POD arguments through uniform buffer.
+   *
+   * \note Only call this function once!
+   * \param value_types The values in the uniform buffer
+   * \param binding The binding location in descriptor set
+   * \return The uniform buffer variable (a pointer to the struct of values).
+   */
+  Value DeclareUniformBuffer(const std::vector<SType>& value_types, uint32_t binding);
+  /*!
+   * \brief Get i-th uniform constant
+   * \param ptr_ubo The uniform buffer variable returned by DeclareUniformBuffer
+   * \param v_type The value type
+   * \param index The uniform index
+   * \return the value of uniform constant
+   */
+  Value GetUniform(Value ptr_ubo, const SType& v_type, uint32_t index);
   /*!
    * \brief Declare a new function
    * \return The created function ID.
@@ -555,6 +575,17 @@ class IRBuilder {
     val.flag = flag;
     return val;
   }
+
+  /*!
+   * \brief The common function to declare push constants or uniform buffer
+   * \param value_types The values in the push constants or uniform buffer
+   * \param storage_class An enum defined by SPIR-V indicating push constant or uniform
+   * \param kind An enum indicating push constant or uniform
+   * \return The declared variable (a pointer to the struct of values)
+   */
+  Value DeclareStorageVariable(const std::vector<SType>& value_types,
+                               spv::StorageClass storage_class, ValueKind kind);
+
   // get constant given value encoded in uint64_t
   Value GetConst_(const SType& dtype, const uint64_t* pvalue);
   // declare type