You are viewing a plain text version of this content. The canonical link to it was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by wu...@apache.org on 2021/03/04 22:04:21 UTC

[tvm] branch vk-i64 created (now 15189ca)

This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a change to branch vk-i64
in repository https://gitbox.apache.org/repos/asf/tvm.git.


      at 15189ca  add test

This branch includes the following new commits:

     new ad07502  introduce ArgUnion64
     new 3a1daca  add missing intrinsic
     new 1bea761  test cumsum on vulkan
     new ef1ed2d  update metal runtime to use ArgUnion64 (not tested)
     new acef292  test get_valid_counts on vulkan
     new 100abfe  formatting
     new 62889b0  pytest fix
     new 341d346  ArgUnion -> ArgUnion32
     new f7d19e9  Update metal codegen for ArgUnion64
     new 47d45b2  Add explicit 64-bit support in metal codegen
     new 15189ca  add test

The 11 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.



[tvm] 11/11: add test

Posted by wu...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch vk-i64
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 15189caf985e2edcb56b679497cd265f298395d2
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Thu Mar 4 17:03:35 2021 -0500

    add test
---
 tests/python/topi/python/test_topi_cumsum.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/tests/python/topi/python/test_topi_cumsum.py b/tests/python/topi/python/test_topi_cumsum.py
index 6b99239..79330e7 100644
--- a/tests/python/topi/python/test_topi_cumsum.py
+++ b/tests/python/topi/python/test_topi_cumsum.py
@@ -29,6 +29,7 @@ def test_cumsum(ctx, target):
             "cuda": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
             "nvptx": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
             "vulkan": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
+            "metal": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
         }
         fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations)
         tvm.topi.testing.compare_numpy_tvm([data], np_ref, target, ctx, fcompute, fschedule)
@@ -47,6 +48,9 @@ def test_cumsum(ctx, target):
         check_cumsum(np.cumsum(data, dtype=np.int32), data, dtype="int32")
 
     for in_dtype in ["float32", "float64"]:
+        if str(target.kind) == 'metal' and in_dtype == 'float64':
+            # float64 is not supported in metal
+            continue
         data = np.random.randn(10, 10).astype(in_dtype)
         check_cumsum(np.cumsum(data), data)
         check_cumsum(np.cumsum(data, axis=0), data, axis=0)
@@ -74,3 +78,4 @@ if __name__ == "__main__":
     test_cumsum(tvm.context("cuda"), tvm.target.Target("cuda"))
     test_cumsum(tvm.context("nvptx"), tvm.target.Target("nvptx"))
     test_cumsum(tvm.context("vulkan"), tvm.target.Target("vulkan"))
+    test_cumsum(tvm.context("metal"), tvm.target.Target("metal"))


[tvm] 02/11: add missing intrinsic

Posted by wu...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch vk-i64
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 3a1dacada52e40960c03c27c4a6e47b3ef14167a
Author: Masahiro Masuda <ma...@gmail.com>
AuthorDate: Wed Mar 3 07:38:02 2021 +0900

    add missing intrinsic
---
 src/target/spirv/intrin_rule_spirv.cc | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc
index 90b2eb2..b75fb53 100644
--- a/src/target/spirv/intrin_rule_spirv.cc
+++ b/src/target/spirv/intrin_rule_spirv.cc
@@ -62,8 +62,14 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.fabs").set_body(DispatchGLSLPureIntr
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp").set_body(DispatchGLSLPureIntrin<GLSLstd450Exp>);
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sin").set_body(DispatchGLSLPureIntrin<GLSLstd450Sin>);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.cos").set_body(DispatchGLSLPureIntrin<GLSLstd450Cos>);
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log").set_body(DispatchGLSLPureIntrin<GLSLstd450Log>);
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log2").set_body(DispatchGLSLPureIntrin<GLSLstd450Log2>);
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sqrt").set_body(DispatchGLSLPureIntrin<GLSLstd450Sqrt>);
 
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow").set_body(DispatchGLSLPureIntrin<GLSLstd450Pow>);


[tvm] 08/11: ArgUnion -> ArgUnion32

