This is an automated email from the ASF dual-hosted git repository.

yaxingcai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1163aa0bcd [FFI][REFACTOR] Establish Stream Context in ffi (#18216)
1163aa0bcd is described below

commit 1163aa0bcd6c19704a679922bcfb34235f8814f9
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Aug 19 18:58:19 2025 -0400

    [FFI][REFACTOR] Establish Stream Context in ffi (#18216)
    
    This PR sets up the stream context in ffi and migrates
    the existing per-device-API stream context management
    to the ffi env API. The new API will help us streamline
    stream-related integration for most libraries.
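    
    A minimal usage sketch of the new env API (illustration only, not part
    of the commit; the helper name is hypothetical, the calls are the ones
    declared in ffi/include/tvm/ffi/extra/c_env_api.h below):
    
        #include <tvm/ffi/extra/c_env_api.h>
    
        // Hypothetical helper: run work on `stream`, then restore the
        // previously recorded stream for (kDLCUDA, device_id).
        void WithStream(int device_id, TVMFFIStreamHandle stream) {
          TVMFFIStreamHandle prev = nullptr;
          // Record the new stream; the old one is returned through `prev`.
          TVM_FFI_CHECK_SAFE_CALL(
              TVMFFIEnvSetStream(kDLCUDA, device_id, stream, &prev));
          // Libraries pick up the current stream like the call sites below:
          TVMFFIStreamHandle cur = TVMFFIEnvGetCurrentStream(kDLCUDA, device_id);
          (void)cur;  // ... launch kernels on `cur` ...
          // Restore the previous stream.
          TVM_FFI_CHECK_SAFE_CALL(
              TVMFFIEnvSetStream(kDLCUDA, device_id, prev, nullptr));
        }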
---
 ffi/CMakeLists.txt                                 |  1 +
 ffi/include/tvm/ffi/extra/c_env_api.h              | 33 +++++++++
 ffi/src/ffi/extra/stream_context.cc                | 81 ++++++++++++++++++++++
 include/tvm/runtime/device_api.h                   |  2 +-
 python/tvm/contrib/cutlass/attention_operation.py  | 12 ++--
 python/tvm/contrib/cutlass/conv2d_operation.py     |  3 +-
 python/tvm/contrib/cutlass/gemm_operation.py       | 13 ++--
 python/tvm/contrib/cutlass/gen_tensor_op.py        |  2 +-
 python/tvm/contrib/cutlass/layer_norm_operation.py |  3 +-
 python/tvm/contrib/cutlass/rms_norm_operation.py   |  3 +-
 src/contrib/msc/plugin/tvm_codegen.cc              |  4 +-
 src/runtime/contrib/cublas/cublas.cc               | 11 +--
 src/runtime/contrib/cublas/cublas_json_runtime.cc  | 16 +++--
 src/runtime/contrib/cublas/cublas_utils.cc         | 12 ++--
 src/runtime/contrib/cublas/cublas_utils.h          |  4 +-
 src/runtime/contrib/cudnn/conv_backward.cc         | 12 ++--
 src/runtime/contrib/cudnn/conv_forward.cc          |  8 ++-
 .../contrib/cudnn/cudnn_frontend/attention.cc      |  8 +--
 src/runtime/contrib/cudnn/cudnn_json_runtime.cc    | 10 ++-
 src/runtime/contrib/cudnn/cudnn_utils.cc           | 16 +++--
 src/runtime/contrib/cudnn/cudnn_utils.h            |  2 +-
 src/runtime/contrib/cudnn/softmax.cc               |  5 +-
 src/runtime/contrib/cutlass/fp16_group_gemm.cuh    |  2 +-
 src/runtime/contrib/cutlass/fp8_gemm.cu            |  7 +-
 src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu |  5 +-
 .../contrib/cutlass/fp8_groupwise_scaled_gemm.cuh  |  9 +--
 .../fp8_groupwise_scaled_group_gemm_sm100.cu       |  4 +-
 src/runtime/contrib/hipblas/hipblas.cc             |  4 +-
 .../contrib/hipblas/hipblas_json_runtime.cc        | 11 +--
 src/runtime/contrib/hipblas/hipblas_utils.cc       |  9 ++-
 src/runtime/contrib/miopen/conv_forward.cc         |  8 ++-
 src/runtime/contrib/miopen/miopen_utils.cc         | 11 ++-
 src/runtime/contrib/miopen/softmax.cc              |  2 +-
 src/runtime/contrib/msc/tensorrt_runtime.cc        |  4 +-
 src/runtime/cuda/cuda_common.h                     |  4 --
 src/runtime/cuda/cuda_device_api.cc                | 30 ++++----
 src/runtime/cuda/cuda_module.cc                    |  3 +-
 src/runtime/cuda/l2_cache_flush.cc                 |  5 +-
 src/runtime/device_api.cc                          |  8 ++-
 src/runtime/metal/metal_common.h                   |  2 -
 src/runtime/metal/metal_device_api.mm              | 11 ---
 src/runtime/rocm/rocm_device_api.cc                | 28 ++++----
 src/runtime/rocm/rocm_module.cc                    |  3 +-
 src/runtime/vm/cuda/cuda_graph_builtin.cc          | 20 ++++--
 src/runtime/vulkan/vulkan_device_api.cc            |  6 --
 src/runtime/vulkan/vulkan_device_api.h             |  2 -
 web/emcc/webgpu_runtime.cc                         |  4 --
 47 files changed, 306 insertions(+), 157 deletions(-)

diff --git a/ffi/CMakeLists.txt b/ffi/CMakeLists.txt
index ce4f4d4e20..b5823b76a7 100644
--- a/ffi/CMakeLists.txt
+++ b/ffi/CMakeLists.txt
@@ -73,6 +73,7 @@ if (TVM_FFI_USE_EXTRA_CXX_API)
     "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module.cc"
     "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module_system_lib.cc"
     "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module_dynamic_lib.cc"
+    "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/stream_context.cc"
   )
 endif()
 
diff --git a/ffi/include/tvm/ffi/extra/c_env_api.h b/ffi/include/tvm/ffi/extra/c_env_api.h
index 5d5d908f78..1211ab0eeb 100644
--- a/ffi/include/tvm/ffi/extra/c_env_api.h
+++ b/ffi/include/tvm/ffi/extra/c_env_api.h
@@ -29,6 +29,39 @@
 extern "C" {
 #endif
 
+// ----------------------------------------------------------------------------
+// Stream context
+// A minimalistic thread-local context that records the stream in use.
+// We explicitly do not handle stream allocation/de-allocation here.
+// ----------------------------------------------------------------------------
+typedef void* TVMFFIStreamHandle;
+
+/*!
+ * \brief FFI function to set the current stream for a device
+ *
+ * \param device_type The type of the device.
+ * \param device_id The id of the device.
+ * \param stream The stream to set.
+ * \param opt_out_original_stream Output the original stream if the address is not nullptr.
+ * \note The stream is a weak reference that is cached/owned by the module.
+ * \return 0 on success, nonzero on failure.
+ */
+TVM_FFI_DLL int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id,
+                                   TVMFFIStreamHandle stream,
+                                   TVMFFIStreamHandle* opt_out_original_stream);
+
+/*!
+ * \brief FFI function to get the current stream for a device
+ *
+ * \param device_type The type of the device.
+ * \param device_id The id of the device.
+ * \return The current stream of the device.
+ */
+TVM_FFI_DLL TVMFFIStreamHandle TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id);
+
+// ----------------------------------------------------------------------------
+// Module symbol management
+// ----------------------------------------------------------------------------
 /*!
  * \brief FFI function to lookup a function from a module's imports.
  *
diff --git a/ffi/src/ffi/extra/stream_context.cc b/ffi/src/ffi/extra/stream_context.cc
new file mode 100644
index 0000000000..d063efdef5
--- /dev/null
+++ b/ffi/src/ffi/extra/stream_context.cc
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*
+ * \file src/ffi/extra/stream_context.cc
+ *
+ * \brief A minimalistic stream context based on ffi values.
+ */
+
+#include <tvm/ffi/extra/c_env_api.h>
+#include <tvm/ffi/function.h>
+
+#include <vector>
+
+namespace tvm {
+namespace ffi {
+
+class StreamContext {
+ public:
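+  // Record `stream` as the current stream for (device_type, device_id);
+  // if out_original_stream is non-null, the previously recorded stream
+  // (nullptr if none) is written through it.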
+  void SetStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream,
+                 TVMFFIStreamHandle* out_original_stream) {
+    if (static_cast<size_t>(device_type) >= stream_table_.size()) {
+      stream_table_.resize(device_type + 1);
+    }
+    if (static_cast<size_t>(device_id) >= stream_table_[device_type].size()) {
+      stream_table_[device_type].resize(device_id + 1, nullptr);
+    }
+    if (out_original_stream != nullptr) {
+      *out_original_stream = stream_table_[device_type][device_id];
+    }
+    stream_table_[device_type][device_id] = stream;
+  }
+
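+  // Returns the recorded stream, or nullptr if none was set for this device.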
+  TVMFFIStreamHandle GetStream(int32_t device_type, int32_t device_id) {
+    if (static_cast<size_t>(device_type) < stream_table_.size() &&
+        static_cast<size_t>(device_id) < stream_table_[device_type].size()) {
+      return stream_table_[device_type][device_id];
+    }
+    return nullptr;
+  }
+
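+  // One table per thread; streams recorded on one thread are not visible
+  // to other threads.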
+  static StreamContext* ThreadLocal() {
+    static thread_local StreamContext inst;
+    return &inst;
+  }
+
+ private:
+  std::vector<std::vector<TVMFFIStreamHandle>> stream_table_;
+};
+
+}  // namespace ffi
+}  // namespace tvm
+
+int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream,
+                       TVMFFIStreamHandle* out_original_stream) {
+  TVM_FFI_SAFE_CALL_BEGIN();
+  tvm::ffi::StreamContext::ThreadLocal()->SetStream(device_type, device_id, stream,
+                                                    out_original_stream);
+  TVM_FFI_SAFE_CALL_END();
+}
+
+TVMFFIStreamHandle TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id) {
+  TVM_FFI_LOG_EXCEPTION_CALL_BEGIN();
+  return tvm::ffi::StreamContext::ThreadLocal()->GetStream(device_type, device_id);
+  TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIEnvGetCurrentStream);
+}
diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h
index 7366b9895d..f14b22c576 100644
--- a/include/tvm/runtime/device_api.h
+++ b/include/tvm/runtime/device_api.h
@@ -225,7 +225,7 @@ class TVM_DLL DeviceAPI {
    * \param dev The device to set stream.
    * \param stream The stream to be set.
    */
-  virtual void SetStream(Device dev, TVMStreamHandle stream) {}
+  virtual void SetStream(Device dev, TVMStreamHandle stream);
   /*!
    * \brief Get the current stream
    * \param dev The device to get stream.
diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
index 4c876142d3..fe29cd5945 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -147,8 +147,7 @@ def instantiate_attention_template(attrs):
   }
 
   CHECK(Attention::check_supported(p));
-  auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
-  cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
+  cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, ${query}->device.device_id));
 
   kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);
 
@@ -186,8 +185,7 @@ def instantiate_flash_attention_template(attrs):
     int v_batch_stride = v_row_stride * ${num_keys};
     int o_batch_stride = o_row_stride * ${num_queries};
 
-    auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
-    cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
+    cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, ${query}->device.device_id));
 
     flash_attn::flash_attention_forward(
                             static_cast<const cutlass::half_t*>(${query}->data),
@@ -237,8 +235,7 @@ def instantiate_flash_attention_template(attrs):
     int v_batch_stride = v_row_stride * ${num_keys};
     int o_batch_stride = o_row_stride * ${num_queries};
 
-    auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
-    cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
+    cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, ${query}->device.device_id));
 
     flash_attn::flash_attention_forward(
                             static_cast<const cutlass::half_t*>(${qkv}->data),
@@ -294,8 +291,7 @@ def instantiate_flash_attention_var_len_template(attrs):
     int v_row_stride = v_head_stride * ${num_kv_heads};
     int o_row_stride = o_head_stride * ${num_q_heads};
 
-    auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
-    cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
+    cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, ${query}->device.device_id));
 
     flash_attn::flash_attention_var_len_forward(
                             static_cast<const cutlass::half_t*>(${query}->data),
diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py b/python/tvm/contrib/cutlass/conv2d_operation.py
index 361bcb54e5..b0afdcdd6e 100644
--- a/python/tvm/contrib/cutlass/conv2d_operation.py
+++ b/python/tvm/contrib/cutlass/conv2d_operation.py
@@ -424,8 +424,7 @@ def instantiate_conv2d_template(attrs):
   TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
   ${split_k_update}
 
-  auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
-  cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
+  cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, ${data_arg}->device.device_id));
 
   status = conv2d_op(stream);
   TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py
index 65dc5da772..453839cc81 100644
--- a/python/tvm/contrib/cutlass/gemm_operation.py
+++ b/python/tvm/contrib/cutlass/gemm_operation.py
@@ -345,8 +345,7 @@ def instantiate_gemm_template(attrs):
   status = gemm_op.initialize(arguments, workspace.get());
   TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
 
-  auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
-  cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
+  cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, ${A_arg}->device.device_id));
 
   status = gemm_op(stream);
   TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
@@ -428,8 +427,8 @@ def emit_fp16A_intB_matmul(attrs):
   int n = ${B_arg}->shape[1] * ${float_per_int};
   int k = ${B_arg}->shape[0];
 
-  auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
-  cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
+  cudaStream_t stream = static_cast<cudaStream_t>(
+    TVMFFIEnvGetCurrentStream(kDLCUDA, ${A_arg}->device.device_id));
     """,
         attrs,
     )
@@ -447,12 +446,14 @@ def emit_fp16A_intB_matmul(attrs):
 
     template_residual = """
   ${template_common}
-  gemm_fp16_int_bias_act_residual<${weight_dtype}, QuantOp>(static_cast<cutlass::half_t*>(${A_arg}->data),
+  gemm_fp16_int_bias_act_residual<${weight_dtype}, QuantOp>(
+                static_cast<cutlass::half_t*>(${A_arg}->data),
                 static_cast<${weight_dtype}*>(${B_arg}->data),
                 static_cast<cutlass::half_t*>(${scales_arg}->data),
                 ${bias},
                 static_cast<cutlass::half_t*>(${residual_arg}->data),
-                static_cast<cutlass::half_t*>(out0->data), "${activation}", "${binary_op}", "${unary_op}",
+                static_cast<cutlass::half_t*>(out0->data),
+                "${activation}", "${binary_op}", "${unary_op}",
                 m, n, k, ${group_size}, nullptr, 0, stream);
 """
 
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 6fa349b28e..c594b3897a 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -487,7 +487,7 @@ def instantiate_template(func_name, annotations, func_args):
         if k in annotations:
             attrs[k] = annotations[k]
 
-    headers = ["tvm/ffi/function.h"]
+    headers = ["tvm/ffi/function.h", "tvm/ffi/extra/c_env_api.h"]
 
     if "relu" in func_name:
         
headers.append("cutlass/epilogue/thread/linear_combination_bias_relu.h")
diff --git a/python/tvm/contrib/cutlass/layer_norm_operation.py b/python/tvm/contrib/cutlass/layer_norm_operation.py
index 74f397b39a..d2a0310244 100644
--- a/python/tvm/contrib/cutlass/layer_norm_operation.py
+++ b/python/tvm/contrib/cutlass/layer_norm_operation.py
@@ -39,8 +39,7 @@ def instantiate_layer_norm_template(attrs):
     cutlass::TensorRef<data_type, RowMajor> _beta((data_type*)${beta}->data, layout_channels);
     cutlass::TensorRef<data_type, RowMajor> _output((data_type*)out0->data, layout_2D);
 
-    auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
-    cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
+    cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, ${input}->device.device_id));
 
     cutlass::layernorm(size, _output, _input, _gamma, _beta, stream);
     """
diff --git a/python/tvm/contrib/cutlass/rms_norm_operation.py b/python/tvm/contrib/cutlass/rms_norm_operation.py
index 27e98fb251..51c18d4ae4 100644
--- a/python/tvm/contrib/cutlass/rms_norm_operation.py
+++ b/python/tvm/contrib/cutlass/rms_norm_operation.py
@@ -38,8 +38,7 @@ def instantiate_rms_norm_template(attrs):
     cutlass::TensorRef<data_type, RowMajor> _weight((data_type*)${weight}->data, layout_channels);
     cutlass::TensorRef<data_type, RowMajor> _output((data_type*)out0->data, layout_2D);
 
-    auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
-    cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
+    cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, ${input}->device.device_id));
 
     cutlass::rmsnorm(size, _output, _input, _weight, stream, ${rms_eps});
     """
diff --git a/src/contrib/msc/plugin/tvm_codegen.cc b/src/contrib/msc/plugin/tvm_codegen.cc
index a3861aabe7..7410867aaf 100644
--- a/src/contrib/msc/plugin/tvm_codegen.cc
+++ b/src/contrib/msc/plugin/tvm_codegen.cc
@@ -230,6 +230,7 @@ void TVMPluginCodeGen::CodeGenOpRuntime(const Plugin& plugin) {
   const auto& attr_name = MetaAttrCls(plugin);
   const auto& func_name = ComputeName(plugin);
   String device_cond = "";
+  String device_index = "";
   for (size_t i = 0; i < plugin->inputs.size(); i++) {
     String device_type = "";
     if (plugin->inputs[i]->device == "cuda" || plugin->inputs[i]->device == "default") {
@@ -381,7 +382,8 @@ void TVMPluginCodeGen::CodeGenCompute(const Plugin& plugin, const String& device
       ICHECK(plugin->buffers.size() == 0) << "Plugin with buffers is not supported in tvm";
       compute_args.push_back("meta_attr");
       if (device == "cuda") {
-        stack_.assign("stream", 
"runtime::CUDAThreadEntry::ThreadLocal()->stream", "auto");
+        // TODO(tvm-team): update to support get stream from device id
+        stack_.assign("stream", "TVMFFIEnvGetCurrentStream(kDLCUDA, 0)", 
"auto");
         compute_args.push_back("stream");
       }
       CodeGenSafeCall(plugin->externs[device + "_compute"], compute_args);
diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc
index d55e0535c2..13f958744e 100644
--- a/src/runtime/contrib/cublas/cublas.cc
+++ b/src/runtime/contrib/cublas/cublas.cc
@@ -20,6 +20,7 @@
 /*!
  * \file Use external cblas library call.
  */
+#include <tvm/ffi/extra/c_env_api.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/runtime/data_type.h>
@@ -522,7 +523,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
         auto A = args[0].cast<DLTensor*>();
         auto C = args[2].cast<DLTensor*>();
 
-        CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal();
+        CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(A->device);
 
         CUBLASTryEnableTensorCore(entry_ptr->handle);
 
@@ -549,15 +550,15 @@ TVM_FFI_STATIC_INIT_BLOCK({
       "tvm.contrib.cublaslt.matmul", [](ffi::PackedArgs args, ffi::Any* ret) {
         auto A = args[0].cast<DLTensor*>();
 
-        CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal();
+        CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(A->device);
 
         CUBLASTryEnableTensorCore(entry_ptr->handle);
 
         ICHECK(TypeMatch(A->dtype, kDLInt, 8)) << "Expects dtype to be int8\n";
         cublasLtHandle_t ltHandle;
         CHECK_CUBLAS_ERROR(cublasLtCreate(&ltHandle));
-        auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
-        cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
+        cudaStream_t stream =
+            static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, A->device.device_id));
         CallLtIgemm(args, ret, ltHandle, stream);
         CHECK_CUBLAS_ERROR(cublasLtDestroy(ltHandle));
       });
@@ -571,7 +572,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
         auto A = args[0].cast<DLTensor*>();
         auto C = args[2].cast<DLTensor*>();
 
-        CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal();
+        CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(A->device);
 
         CUBLASTryEnableTensorCore(entry_ptr->handle);
         if (TypeEqual(A->dtype, C->dtype)) {
diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc
index 11fa3b0c4d..0416391303 100644
--- a/src/runtime/contrib/cublas/cublas_json_runtime.cc
+++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc
@@ -22,6 +22,7 @@
  * \brief A simple JSON runtime for CUBLAS.
  */
 
+#include <tvm/ffi/extra/c_env_api.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/runtime/ndarray.h>
@@ -30,6 +31,7 @@
 #include <string>
 #include <vector>
 
+#include "../../cuda/cuda_common.h"
 #include "../json/json_node.h"
 #include "../json/json_runtime.h"
 #include "cublas_utils.h"
@@ -67,13 +69,8 @@ class CublasJSONRuntime : public JSONRuntimeBase {
   const char* kind() const override { return "cublas_json"; }  // May be 
overridden
 
   void Run(ffi::PackedArgs args) {
-    auto* entry_ptr = tvm::contrib::CuBlasLtThreadEntry::ThreadLocal();
-
-    auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
-    cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
-
     std::vector<const DLTensor*> dl_tensors(NumEntries());
-
+    int device_id = -1;
     for (size_t i = 0; i < static_cast<size_t>(args.size()); i++) {
       auto eid = i < input_var_eid_.size() ? input_var_eid_[i]
                                            : EntryID(outputs_[i - input_var_eid_.size()]);
@@ -87,7 +84,14 @@ class CublasJSONRuntime : public JSONRuntimeBase {
       }
 
       dl_tensors[eid] = arg;
+      device_id = arg->device.device_id;
+    }
+
+    if (device_id == -1) {
+      CUDA_CALL(cudaGetDevice(&device_id));
     }
+    auto* entry_ptr = tvm::contrib::CuBlasLtThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id});
+    cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id));
 
     auto get_input = [this, &dl_tensors](const JSONGraphNode& node, int idx) {
       ICHECK_LT(idx, node.GetInputs().size());
diff --git a/src/runtime/contrib/cublas/cublas_utils.cc b/src/runtime/contrib/cublas/cublas_utils.cc
index 53e00fe141..0ba654c9eb 100644
--- a/src/runtime/contrib/cublas/cublas_utils.cc
+++ b/src/runtime/contrib/cublas/cublas_utils.cc
@@ -23,6 +23,7 @@
 #include "cublas_utils.h"
 
 #include <dmlc/thread_local.h>
+#include <tvm/ffi/extra/c_env_api.h>
 #include <tvm/ffi/function.h>
 
 #include "../../cuda/cuda_common.h"
@@ -41,10 +42,11 @@ CuBlasThreadEntry::~CuBlasThreadEntry() {
 
 typedef dmlc::ThreadLocalStore<CuBlasThreadEntry> CuBlasThreadStore;
 
-CuBlasThreadEntry* CuBlasThreadEntry::ThreadLocal() {
-  auto stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
+CuBlasThreadEntry* CuBlasThreadEntry::ThreadLocal(DLDevice curr_device) {
   CuBlasThreadEntry* retval = CuBlasThreadStore::Get();
-  CHECK_CUBLAS_ERROR(cublasSetStream(retval->handle, static_cast<cudaStream_t>(stream)));
+  cudaStream_t stream = static_cast<cudaStream_t>(
+      TVMFFIEnvGetCurrentStream(curr_device.device_type, curr_device.device_id));
+  CHECK_CUBLAS_ERROR(cublasSetStream(retval->handle, stream));
   return retval;
 }
 
@@ -71,7 +73,9 @@ CuBlasLtThreadEntry::~CuBlasLtThreadEntry() {
 
 typedef dmlc::ThreadLocalStore<CuBlasLtThreadEntry> CuBlasLtThreadStore;
 
-CuBlasLtThreadEntry* CuBlasLtThreadEntry::ThreadLocal() { return CuBlasLtThreadStore::Get(); }
+CuBlasLtThreadEntry* CuBlasLtThreadEntry::ThreadLocal(DLDevice curr_device) {
+  return CuBlasLtThreadStore::Get();
+}
 
 }  // namespace contrib
 }  // namespace tvm
diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h
index 3e9ded08de..12260a78ef 100644
--- a/src/runtime/contrib/cublas/cublas_utils.h
+++ b/src/runtime/contrib/cublas/cublas_utils.h
@@ -75,7 +75,7 @@ struct CuBlasThreadEntry {
   CuBlasThreadEntry();
   ~CuBlasThreadEntry();
   cublasHandle_t handle{nullptr};
-  static CuBlasThreadEntry* ThreadLocal();
+  static CuBlasThreadEntry* ThreadLocal(DLDevice curr_device);
 };  // CuBlasThreadEntry
 
 struct CuBlasLtThreadEntry {
@@ -89,7 +89,7 @@ struct CuBlasLtThreadEntry {
   // https://docs.nvidia.com/cuda/cublas/index.html#cublassetworkspace.
   static constexpr const size_t workspace_size = 33554432;
 
-  static CuBlasLtThreadEntry* ThreadLocal();
+  static CuBlasLtThreadEntry* ThreadLocal(DLDevice curr_device);
 };  // CuBlasLtThreadEntry
 
 inline cudaDataType_t GetCudaDataType(DLDataType type) {
diff --git a/src/runtime/contrib/cudnn/conv_backward.cc b/src/runtime/contrib/cudnn/conv_backward.cc
index 915f21bc7c..515263ef36 100644
--- a/src/runtime/contrib/cudnn/conv_backward.cc
+++ b/src/runtime/contrib/cudnn/conv_backward.cc
@@ -35,7 +35,7 @@ using namespace runtime;
 void ConvolutionBackwardData(int mode, int format, int algo, int dims, int groups, const int pad[],
                              const int stride[], const int dilation[], DLTensor* dy, DLTensor* w,
                              DLTensor* dx, const std::string& conv_dtype) {
-  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(dy->device);
   // Set Mode
   entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
   SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, dx->shape, w->shape,
@@ -65,7 +65,9 @@ void BackwardDataFindAlgo(int format, int dims, int groups, const int pad[], con
                           const int dilation[], const int dy_dim[], const int w_dim[],
                           const int dx_dim[], const std::string& data_dtype,
                           const std::string& conv_dtype, bool verbose, ffi::Any* ret) {
-  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+  int device_id;
+  CUDA_CALL(cudaGetDevice(&device_id));
+  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id});
   const int full_dims = dims + 2;
   std::vector<int64_t> dy_dim_int64(full_dims);
   std::vector<int64_t> w_dim_int64(full_dims);
@@ -112,7 +114,7 @@ void ConvolutionBackwardFilter(int mode, int format, int algo, int dims, int gro
                                const int pad[], const int stride[], const int dilation[],
                                DLTensor* dy, DLTensor* x, DLTensor* dw,
                                const std::string& conv_dtype) {
-  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(x->device);
   // Set Mode
   entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
   SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x->shape, dw->shape,
@@ -142,7 +144,9 @@ void BackwardFilterFindAlgo(int format, int dims, int groups, const int pad[], c
                             const int dilation[], const int dy_dim[], const int x_dim[],
                             const int dw_dim[], const std::string& data_dtype,
                             const std::string& conv_dtype, bool verbose, ffi::Any* ret) {
-  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+  int device_id;
+  CUDA_CALL(cudaGetDevice(&device_id));
+  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id});
   const int full_dims = dims + 2;
   std::vector<int64_t> x_dim_int64(full_dims);
   std::vector<int64_t> dy_dim_int64(full_dims);
diff --git a/src/runtime/contrib/cudnn/conv_forward.cc b/src/runtime/contrib/cudnn/conv_forward.cc
index a0a9edef97..7a93e194ce 100644
--- a/src/runtime/contrib/cudnn/conv_forward.cc
+++ b/src/runtime/contrib/cudnn/conv_forward.cc
@@ -35,7 +35,7 @@ using namespace runtime;
 void ConvolutionForward(int mode, int format, int algo, int dims, int groups, const int pad[],
                         const int stride[], const int dilation[], const DLTensor* x,
                         const DLTensor* w, const DLTensor* y, const std::string& conv_dtype) {
-  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(x->device);
   // Set Mode
   entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
   SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x->shape, w->shape,
@@ -69,7 +69,7 @@ void ConvolutionBiasActivationForward(int mode, int format, int algo, int dims,
                                       const int dilation[], const DLTensor* x, const DLTensor* w,
                                       const DLTensor* y, const DLTensor* bias,
                                       const std::string& conv_dtype) {
-  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(x->device);
   // Set Mode
   entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
   CUDNN_CALL(cudnnSetActivationDescriptor(entry_ptr->conv_entry.activation_desc,
@@ -110,7 +110,9 @@ void FindAlgo(int format, int dims, int groups, const int pad[], const int strid
               const int dilation[], const int x_dim[], const int w_dim[], const int y_dim[],
               const std::string& data_dtype, const std::string& conv_dtype, bool verbose,
               ffi::Any* ret) {
-  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+  int device_id;
+  CUDA_CALL(cudaGetDevice(&device_id));
+  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id});
   const int full_dims = dims + 2;
   std::vector<int64_t> x_dim_int64(full_dims);
   std::vector<int64_t> w_dim_int64(full_dims);
diff --git a/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc b/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc
index dffce67389..fbde314bc6 100644
--- a/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc
+++ b/src/runtime/contrib/cudnn/cudnn_frontend/attention.cc
@@ -98,13 +98,13 @@ void CuDNNSDPARunnerNode::Init(int64_t batch, int64_t seq_len, int64_t num_heads
   auto [o, stats] = graph_->sdpa(q, k, v, sdpa_options);
   CHECK(stats == nullptr);
   o->set_output(true).set_dim({batch, num_heads, seq_len, head_size_v}).set_stride(o_stride);
-  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+  int device_id;
+  CUDA_CALL(cudaGetDevice(&device_id));
+  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id});
   CUDNN_FRONTEND_CALL(graph_->build(entry_ptr->handle, {cudnn_frontend::HeurMode_t::A}));
 }
 
 void CuDNNSDPARunnerNode::Run(const DLTensor* qkv, DLTensor* workspace, DLTensor* out) {
-  CUDNN_CALL(
-      cudnnSetStream(CuDNNThreadEntry::ThreadLocal()->handle, tvm::runtime::GetCUDAStream()));
   auto* qkv_base = reinterpret_cast<uint8_t*>(qkv->data) + qkv->byte_offset;
   auto* q_ptr = reinterpret_cast<uint16_t*>(qkv_base) + offset_q_;
   auto* k_ptr = reinterpret_cast<uint16_t*>(qkv_base) + offset_k_;
@@ -116,7 +116,7 @@ void CuDNNSDPARunnerNode::Run(const DLTensor* qkv, DLTensor* workspace, DLTensor
   std::unordered_map<cudnn_frontend::graph::Tensor_attributes::uid_t, void*> inputs = {
       {kTensorIDQ, q_ptr}, {kTensorIDK, k_ptr}, {kTensorIDV, v_ptr}, {kTensorIDOut, out_ptr}};
 
-  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(qkv->device);
+  CUDNN_FRONTEND_CALL(graph_->execute(entry_ptr->handle, inputs, workspace->data));
 }
 
diff --git a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc
index fd4fa68c78..3888bca3df 100644
--- a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc
+++ b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc
@@ -22,6 +22,7 @@
  * \brief A simple JSON runtime for CUDNN.
  */
 
+#include <tvm/ffi/extra/c_env_api.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/runtime/ndarray.h>
@@ -100,7 +101,9 @@ class cuDNNJSONRuntime : public JSONRuntimeBase {
   }
 
   std::function<void()> GetConv2DExec(const JSONGraphNode& node) {
-    auto* entry_ptr = tvm::contrib::CuDNNThreadEntry::ThreadLocal();
+    int device_id;
+    CUDA_CALL(cudaGetDevice(&device_id));
+    auto* entry_ptr = tvm::contrib::CuDNNThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id});
     auto op_name = node.GetOpName();
 
     std::vector<int> input_dims, kernel_dims, output_dims;
@@ -159,7 +162,10 @@ class cuDNNJSONRuntime : public JSONRuntimeBase {
 
     int algo = best_algo.cast<int>();
     std::function<void()> op_exec = [=]() {
-      auto stream = static_cast<cudaStream_t>(GetCUDAStream());
+      int device_id;
+      CUDA_CALL(cudaGetDevice(&device_id));
+      cudaStream_t stream =
+          static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id));
       CUDNN_CALL(cudnnSetStream(entry_ptr->handle, stream));
 
       auto get_inputs = [this](const JSONGraphNode& node, bool has_bias) {
diff --git a/src/runtime/contrib/cudnn/cudnn_utils.cc b/src/runtime/contrib/cudnn/cudnn_utils.cc
index f5bb56e089..acedf7a9e2 100644
--- a/src/runtime/contrib/cudnn/cudnn_utils.cc
+++ b/src/runtime/contrib/cudnn/cudnn_utils.cc
@@ -24,6 +24,7 @@
 #include "cudnn_utils.h"
 
 #include <dmlc/thread_local.h>
+#include <tvm/ffi/extra/c_env_api.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/runtime/data_type.h>
@@ -101,7 +102,6 @@ const void* CuDNNDataType::GetConst<1>(cudnnDataType_t type) {
 // CuDNNThreadEntry
 
 CuDNNThreadEntry::CuDNNThreadEntry() {
-  auto stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
   auto func = tvm::ffi::Function::GetGlobalRequired("device_api.cuda");
   void* ret = func().cast<void*>();
   cuda_api = static_cast<runtime::DeviceAPI*>(ret);
@@ -116,8 +116,6 @@ CuDNNThreadEntry::CuDNNThreadEntry() {
     }
     CUDNN_CALL(create_res);
   }
-
-  CUDNN_CALL(cudnnSetStream(handle, stream));
   conv_entry.cuda_api = cuda_api;
 }
 
@@ -125,12 +123,15 @@ CuDNNThreadEntry::~CuDNNThreadEntry() {}
 
 typedef dmlc::ThreadLocalStore<CuDNNThreadEntry> CuDNNThreadStore;
 
-CuDNNThreadEntry* CuDNNThreadEntry::ThreadLocal(bool check_exists) {
+CuDNNThreadEntry* CuDNNThreadEntry::ThreadLocal(Device curr_device, bool check_exists) {
   auto* res = CuDNNThreadStore::Get();
   if (check_exists) {
     ICHECK(res->exists()) << "CUDNN_STATUS_NOT_INITIALIZED";
   }
 
+  cudaStream_t stream = static_cast<cudaStream_t>(
+      TVMFFIEnvGetCurrentStream(curr_device.device_type, curr_device.device_id));
+  CUDNN_CALL(cudnnSetStream(res->handle, stream));
   return res;
 }
 
@@ -268,8 +269,11 @@ SoftmaxEntry::~SoftmaxEntry() { CUDNN_CALL(cudnnDestroyTensorDescriptor(shape_de
 
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("tvm.contrib.cudnn.exists",
-                        []() -> bool { return 
CuDNNThreadEntry::ThreadLocal(false)->exists(); });
+  refl::GlobalDef().def("tvm.contrib.cudnn.exists", []() -> bool {
+    int device_id;
+    CUDA_CALL(cudaGetDevice(&device_id));
+    return CuDNNThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id}, false)->exists();
+  });
 });
 
 }  // namespace contrib
diff --git a/src/runtime/contrib/cudnn/cudnn_utils.h b/src/runtime/contrib/cudnn/cudnn_utils.h
index 902b615323..499cc5d6c9 100644
--- a/src/runtime/contrib/cudnn/cudnn_utils.h
+++ b/src/runtime/contrib/cudnn/cudnn_utils.h
@@ -106,7 +106,7 @@ struct CuDNNThreadEntry {
   ConvEntry conv_entry;
   SoftmaxEntry softmax_entry;
   runtime::DeviceAPI* cuda_api{nullptr};
-  static CuDNNThreadEntry* ThreadLocal(bool check_exists = true);
+  static CuDNNThreadEntry* ThreadLocal(Device curr_device, bool check_exists = true);
 };  // CuDNNThreadEntry
 
 void SetConvDescriptors(CuDNNThreadEntry* entry_ptr, int format, int dims, int groups,
diff --git a/src/runtime/contrib/cudnn/softmax.cc b/src/runtime/contrib/cudnn/softmax.cc
index f0fda4fd59..eb2fceb3d2 100644
--- a/src/runtime/contrib/cudnn/softmax.cc
+++ b/src/runtime/contrib/cudnn/softmax.cc
@@ -40,8 +40,9 @@ void softmax_impl(cudnnSoftmaxAlgorithm_t alg, ffi::PackedArgs args, ffi::Any* r
   int64_t* shape = x->shape;
   if (axis < 0) axis += ndim;
   ICHECK(axis >= 0 && axis < ndim);
-
-  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+  int device_id;
+  CUDA_CALL(cudaGetDevice(&device_id));
+  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id});
   entry_ptr->softmax_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);
 
   // Set mode and shape descriptor
diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm.cuh b/src/runtime/contrib/cutlass/fp16_group_gemm.cuh
index ebb8f58a6b..a09051a86e 100644
--- a/src/runtime/contrib/cutlass/fp16_group_gemm.cuh
+++ b/src/runtime/contrib/cutlass/fp16_group_gemm.cuh
@@ -36,7 +36,7 @@ void tvm_cutlass_group_gemm_impl(NDArray x, NDArray weight, NDArray indptr, NDAr
                                  NDArray out) {
   // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace.
   // Recommened size is 4MB.
-  static auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
+  cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, x->device.device_id));
   CHECK_EQ(x->ndim, 2);
   CHECK_EQ(weight->ndim, 3);
   CHECK_EQ(indptr->ndim, 1);
diff --git a/src/runtime/contrib/cutlass/fp8_gemm.cu b/src/runtime/contrib/cutlass/fp8_gemm.cu
index befef1db93..5cabd0ca7a 100644
--- a/src/runtime/contrib/cutlass/fp8_gemm.cu
+++ b/src/runtime/contrib/cutlass/fp8_gemm.cu
@@ -19,6 +19,7 @@
 
 #include <cuda_fp16.h>
 #include <float.h>
+#include <tvm/ffi/extra/c_env_api.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/runtime/ndarray.h>
@@ -42,8 +43,8 @@ void tvm_cutlass_fp8_gemm(NDArray x, NDArray weight, NDArray workspace, NDArray
                           NDArray out) {
   // Workspace is used for storing device-side gemm arguments and cutlass internal workspace.
   // Recommened size is 4MB.
-  static auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
-  cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
+  cudaStream_t stream =
+      static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, x->device.device_id));
 
   CHECK_GE(x->ndim, 2);
   CHECK_EQ(weight->ndim, 2);
@@ -68,7 +69,7 @@ void tvm_cutlass_fp8_gemm(NDArray x, NDArray weight, NDArray workspace, NDArray
         static_cast<float*>(alpha->data), beta, static_cast<ElementC*>(out->data), stream);
   } else {
     tvm::contrib::CuBlasLtThreadEntry* cublas_entry =
-        tvm::contrib::CuBlasLtThreadEntry::ThreadLocal();
+        tvm::contrib::CuBlasLtThreadEntry::ThreadLocal(x->device);
     tvm::contrib::CallCublasLt(cublas_entry->handle, stream, cublas_entry->matmul_pref_desc,
                                x.operator->(), weight.operator->(), nullptr, alpha.operator->(),
                                nullptr, out.operator->(), /*transa=*/false, /*transb=*/true,
diff --git a/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu b/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu
index f9f03fc4ed..150485b868 100644
--- a/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu
+++ b/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu
@@ -19,6 +19,7 @@
 
 #include <cuda_fp16.h>
 #include <float.h>
+#include <tvm/ffi/extra/c_env_api.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/runtime/ndarray.h>
@@ -45,8 +46,8 @@ void tvm_cutlass_fp8_group_gemm(NDArray x, NDArray weight, NDArray indptr, NDArr
                                 NDArray alpha, NDArray out) {
   // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace.
   // Recommened size is 4MB.
-  static auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
-  cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
+  cudaStream_t stream =
+      static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, x->device.device_id));
   CHECK_EQ(x->ndim, 2);
   CHECK_EQ(weight->ndim, 3);
   CHECK_EQ(indptr->ndim, 1);
diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh
index 4ecca5f1d8..0f688616d5 100644
--- a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh
+++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh
@@ -20,6 +20,7 @@
 #include <cuda_fp16.h>
 #include <float.h>
 #include <tvm/ffi/function.h>
+#include <tvm/ffi/extra/c_env_api.h>
 #include <tvm/runtime/ndarray.h>
 
 #include "cutlass/bfloat16.h"
@@ -39,9 +40,7 @@ void tvm_cutlass_fp8_groupwise_scaled_gemm_impl(NDArray a, NDArray b, NDArray sc
                                                 NDArray out) {
   // Workspace is used for storing device-side gemm arguments and cutlass internal workspace.
   // Recommened size is 4MB.
-  static tvm::ffi::Function get_stream_func =
-      tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
-  cudaStream_t stream = static_cast<cudaStream_t>(get_stream_func().cast<void*>());
+  cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, a->device.device_id));
 
   CHECK_GE(a->ndim, 2);
   CHECK_EQ(scales_a->ndim, a->ndim);
@@ -107,9 +106,7 @@ void tvm_cutlass_fp8_groupwise_scaled_bmm_impl(NDArray a, NDArray b, NDArray sca
                                                NDArray out) {
   // Workspace is used for storing device-side gemm arguments and cutlass internal workspace.
   // Recommened size is 4MB.
-  static tvm::ffi::Function get_stream_func =
-      tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
-  cudaStream_t stream = static_cast<cudaStream_t>(get_stream_func().cast<void*>());
+  cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, a->device.device_id));
 
   CHECK_EQ(a->ndim, 3);
   CHECK_EQ(scales_a->ndim, 3);
diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu
index 5b467c9bd5..2745c0b1fc 100644
--- a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu
+++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu
@@ -37,8 +37,8 @@ void tvm_fp8_groupwise_scaled_group_gemm_sm100(NDArray a, NDArray b, NDArray sca
                                                NDArray out) {
   // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace.
   // Recommended size is 4MB.
-  static auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
-  cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
+  cudaStream_t stream =
+      static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, a->device.device_id));
   CHECK_EQ(a->ndim, 2);
   CHECK_EQ(b->ndim, 3);
   CHECK_EQ(indptr->ndim, 1);
diff --git a/src/runtime/contrib/hipblas/hipblas.cc b/src/runtime/contrib/hipblas/hipblas.cc
index 4e7a5c5d10..628ffb5bdf 100644
--- a/src/runtime/contrib/hipblas/hipblas.cc
+++ b/src/runtime/contrib/hipblas/hipblas.cc
@@ -416,7 +416,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
                     auto A = args[0].cast<DLTensor*>();
                     auto C = args[2].cast<DLTensor*>();
 
-                    HipBlasThreadEntry* entry_ptr = HipBlasThreadEntry::ThreadLocal();
+                    HipBlasThreadEntry* entry_ptr = HipBlasThreadEntry::ThreadLocal(A->device);
 
                     if (TypeEqual(A->dtype, C->dtype)) {
                       ICHECK(TypeMatch(A->dtype, kDLFloat, 16) ||
@@ -438,7 +438,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
         auto A = args[0].cast<DLTensor*>();
         auto C = args[2].cast<DLTensor*>();
 
-        HipBlasThreadEntry* entry_ptr = HipBlasThreadEntry::ThreadLocal();
+        HipBlasThreadEntry* entry_ptr = HipBlasThreadEntry::ThreadLocal(A->device);
 
         if (TypeEqual(A->dtype, C->dtype)) {
           ICHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) ||
diff --git a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc
index 5750b91ab4..ab8545561b 100644
--- a/src/runtime/contrib/hipblas/hipblas_json_runtime.cc
+++ b/src/runtime/contrib/hipblas/hipblas_json_runtime.cc
@@ -65,10 +65,7 @@ class HipblasJSONRuntime : public JSONRuntimeBase {
   const char* kind() const override { return "hipblas_json"; }  // May be 
overridden
 
   void Run(ffi::PackedArgs args) {
-    auto* entry_ptr = tvm::contrib::HipBlasLtThreadEntry::ThreadLocal();
-    static auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_rocm_stream");
-    hipStream_t stream = static_cast<hipStream_t>(func().cast<void*>());
-
+    int device_id = -1;
     std::vector<const DLTensor*> dl_tensors(NumEntries());
 
     for (size_t i = 0; i < static_cast<size_t>(args.size()); i++) {
@@ -84,7 +81,13 @@ class HipblasJSONRuntime : public JSONRuntimeBase {
       }
 
       dl_tensors[eid] = arg;
+      device_id = arg->device.device_id;
+    }
+    if (device_id == -1) {
+      ROCM_CALL(hipGetDevice(&device_id));
     }
+    auto* entry_ptr = tvm::contrib::HipBlasLtThreadEntry::ThreadLocal(Device(kDLROCM, device_id));
+    hipStream_t stream = static_cast<hipStream_t>(TVMFFIEnvGetCurrentStream(kDLROCM, device_id));
 
     auto get_input = [this, &dl_tensors](const JSONGraphNode& node, int idx) {
       ICHECK_LT(idx, node.GetInputs().size());
diff --git a/src/runtime/contrib/hipblas/hipblas_utils.cc b/src/runtime/contrib/hipblas/hipblas_utils.cc
index 6facbb232b..454ab7a370 100644
--- a/src/runtime/contrib/hipblas/hipblas_utils.cc
+++ b/src/runtime/contrib/hipblas/hipblas_utils.cc
@@ -41,9 +41,10 @@ HipBlasThreadEntry::~HipBlasThreadEntry() {
 
 typedef dmlc::ThreadLocalStore<HipBlasThreadEntry> HipBlasThreadStore;
 
-HipBlasThreadEntry* HipBlasThreadEntry::ThreadLocal() {
-  auto stream = runtime::ROCMThreadEntry::ThreadLocal()->stream;
+HipBlasThreadEntry* HipBlasThreadEntry::ThreadLocal(Device curr_device) {
   HipBlasThreadEntry* retval = HipBlasThreadStore::Get();
+  TVMFFIStreamHandle stream =
+      TVMFFIEnvGetCurrentStream(curr_device.device_type, curr_device.device_id);
   CHECK_HIPBLAS_ERROR(hipblasSetStream(retval->handle, static_cast<hipStream_t>(stream)));
   return retval;
 }
@@ -71,7 +72,9 @@ HipBlasLtThreadEntry::~HipBlasLtThreadEntry() {
 
 typedef dmlc::ThreadLocalStore<HipBlasLtThreadEntry> HipBlasLtThreadStore;
 
-HipBlasLtThreadEntry* HipBlasLtThreadEntry::ThreadLocal() { return HipBlasLtThreadStore::Get(); }
+HipBlasLtThreadEntry* HipBlasLtThreadEntry::ThreadLocal(Device curr_device) {
+  return HipBlasLtThreadStore::Get();
+}
 
 }  // namespace contrib
 
diff --git a/src/runtime/contrib/miopen/conv_forward.cc b/src/runtime/contrib/miopen/conv_forward.cc
index 53eba8e9c4..2c8a70aa6b 100644
--- a/src/runtime/contrib/miopen/conv_forward.cc
+++ b/src/runtime/contrib/miopen/conv_forward.cc
@@ -59,8 +59,10 @@ TVM_FFI_STATIC_INIT_BLOCK({
             const int w_dim3 = args[15].cast<int>();
             const int n_group = args[16].cast<int>();
             void* out_shape = args[17].cast<void*>();
-
-            MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal();
+            int device_id = -1;
+            ROCM_CALL(hipGetDevice(&device_id));
+            MIOpenThreadEntry* entry_ptr =
+                MIOpenThreadEntry::ThreadLocal(Device{kDLROCM, device_id});
             assert(n_group > 0 && "Group Size > 0 is expected");
             if (n_group > 1)
               assert(mode > 1 && "Group /Depthwise Conv mode when num of groups > 1");
@@ -168,7 +170,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
         const auto w = args[10].cast<DLTensor*>();
         const auto y = args[11].cast<DLTensor*>();
 
-        MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal();
+        MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(x->device);
         entry_ptr->conv_entry.fwd_algo = static_cast<miopenConvFwdAlgorithm_t>(algo);
         // Set Mode
         entry_ptr->conv_entry.mode = static_cast<miopenConvolutionMode_t>(mode);
diff --git a/src/runtime/contrib/miopen/miopen_utils.cc b/src/runtime/contrib/miopen/miopen_utils.cc
index bb091fdf7a..e860ba8ea7 100644
--- a/src/runtime/contrib/miopen/miopen_utils.cc
+++ b/src/runtime/contrib/miopen/miopen_utils.cc
@@ -42,12 +42,10 @@ std::string miopenGetErrorString(int error_code) {
 
 // MiopenThreadEntry
 MIOpenThreadEntry::MIOpenThreadEntry() {
-  auto stream = runtime::ROCMThreadEntry::ThreadLocal()->stream;
   const auto get_rocm_api = tvm::ffi::Function::GetGlobalRequired("device_api.rocm");
   void* ret = get_rocm_api();
   rocm_api = static_cast<runtime::DeviceAPI*>(ret);
   MIOPEN_CALL(miopenCreate(&handle));
-  MIOPEN_CALL(miopenSetStream(handle, stream));
   conv_entry.rocm_api = rocm_api;
 }
 
@@ -55,7 +53,14 @@ MIOpenThreadEntry::~MIOpenThreadEntry() { MIOPEN_CALL(miopenDestroy(handle)); }
 
 typedef dmlc::ThreadLocalStore<MIOpenThreadEntry> MIOpenThreadStore;
 
-MIOpenThreadEntry* MIOpenThreadEntry::ThreadLocal() { return MIOpenThreadStore::Get(); }
+MIOpenThreadEntry* MIOpenThreadEntry::ThreadLocal(Device curr_device) {
+  // Need to update stream per fetch to avoid stream switching
+  MIOpenThreadEntry* res = MIOpenThreadStore::Get();
+  TVMFFIStreamHandle stream =
+      TVMFFIEnvGetCurrentStream(curr_device.device_type, curr_device.device_id);
+  MIOPEN_CALL(miopenSetStream(res->handle, stream));
+  return res;
+}
 
 // ConvEntry
 
diff --git a/src/runtime/contrib/miopen/softmax.cc b/src/runtime/contrib/miopen/softmax.cc
index dfcde9e879..5853cb2a7b 100644
--- a/src/runtime/contrib/miopen/softmax.cc
+++ b/src/runtime/contrib/miopen/softmax.cc
@@ -45,7 +45,7 @@ void softmax_impl(ffi::PackedArgs args, ffi::Any* ret, miopenSoftmaxAlgorithm_t
   ICHECK(TypeMatch(x->dtype, kDLFloat, 32));
   ICHECK(TypeMatch(y->dtype, kDLFloat, 32));
 
-  MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal();
+  MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(x->device);
 
   miopenSoftmaxMode_t mode;
   if (axis == ndim - 1) {
diff --git a/src/runtime/contrib/msc/tensorrt_runtime.cc b/src/runtime/contrib/msc/tensorrt_runtime.cc
index e19c03d4fd..37ae9f2548 100644
--- a/src/runtime/contrib/msc/tensorrt_runtime.cc
+++ b/src/runtime/contrib/msc/tensorrt_runtime.cc
@@ -123,15 +123,17 @@ class MSCTensorRTRuntime : public JSONRuntimeBase {
       const auto pf = tvm::ffi::Function::GetGlobal("msc_tool.callback_step");
       ICHECK(pf.has_value()) << "Cannot find msc_tool.callback_step func.";
       Map<String, runtime::NDArray> input_datas;
+      int device_id = 0;
       for (const auto& pair : input_bindings_) {
         const auto& tensor_name = engine_->getBindingName(pair.first);
         input_datas.Set(tensor_name, device_buffers_[pair.first]);
+        device_id = data_entry_[pair.first]->device.device_id;
       }
       Map<String, Map<String, runtime::NDArray>> context;
       context.Set("datas", input_datas);
       (*pf)(context, "before_forward", graph_name_, tool_tag_);
     }
-    auto tvm_stream = CUDAThreadEntry::ThreadLocal()->stream;
+    auto tvm_stream = TVMFFIEnvGetCurrentStream(kDLCUDA, device_id);
 #if TRT_VERSION_GE(6, 0, 1)
     ICHECK(context_->enqueueV2(bindings_.data(), tvm_stream, nullptr))
         << "Running TensorRT failed.";
diff --git a/src/runtime/cuda/cuda_common.h b/src/runtime/cuda/cuda_common.h
index a378e53c54..fd032fc75b 100644
--- a/src/runtime/cuda/cuda_common.h
+++ b/src/runtime/cuda/cuda_common.h
@@ -54,8 +54,6 @@ namespace runtime {
 /*! \brief Thread local workspace */
 class CUDAThreadEntry {
  public:
-  /*! \brief The cuda stream */
-  cudaStream_t stream{nullptr};
   /*! \brief thread local pool*/
   WorkspacePool pool;
   /*! \brief constructor */
@@ -64,8 +62,6 @@ class CUDAThreadEntry {
   static CUDAThreadEntry* ThreadLocal();
 };
 
-inline cudaStream_t GetCUDAStream() { return CUDAThreadEntry::ThreadLocal()->stream; }
-
 }  // namespace runtime
 }  // namespace tvm
 #endif  // TVM_RUNTIME_CUDA_CUDA_COMMON_H_
diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc
index 8a0da35c20..451348afbf 100644
--- a/src/runtime/cuda/cuda_device_api.cc
+++ b/src/runtime/cuda/cuda_device_api.cc
@@ -24,6 +24,7 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <dmlc/thread_local.h>
+#include <tvm/ffi/extra/c_env_api.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/runtime/device_api.h>
@@ -249,14 +250,6 @@ class CUDADeviceAPI final : public DeviceAPI {
     CUDA_CALL(cudaStreamSynchronize(static_cast<cudaStream_t>(stream)));
   }
 
-  void SetStream(Device dev, TVMStreamHandle stream) final {
-    CUDAThreadEntry::ThreadLocal()->stream = static_cast<cudaStream_t>(stream);
-  }
-
-  TVMStreamHandle GetCurrentStream(Device dev) final {
-    return static_cast<TVMStreamHandle>(CUDAThreadEntry::ThreadLocal()->stream);
-  }
-
   void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final {
     return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size);
   }
@@ -306,9 +299,16 @@ class CUDATimerNode : public TimerNode {
   virtual void Start() {
     // This initial cudaEventRecord is sometimes pretty slow (~100us). Does
     // cudaEventRecord do some stream synchronization?
-    CUDA_CALL(cudaEventRecord(start_, CUDAThreadEntry::ThreadLocal()->stream));
+    int device_id;
+    CUDA_CALL(cudaGetDevice(&device_id));
+    stream_ = TVMFFIEnvGetCurrentStream(kDLCUDA, device_id);
+    CUDA_CALL(cudaEventRecord(start_, static_cast<cudaStream_t>(stream_)));
+  }
+  virtual void Stop() {
+    int device_id;
+    CUDA_CALL(cudaGetDevice(&device_id));
+    CUDA_CALL(cudaEventRecord(stop_, static_cast<cudaStream_t>(stream_)));
   }
-  virtual void Stop() { CUDA_CALL(cudaEventRecord(stop_, CUDAThreadEntry::ThreadLocal()->stream)); }
   virtual int64_t SyncAndGetElapsedNanos() {
     CUDA_CALL(cudaEventSynchronize(stop_));
     float milliseconds = 0;
@@ -330,6 +330,7 @@ class CUDATimerNode : public TimerNode {
  private:
   cudaEvent_t start_;
   cudaEvent_t stop_;
+  TVMStreamHandle stream_;
 };
 
 TVM_FFI_STATIC_INIT_BLOCK({
@@ -351,8 +352,13 @@ TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef()
       .def("runtime.GetCudaFreeMemory", GetCudaFreeMemory)
-      .def("runtime.get_cuda_stream",
-           []() { return 
static_cast<void*>(CUDAThreadEntry::ThreadLocal()->stream); });
+      .def("runtime.get_cuda_stream", []() {
+        // TODO(tvm-team): remove once confirms all dep such as flashinfer
+        // migrated to TVMFFIEnvGetCurrentStream
+        int device_id;
+        CUDA_CALL(cudaGetDevice(&device_id));
+        return static_cast<void*>(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id));
+      });
 });
 
 TVM_DLL int GetCudaDeviceCount() {
diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc
index 5a4e682da8..eb3bee4757 100644
--- a/src/runtime/cuda/cuda_module.cc
+++ b/src/runtime/cuda/cuda_module.cc
@@ -25,6 +25,7 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <dmlc/memory_io.h>
+#include <tvm/ffi/extra/c_env_api.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 
@@ -198,7 +199,7 @@ class CUDAWrappedFunc {
         }
       }
     }
-    CUstream strm = static_cast<CUstream>(CUDAThreadEntry::ThreadLocal()->stream);
+    CUstream strm = static_cast<CUstream>(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id));
     CUresult result = cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1),
                                      wl.grid_dim(2), wl.block_dim(0), wl.block_dim(1),
                                      wl.block_dim(2), wl.dyn_shmem_size, strm, void_args, nullptr);
diff --git a/src/runtime/cuda/l2_cache_flush.cc b/src/runtime/cuda/l2_cache_flush.cc
index 9427a6a3ee..0c7f939181 100644
--- a/src/runtime/cuda/l2_cache_flush.cc
+++ b/src/runtime/cuda/l2_cache_flush.cc
@@ -19,6 +19,7 @@
 #include "../../../3rdparty/nvbench/l2_cache_flush.h"
 
 #include <dmlc/thread_local.h>
+#include <tvm/ffi/extra/c_env_api.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/runtime/device_api.h>
@@ -37,7 +38,9 @@ TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def_packed("l2_cache_flush_cuda", [](ffi::PackedArgs args, ffi::Any* rv) {
     ICHECK(L2Flush::ThreadLocal() != nullptr) << "L2Flush::ThreadLocal do not exist.";
-    cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;
+    int device_id;
+    CUDA_CALL(cudaGetDevice(&device_id));
+    cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id));
     L2Flush::ThreadLocal()->Flush(stream);
   });
 });
diff --git a/src/runtime/device_api.cc b/src/runtime/device_api.cc
index ae85f9ce53..31006069a2 100644
--- a/src/runtime/device_api.cc
+++ b/src/runtime/device_api.cc
@@ -164,7 +164,13 @@ TVMStreamHandle DeviceAPI::CreateStream(Device dev) { return nullptr; }
 
 void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) {}
 
-TVMStreamHandle DeviceAPI::GetCurrentStream(Device dev) { return nullptr; }
+void DeviceAPI::SetStream(Device dev, TVMStreamHandle stream) {
+  TVM_FFI_CHECK_SAFE_CALL(TVMFFIEnvSetStream(dev.device_type, dev.device_id, stream, nullptr));
+}
+
+TVMStreamHandle DeviceAPI::GetCurrentStream(Device dev) {
+  return TVMFFIEnvGetCurrentStream(dev.device_type, dev.device_id);
+}
 
 void DeviceAPI::SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) {
 }
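
With these defaults in device_api.cc, per-backend SetStream/GetCurrentStream
overrides become unnecessary: both calls round-trip through the ffi env, so a
stream installed via DeviceAPI is visible to TVMFFIEnvGetCurrentStream and
vice versa. A hedged sketch of that equivalence (assumes a valid `dev` and
`stream` on a backend that keeps the default implementation):

    // SetStream forwards to the ffi env; GetCurrentStream reads it back.
    DeviceAPI* api = DeviceAPI::Get(dev);
    api->SetStream(dev, stream);
    ICHECK_EQ(api->GetCurrentStream(dev),
              TVMFFIEnvGetCurrentStream(dev.device_type, dev.device_id));
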
diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h
index 138d312dd4..f10489826a 100644
--- a/src/runtime/metal/metal_common.h
+++ b/src/runtime/metal/metal_common.h
@@ -168,8 +168,6 @@ class MetalWorkspace final : public DeviceAPI {
   TVMStreamHandle CreateStream(Device dev) final;
   void FreeStream(Device dev, TVMStreamHandle stream) final;
   void StreamSync(Device dev, TVMStreamHandle stream) final;
-  void SetStream(Device dev, TVMStreamHandle stream) final;
-  TVMStreamHandle GetCurrentStream(Device dev) final;
   void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final;
   void FreeWorkspace(Device dev, void* data) final;
   void ReinitializeDefaultStreams();
diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm
index ba2f69b8e7..2a8544f6f1 100644
--- a/src/runtime/metal/metal_device_api.mm
+++ b/src/runtime/metal/metal_device_api.mm
@@ -312,17 +312,6 @@ void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) {
   };
 }
 
-void MetalWorkspace::SetStream(Device dev, TVMStreamHandle stream) {
-  ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id;
-  ICHECK(stream != nullptr);
-  MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = stream;
-}
-
-TVMStreamHandle MetalWorkspace::GetCurrentStream(Device dev) {
-  ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id;
-  return MetalThreadEntry::ThreadLocal()->stream[dev.device_id];
-}
-
 void* MetalWorkspace::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) {
   return MetalThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size);
 }
diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc
index c8842f7f53..9692b811a4 100644
--- a/src/runtime/rocm/rocm_device_api.cc
+++ b/src/runtime/rocm/rocm_device_api.cc
@@ -24,6 +24,7 @@
 #include <dmlc/thread_local.h>
 #include <hip/hip_runtime_api.h>
 #include <hsa/hsa.h>
+#include <tvm/ffi/extra/c_env_api.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/runtime/device_api.h>
@@ -214,14 +215,6 @@ class ROCMDeviceAPI final : public DeviceAPI {
     ROCM_CALL(hipStreamSynchronize(static_cast<hipStream_t>(stream)));
   }
 
-  void SetStream(Device dev, TVMStreamHandle stream) final {
-    ROCMThreadEntry::ThreadLocal()->stream = static_cast<hipStream_t>(stream);
-  }
-
-  TVMStreamHandle GetCurrentStream(Device dev) final {
-    return static_cast<TVMStreamHandle>(ROCMThreadEntry::ThreadLocal()->stream);
-  }
-
   void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final {
     return ROCMThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size);
   }
@@ -269,9 +262,16 @@ TVM_FFI_STATIC_INIT_BLOCK({
 class ROCMTimerNode : public TimerNode {
  public:
   virtual void Start() {
-    ROCM_CALL(hipEventRecord(start_, ROCMThreadEntry::ThreadLocal()->stream));
+    int device_id;
+    ROCM_CALL(hipGetDevice(&device_id));
+    stream_ = TVMFFIEnvGetCurrentStream(kDLROCM, device_id);
+    ROCM_CALL(hipEventRecord(start_, static_cast<hipStream_t>(stream_)));
+  }
+  virtual void Stop() {
+    int device_id;
+    ROCM_CALL(hipGetDevice(&device_id));
+    ROCM_CALL(hipEventRecord(stop_, static_cast<hipStream_t>(stream_)));
   }
-  virtual void Stop() { ROCM_CALL(hipEventRecord(stop_, ROCMThreadEntry::ThreadLocal()->stream)); }
   virtual int64_t SyncAndGetElapsedNanos() {
     ROCM_CALL(hipEventSynchronize(stop_));
     float milliseconds = 0;
@@ -293,14 +293,18 @@ class ROCMTimerNode : public TimerNode {
  private:
   hipEvent_t start_;
   hipEvent_t stop_;
+  TVMStreamHandle stream_;
 };
 
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef()
       .def("profiling.timer.rocm", [](Device dev) { return 
Timer(make_object<ROCMTimerNode>()); })
-      .def("runtime.get_rocm_stream",
-           []() { return static_cast<void*>(ROCMThreadEntry::ThreadLocal()->stream); });
+      .def("runtime.get_rocm_stream", []() {
+        int device_id;
+        ROCM_CALL(hipGetDevice(&device_id));
+        return static_cast<void*>(TVMFFIEnvGetCurrentStream(kDLROCM, device_id));
+      });
 });
 
 }  // namespace runtime
diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc
index 13b14e13e0..f6beaca210 100644
--- a/src/runtime/rocm/rocm_module.cc
+++ b/src/runtime/rocm/rocm_module.cc
@@ -24,6 +24,7 @@
 
 #include <dmlc/memory_io.h>
 #include <hip/hip_runtime_api.h>
+#include <tvm/ffi/extra/c_env_api.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 
@@ -171,7 +172,7 @@ class ROCMWrappedFunc {
       fcache_[device_id] = m_->GetFunc(device_id, func_name_);
     }
 
-    hipStream_t strm = static_cast<hipStream_t>(ROCMThreadEntry::ThreadLocal()->stream);
+    hipStream_t strm = static_cast<hipStream_t>(TVMFFIEnvGetCurrentStream(kDLROCM, device_id));
 
     ThreadWorkLoad wl = launch_param_config_.Extract(args);
     void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args, HIP_LAUNCH_PARAM_BUFFER_SIZE,
diff --git a/src/runtime/vm/cuda/cuda_graph_builtin.cc b/src/runtime/vm/cuda/cuda_graph_builtin.cc
index 691246c3bf..d7ccff66a0 100644
--- a/src/runtime/vm/cuda/cuda_graph_builtin.cc
+++ b/src/runtime/vm/cuda/cuda_graph_builtin.cc
@@ -23,6 +23,7 @@
  */
 
 #include <tvm/ffi/container/array.h>
+#include <tvm/ffi/extra/c_env_api.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/runtime/vm/vm.h>
@@ -114,18 +115,20 @@ class ScopedCUDAStream {
 
 class CUDACaptureStream {
  public:
-  explicit CUDACaptureStream(cudaGraph_t* graph)
-      : prev_default_stream_(CUDAThreadEntry::ThreadLocal()->stream), output_graph_(graph) {
-    CUDAThreadEntry::ThreadLocal()->stream = capture_stream_;
-
+  explicit CUDACaptureStream(cudaGraph_t* graph) : output_graph_(graph) {
+    CUDA_CALL(cudaGetDevice(&device_id_));
+    TVM_FFI_CHECK_SAFE_CALL(
+        TVMFFIEnvSetStream(kDLCUDA, device_id_, capture_stream_,
+                           reinterpret_cast<TVMFFIStreamHandle*>(&prev_default_stream_)));
     CUDA_CALL(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal));
   }
-  ~CUDACaptureStream() {
+  ~CUDACaptureStream() noexcept(false) {
     cudaStreamEndCapture(capture_stream_, output_graph_);
-    CUDAThreadEntry::ThreadLocal()->stream = prev_default_stream_;
+    TVM_FFI_CHECK_SAFE_CALL(TVMFFIEnvSetStream(kDLCUDA, device_id_, prev_default_stream_, nullptr));
   }
 
  private:
+  int device_id_;
   cudaStream_t prev_default_stream_;
   ScopedCUDAStream capture_stream_;
 
@@ -155,7 +158,10 @@ class CUDAGraphExtensionNode : public VMExtensionNode {
     if (auto it = capture_cache_.find(entry_key); it != capture_cache_.end()) {
       // Launch CUDA graph
       const auto& [states, exec] = it->second;
-      CUDA_CALL(cudaGraphLaunch(exec, CUDAThreadEntry::ThreadLocal()->stream));
+      int device_id;
+      CUDA_CALL(cudaGetDevice(&device_id));
+      CUDA_CALL(cudaGraphLaunch(
+          exec, static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id))));
       return states;
     }
 
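CUDACaptureStream leans on the fourth argument of TVMFFIEnvSetStream, which
hands back the previously installed stream, to make the swap exception-safe.
The same save-and-restore idiom can be factored into a small RAII guard; a
sketch under the signatures shown in this diff (`ScopedEnvStream` is a
hypothetical name, not part of this commit):

    class ScopedEnvStream {
     public:
      ScopedEnvStream(int device_id, TVMFFIStreamHandle stream) : device_id_(device_id) {
        // Install `stream` and capture whatever was current before it.
        TVM_FFI_CHECK_SAFE_CALL(TVMFFIEnvSetStream(kDLCUDA, device_id_, stream, &prev_));
      }
      ~ScopedEnvStream() {
        // Restore the previous stream on scope exit.
        TVM_FFI_CHECK_SAFE_CALL(TVMFFIEnvSetStream(kDLCUDA, device_id_, prev_, nullptr));
      }

     private:
      int device_id_;
      TVMFFIStreamHandle prev_{nullptr};
    };
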
diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc
index 09c3d522b0..023d34e68b 100644
--- a/src/runtime/vulkan/vulkan_device_api.cc
+++ b/src/runtime/vulkan/vulkan_device_api.cc
@@ -332,12 +332,6 @@ void VulkanDeviceAPI::StreamSync(Device dev, TVMStreamHandle stream) {
   device(dev.device_id).ThreadLocalStream().Synchronize();
 }
 
-void VulkanDeviceAPI::SetStream(Device dev, TVMStreamHandle stream) {
-  ICHECK_EQ(stream, static_cast<void*>(nullptr));
-}
-
-TVMStreamHandle VulkanDeviceAPI::GetCurrentStream(Device dev) { return nullptr; }
-
 void VulkanDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* to,
                                      size_t to_offset, size_t size, Device dev_from, Device dev_to,
                                      DLDataType type_hint, TVMStreamHandle stream) {
diff --git a/src/runtime/vulkan/vulkan_device_api.h b/src/runtime/vulkan/vulkan_device_api.h
index 64ca0db701..5e9bfeb8c0 100644
--- a/src/runtime/vulkan/vulkan_device_api.h
+++ b/src/runtime/vulkan/vulkan_device_api.h
@@ -61,8 +61,6 @@ class VulkanDeviceAPI final : public DeviceAPI {
   void FreeStream(Device dev, TVMStreamHandle stream) final;
   void SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) final;
   void StreamSync(Device dev, TVMStreamHandle stream) final;
-  void SetStream(Device dev, TVMStreamHandle stream) final;
-  TVMStreamHandle GetCurrentStream(Device dev) final;
 
  protected:
   void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc
index cd50bc0679..eb14a7b7d7 100644
--- a/web/emcc/webgpu_runtime.cc
+++ b/web/emcc/webgpu_runtime.cc
@@ -118,10 +118,6 @@ class WebGPUDeviceAPI : public DeviceAPI {
     (*func)();
   }
 
-  void SetStream(Device dev, TVMStreamHandle stream) final { LOG(FATAL) << "Not implemented"; }
-
-  TVMStreamHandle GetCurrentStream(Device dev) final { LOG(FATAL) << "Not implemented"; }
-
   void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final {
     return WebGPUThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size);
   }
