You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this text.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/05/22 00:32:04 UTC

[tvm] branch main updated: [Vulkan] Remove some interface block decoration (#8102)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a2bf07f  [Vulkan] Remove some interface block decoration (#8102)
a2bf07f is described below

commit a2bf07ff34eb954951309f1672da2ad07f6d10f8
Author: llehtahw <qa...@gmail.com>
AuthorDate: Sat May 22 08:30:55 2021 +0800

    [Vulkan] Remove some interface block decoration (#8102)
    
    * Remove block decorator for shared/local variables
    
    * Fix lint
---
 src/target/spirv/ir_builder.cc         | 33 ++++++++++++++++++---------------
 src/target/spirv/ir_builder.h          |  7 +++++--
 tests/python/integration/test_ewise.py |  3 +++
 3 files changed, 26 insertions(+), 17 deletions(-)

diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc
index b27d0e9..c2460b2 100644
--- a/src/target/spirv/ir_builder.cc
+++ b/src/target/spirv/ir_builder.cc
@@ -137,8 +137,9 @@ SType IRBuilder::GetPointerType(const SType& value_type, spv::StorageClass stora
   return t;
 }
 
-SType IRBuilder::GetStructArrayType(const SType& value_type, uint32_t num_elems) {
-  auto key = std::make_pair(value_type.id, num_elems);
+SType IRBuilder::GetStructArrayType(const SType& value_type, uint32_t num_elems,
+                                    bool interface_block) {
+  auto key = std::make_tuple(value_type.id, num_elems, interface_block);
   auto it = struct_array_type_tbl_.find(key);
   if (it != struct_array_type_tbl_.end()) {
     return it->second;
@@ -171,17 +172,19 @@ SType IRBuilder::GetStructArrayType(const SType& value_type, uint32_t num_elems)
       .AddSeq(struct_type, 0, spv::DecorationOffset, 0)
       .Commit(&decorate_);
 
-  // Runtime array are always decorated as Block or BufferBlock
-  // (shader storage buffer)
-  if (spirv_support_.supports_storage_buffer_storage_class) {
-    // If SPIRV 1.3+, or with extension
-    // SPV_KHR_storage_buffer_storage_class, BufferBlock is
-    // deprecated.
-    extensions_used_.insert("SPV_KHR_storage_buffer_storage_class");
-    this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock);
-  } else {
-    if (num_elems == 0) {
-      this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBufferBlock);
+  if (interface_block) {
+    // Runtime array are always decorated as Block or BufferBlock
+    // (shader storage buffer)
+    if (spirv_support_.supports_storage_buffer_storage_class) {
+      // If SPIRV 1.3+, or with extension
+      // SPV_KHR_storage_buffer_storage_class, BufferBlock is
+      // deprecated.
+      extensions_used_.insert("SPV_KHR_storage_buffer_storage_class");
+      this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock);
+    } else {
+      if (num_elems == 0) {
+        this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBufferBlock);
+      }
     }
   }
   struct_array_type_tbl_[key] = struct_type;
@@ -224,7 +227,7 @@ Value IRBuilder::BufferArgument(const SType& value_type, uint32_t descriptor_set
     storage_class = spv::StorageClassUniform;
   }
 
-  SType sarr_type = GetStructArrayType(value_type, 0);
+  SType sarr_type = GetStructArrayType(value_type, 0, true);
   SType ptr_type = GetPointerType(sarr_type, storage_class);
   Value val = NewValue(ptr_type, kStructArrayPtr);
 
@@ -335,7 +338,7 @@ void IRBuilder::SetLocalSize(const Value& func, uint32_t local_size[3]) {
 Value IRBuilder::Allocate(const SType& value_type, uint32_t num_elems,
                           spv::StorageClass storage_class) {
   ICHECK_NE(num_elems, 0U);
-  SType sarr_type = GetStructArrayType(value_type, num_elems);
+  SType sarr_type = GetStructArrayType(value_type, num_elems, false);
   SType ptr_type = GetPointerType(sarr_type, storage_class);
   Value val = NewValue(ptr_type, kStructArrayPtr);
   if (storage_class == spv::StorageClassFunction) {
diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h
index afd9be9..959ed29 100644
--- a/src/target/spirv/ir_builder.h
+++ b/src/target/spirv/ir_builder.h
@@ -35,6 +35,7 @@
 #include <unordered_map>
 #include <utility>
 #include <vector>
+#include <tuple>
 #include <spirv.hpp>
 // clang-format on
 
@@ -432,10 +433,12 @@ class IRBuilder {
    * \param value_type the content value type.
    * \param num_elems number of elements in array
    *   num_elems = 0 means runtime array with BufferBlock Decoration
+   * \param interface_block if this array type for interface blocks(input, output, uniform,
+   *   storage buffer).
    *
    * \return The corresponding spirv type.
    */
-  SType GetStructArrayType(const SType& value_type, uint32_t num_elems);
+  SType GetStructArrayType(const SType& value_type, uint32_t num_elems, bool interface_block);
   /*!
    * \brief Get a struct array access with a given index.
    * \param ptr_type The pointer type.
@@ -634,7 +637,7 @@ class IRBuilder {
   /*! \brief map from type code to the type */
   std::unordered_map<uint32_t, SType> pod_type_tbl_;
   /*! \brief map from value to array type */
-  std::map<std::pair<uint32_t, uint32_t>, SType> struct_array_type_tbl_;
+  std::map<std::tuple<uint32_t, uint32_t, bool>, SType> struct_array_type_tbl_;
   /*! \brief map from value to its pointer type */
   std::map<std::pair<uint32_t, spv::StorageClass>, SType> pointer_type_tbl_;
   /*! \brief map from constant int to its value */
diff --git a/tests/python/integration/test_ewise.py b/tests/python/integration/test_ewise.py
index 1f7deb0..5461a51 100644
--- a/tests/python/integration/test_ewise.py
+++ b/tests/python/integration/test_ewise.py
@@ -40,6 +40,9 @@ def test_exp():
         if not tvm.testing.device_enabled(host):
             return
         dev = tvm.device(device, 0)
+        if not tvm.testing.device_enabled(device):
+            print("skip because %s is not enabled.." % device)
+            return
         fexp = tvm.build(s, [A, B], device, host, name="myexp")
         dev = tvm.device(device, 0)
         # launch the kernel.