This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3e8fe75380 [ROCm] Minor fixes for latest refactor (#18225)
3e8fe75380 is described below
commit 3e8fe753800d55e843c441f302a6a90b26dd22b9
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Aug 23 19:25:31 2025 -0400
[ROCm] Minor fixes for latest refactor (#18225)
This PR fixes a few ROCm and hipBLAS build issues after recent
refactors.
---
src/runtime/contrib/hipblas/hipblas_json_runtime.cc | 4 +++-
src/runtime/contrib/hipblas/hipblas_utils.cc | 5 +++--
src/runtime/contrib/hipblas/hipblas_utils.h | 4 ++--
3 files changed, 8 insertions(+), 5 deletions(-)
diff --git a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc
index ab8545561b..08866fc108 100644
--- a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc
+++ b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc
@@ -22,6 +22,7 @@
* \brief A simple JSON runtime for HIPBLAS.
*/
+#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/ndarray.h>
@@ -30,6 +31,7 @@
#include <string>
#include <vector>
+#include "../../rocm/rocm_common.h"
#include "../json/json_node.h"
#include "../json/json_runtime.h"
#include "hipblas_utils.h"
@@ -86,7 +88,7 @@ class HipblasJSONRuntime : public JSONRuntimeBase {
if (device_id == -1) {
ROCM_CALL(hipGetDevice(&device_id));
}
- auto* entry_ptr = tvm::contrib::HipBlasLtThreadEntry::ThreadLocal(Device(kDLROCM, device_id));
+ auto* entry_ptr = tvm::contrib::HipBlasLtThreadEntry::ThreadLocal(DLDevice{kDLROCM, device_id});
hipStream_t stream =
static_cast<hipStream_t>(TVMFFIEnvGetCurrentStream(kDLROCM, device_id));
auto get_input = [this, &dl_tensors](const JSONGraphNode& node, int idx) {
diff --git a/src/runtime/contrib/hipblas/hipblas_utils.cc b/src/runtime/contrib/hipblas/hipblas_utils.cc
index 454ab7a370..1b61cbd382 100644
--- a/src/runtime/contrib/hipblas/hipblas_utils.cc
+++ b/src/runtime/contrib/hipblas/hipblas_utils.cc
@@ -23,6 +23,7 @@
#include "hipblas_utils.h"
#include <dmlc/thread_local.h>
+#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/function.h>
#include "../../rocm/rocm_common.h"
@@ -41,7 +42,7 @@ HipBlasThreadEntry::~HipBlasThreadEntry() {
typedef dmlc::ThreadLocalStore<HipBlasThreadEntry> HipBlasThreadStore;
-HipBlasThreadEntry* HipBlasThreadEntry::ThreadLocal(Device curr_device) {
+HipBlasThreadEntry* HipBlasThreadEntry::ThreadLocal(DLDevice curr_device) {
HipBlasThreadEntry* retval = HipBlasThreadStore::Get();
TVMFFIStreamHandle stream =
TVMFFIEnvGetCurrentStream(curr_device.device_type,
curr_device.device_id);
@@ -72,7 +73,7 @@ HipBlasLtThreadEntry::~HipBlasLtThreadEntry() {
typedef dmlc::ThreadLocalStore<HipBlasLtThreadEntry> HipBlasLtThreadStore;
-HipBlasLtThreadEntry* HipBlasLtThreadEntry::ThreadLocal(Device curr_device) {
+HipBlasLtThreadEntry* HipBlasLtThreadEntry::ThreadLocal(DLDevice curr_device) {
return HipBlasLtThreadStore::Get();
}
diff --git a/src/runtime/contrib/hipblas/hipblas_utils.h b/src/runtime/contrib/hipblas/hipblas_utils.h
index 66d7afafbd..d07e825c21 100644
--- a/src/runtime/contrib/hipblas/hipblas_utils.h
+++ b/src/runtime/contrib/hipblas/hipblas_utils.h
@@ -68,7 +68,7 @@ struct HipBlasThreadEntry {
HipBlasThreadEntry();
~HipBlasThreadEntry();
hipblasHandle_t handle{nullptr};
- static HipBlasThreadEntry* ThreadLocal();
+ static HipBlasThreadEntry* ThreadLocal(DLDevice curr_device);
}; // HipBlasThreadEntry
struct HipBlasLtThreadEntry {
@@ -82,7 +82,7 @@ struct HipBlasLtThreadEntry {
// https://docs.nvidia.com/cuda/cublas/index.html#cublassetworkspace.
static constexpr const size_t workspace_size = 33554432;
- static HipBlasLtThreadEntry* ThreadLocal();
+ static HipBlasLtThreadEntry* ThreadLocal(DLDevice curr_device);
}; // HipBlasLtThreadEntry
inline hipDataType GetHipDataType(DLDataType type) {