This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new bab295e409 [ROCm] Fix some ROCm codegen bugs (#15454)
bab295e409 is described below
commit bab295e4096f4a2e7a7f220a5e4d77f322101412
Author: Bohan Hou <[email protected]>
AuthorDate: Wed Aug 2 04:26:45 2023 -0700
[ROCm] Fix some ROCm codegen bugs (#15454)
* rocm bug fix: Module hip should be either dso exportable or binary
serializable
rocm bug fix: llvm.amdgcn.ds.bpermute Intrinsic has incorrect return type
rocm bug fix: ptr addrspace(3) @shmem Global is external, but doesn't have
external or weak linkage
Co-authored-by: zhangxiao-stack <[email protected]>
* lint
---------
Co-authored-by: zhangxiao-stack <[email protected]>
Co-authored-by: zhangxiao-stack <[email protected]>
---
src/runtime/rocm/rocm_module.cc | 4 +++-
src/target/llvm/codegen_llvm.cc | 4 ++--
src/tir/transforms/lower_thread_allreduce.cc | 2 +-
3 files changed, 6 insertions(+), 4 deletions(-)
diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc
index cf3530c0af..9acd1ca903 100644
--- a/src/runtime/rocm/rocm_module.cc
+++ b/src/runtime/rocm/rocm_module.cc
@@ -63,7 +63,9 @@ class ROCMModuleNode : public runtime::ModuleNode {
}
const char* type_key() const final { return "hip"; }
-
+ int GetPropertyMask() const final {
+ return ModulePropertyMask::kBinarySerializable |
ModulePropertyMask::kRunnable;
+ }
PackedFunc GetFunction(const String& name, const ObjectPtr<Object>&
sptr_to_self) final;
void SaveToFile(const String& file_name, const String& format) final {
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 67c81d2803..02d203b7e9 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -702,8 +702,8 @@ llvm::GlobalVariable*
CodeGenLLVM::AllocateSharedMemory(DataType dtype, size_t s
llvm::GlobalValue::LinkageTypes linkage) {
llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(dtype), size);
llvm::GlobalVariable* global =
- new llvm::GlobalVariable(*module_, type, false, linkage, nullptr,
"shmem", nullptr,
- llvm::GlobalValue::NotThreadLocal,
shared_address_space);
+ new llvm::GlobalVariable(*module_, type, false, linkage,
llvm::UndefValue::get(type), "shmem",
+ nullptr, llvm::GlobalValue::NotThreadLocal,
shared_address_space);
#if TVM_LLVM_VERSION >= 100
global->setAlignment(llvm::Align(alignment));
#else
diff --git a/src/tir/transforms/lower_thread_allreduce.cc
b/src/tir/transforms/lower_thread_allreduce.cc
index fba62a0c18..abc288f0eb 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -729,7 +729,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator
{
// rocm only supports 32 bit operands for shuffling at the moment
if ((target_->kind->name == "rocm") &&
(std::any_of(types.begin(), types.end(), [](DataType ty) {
- if (ty.is_vector()) return true;
+ if ((ty.is_vector()) || !ty.is_int()) return true;
return ty.bits() != 32;
}))) {
return false;