You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2023/08/02 11:26:52 UTC

[tvm] branch main updated: [ROCm] Fix some ROCm codegen bugs (#15454)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new bab295e409 [ROCm] Fix some ROCm codegen bugs (#15454)
bab295e409 is described below

commit bab295e4096f4a2e7a7f220a5e4d77f322101412
Author: Bohan Hou <bo...@andrew.cmu.edu>
AuthorDate: Wed Aug 2 04:26:45 2023 -0700

    [ROCm] Fix some ROCm codegen bugs (#15454)
    
    * rocm bug fix: Module hip should be either dso exportable or binary serializable
    
    rocm bug fix: llvm.amdgcn.ds.bpermute Intrinsic has incorrect return type
    
    rocm bug fix: ptr addrspace(3) @shmem Global is external, but doesn't have external or weak linkage
    
    Co-authored-by: zhangxiao-stack <12...@qq.com>
    
    * lint
    
    ---------
    
    Co-authored-by: zhangxiao-stack <zh...@sugon.com>
    Co-authored-by: zhangxiao-stack <12...@qq.com>
---
 src/runtime/rocm/rocm_module.cc              | 4 +++-
 src/target/llvm/codegen_llvm.cc              | 4 ++--
 src/tir/transforms/lower_thread_allreduce.cc | 2 +-
 3 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc
index cf3530c0af..9acd1ca903 100644
--- a/src/runtime/rocm/rocm_module.cc
+++ b/src/runtime/rocm/rocm_module.cc
@@ -63,7 +63,9 @@ class ROCMModuleNode : public runtime::ModuleNode {
   }
 
   const char* type_key() const final { return "hip"; }
-
+  int GetPropertyMask() const final {
+    return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable;
+  }
   PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) final;
 
   void SaveToFile(const String& file_name, const String& format) final {
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 67c81d2803..02d203b7e9 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -702,8 +702,8 @@ llvm::GlobalVariable* CodeGenLLVM::AllocateSharedMemory(DataType dtype, size_t s
                                                         llvm::GlobalValue::LinkageTypes linkage) {
   llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(dtype), size);
   llvm::GlobalVariable* global =
-      new llvm::GlobalVariable(*module_, type, false, linkage, nullptr, "shmem", nullptr,
-                               llvm::GlobalValue::NotThreadLocal, shared_address_space);
+      new llvm::GlobalVariable(*module_, type, false, linkage, llvm::UndefValue::get(type), "shmem",
+                               nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space);
 #if TVM_LLVM_VERSION >= 100
   global->setAlignment(llvm::Align(alignment));
 #else
diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index fba62a0c18..abc288f0eb 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -729,7 +729,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     // rocm only supports 32 bit operands for shuffling at the moment
     if ((target_->kind->name == "rocm") &&
         (std::any_of(types.begin(), types.end(), [](DataType ty) {
-          if (ty.is_vector()) return true;
+          if ((ty.is_vector()) || !ty.is_int()) return true;
           return ty.bits() != 32;
         }))) {
       return false;