Posted by wu...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch vk-i64
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 341d3468464489363d270e298adbb2344cd2d609
Author: Masahiro Masuda <ma...@gmail.com>
AuthorDate: Thu Mar 4 04:35:20 2021 +0900

    ArgUnion -> ArgUnion32
---
 src/runtime/pack_args.h | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h
index 2e7a881..7c852da 100644
--- a/src/runtime/pack_args.h
+++ b/src/runtime/pack_args.h
@@ -41,7 +41,7 @@ namespace runtime {
 /*!
  * \brief argument union type of 32bit.
  */
-union ArgUnion {
+union ArgUnion32 {
   int32_t v_int32;
   uint32_t v_uint32;
   float v_float32;
@@ -151,9 +151,9 @@ inline PackedFunc PackFuncVoidAddr_(F f, const std::vector<ArgConvertCode>& code
   int num_args = static_cast<int>(codes.size());
   auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) {
     TempArray<void*, N> addr_(num_args);
-    TempArray<ArgUnion, N> holder_(num_args);
+    TempArray<ArgUnion32, N> holder_(num_args);
     void** addr = addr_.data();
-    ArgUnion* holder = holder_.data();
+    ArgUnion32* holder = holder_.data();
     for (int i = 0; i < num_args; ++i) {
       switch (codes[i]) {
         case INT64_TO_INT64:


[tvm] 03/11: test cumsum on vulkan

Posted by wu...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch vk-i64
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 1bea761d9423d1a79f05caaa9d043a106be2bfe4
Author: Masahiro Masuda <ma...@gmail.com>
AuthorDate: Wed Mar 3 08:18:24 2021 +0900

    test cumsum on vulkan
---
 tests/python/topi/python/test_topi_cumsum.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/tests/python/topi/python/test_topi_cumsum.py b/tests/python/topi/python/test_topi_cumsum.py
index a01a496..bf962d9 100644
--- a/tests/python/topi/python/test_topi_cumsum.py
+++ b/tests/python/topi/python/test_topi_cumsum.py
@@ -28,6 +28,7 @@ def test_cumsum(ctx, target):
             "generic": (lambda x: topi.cumsum(x, axis, dtype), topi.generic.schedule_extern),
             "cuda": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
             "nvptx": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
+            "vulkan": (lambda x: topi.cuda.cumsum(x, axis, dtype), topi.cuda.schedule_scan),
         }
         fcompute, fschedule = tvm.topi.testing.dispatch(target, implementations)
         tvm.topi.testing.compare_numpy_tvm([data], np_ref, target, ctx, fcompute, fschedule)
@@ -40,8 +41,10 @@ def test_cumsum(ctx, target):
     check_cumsum(np.cumsum(data, dtype=np.int32), data)
     check_cumsum(np.cumsum(data), data, dtype="int64")
 
-    data = np.random.rand(10) > 0.5
-    check_cumsum(np.cumsum(data, dtype=np.int32), data, dtype="int32")
+    if str(target.kind) != "vulkan":
+        # TODO(masahi): Support bool tensor in SPIRV codegen
+        data = np.random.rand(10) > 0.5
+        check_cumsum(np.cumsum(data, dtype=np.int32), data, dtype="int32")
 
     for in_dtype in ["float32", "float64"]:
         data = np.random.randn(10, 10).astype(in_dtype)
@@ -70,3 +73,4 @@ if __name__ == "__main__":
     test_cumsum(tvm.context("cpu"), tvm.target.Target("llvm"))
     test_cumsum(tvm.context("cuda"), tvm.target.Target("cuda"))
     test_cumsum(tvm.context("nvptx"), tvm.target.Target("nvptx"))
+    test_cumsum(tvm.context("vulkan"), tvm.target.Target("vulkan"))


[tvm] 09/11: Update metal codegen for ArgUnion64

Posted by wu...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch vk-i64
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit f7d19e99604368c3e29a99679565575eab1d7c77
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Thu Mar 4 16:14:42 2021 -0500

    Update metal codegen for ArgUnion64
---
 src/target/source/codegen_metal.cc | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc
index baa3006..f7219cb 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -47,7 +47,7 @@ CodeGenMetal::CodeGenMetal() {
   decl_stream << "#include <metal_stdlib>\n";
   decl_stream << "using namespace metal;\n\n";
   decl_stream << "union __TVMArgUnion {\n"
-              << " int v_int;\n"
+              << " int v_long;\n"
               << "};\n\n";
 }
 
@@ -104,8 +104,8 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
       if (v.dtype().bits() == 32) {
         decl_stream << "  ";
         PrintType(v.dtype(), decl_stream);
-        decl_stream << " " << vid << ";\n";
-        vref << varg << "." << vid;
+        decl_stream << " " << vid << "[2];\n";
+        vref << varg << "." << vid << "[0]";
       } else {
         // For non 32bit type, ref through arg union.
         decl_stream << "  __TVMArgUnion " << vid << ";\n";


[tvm] 05/11: test get_valid_counts on vulkan

Posted by wu...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch vk-i64
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit acef292c42069abb94e09234dab701cc84271636
Author: Masahiro Masuda <ma...@gmail.com>
AuthorDate: Wed Mar 3 08:25:08 2021 +0900

    test get_valid_counts on vulkan
---
 tests/python/topi/python/test_topi_vision.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/topi/python/test_topi_vision.py b/tests/python/topi/python/test_topi_vision.py
index 8393568..2fdf3cf 100644
--- a/tests/python/topi/python/test_topi_vision.py
+++ b/tests/python/topi/python/test_topi_vision.py
@@ -112,7 +112,7 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index):
         tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)
         tvm.testing.assert_allclose(tvm_out3.asnumpy(), np_out3, rtol=1e-3)
 
-    for device in ["llvm", "cuda", "opencl"]:
+    for device in ["llvm", "cuda", "opencl", "vulkan"]:
         check_device(device)
 
 


[tvm] 01/11: introduce ArgUnion64

Posted by wu...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch vk-i64
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit ad075025a61bed23df0bf5f6c78853b5204725f4
Author: Masahiro Masuda <ma...@gmail.com>
AuthorDate: Wed Mar 3 07:37:52 2021 +0900

    introduce ArgUnion64
---
 src/runtime/pack_args.h      | 26 +++++++++++++++++++-------
 src/runtime/vulkan/vulkan.cc | 12 ++++++------
 2 files changed, 25 insertions(+), 13 deletions(-)

diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h
index 45cde22..54a75d6 100644
--- a/src/runtime/pack_args.h
+++ b/src/runtime/pack_args.h
@@ -47,6 +47,15 @@ union ArgUnion {
   uint32_t v_uint32;
   float v_float32;
 };
+
+union ArgUnion64 {
+  int32_t v_int32[2];
+  uint32_t v_uint32[2];
+  float v_float32[2];
+  int64_t v_int64;
+  uint64_t v_uint64;
+  double v_float64;
+};
 /*!
  * \brief Create a packed function from void addr types.
  *
@@ -177,25 +186,28 @@ template <int N, typename F>
 inline PackedFunc PackFuncNonBufferArg_(F f, int base, const std::vector<ArgConvertCode>& codes) {
   int num_args = static_cast<int>(codes.size());
   auto ret = [f, codes, base, num_args](TVMArgs args, TVMRetValue* ret) {
-    TempArray<ArgUnion, N> holder_(num_args);
-    ArgUnion* holder = holder_.data();
+    TempArray<ArgUnion64, N> holder_(num_args);
+    ArgUnion64* holder = holder_.data();
     for (int i = 0; i < num_args; ++i) {
       switch (codes[i]) {
-        case INT64_TO_INT64:
+        case INT64_TO_INT64: {
+          holder[i].v_int64 = args.values[base + i].v_int64;
+          break;
+        }
         case FLOAT64_TO_FLOAT64: {
-          LOG(FATAL) << "Do not support 64bit argument to device function";
+          holder[i].v_float64 = args.values[base + i].v_float64;
           break;
         }
         case INT64_TO_INT32: {
-          holder[i].v_int32 = static_cast<int32_t>(args.values[base + i].v_int64);
+          holder[i].v_int32[0] = static_cast<int32_t>(args.values[base + i].v_int64);
           break;
         }
         case INT64_TO_UINT32: {
-          holder[i].v_uint32 = static_cast<uint32_t>(args.values[base + i].v_int64);
+          holder[i].v_uint32[0] = static_cast<uint32_t>(args.values[base + i].v_int64);
           break;
         }
         case FLOAT64_TO_FLOAT32: {
-          holder[i].v_float32 = static_cast<float>(args.values[base + i].v_float64);
+          holder[i].v_float32[0] = static_cast<float>(args.values[base + i].v_float64);
           break;
         }
         case HANDLE_TO_HANDLE: {
diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc
index f40fd80..4eb3481 100644
--- a/src/runtime/vulkan/vulkan.cc
+++ b/src/runtime/vulkan/vulkan.cc
@@ -711,7 +711,7 @@ class VulkanWrappedFunc {
     thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags);
   }
 
-  void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const;
+  void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const;
 
  private:
   // internal module
@@ -875,7 +875,7 @@ class VulkanModuleNode final : public runtime::ModuleNode {
     VkPushConstantRange crange;
     crange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
     crange.offset = 0;
-    crange.size = sizeof(ArgUnion) * num_pack_args;
+    crange.size = sizeof(ArgUnion64) * num_pack_args;
 
     VkPipelineLayoutCreateInfo playout_cinfo;
     playout_cinfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
@@ -1046,7 +1046,7 @@ VulkanStream* VulkanThreadEntry::Stream(size_t device_id) {
   return streams_[device_id].get();
 }
 
-void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const {
+void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const {
   int device_id = VulkanThreadEntry::ThreadLocal()->ctx.device_id;
   ICHECK_LT(device_id, kVulkanMaxNumDevice);
   const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
@@ -1075,7 +1075,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion
           descriptor_buffers.data());
       if (num_pack_args_ != 0) {
         vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout,
-                           VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion),
+                           VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion64),
                            pack_args);
       }
       vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
@@ -1093,7 +1093,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion
   }
 
   // Otherwise, the more expensive deferred path.
-  std::vector<ArgUnion> pack_args_storage(pack_args, pack_args + num_pack_args_);
+  std::vector<ArgUnion64> pack_args_storage(pack_args, pack_args + num_pack_args_);
   const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers]() {
     std::vector<VkWriteDescriptorSet> write_descriptor_sets;
     write_descriptor_sets.resize(descriptor_buffers.size());
@@ -1119,7 +1119,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion
                             nullptr);
     if (pack_args_storage.size() != 0) {
       vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT,
-                         0, pack_args_storage.size() * sizeof(ArgUnion), pack_args_storage.data());
+                         0, pack_args_storage.size() * sizeof(ArgUnion64), pack_args_storage.data());
     }
     vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
     VkMemoryBarrier barrier_info;


[tvm] 07/11: pytest fix

Posted by wu...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch vk-i64
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 62889b04edb5b97b5269d2f7d5715de56319d846
Author: Masahiro Masuda <ma...@gmail.com>
AuthorDate: Wed Mar 3 18:14:23 2021 +0900

    pytest fix
---
 tests/python/topi/python/test_topi_cumsum.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/topi/python/test_topi_cumsum.py b/tests/python/topi/python/test_topi_cumsum.py
index bf962d9..6b99239 100644
--- a/tests/python/topi/python/test_topi_cumsum.py
+++ b/tests/python/topi/python/test_topi_cumsum.py
@@ -41,7 +41,7 @@ def test_cumsum(ctx, target):
     check_cumsum(np.cumsum(data, dtype=np.int32), data)
     check_cumsum(np.cumsum(data), data, dtype="int64")
 
-    if str(target.kind) != "vulkan":
+    if target != "vulkan":
         # TODO(masahi): Support bool tensor in SPIRV codegen
         data = np.random.rand(10) > 0.5
         check_cumsum(np.cumsum(data, dtype=np.int32), data, dtype="int32")


[tvm] 06/11: formatting

Posted by wu...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch vk-i64
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 100abfec25cc62295e564041cf3b8e3c37ded389
Author: Masahiro Masuda <ma...@gmail.com>
AuthorDate: Wed Mar 3 09:50:43 2021 +0900

    formatting
---
 src/runtime/vulkan/vulkan.cc | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc
index 4eb3481..794f3c5 100644
--- a/src/runtime/vulkan/vulkan.cc
+++ b/src/runtime/vulkan/vulkan.cc
@@ -1046,7 +1046,8 @@ VulkanStream* VulkanThreadEntry::Stream(size_t device_id) {
   return streams_[device_id].get();
 }
 
-void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const {
+void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
+                                   const ArgUnion64* pack_args) const {
   int device_id = VulkanThreadEntry::ThreadLocal()->ctx.device_id;
   ICHECK_LT(device_id, kVulkanMaxNumDevice);
   const auto& vctx = VulkanDeviceAPI::Global()->context(device_id);
@@ -1119,7 +1120,8 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion
                             nullptr);
     if (pack_args_storage.size() != 0) {
       vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT,
-                         0, pack_args_storage.size() * sizeof(ArgUnion64), pack_args_storage.data());
+                         0, pack_args_storage.size() * sizeof(ArgUnion64),
+                         pack_args_storage.data());
     }
     vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2));
     VkMemoryBarrier barrier_info;


[tvm] 10/11: Add explicit 64-bit support in metal codegen

Posted by wu...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch vk-i64
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 47d45b28b19aba7f6e4eb201b1d9f0bec220f2c5
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Thu Mar 4 17:01:42 2021 -0500

    Add explicit 64-bit support in metal codegen
---
 src/target/source/codegen_metal.cc | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc
index f7219cb..c95d578 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -47,7 +47,7 @@ CodeGenMetal::CodeGenMetal() {
   decl_stream << "#include <metal_stdlib>\n";
   decl_stream << "using namespace metal;\n\n";
   decl_stream << "union __TVMArgUnion {\n"
-              << " int v_long;\n"
+              << " int v_int[2];\n"
               << "};\n\n";
 }
 
@@ -106,6 +106,11 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
         PrintType(v.dtype(), decl_stream);
         decl_stream << " " << vid << "[2];\n";
         vref << varg << "." << vid << "[0]";
+      } else if (v.dtype().bits() == 64) {
+        decl_stream << "  ";
+        PrintType(v.dtype(), decl_stream);
+        decl_stream << " " << vid << ";\n";
+        vref << varg << "." << vid;
       } else {
         // For non 32bit type, ref through arg union.
         decl_stream << "  __TVMArgUnion " << vid << ";\n";


[tvm] 04/11: update metal runtime to use ArgUnion64 (not tested)

Posted by wu...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch vk-i64
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit ef1ed2d732000ba2c7bf68bd6a1342ca3c98e0df
Author: Masahiro Masuda <ma...@gmail.com>
AuthorDate: Wed Mar 3 08:23:16 2021 +0900

    update metal runtime to use ArgUnion64 (not tested)
---
 src/runtime/metal/metal_module.mm | 4 ++--
 src/runtime/pack_args.h           | 4 +++-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm
index 981dd61..8f1fde8 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -180,7 +180,7 @@ class MetalWrappedFunc {
     scache_[dev_id] = m->GetPipelineState(dev_id, func_name);
   }
   // invoke the function with void arguments
-  void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const {
+  void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const {
     metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
     int device_id = t->context.device_id;
     if (scache_[device_id] == nil) {
@@ -197,7 +197,7 @@ class MetalWrappedFunc {
     }
     if (num_pack_args_ != 0) {
       [encoder setBytes:pack_args
-                 length:num_pack_args_ * sizeof(ArgUnion)
+                 length:num_pack_args_ * sizeof(ArgUnion64)
                 atIndex:num_buffer_args_];
     }
     // launch
diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h
index 54a75d6..2e7a881 100644
--- a/src/runtime/pack_args.h
+++ b/src/runtime/pack_args.h
@@ -40,7 +40,6 @@ namespace tvm {
 namespace runtime {
 /*!
  * \brief argument union type of 32bit.
- * Choose 32 bit because most GPU API do not work well with 64 bit.
  */
 union ArgUnion {
   int32_t v_int32;
@@ -48,6 +47,9 @@ union ArgUnion {
   float v_float32;
 };
 
+/*!
+ * \brief argument union type of 64 bit, for use by Vulkan and Metal runtime.
+ */
 union ArgUnion64 {
   int32_t v_int32[2];
   uint32_t v_uint32[2];