This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3e8fe75380 [ROCm] Minor fixes for latest refactor (#18225)
3e8fe75380 is described below

commit 3e8fe753800d55e843c441f302a6a90b26dd22b9
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Aug 23 19:25:31 2025 -0400

    [ROCm] Minor fixes for latest refactor (#18225)
    
    This PR fixes a few ROCm and hipBLAS build issues after recent
    refactors.
---
 src/runtime/contrib/hipblas/hipblas_json_runtime.cc | 4 +++-
 src/runtime/contrib/hipblas/hipblas_utils.cc        | 5 +++--
 src/runtime/contrib/hipblas/hipblas_utils.h         | 4 ++--
 3 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc
index ab8545561b..08866fc108 100644
--- a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc
+++ b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc
@@ -22,6 +22,7 @@
  * \brief A simple JSON runtime for HIPBLAS.
  */
 
+#include <tvm/ffi/extra/c_env_api.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/runtime/ndarray.h>
@@ -30,6 +31,7 @@
 #include <string>
 #include <vector>
 
+#include "../../rocm/rocm_common.h"
 #include "../json/json_node.h"
 #include "../json/json_runtime.h"
 #include "hipblas_utils.h"
@@ -86,7 +88,7 @@ class HipblasJSONRuntime : public JSONRuntimeBase {
     if (device_id == -1) {
       ROCM_CALL(hipGetDevice(&device_id));
     }
-    auto* entry_ptr = tvm::contrib::HipBlasLtThreadEntry::ThreadLocal(Device(kDLROCM, device_id));
+    auto* entry_ptr = tvm::contrib::HipBlasLtThreadEntry::ThreadLocal(DLDevice{kDLROCM, device_id});
     hipStream_t stream = static_cast<hipStream_t>(TVMFFIEnvGetCurrentStream(kDLROCM, device_id));
 
     auto get_input = [this, &dl_tensors](const JSONGraphNode& node, int idx) {
diff --git a/src/runtime/contrib/hipblas/hipblas_utils.cc b/src/runtime/contrib/hipblas/hipblas_utils.cc
index 454ab7a370..1b61cbd382 100644
--- a/src/runtime/contrib/hipblas/hipblas_utils.cc
+++ b/src/runtime/contrib/hipblas/hipblas_utils.cc
@@ -23,6 +23,7 @@
 #include "hipblas_utils.h"
 
 #include <dmlc/thread_local.h>
+#include <tvm/ffi/extra/c_env_api.h>
 #include <tvm/ffi/function.h>
 
 #include "../../rocm/rocm_common.h"
@@ -41,7 +42,7 @@ HipBlasThreadEntry::~HipBlasThreadEntry() {
 
 typedef dmlc::ThreadLocalStore<HipBlasThreadEntry> HipBlasThreadStore;
 
-HipBlasThreadEntry* HipBlasThreadEntry::ThreadLocal(Device curr_device) {
+HipBlasThreadEntry* HipBlasThreadEntry::ThreadLocal(DLDevice curr_device) {
   HipBlasThreadEntry* retval = HipBlasThreadStore::Get();
   TVMFFIStreamHandle stream =
       TVMFFIEnvGetCurrentStream(curr_device.device_type, curr_device.device_id);
@@ -72,7 +73,7 @@ HipBlasLtThreadEntry::~HipBlasLtThreadEntry() {
 
 typedef dmlc::ThreadLocalStore<HipBlasLtThreadEntry> HipBlasLtThreadStore;
 
-HipBlasLtThreadEntry* HipBlasLtThreadEntry::ThreadLocal(Device curr_device) {
+HipBlasLtThreadEntry* HipBlasLtThreadEntry::ThreadLocal(DLDevice curr_device) {
   return HipBlasLtThreadStore::Get();
 }
 
diff --git a/src/runtime/contrib/hipblas/hipblas_utils.h b/src/runtime/contrib/hipblas/hipblas_utils.h
index 66d7afafbd..d07e825c21 100644
--- a/src/runtime/contrib/hipblas/hipblas_utils.h
+++ b/src/runtime/contrib/hipblas/hipblas_utils.h
@@ -68,7 +68,7 @@ struct HipBlasThreadEntry {
   HipBlasThreadEntry();
   ~HipBlasThreadEntry();
   hipblasHandle_t handle{nullptr};
-  static HipBlasThreadEntry* ThreadLocal();
+  static HipBlasThreadEntry* ThreadLocal(DLDevice curr_device);
 };  // HipBlasThreadEntry
 
 struct HipBlasLtThreadEntry {
@@ -82,7 +82,7 @@ struct HipBlasLtThreadEntry {
   // https://docs.nvidia.com/cuda/cublas/index.html#cublassetworkspace.
   static constexpr const size_t workspace_size = 33554432;
 
-  static HipBlasLtThreadEntry* ThreadLocal();
+  static HipBlasLtThreadEntry* ThreadLocal(DLDevice curr_device);
 };  // HipBlasLtThreadEntry
 
 inline hipDataType GetHipDataType(DLDataType type) {

Reply via email to