This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8ad523811c [Thrust] Fix getting CUDA stream (#18220)
8ad523811c is described below
commit 8ad523811cc30401955089a4ec0fd73bbc18ab29
Author: Ruihang Lai <[email protected]>
AuthorDate: Thu Aug 21 17:02:15 2025 -0400
[Thrust] Fix getting CUDA stream (#18220)
This PR updates the CUDA Thrust integration to obtain the current CUDA
stream through the latest `TVMFFIEnvGetCurrentStream` interface instead of `GetCUDAStream`.
---
src/runtime/contrib/thrust/thrust.cu | 6 +++++-
1 file changed, 5 insertions(+), 1 deletion(-)
diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu
index 45c1bcc7cf..1adf95f693 100644
--- a/src/runtime/contrib/thrust/thrust.cu
+++ b/src/runtime/contrib/thrust/thrust.cu
@@ -32,6 +32,7 @@
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <tvm/ffi/dtype.h>
+#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
@@ -91,7 +92,10 @@ class WorkspaceMemoryResource : public thrust::mr::memory_resource<void*> {
};
auto get_thrust_exec_policy(WorkspaceMemoryResource* memory_resouce) {
- return thrust::cuda::par_nosync(memory_resouce).on(GetCUDAStream());
+ int device_id;
+ CUDA_CALL(cudaGetDevice(&device_id));
+ cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id));
+ return thrust::cuda::par_nosync(memory_resouce).on(stream);
}
// Performs sorting along axis -1 and returns both sorted values and indices.