This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 524ec5f1b0 [Runtime] Use cudaGetDeviceCount to check if device exists (#16377)
524ec5f1b0 is described below
commit 524ec5f1b03b54d796160ca1eca5804edbe38b3e
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Jan 10 09:47:02 2024 -0800
[Runtime] Use cudaGetDeviceCount to check if device exists (#16377)
Using `cudaDeviceGetAttribute` will set the global error code when the
device doesn't exist and will impact subsequent CUDA API calls.
---
src/runtime/cuda/cuda_device_api.cc | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc
index 21416f619f..769f01063f 100644
--- a/src/runtime/cuda/cuda_device_api.cc
+++ b/src/runtime/cuda/cuda_device_api.cc
@@ -42,8 +42,9 @@ class CUDADeviceAPI final : public DeviceAPI {
int value = 0;
switch (kind) {
case kExist:
- value = (cudaDeviceGetAttribute(&value, cudaDevAttrMaxThreadsPerBlock, dev.device_id) ==
- cudaSuccess);
+ int count;
+ CUDA_CALL(cudaGetDeviceCount(&count));
+ value = static_cast<int>(dev.device_id < count);
break;
case kMaxThreadsPerBlock: {
CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrMaxThreadsPerBlock, dev.device_id));