You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text copy.)
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/04/11 19:08:49 UTC
[tvm] branch main updated: Revert "[Vulkan] Support uniform buffer object for passing many scalar arguments (#7717)" (#7821)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ab0dc2e Revert "[Vulkan] Support uniform buffer object for passing many scalar arguments (#7717)" (#7821)
ab0dc2e is described below
commit ab0dc2e6c3aa5cdd93fe3e71f37abff80cec2d38
Author: masahi <ma...@gmail.com>
AuthorDate: Mon Apr 12 04:08:24 2021 +0900
Revert "[Vulkan] Support uniform buffer object for passing many scalar arguments (#7717)" (#7821)
This reverts commit 5bc1cec4c4acf0a54889227c1d19a6b65b6803c2.
---
src/runtime/vulkan/vulkan.cc | 268 ++++++++++++-------------------------
src/runtime/vulkan/vulkan_common.h | 3 -
src/target/spirv/codegen_spirv.cc | 23 +---
src/target/spirv/ir_builder.cc | 31 +----
src/target/spirv/ir_builder.h | 32 +----
5 files changed, 102 insertions(+), 255 deletions(-)
diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc
index c8a0858..5cd4812 100644
--- a/src/runtime/vulkan/vulkan.cc
+++ b/src/runtime/vulkan/vulkan.cc
@@ -91,11 +91,6 @@ struct VulkanBuffer {
VkDeviceMemory memory{VK_NULL_HANDLE};
};
-struct UniformBuffer {
- VulkanBuffer* vk_buf;
- void* host_buf;
-};
-
struct VulkanPipeline {
VulkanContext* vctx_{nullptr};
VkShaderModule shader{VK_NULL_HANDLE};
@@ -105,105 +100,10 @@ struct VulkanPipeline {
VkPipelineLayout pipeline_layout{VK_NULL_HANDLE};
VkPipeline pipeline{VK_NULL_HANDLE};
VkDescriptorUpdateTemplateKHR descriptor_update_template{VK_NULL_HANDLE};
- UniformBuffer ubo;
};
typedef dmlc::ThreadLocalStore<VulkanThreadEntry> VulkanThreadStore;
-uint32_t FindMemoryType(VkDevice logical_device, VkPhysicalDevice phy_device, VkBuffer buffer,
- VkMemoryPropertyFlags req_prop) {
- VkMemoryRequirements mem_reqs;
- vkGetBufferMemoryRequirements(logical_device, buffer, &mem_reqs);
- uint32_t type_bits = mem_reqs.memoryTypeBits;
- VkPhysicalDeviceMemoryProperties phy_mem_prop;
- vkGetPhysicalDeviceMemoryProperties(phy_device, &phy_mem_prop);
- for (uint32_t i = 0; i < phy_mem_prop.memoryTypeCount; i++) {
- if ((type_bits & 1) == 1 &&
- (phy_mem_prop.memoryTypes[i].propertyFlags & req_prop) == req_prop) {
- return i;
- }
- type_bits >>= 1;
- }
- LOG(FATAL) << "Requested memory type not found";
- return 0;
-}
-
-VulkanBuffer* CreateBuffer(const VulkanContext& vctx, size_t nbytes, VkBufferUsageFlags usage) {
- VkBufferCreateInfo info;
- info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
- info.pNext = nullptr;
- info.flags = 0;
- info.size = nbytes;
- info.queueFamilyIndexCount = 1;
- info.pQueueFamilyIndices = &(vctx.queue_family_index);
- info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
- info.usage = usage;
- // create buffer
- VkBuffer buffer;
- VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer));
-
- uint32_t mem_type_index = vctx.compute_mtype_index;
-
- if (usage & VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT) {
- // Find a memory type that supports UBO
- auto prop = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT;
- mem_type_index = FindMemoryType(vctx.device, vctx.phy_device, buffer, prop);
- }
-
- // bind to memory
- bool dedicated_allocation = false;
- VkMemoryRequirements2KHR req2;
-
- if (vctx.get_buffer_memory_requirements_2_functions) {
- VkBufferMemoryRequirementsInfo2KHR req_info2;
- req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR;
- req_info2.pNext = 0;
- req_info2.buffer = buffer;
-
- req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR;
- req2.pNext = 0;
-
- VkMemoryDedicatedRequirementsKHR dedicated_req;
- dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR;
- dedicated_req.pNext = 0;
- req2.pNext = &dedicated_req;
-
- vctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR(
- vctx.device, &req_info2, &req2);
- dedicated_allocation =
- dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation;
- }
-
- VkDeviceMemory memory;
- if (!dedicated_allocation) {
- VkMemoryAllocateInfo minfo;
- minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
- minfo.pNext = nullptr;
- minfo.allocationSize = nbytes;
- minfo.memoryTypeIndex = mem_type_index;
- VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
- } else {
- VkMemoryAllocateInfo minfo;
- minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
- minfo.pNext = nullptr;
- minfo.allocationSize = req2.memoryRequirements.size;
- minfo.memoryTypeIndex = mem_type_index;
-
- VkMemoryDedicatedAllocateInfoKHR mdinfo;
- mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR;
- mdinfo.pNext = 0;
- mdinfo.image = 0;
- mdinfo.buffer = buffer;
- minfo.pNext = &mdinfo;
- VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
- }
- VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0));
- VulkanBuffer* pbuf = new VulkanBuffer();
- pbuf->memory = memory;
- pbuf->buffer = buffer;
- return pbuf;
-}
-
class VulkanDeviceAPI final : public DeviceAPI {
public:
VulkanDeviceAPI();
@@ -224,9 +124,70 @@ class VulkanDeviceAPI final : public DeviceAPI {
nbytes = 1;
}
const auto& vctx = context(dev.device_id);
- auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
+ VkBufferCreateInfo info;
+ info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
+ info.pNext = nullptr;
+ info.flags = 0;
+ info.size = nbytes;
+ info.queueFamilyIndexCount = 1;
+ info.pQueueFamilyIndices = &(vctx.queue_family_index);
+ info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
+ info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
- return CreateBuffer(vctx, nbytes, usage);
+ // create buffer
+ VkBuffer buffer;
+ VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer));
+ // bind to memory
+ VkBufferMemoryRequirementsInfo2KHR req_info2;
+ req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR;
+ req_info2.pNext = 0;
+ req_info2.buffer = buffer;
+
+ VkMemoryRequirements2KHR req2;
+ req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR;
+ req2.pNext = 0;
+
+ VkMemoryDedicatedRequirementsKHR dedicated_req;
+ dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR;
+ dedicated_req.pNext = 0;
+ req2.pNext = &dedicated_req;
+
+ bool dedicated_allocation = false;
+ if (vctx.get_buffer_memory_requirements_2_functions) {
+ vctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR(
+ vctx.device, &req_info2, &req2);
+ dedicated_allocation =
+ dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation;
+ }
+
+ VkDeviceMemory memory;
+ if (!dedicated_allocation) {
+ VkMemoryAllocateInfo minfo;
+ minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
+ minfo.pNext = nullptr;
+ minfo.allocationSize = nbytes;
+ minfo.memoryTypeIndex = vctx.compute_mtype_index;
+ VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
+ } else {
+ VkMemoryAllocateInfo minfo;
+ minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
+ minfo.pNext = nullptr;
+ minfo.allocationSize = req2.memoryRequirements.size;
+ minfo.memoryTypeIndex = vctx.compute_mtype_index;
+
+ VkMemoryDedicatedAllocateInfoKHR mdinfo;
+ mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR;
+ mdinfo.pNext = 0;
+ mdinfo.image = 0;
+ mdinfo.buffer = buffer;
+ minfo.pNext = &mdinfo;
+ VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory));
+ }
+ VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0));
+ VulkanBuffer* pbuf = new VulkanBuffer();
+ pbuf->memory = memory;
+ pbuf->buffer = buffer;
+ return pbuf;
}
void FreeDataSpace(Device dev, void* ptr) final {
@@ -786,7 +747,7 @@ class VulkanModuleNode final : public runtime::ModuleNode {
public:
explicit VulkanModuleNode(std::unordered_map<std::string, VulkanShader> smap,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
- : smap_(smap), fmap_(fmap), source_(source), max_push_constants_(GetMaxPushConstantsSize()) {}
+ : smap_(smap), fmap_(fmap), source_(source) {}
const char* type_key() const final { return "vulkan"; }
@@ -820,13 +781,6 @@ class VulkanModuleNode final : public runtime::ModuleNode {
vkDestroyDescriptorPool(vctx.device, pe->descriptor_pool, nullptr);
vkDestroyDescriptorSetLayout(vctx.device, pe->descriptor_set_layout, nullptr);
vkDestroyShaderModule(vctx.device, pe->shader, nullptr);
- // UBO
- if (pe->ubo.vk_buf) {
- vkUnmapMemory(vctx.device, pe->ubo.vk_buf->memory);
- vkDestroyBuffer(vctx.device, pe->ubo.vk_buf->buffer, nullptr);
- vkFreeMemory(vctx.device, pe->ubo.vk_buf->memory, nullptr);
- delete pe->ubo.vk_buf;
- }
}
}
}
@@ -858,35 +812,30 @@ class VulkanModuleNode final : public runtime::ModuleNode {
std::vector<VkDescriptorUpdateTemplateEntryKHR> arg_template;
uint32_t num_pod = 0, num_buffer = 0;
- auto push_arg_info = [&arg_binding, &arg_template](uint32_t binding,
- VkDescriptorType desc_type) {
- {
- VkDescriptorSetLayoutBinding bd;
- bd.binding = binding;
- bd.descriptorType = desc_type;
- bd.descriptorCount = 1;
- bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
- bd.pImmutableSamplers = nullptr;
- arg_binding.push_back(bd);
- }
- {
- VkDescriptorUpdateTemplateEntryKHR tpl;
- tpl.dstBinding = binding;
- tpl.dstArrayElement = 0;
- tpl.descriptorCount = 1;
- tpl.descriptorType = desc_type;
- tpl.offset = binding * sizeof(VkDescriptorBufferInfo);
- tpl.stride = sizeof(VkDescriptorBufferInfo);
- arg_template.push_back(tpl);
- }
- };
-
{
auto fit = fmap_.find(func_name);
ICHECK(fit != fmap_.end());
for (DLDataType arg_type : fit->second.arg_types) {
if (arg_type.code == kTVMOpaqueHandle) {
- push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
+ {
+ VkDescriptorSetLayoutBinding bd;
+ bd.binding = num_buffer;
+ bd.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
+ bd.descriptorCount = 1;
+ bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
+ bd.pImmutableSamplers = nullptr;
+ arg_binding.push_back(bd);
+ }
+ {
+ VkDescriptorUpdateTemplateEntryKHR tpl;
+ tpl.dstBinding = num_buffer;
+ tpl.dstArrayElement = 0;
+ tpl.descriptorCount = 1;
+ tpl.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
+ tpl.offset = num_buffer * sizeof(VkDescriptorBufferInfo);
+ tpl.stride = sizeof(VkDescriptorBufferInfo);
+ arg_template.push_back(tpl);
+ }
++num_buffer;
} else {
++num_pod;
@@ -894,11 +843,6 @@ class VulkanModuleNode final : public runtime::ModuleNode {
}
}
- size_t nbytes_scalars = num_pod * sizeof(ArgUnion64);
- if (nbytes_scalars > max_push_constants_) {
- push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER);
- }
-
{
VkDescriptorSetLayoutCreateInfo descrip_cinfo;
descrip_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
@@ -950,7 +894,7 @@ class VulkanModuleNode final : public runtime::ModuleNode {
playout_cinfo.setLayoutCount = 1;
playout_cinfo.pSetLayouts = &(pe->descriptor_set_layout);
- if (0 < nbytes_scalars && nbytes_scalars <= max_push_constants_) {
+ if (num_pack_args != 0) {
playout_cinfo.pushConstantRangeCount = 1;
playout_cinfo.pPushConstantRanges = &crange;
ICHECK_LE(crange.size, vctx.phy_device_prop.limits.maxPushConstantsSize);
@@ -979,13 +923,6 @@ class VulkanModuleNode final : public runtime::ModuleNode {
VULKAN_CALL(vkCreateComputePipelines(vctx.device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr,
&(pe->pipeline)));
- if (nbytes_scalars > max_push_constants_) {
- // Allocate, bind and map UBO
- UniformBuffer& ubo = pe->ubo;
- ubo.vk_buf = CreateBuffer(vctx, nbytes_scalars, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT);
- vkMapMemory(vctx.device, ubo.vk_buf->memory, 0, nbytes_scalars, 0, &(ubo.host_buf));
- }
-
if (vctx.UseImmediate()) {
VkDescriptorUpdateTemplateCreateInfoKHR descrip_template_cinfo;
descrip_template_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_CREATE_INFO_KHR;
@@ -1029,8 +966,6 @@ class VulkanModuleNode final : public runtime::ModuleNode {
return source_;
}
- uint32_t MaxPushConstantsSize() const { return max_push_constants_; }
-
private:
// function information table.
std::unordered_map<std::string, VulkanShader> smap_;
@@ -1040,8 +975,6 @@ class VulkanModuleNode final : public runtime::ModuleNode {
std::string fmt_{"vulkan"};
// The source
std::string source_;
- // The maximum size of push constants in bytes
- const uint32_t max_push_constants_;
// Guards accesses to `ecache_`
std::mutex mutex_;
@@ -1143,17 +1076,6 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
binfo.range = VK_WHOLE_SIZE;
descriptor_buffers[i] = binfo;
}
- const size_t nbytes_scalars = num_pack_args_ * sizeof(ArgUnion64);
- bool use_ubo = num_pack_args_ != 0 && nbytes_scalars > m_->MaxPushConstantsSize();
- if (use_ubo) {
- CHECK(pipeline->ubo.host_buf) << "The UBO host buffer is not allocated";
- memcpy(pipeline->ubo.host_buf, pack_args, nbytes_scalars);
- VkDescriptorBufferInfo binfo;
- binfo.buffer = pipeline->ubo.vk_buf->buffer;
- binfo.offset = 0;
- binfo.range = VK_WHOLE_SIZE;
- descriptor_buffers.push_back(binfo);
- }
if (vctx.UseImmediate()) {
// Can safely capture by reference as this lambda is immediately executed on the calling thread.
VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Launch([&](VulkanStreamState* state) {
@@ -1162,7 +1084,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
vctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR(
state->cmd_buffer_, pipeline->descriptor_update_template, pipeline->pipeline_layout, 0,
descriptor_buffers.data());
- if (num_pack_args_ > 0 && num_pack_args_ <= m_->MaxPushConstantsSize()) {
+ if (num_pack_args_ != 0) {
vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout,
VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion64),
pack_args);
@@ -1183,7 +1105,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
// Otherwise, the more expensive deferred path.
std::vector<ArgUnion64> pack_args_storage(pack_args, pack_args + num_pack_args_);
- const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers, use_ubo]() {
+ const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers]() {
std::vector<VkWriteDescriptorSet> write_descriptor_sets;
write_descriptor_sets.resize(descriptor_buffers.size());
for (size_t i = 0; i < write_descriptor_sets.size(); i++) {
@@ -1193,26 +1115,20 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
write_descriptor_sets[i].dstBinding = i;
write_descriptor_sets[i].dstArrayElement = 0;
write_descriptor_sets[i].descriptorCount = 1;
+ write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
write_descriptor_sets[i].pImageInfo = 0;
write_descriptor_sets[i].pBufferInfo = &(descriptor_buffers[i]);
write_descriptor_sets[i].pTexelBufferView = 0;
-
- if (use_ubo && i == write_descriptor_sets.size() - 1) {
- // The last binding is for UBO
- write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
- } else {
- write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
- }
}
vkUpdateDescriptorSets(vctx.device, write_descriptor_sets.size(), write_descriptor_sets.data(),
0, 0);
};
- const auto& deferred_kernel = [this, pipeline, wl, pack_args_storage](VulkanStreamState* state) {
+ const auto& deferred_kernel = [pipeline, wl, pack_args_storage](VulkanStreamState* state) {
vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline);
vkCmdBindDescriptorSets(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE,
pipeline->pipeline_layout, 0, 1, &(pipeline->descriptor_set), 0,
nullptr);
- if (num_pack_args_ > 0 && num_pack_args_ <= m_->MaxPushConstantsSize()) {
+ if (pack_args_storage.size() != 0) {
vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT,
0, pack_args_storage.size() * sizeof(ArgUnion64),
pack_args_storage.data());
@@ -1267,12 +1183,6 @@ Module VulkanModuleLoadBinary(void* strm) {
return VulkanModuleCreate(smap, fmap, "");
}
-uint32_t GetMaxPushConstantsSize() {
- int device_id = VulkanThreadEntry::ThreadLocal()->device.device_id;
- const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
- return vctx.phy_device_prop.limits.maxPushConstantsSize;
-}
-
TVM_REGISTER_GLOBAL("runtime.module.loadfile_vulkan").set_body_typed(VulkanModuleLoadFile);
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary);
diff --git a/src/runtime/vulkan/vulkan_common.h b/src/runtime/vulkan/vulkan_common.h
index e94a9fe..3083ba6 100644
--- a/src/runtime/vulkan/vulkan_common.h
+++ b/src/runtime/vulkan/vulkan_common.h
@@ -142,9 +142,6 @@ struct VulkanContext {
bool UseImmediate() const { return descriptor_template_khr_functions.get() != nullptr; }
};
-/*! \brief returns maximum push constant sizes in bytes for the target platform */
-uint32_t GetMaxPushConstantsSize();
-
} // namespace vulkan
} // namespace runtime
} // namespace tvm
diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index 4d55f4c..24608eb 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -30,9 +30,6 @@
#include <string>
-#include "../../runtime/pack_args.h"
-#include "../../runtime/vulkan/vulkan_common.h"
-
namespace tvm {
namespace codegen {
@@ -69,26 +66,16 @@ std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::
spirv::Value func_ptr = builder_->NewFunction();
builder_->StartFunction(func_ptr);
+ // All the POD arguments are passed in through PushConstant
if (pod_args.size() != 0) {
std::vector<spirv::SType> value_types;
for (size_t i = 0; i < pod_args.size(); ++i) {
value_types.push_back(builder_->GetSType(pod_args[i].dtype()));
}
- const auto max_push_constants = runtime::vulkan::GetMaxPushConstantsSize();
- if (pod_args.size() * sizeof(runtime::ArgUnion64) <= max_push_constants) {
- spirv::Value ptr = builder_->DeclarePushConstant(value_types);
- for (size_t i = 0; i < pod_args.size(); ++i) {
- spirv::Value value =
- builder_->GetPushConstant(ptr, value_types[i], static_cast<uint32_t>(i));
- var_map_[pod_args[i].get()] = value;
- }
- } else {
- // If we need to pass more arguments than push constants could handle, we use UBO.
- spirv::Value ptr = builder_->DeclareUniformBuffer(value_types, num_buffer);
- for (size_t i = 0; i < pod_args.size(); ++i) {
- spirv::Value value = builder_->GetUniform(ptr, value_types[i], static_cast<uint32_t>(i));
- var_map_[pod_args[i].get()] = value;
- }
+ spirv::Value ptr = builder_->DeclarePushConstant(value_types);
+ for (size_t i = 0; i < pod_args.size(); ++i) {
+ spirv::Value value = builder_->GetPushConstant(ptr, value_types[i], static_cast<uint32_t>(i));
+ var_map_[pod_args[i].get()] = value;
}
}
this->VisitStmt(f->body);
diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc
index cd48c93..5a14573 100644
--- a/src/target/spirv/ir_builder.cc
+++ b/src/target/spirv/ir_builder.cc
@@ -205,8 +205,8 @@ Value IRBuilder::BufferArgument(const SType& value_type, uint32_t descriptor_set
return val;
}
-Value IRBuilder::DeclareStorageVariable(const std::vector<SType>& value_types,
- spv::StorageClass storage_class, ValueKind kind) {
+Value IRBuilder::DeclarePushConstant(const std::vector<SType>& value_types) {
+ ICHECK_EQ(push_const_.id, 0);
SType struct_type;
struct_type.id = id_counter_++;
struct_type.type = DataType::Handle();
@@ -226,26 +226,22 @@ Value IRBuilder::DeclareStorageVariable(const std::vector<SType>& value_types,
ICHECK_EQ(nbits % 8, 0);
uint32_t bytes = (nbits / 8);
if (t.bits() == 32) {
- // In our Vulkan runtime, each scalar argument always occupies 64 bit.
+ // In our Vulkan runtime, each push constant always occupies 64 bit.
offset += bytes * 2;
} else {
ICHECK_EQ(t.bits(), 64);
offset += bytes;
}
}
+ // Decorate push constants as UBO
this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock);
- SType ptr_type = GetPointerType(struct_type, storage_class);
- Value val = NewValue(ptr_type, kind);
- ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, storage_class).Commit(&global_);
+ SType ptr_type = GetPointerType(struct_type, spv::StorageClassPushConstant);
+ Value val = NewValue(ptr_type, kPushConstantPtr);
+ ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, spv::StorageClassPushConstant).Commit(&global_);
return val;
}
-Value IRBuilder::DeclarePushConstant(const std::vector<SType>& value_types) {
- ICHECK_EQ(push_const_.id, 0);
- return DeclareStorageVariable(value_types, spv::StorageClassPushConstant, kPushConstantPtr);
-}
-
Value IRBuilder::GetPushConstant(Value ptr_push_const, const SType& v_type, uint32_t index) {
SType ptr_vtype = this->GetPointerType(v_type, spv::StorageClassPushConstant);
Value ptr = this->MakeValue(spv::OpAccessChain, ptr_vtype, ptr_push_const,
@@ -253,19 +249,6 @@ Value IRBuilder::GetPushConstant(Value ptr_push_const, const SType& v_type, uint
return this->MakeValue(spv::OpLoad, v_type, ptr);
}
-Value IRBuilder::DeclareUniformBuffer(const std::vector<SType>& value_types, uint32_t binding) {
- Value val = DeclareStorageVariable(value_types, spv::StorageClassUniform, kUniformPtr);
- this->Decorate(spv::OpDecorate, val, spv::DecorationBinding, binding);
- return val;
-}
-
-Value IRBuilder::GetUniform(Value ptr_push_const, const SType& v_type, uint32_t index) {
- SType ptr_vtype = this->GetPointerType(v_type, spv::StorageClassUniform);
- Value ptr = this->MakeValue(spv::OpAccessChain, ptr_vtype, ptr_push_const,
- IntImm(t_int32_, static_cast<int64_t>(index)));
- return this->MakeValue(spv::OpLoad, v_type, ptr);
-}
-
Value IRBuilder::NewFunction() { return NewValue(t_void_func_, kFunction); }
void IRBuilder::CommitKernelFunction(const Value& func, const std::string& name) {
diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h
index 05a2bc6..8a08048 100644
--- a/src/target/spirv/ir_builder.h
+++ b/src/target/spirv/ir_builder.h
@@ -60,8 +60,7 @@ enum ValueKind {
kStructArrayPtr,
kPushConstantPtr,
kFunction,
- kExtInst,
- kUniformPtr
+ kExtInst
};
/*! \brief Represent the SPIRV Value */
@@ -474,7 +473,6 @@ class IRBuilder {
* \param The argument type.
*/
Value BufferArgument(const SType& value_type, uint32_t descriptor_set, uint32_t binding);
-
/*!
* \brief Declare POD arguments through push constants.
*
@@ -490,23 +488,6 @@ class IRBuilder {
* \return the value of push constant
*/
Value GetPushConstant(Value ptr_push_const, const SType& v_type, uint32_t index);
-
- /*!
- * \brief Declare POD arguments through uniform buffer.
- *
- * \note Only call this function once!
- * \param value_types The values in the uniform buffer
- * \param binding The binding locaiton in descriptor set
- * \return reference to self.
- */
- Value DeclareUniformBuffer(const std::vector<SType>& value_types, uint32_t binding);
- /*!
- * \brief Get i-th uniform constant
- * \param v_type The value type
- * \param index The uniform index
- * \return the value of uniform constant
- */
- Value GetUniform(Value ptr_ubo, const SType& v_type, uint32_t index);
/*!
* \brief Declare a new function
* \return The created function ID.
@@ -574,17 +555,6 @@ class IRBuilder {
val.flag = flag;
return val;
}
-
- /*!
- * \brief The common function to declare push constants or uniform buffer
- * \param value_types The values in the push constants or uniform buffer
- * \param storage_class An enum defined by SPIR-V indicating push constant or uniform
- * \param kind An enum indicating push constant or uniform
- * \return The created new label
- */
- Value DeclareStorageVariable(const std::vector<SType>& value_types,
- spv::StorageClass storage_class, ValueKind kind);
-
// get constant given value encoded in uint64_t
Value GetConst_(const SType& dtype, const uint64_t* pvalue);
// declare type