This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new cdfd041  Add cuda DeviceGuard (#287)
cdfd041 is described below

commit cdfd04109f74a2115c909302f4adf90536bce865
Author: Yaxing Cai <[email protected]>
AuthorDate: Tue Nov 25 19:19:53 2025 -0800

    Add cuda DeviceGuard (#287)
    
    This PR adds the `DeviceGuard`, as `c10::cuda::CUDAGuard` in pytorch
---
 include/tvm/ffi/extra/cuda/base.h           | 54 +++++++++++++++++++++
 include/tvm/ffi/extra/cuda/cubin_launcher.h | 26 ++--------
 include/tvm/ffi/extra/cuda/device_guard.h   | 74 +++++++++++++++++++++++++++++
 3 files changed, 132 insertions(+), 22 deletions(-)

diff --git a/include/tvm/ffi/extra/cuda/base.h 
b/include/tvm/ffi/extra/cuda/base.h
new file mode 100644
index 0000000..810fa06
--- /dev/null
+++ b/include/tvm/ffi/extra/cuda/base.h
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/ffi/extra/cuda/base.h
+ * \brief CUDA base utilities.
+ */
+#ifndef TVM_FFI_EXTRA_CUDA_BASE_H_
+#define TVM_FFI_EXTRA_CUDA_BASE_H_
+
+#include <cuda_runtime.h>
+#include <tvm/ffi/error.h>
+
+namespace tvm {
+namespace ffi {
+
+/*!
+ * \brief Macro for checking CUDA runtime API errors.
+ *
+ * Evaluates \p stmt exactly once. If the result is not `cudaSuccess`, throws a
+ * RuntimeError whose message has the form
+ * `CUDA Runtime Error: <name> (<code>): <description>`, where the name and
+ * description come from `cudaGetErrorName` and `cudaGetErrorString`.
+ *
+ * The body is wrapped in `do { ... } while (0)` so the macro expands to a
+ * single statement; terminate it with a semicolon at the call site.
+ *
+ * Locals use a `tvm_ffi_` prefix instead of a leading double underscore,
+ * because identifiers containing `__` are reserved for the implementation.
+ *
+ * \param stmt The CUDA runtime API call to check; must yield a `cudaError_t`.
+ */
+#define TVM_FFI_CHECK_CUDA_ERROR(stmt)                                      \
+  do {                                                                      \
+    const cudaError_t tvm_ffi_err_ = (stmt);                                \
+    if (tvm_ffi_err_ != cudaSuccess) {                                      \
+      const char* tvm_ffi_err_name_ = cudaGetErrorName(tvm_ffi_err_);       \
+      const char* tvm_ffi_err_str_ = cudaGetErrorString(tvm_ffi_err_);      \
+      TVM_FFI_THROW(RuntimeError)                                           \
+          << "CUDA Runtime Error: " << tvm_ffi_err_name_ << " ("            \
+          << static_cast<int>(tvm_ffi_err_) << "): " << tvm_ffi_err_str_;   \
+    }                                                                       \
+  } while (0)
+
+}  // namespace ffi
+}  // namespace tvm
+
+#endif  // TVM_FFI_EXTRA_CUDA_BASE_H_
diff --git a/include/tvm/ffi/extra/cuda/cubin_launcher.h 
b/include/tvm/ffi/extra/cuda/cubin_launcher.h
index f0cc924..72eadd2 100644
--- a/include/tvm/ffi/extra/cuda/cubin_launcher.h
+++ b/include/tvm/ffi/extra/cuda/cubin_launcher.h
@@ -26,12 +26,13 @@
  * - Multi-GPU execution using CUDA primary contexts
  * - Kernel parameter management and launch configuration
  */
-#ifndef TVM_FFI_EXTRA_CUBIN_LAUNCHER_H_
-#define TVM_FFI_EXTRA_CUBIN_LAUNCHER_H_
+#ifndef TVM_FFI_EXTRA_CUDA_CUBIN_LAUNCHER_H_
+#define TVM_FFI_EXTRA_CUDA_CUBIN_LAUNCHER_H_
 
 #include <cuda_runtime.h>
 #include <tvm/ffi/error.h>
 #include <tvm/ffi/extra/c_env_api.h>
+#include <tvm/ffi/extra/cuda/base.h>
 #include <tvm/ffi/string.h>
 
 #include <cstdint>
@@ -40,25 +41,6 @@
 namespace tvm {
 namespace ffi {
 
-/*!
- * \brief Macro for checking CUDA runtime API errors.
- *
- * This macro checks the return value of CUDA runtime API calls and throws
- * a RuntimeError with detailed error information if the call fails.
- *
- * \param stmt The CUDA runtime API call to check.
- */
-#define TVM_FFI_CHECK_CUDA_ERROR(stmt)                                         
     \
-  do {                                                                         
     \
-    cudaError_t __err = (stmt);                                                
     \
-    if (__err != cudaSuccess) {                                                
     \
-      const char* __err_name = cudaGetErrorName(__err);                        
     \
-      const char* __err_str = cudaGetErrorString(__err);                       
     \
-      TVM_FFI_THROW(RuntimeError) << "CUDA Runtime Error: " << __err_name << " 
("   \
-                                  << static_cast<int>(__err) << "): " << 
__err_str; \
-    }                                                                          
     \
-  } while (0)
-
 /*!
  * \brief A simple 3D dimension type for CUDA kernel launch configuration.
  *
@@ -619,4 +601,4 @@ inline CubinKernel CubinModule::operator[](const char* 
name) { return GetKernel(
 }  // namespace ffi
 }  // namespace tvm
 
-#endif  // TVM_FFI_EXTRA_CUBIN_LAUNCHER_H_
+#endif  // TVM_FFI_EXTRA_CUDA_CUBIN_LAUNCHER_H_
diff --git a/include/tvm/ffi/extra/cuda/device_guard.h 
b/include/tvm/ffi/extra/cuda/device_guard.h
new file mode 100644
index 0000000..083580f
--- /dev/null
+++ b/include/tvm/ffi/extra/cuda/device_guard.h
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/ffi/extra/cuda/device_guard.h
+ * \brief Device guard structs.
+ */
+#ifndef TVM_FFI_EXTRA_CUDA_DEVICE_GUARD_H_
+#define TVM_FFI_EXTRA_CUDA_DEVICE_GUARD_H_
+
+#include <tvm/ffi/extra/cuda/base.h>
+
+namespace tvm {
+namespace ffi {
+
+/*!
+ * \brief CUDA Device Guard. On construction, it calls `cudaGetDevice` to
+ * record the current device index and, if that differs from the target index,
+ * calls `cudaSetDevice` to switch to the target device. On destruction, it
+ * sets the current CUDA device back to the original device index.
+ *
+ * Example usage:
+ * \code
+ * void kernel(ffi::TensorView x) {
+ *   ffi::CUDADeviceGuard guard(x.device().device_id);
+ *   ...
+ * }
+ * \endcode
+ */
+struct CUDADeviceGuard {
+  CUDADeviceGuard() = delete;
+
+  // A guard owns exactly one "restore the previous device" action. Copying or
+  // moving it would run that restore more than once, so both are disallowed.
+  CUDADeviceGuard(const CUDADeviceGuard&) = delete;
+  CUDADeviceGuard& operator=(const CUDADeviceGuard&) = delete;
+  CUDADeviceGuard(CUDADeviceGuard&&) = delete;
+  CUDADeviceGuard& operator=(CUDADeviceGuard&&) = delete;
+
+  /*!
+   * \brief Switch the current CUDA device to `device_index`, remembering the
+   * device that was current beforehand.
+   *
+   * `cudaSetDevice` is only called when the target differs from the current
+   * device. A failing CUDA call is reported through TVM_FFI_CHECK_CUDA_ERROR.
+   *
+   * \param device_index The device index to guard.
+   */
+  explicit CUDADeviceGuard(int device_index) : target_device_index_(device_index) {
+    TVM_FFI_CHECK_CUDA_ERROR(cudaGetDevice(&original_device_index_));
+    if (target_device_index_ != original_device_index_) {
+      TVM_FFI_CHECK_CUDA_ERROR(cudaSetDevice(target_device_index_));
+    }
+  }
+
+  /*!
+   * \brief Restore the original device if the constructor switched away from it.
+   *
+   * Declared `noexcept(false)` because a failing `cudaSetDevice` is reported
+   * through TVM_FFI_CHECK_CUDA_ERROR. Note that an exception escaping a
+   * destructor while another exception is unwinding terminates the program.
+   */
+  ~CUDADeviceGuard() noexcept(false) {
+    if (original_device_index_ != target_device_index_) {
+      TVM_FFI_CHECK_CUDA_ERROR(cudaSetDevice(original_device_index_));
+    }
+  }
+
+ private:
+  /*! \brief Device that was current at construction; filled in by `cudaGetDevice`. */
+  int original_device_index_ = 0;
+  /*! \brief Device requested by the caller. */
+  int target_device_index_;
+};
+
+}  // namespace ffi
+}  // namespace tvm
+#endif  // TVM_FFI_EXTRA_CUDA_DEVICE_GUARD_H_

Reply via email to