You are viewing a plain-text version of this content. The canonical link for it ("here" in the original page) is not preserved in this plain-text rendering.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/04/11 19:08:49 UTC

[tvm] branch main updated: Revert "[Vulkan] Support uniform buffer object for passing many scalar arguments (#7717)" (#7821)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ab0dc2e  Revert "[Vulkan] Support uniform buffer object for passing many scalar arguments (#7717)" (#7821)
ab0dc2e is described below

commit ab0dc2e6c3aa5cdd93fe3e71f37abff80cec2d38
Author: masahi <ma...@gmail.com>
AuthorDate: Mon Apr 12 04:08:24 2021 +0900

    Revert "[Vulkan] Support uniform buffer object for passing many scalar arguments (#7717)" (#7821)
    
    This reverts commit 5bc1cec4c4acf0a54889227c1d19a6b65b6803c2.
---
 src/runtime/vulkan/vulkan.cc       | 268 ++++++++++++-------------------------
 src/runtime/vulkan/vulkan_common.h |   3 -
 src/target/spirv/codegen_spirv.cc  |  23 +---
 src/target/spirv/ir_builder.cc     |  31 +----
 src/target/spirv/ir_builder.h      |  32 +----
 5 files changed, 102 insertions(+), 255 deletions(-)

diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc
index c8a0858..5cd4812 100644
--- a/src/runtime/vulkan/vulkan.cc
+++ b/src/runtime/vulkan/vulkan.cc
@@ -91,11 +91,6 @@ struct VulkanBuffer {
   VkDeviceMemory memory{VK_NULL_HANDLE};
 };
 
-struct UniformBuffer {
-  VulkanBuffer* vk_buf;
-  void* host_buf;
-};
-
 struct VulkanPipeline {
   VulkanContext* vctx_{nullptr};
   VkShaderModule shader{VK_NULL_HANDLE};
@@ -105,105 +100,10 @@ struct VulkanPipeline {
   VkPipelineLayout pipeline_layout{VK_NULL_HANDLE};
   VkPipeline pipeline{VK_NULL_HANDLE};
   VkDescriptorUpdateTemplateKHR descriptor_update_template{VK_NULL_HANDLE};
-  UniformBuffer ubo;
 };
 
 typedef dmlc::ThreadLocalStore<VulkanThreadEntry> VulkanThreadStore;
 
-uint32_t FindMemoryType(VkDevice logical_device, VkPhysicalDevice phy_device, VkBuffer buffer,
-                        VkMemoryPropertyFlags req_prop) {
-  VkMemoryRequirements mem_reqs;
-  vkGetBufferMemoryRequirements(logical_device, buffer, &mem_reqs);
-  uint32_t type_bits = mem_reqs.memoryTypeBits;
-  VkPhysicalDeviceMemoryProperties phy_mem_prop;
-  vkGetPhysicalDeviceMemoryProperties(phy_device, &phy_mem_prop);
-  for (uint32_t i = 0; i < phy_mem_prop.memoryTypeCount; i++) {
-    if ((type_bits & 1) == 1 &&
-        (phy_mem_prop.memoryTypes[i].propertyFlags & req_prop) == req_prop) {
-      return i;
-    }
-    type_bits >>= 1;
-  }
-  LOG(FATAL) << "Requested memory type not found";
-  return 0;
-}
-
-VulkanBuffer* CreateBuffer(const VulkanContext& vctx, size_t nbytes, VkBufferUsageFlags usage) {
-  VkBufferCreateInfo info;
-  info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
-  info.pNext = nullptr;
-  info.flags = 0;
-  info.size = nbytes;
-  info.queueFamilyIndexCount = 1;
-  info.pQueueFamilyIndices = &(vctx.queue_family_index);
-  info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
-  info.usage = usage;
-  // create buffer
-  VkBuffer buffer;
-  VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer));
-
-  uint32_t mem_type_index = vctx.compute_mtype_index;
-
-  if (usage & VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT) {
-    // Find a memory type that supports UBO
-    auto prop = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT;
-    mem_type_index = FindMemoryType(vctx.device, vctx.phy_device, buffer, prop);
-  }
-
-  // bind to memory
-  bool dedicated_allocation = false;
-  VkMemoryRequirements2KHR req2;
-
-  if (vctx.get_buffer_memory_requirements_2_functions) {
-    VkBufferMemoryRequirementsInfo2KHR req_info2;
-    req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR;
-    req_info2.pNext = 0;
-    req_info2.buffer = buffer;
-
-    req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR;
-    req2.pNext = 0;
-
-    VkMemoryDedicatedRequirementsKHR dedicated_req;
-    dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR;
-    dedicated_req.pNext = 0;
-    req2.pNext = &dedicated_req;
-
-    vctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR(
-        vctx.device, &req_info2, &req2);
-    dedicated_allocation =
-        dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation;
-  }
-
-  VkDeviceMemory memory;
-  if (!dedicated_allocation) {
-    VkMemoryAllocateInfo minfo;
-    minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
-    minfo.pNext = nullptr;
-    minfo.allocationSize = nbytes;
-    minfo.memoryTypeIndex = mem_type_index;
-    VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
-  } else {
-    VkMemoryAllocateInfo minfo;
-    minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
-    minfo.pNext = nullptr;
-    minfo.allocationSize = req2.memoryRequirements.size;
-    minfo.memoryTypeIndex = mem_type_index;
-
-    VkMemoryDedicatedAllocateInfoKHR mdinfo;
-    mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR;
-    mdinfo.pNext = 0;
-    mdinfo.image = 0;
-    mdinfo.buffer = buffer;
-    minfo.pNext = &mdinfo;
-    VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
-  }
-  VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0));
-  VulkanBuffer* pbuf = new VulkanBuffer();
-  pbuf->memory = memory;
-  pbuf->buffer = buffer;
-  return pbuf;
-}
-
 class VulkanDeviceAPI final : public DeviceAPI {
  public:
   VulkanDeviceAPI();
@@ -224,9 +124,70 @@ class VulkanDeviceAPI final : public DeviceAPI {
       nbytes = 1;
     }
     const auto& vctx = context(dev.device_id);
-    auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
+    VkBufferCreateInfo info;
+    info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
+    info.pNext = nullptr;
+    info.flags = 0;
+    info.size = nbytes;
+    info.queueFamilyIndexCount = 1;
+    info.pQueueFamilyIndices = &(vctx.queue_family_index);
+    info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
+    info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
                  VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
-    return CreateBuffer(vctx, nbytes, usage);
+    // create buffer
+    VkBuffer buffer;
+    VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer));
+    // bind to memory
+    VkBufferMemoryRequirementsInfo2KHR req_info2;
+    req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR;
+    req_info2.pNext = 0;
+    req_info2.buffer = buffer;
+
+    VkMemoryRequirements2KHR req2;
+    req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR;
+    req2.pNext = 0;
+
+    VkMemoryDedicatedRequirementsKHR dedicated_req;
+    dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR;
+    dedicated_req.pNext = 0;
+    req2.pNext = &dedicated_req;
+
+    bool dedicated_allocation = false;
+    if (vctx.get_buffer_memory_requirements_2_functions) {
+      vctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR(
+          vctx.device, &req_info2, &req2);
+      dedicated_allocation =
+          dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation;
+    }
+
+    VkDeviceMemory memory;
+    if (!dedicated_allocation) {
+      VkMemoryAllocateInfo minfo;
+      minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
+      minfo.pNext = nullptr;
+      minfo.allocationSize = nbytes;
+      minfo.memoryTypeIndex = vctx.compute_mtype_index;
+      VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
+    } else {
+      VkMemoryAllocateInfo minfo;
+      minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
+      minfo.pNext = nullptr;
+      minfo.allocationSize = req2.memoryRequirements.size;
+      minfo.memoryTypeIndex = vctx.compute_mtype_index;
+
+      VkMemoryDedicatedAllocateInfoKHR mdinfo;
+      mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR;
+      mdinfo.pNext = 0;
+      mdinfo.image = 0;
+      mdinfo.buffer = buffer;
+      minfo.pNext = &mdinfo;
+      VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
+    }
+    VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0));
+    VulkanBuffer* pbuf = new VulkanBuffer();
+    pbuf->memory = memory;
+    pbuf->buffer = buffer;
+    return pbuf;
   }
 
   void FreeDataSpace(Device dev, void* ptr) final {
@@ -786,7 +747,7 @@ class VulkanModuleNode final : public runtime::ModuleNode {
  public:
   explicit VulkanModuleNode(std::unordered_map<std::string, VulkanShader> smap,
                             std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
-      : smap_(smap), fmap_(fmap), source_(source), max_push_constants_(GetMaxPushConstantsSize()) {}
+      : smap_(smap), fmap_(fmap), source_(source) {}
 
   const char* type_key() const final { return "vulkan"; }
 
@@ -820,13 +781,6 @@ class VulkanModuleNode final : public runtime::ModuleNode {
         vkDestroyDescriptorPool(vctx.device, pe->descriptor_pool, nullptr);
         vkDestroyDescriptorSetLayout(vctx.device, pe->descriptor_set_layout, nullptr);
         vkDestroyShaderModule(vctx.device, pe->shader, nullptr);
-        // UBO
-        if (pe->ubo.vk_buf) {
-          vkUnmapMemory(vctx.device, pe->ubo.vk_buf->memory);
-          vkDestroyBuffer(vctx.device, pe->ubo.vk_buf->buffer, nullptr);
-          vkFreeMemory(vctx.device, pe->ubo.vk_buf->memory, nullptr);
-          delete pe->ubo.vk_buf;
-        }
       }
     }
   }
@@ -858,35 +812,30 @@ class VulkanModuleNode final : public runtime::ModuleNode {
     std::vector<VkDescriptorUpdateTemplateEntryKHR> arg_template;
     uint32_t num_pod = 0, num_buffer = 0;
 
-    auto push_arg_info = [&arg_binding, &arg_template](uint32_t binding,
-                                                       VkDescriptorType desc_type) {
-      {
-        VkDescriptorSetLayoutBinding bd;
-        bd.binding = binding;
-        bd.descriptorType = desc_type;
-        bd.descriptorCount = 1;
-        bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
-        bd.pImmutableSamplers = nullptr;
-        arg_binding.push_back(bd);
-      }
-      {
-        VkDescriptorUpdateTemplateEntryKHR tpl;
-        tpl.dstBinding = binding;
-        tpl.dstArrayElement = 0;
-        tpl.descriptorCount = 1;
-        tpl.descriptorType = desc_type;
-        tpl.offset = binding * sizeof(VkDescriptorBufferInfo);
-        tpl.stride = sizeof(VkDescriptorBufferInfo);
-        arg_template.push_back(tpl);
-      }
-    };
-
     {
       auto fit = fmap_.find(func_name);
       ICHECK(fit != fmap_.end());
       for (DLDataType arg_type : fit->second.arg_types) {
         if (arg_type.code == kTVMOpaqueHandle) {
-          push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
+          {
+            VkDescriptorSetLayoutBinding bd;
+            bd.binding = num_buffer;
+            bd.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
+            bd.descriptorCount = 1;
+            bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
+            bd.pImmutableSamplers = nullptr;
+            arg_binding.push_back(bd);
+          }
+          {
+            VkDescriptorUpdateTemplateEntryKHR tpl;
+            tpl.dstBinding = num_buffer;
+            tpl.dstArrayElement = 0;
+            tpl.descriptorCount = 1;
+            tpl.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
+            tpl.offset = num_buffer * sizeof(VkDescriptorBufferInfo);
+            tpl.stride = sizeof(VkDescriptorBufferInfo);
+            arg_template.push_back(tpl);
+          }
           ++num_buffer;
         } else {
           ++num_pod;
@@ -894,11 +843,6 @@ class VulkanModuleNode final : public runtime::ModuleNode {
       }
     }
 
-    size_t nbytes_scalars = num_pod * sizeof(ArgUnion64);
-    if (nbytes_scalars > max_push_constants_) {
-      push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER);
-    }
-
     {
       VkDescriptorSetLayoutCreateInfo descrip_cinfo;
       descrip_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
@@ -950,7 +894,7 @@ class VulkanModuleNode final : public runtime::ModuleNode {
     playout_cinfo.setLayoutCount = 1;
     playout_cinfo.pSetLayouts = &(pe->descriptor_set_layout);
 
-    if (0 < nbytes_scalars && nbytes_scalars <= max_push_constants_) {
+    if (num_pack_args != 0) {
       playout_cinfo.pushConstantRangeCount = 1;
       playout_cinfo.pPushConstantRanges = &crange;
       ICHECK_LE(crange.size, vctx.phy_device_prop.limits.maxPushConstantsSize);
@@ -979,13 +923,6 @@ class VulkanModuleNode final : public runtime::ModuleNode {
     VULKAN_CALL(vkCreateComputePipelines(vctx.device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr,
                                          &(pe->pipeline)));
 
-    if (nbytes_scalars > max_push_constants_) {
-      // Allocate, bind and map UBO
-      UniformBuffer& ubo = pe->ubo;
-      ubo.vk_buf = CreateBuffer(vctx, nbytes_scalars, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT);
-      vkMapMemory(vctx.device, ubo.vk_buf->memory, 0, nbytes_scalars, 0, &(ubo.host_buf));
-    }
-
     if (vctx.UseImmediate()) {
       VkDescriptorUpdateTemplateCreateInfoKHR descrip_template_cinfo;
       descrip_template_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_CREATE_INFO_KHR;
@@ -1029,8 +966,6 @@ class VulkanModuleNode final : public runtime::ModuleNode {
     return source_;
   }
 
-  uint32_t MaxPushConstantsSize() const { return max_push_constants_; }
-
  private:
   // function information table.
   std::unordered_map<std::string, VulkanShader> smap_;
@@ -1040,8 +975,6 @@ class VulkanModuleNode final : public runtime::ModuleNode {
   std::string fmt_{"vulkan"};
   // The source
   std::string source_;
-  // The maximum size of push constants in bytes
-  const uint32_t max_push_constants_;
 
   // Guards accesses to `ecache_`
   std::mutex mutex_;
@@ -1143,17 +1076,6 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
     binfo.range = VK_WHOLE_SIZE;
     descriptor_buffers[i] = binfo;
   }
-  const size_t nbytes_scalars = num_pack_args_ * sizeof(ArgUnion64);
-  bool use_ubo = num_pack_args_ != 0 && nbytes_scalars > m_->MaxPushConstantsSize();
-  if (use_ubo) {
-    CHECK(pipeline->ubo.host_buf) << "The UBO host buffer is not allocated";
-    memcpy(pipeline->ubo.host_buf, pack_args, nbytes_scalars);
-    VkDescriptorBufferInfo binfo;
-    binfo.buffer = pipeline->ubo.vk_buf->buffer;
-    binfo.offset = 0;
-    binfo.range = VK_WHOLE_SIZE;
-    descriptor_buffers.push_back(binfo);
-  }
   if (vctx.UseImmediate()) {
     // Can safely capture by reference as this lambda is immediately executed on the calling thread.
     VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Launch([&](VulkanStreamState* state) {
@@ -1162,7 +1084,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
       vctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR(
           state->cmd_buffer_, pipeline->descriptor_update_template, pipeline->pipeline_layout, 0,
           descriptor_buffers.data());
-      if (num_pack_args_ > 0 && num_pack_args_ <= m_->MaxPushConstantsSize()) {
+      if (num_pack_args_ != 0) {
         vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout,
                            VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion64),
                            pack_args);
@@ -1183,7 +1105,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
 
   // Otherwise, the more expensive deferred path.
   std::vector<ArgUnion64> pack_args_storage(pack_args, pack_args + num_pack_args_);
-  const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers, use_ubo]() {
+  const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers]() {
     std::vector<VkWriteDescriptorSet> write_descriptor_sets;
     write_descriptor_sets.resize(descriptor_buffers.size());
     for (size_t i = 0; i < write_descriptor_sets.size(); i++) {
@@ -1193,26 +1115,20 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
       write_descriptor_sets[i].dstBinding = i;
       write_descriptor_sets[i].dstArrayElement = 0;
       write_descriptor_sets[i].descriptorCount = 1;
+      write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
       write_descriptor_sets[i].pImageInfo = 0;
       write_descriptor_sets[i].pBufferInfo = &(descriptor_buffers[i]);
       write_descriptor_sets[i].pTexelBufferView = 0;
-
-      if (use_ubo && i == write_descriptor_sets.size() - 1) {
-        // The last binding is for UBO
-        write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
-      } else {
-        write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
-      }
     }
     vkUpdateDescriptorSets(vctx.device, write_descriptor_sets.size(), write_descriptor_sets.data(),
                            0, 0);
   };
-  const auto& deferred_kernel = [this, pipeline, wl, pack_args_storage](VulkanStreamState* state) {
+  const auto& deferred_kernel = [pipeline, wl, pack_args_storage](VulkanStreamState* state) {
     vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline);
     vkCmdBindDescriptorSets(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE,
                             pipeline->pipeline_layout, 0, 1, &(pipeline->descriptor_set), 0,
                             nullptr);
-    if (num_pack_args_ > 0 && num_pack_args_ <= m_->MaxPushConstantsSize()) {
+    if (pack_args_storage.size() != 0) {
       vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT,
                          0, pack_args_storage.size() * sizeof(ArgUnion64),
                          pack_args_storage.data());
@@ -1267,12 +1183,6 @@ Module VulkanModuleLoadBinary(void* strm) {
   return VulkanModuleCreate(smap, fmap, "");
 }
 
-uint32_t GetMaxPushConstantsSize() {
-  int device_id = VulkanThreadEntry::ThreadLocal()->device.device_id;
-  const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
-  return vctx.phy_device_prop.limits.maxPushConstantsSize;
-}
-
 TVM_REGISTER_GLOBAL("runtime.module.loadfile_vulkan").set_body_typed(VulkanModuleLoadFile);
 
 TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary);
diff --git a/src/runtime/vulkan/vulkan_common.h b/src/runtime/vulkan/vulkan_common.h
index e94a9fe..3083ba6 100644
--- a/src/runtime/vulkan/vulkan_common.h
+++ b/src/runtime/vulkan/vulkan_common.h
@@ -142,9 +142,6 @@ struct VulkanContext {
   bool UseImmediate() const { return descriptor_template_khr_functions.get() != nullptr; }
 };
 
-/*! \brief returns maximum push constant sizes in bytes for the target platform */
-uint32_t GetMaxPushConstantsSize();
-
 }  // namespace vulkan
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index 4d55f4c..24608eb 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -30,9 +30,6 @@
 
 #include <string>
 
-#include "../../runtime/pack_args.h"
-#include "../../runtime/vulkan/vulkan_common.h"
-
 namespace tvm {
 namespace codegen {
 
@@ -69,26 +66,16 @@ std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::
   spirv::Value func_ptr = builder_->NewFunction();
   builder_->StartFunction(func_ptr);
 
+  // All the POD arguments are passed in through PushConstant
   if (pod_args.size() != 0) {
     std::vector<spirv::SType> value_types;
     for (size_t i = 0; i < pod_args.size(); ++i) {
       value_types.push_back(builder_->GetSType(pod_args[i].dtype()));
     }
-    const auto max_push_constants = runtime::vulkan::GetMaxPushConstantsSize();
-    if (pod_args.size() * sizeof(runtime::ArgUnion64) <= max_push_constants) {
-      spirv::Value ptr = builder_->DeclarePushConstant(value_types);
-      for (size_t i = 0; i < pod_args.size(); ++i) {
-        spirv::Value value =
-            builder_->GetPushConstant(ptr, value_types[i], static_cast<uint32_t>(i));
-        var_map_[pod_args[i].get()] = value;
-      }
-    } else {
-      // If we need to pass more arguments than push constants could handle, we use UBO.
-      spirv::Value ptr = builder_->DeclareUniformBuffer(value_types, num_buffer);
-      for (size_t i = 0; i < pod_args.size(); ++i) {
-        spirv::Value value = builder_->GetUniform(ptr, value_types[i], static_cast<uint32_t>(i));
-        var_map_[pod_args[i].get()] = value;
-      }
+    spirv::Value ptr = builder_->DeclarePushConstant(value_types);
+    for (size_t i = 0; i < pod_args.size(); ++i) {
+      spirv::Value value = builder_->GetPushConstant(ptr, value_types[i], static_cast<uint32_t>(i));
+      var_map_[pod_args[i].get()] = value;
     }
   }
   this->VisitStmt(f->body);
diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc
index cd48c93..5a14573 100644
--- a/src/target/spirv/ir_builder.cc
+++ b/src/target/spirv/ir_builder.cc
@@ -205,8 +205,8 @@ Value IRBuilder::BufferArgument(const SType& value_type, uint32_t descriptor_set
   return val;
 }
 
-Value IRBuilder::DeclareStorageVariable(const std::vector<SType>& value_types,
-                                        spv::StorageClass storage_class, ValueKind kind) {
+Value IRBuilder::DeclarePushConstant(const std::vector<SType>& value_types) {
+  ICHECK_EQ(push_const_.id, 0);
   SType struct_type;
   struct_type.id = id_counter_++;
   struct_type.type = DataType::Handle();
@@ -226,26 +226,22 @@ Value IRBuilder::DeclareStorageVariable(const std::vector<SType>& value_types,
     ICHECK_EQ(nbits % 8, 0);
     uint32_t bytes = (nbits / 8);
     if (t.bits() == 32) {
-      // In our Vulkan runtime, each scalar argument always occupies 64 bit.
+      // In our Vulkan runtime, each push constant always occupies 64 bit.
       offset += bytes * 2;
     } else {
       ICHECK_EQ(t.bits(), 64);
       offset += bytes;
     }
   }
+  // Decorate push constants as UBO
   this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock);
 
-  SType ptr_type = GetPointerType(struct_type, storage_class);
-  Value val = NewValue(ptr_type, kind);
-  ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, storage_class).Commit(&global_);
+  SType ptr_type = GetPointerType(struct_type, spv::StorageClassPushConstant);
+  Value val = NewValue(ptr_type, kPushConstantPtr);
+  ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, spv::StorageClassPushConstant).Commit(&global_);
   return val;
 }
 
-Value IRBuilder::DeclarePushConstant(const std::vector<SType>& value_types) {
-  ICHECK_EQ(push_const_.id, 0);
-  return DeclareStorageVariable(value_types, spv::StorageClassPushConstant, kPushConstantPtr);
-}
-
 Value IRBuilder::GetPushConstant(Value ptr_push_const, const SType& v_type, uint32_t index) {
   SType ptr_vtype = this->GetPointerType(v_type, spv::StorageClassPushConstant);
   Value ptr = this->MakeValue(spv::OpAccessChain, ptr_vtype, ptr_push_const,
@@ -253,19 +249,6 @@ Value IRBuilder::GetPushConstant(Value ptr_push_const, const SType& v_type, uint
   return this->MakeValue(spv::OpLoad, v_type, ptr);
 }
 
-Value IRBuilder::DeclareUniformBuffer(const std::vector<SType>& value_types, uint32_t binding) {
-  Value val = DeclareStorageVariable(value_types, spv::StorageClassUniform, kUniformPtr);
-  this->Decorate(spv::OpDecorate, val, spv::DecorationBinding, binding);
-  return val;
-}
-
-Value IRBuilder::GetUniform(Value ptr_push_const, const SType& v_type, uint32_t index) {
-  SType ptr_vtype = this->GetPointerType(v_type, spv::StorageClassUniform);
-  Value ptr = this->MakeValue(spv::OpAccessChain, ptr_vtype, ptr_push_const,
-                              IntImm(t_int32_, static_cast<int64_t>(index)));
-  return this->MakeValue(spv::OpLoad, v_type, ptr);
-}
-
 Value IRBuilder::NewFunction() { return NewValue(t_void_func_, kFunction); }
 
 void IRBuilder::CommitKernelFunction(const Value& func, const std::string& name) {
diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h
index 05a2bc6..8a08048 100644
--- a/src/target/spirv/ir_builder.h
+++ b/src/target/spirv/ir_builder.h
@@ -60,8 +60,7 @@ enum ValueKind {
   kStructArrayPtr,
   kPushConstantPtr,
   kFunction,
-  kExtInst,
-  kUniformPtr
+  kExtInst
 };
 
 /*! \brief Represent the SPIRV Value */
@@ -474,7 +473,6 @@ class IRBuilder {
    * \param The argument type.
    */
   Value BufferArgument(const SType& value_type, uint32_t descriptor_set, uint32_t binding);
-
   /*!
    * \brief Declare POD arguments through push constants.
    *
@@ -490,23 +488,6 @@ class IRBuilder {
    * \return the value of push constant
    */
   Value GetPushConstant(Value ptr_push_const, const SType& v_type, uint32_t index);
-
-  /*!
-   * \brief Declare POD arguments through uniform buffer.
-   *
-   * \note Only call this function once!
-   * \param value_types The values in the uniform buffer
-   * \param binding The binding locaiton in descriptor set
-   * \return reference to self.
-   */
-  Value DeclareUniformBuffer(const std::vector<SType>& value_types, uint32_t binding);
-  /*!
-   * \brief Get i-th uniform constant
-   * \param v_type The value type
-   * \param index The uniform index
-   * \return the value of uniform constant
-   */
-  Value GetUniform(Value ptr_ubo, const SType& v_type, uint32_t index);
   /*!
    * \brief Declare a new function
    * \return The created function ID.
@@ -574,17 +555,6 @@ class IRBuilder {
     val.flag = flag;
     return val;
   }
-
-  /*!
-   * \brief The common function to declare push constants or uniform buffer
-   * \param value_types The values in the push constants or uniform buffer
-   * \param storage_class An enum defined by SPIR-V indicating push constant or uniform
-   * \param kind An enum indicating push constant or uniform
-   * \return The created new label
-   */
-  Value DeclareStorageVariable(const std::vector<SType>& value_types,
-                               spv::StorageClass storage_class, ValueKind kind);
-
   // get constant given value encoded in uint64_t
   Value GetConst_(const SType& dtype, const uint64_t* pvalue);
   // declare type