This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8ad523811c [Thrust] Fix getting CUDA stream (#18220)
8ad523811c is described below

commit 8ad523811cc30401955089a4ec0fd73bbc18ab29
Author: Ruihang Lai <[email protected]>
AuthorDate: Thu Aug 21 17:02:15 2025 -0400

    [Thrust] Fix getting CUDA stream (#18220)
    
    This PR updates the `GetCUDAStream` in CUDA thrust integration
    to the latest `TVMFFIEnvGetCurrentStream` interface.
---
 src/runtime/contrib/thrust/thrust.cu | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu
index 45c1bcc7cf..1adf95f693 100644
--- a/src/runtime/contrib/thrust/thrust.cu
+++ b/src/runtime/contrib/thrust/thrust.cu
@@ -32,6 +32,7 @@
 #include <thrust/sequence.h>
 #include <thrust/sort.h>
 #include <tvm/ffi/dtype.h>
+#include <tvm/ffi/extra/c_env_api.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 
@@ -91,7 +92,10 @@ class WorkspaceMemoryResource : public thrust::mr::memory_resource<void*> {
 };
 
 auto get_thrust_exec_policy(WorkspaceMemoryResource* memory_resouce) {
-  return thrust::cuda::par_nosync(memory_resouce).on(GetCUDAStream());
+  int device_id;
+  CUDA_CALL(cudaGetDevice(&device_id));
+  cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id));
+  return thrust::cuda::par_nosync(memory_resouce).on(stream);
 }
 
 // Performs sorting along axis -1 and returns both sorted values and indices.

Reply via email to