This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 524ec5f1b0 [Runtime] Use cudaGetDeviceCount to check if device exists (#16377)
524ec5f1b0 is described below

commit 524ec5f1b03b54d796160ca1eca5804edbe38b3e
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Jan 10 09:47:02 2024 -0800

    [Runtime] Use cudaGetDeviceCount to check if device exists (#16377)
    
    Calling `cudaDeviceGetAttribute` on a device that doesn't exist sets the
    global CUDA error code, which then impacts subsequent CUDA API calls.
    Check existence by comparing the device id against `cudaGetDeviceCount`
    instead.
---
 src/runtime/cuda/cuda_device_api.cc | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc
index 21416f619f..769f01063f 100644
--- a/src/runtime/cuda/cuda_device_api.cc
+++ b/src/runtime/cuda/cuda_device_api.cc
@@ -42,8 +42,9 @@ class CUDADeviceAPI final : public DeviceAPI {
     int value = 0;
     switch (kind) {
       case kExist:
-        value = (cudaDeviceGetAttribute(&value, cudaDevAttrMaxThreadsPerBlock, 
dev.device_id) ==
-                 cudaSuccess);
+        int count;
+        CUDA_CALL(cudaGetDeviceCount(&count));
+        value = static_cast<int>(dev.device_id < count);
         break;
       case kMaxThreadsPerBlock: {
         CUDA_CALL(cudaDeviceGetAttribute(&value, 
cudaDevAttrMaxThreadsPerBlock, dev.device_id));

Reply via email to