This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new cdfd041 Add cuda DeviceGuard (#287)
cdfd041 is described below
commit cdfd04109f74a2115c909302f4adf90536bce865
Author: Yaxing Cai <[email protected]>
AuthorDate: Tue Nov 25 19:19:53 2025 -0800
Add cuda DeviceGuard (#287)
This PR adds `CUDADeviceGuard`, a scoped CUDA device guard analogous to `c10::cuda::CUDAGuard` in PyTorch.
---
include/tvm/ffi/extra/cuda/base.h | 54 +++++++++++++++++++++
include/tvm/ffi/extra/cuda/cubin_launcher.h | 26 ++--------
include/tvm/ffi/extra/cuda/device_guard.h | 74 +++++++++++++++++++++++++++++
3 files changed, 132 insertions(+), 22 deletions(-)
diff --git a/include/tvm/ffi/extra/cuda/base.h
b/include/tvm/ffi/extra/cuda/base.h
new file mode 100644
index 0000000..810fa06
--- /dev/null
+++ b/include/tvm/ffi/extra/cuda/base.h
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/ffi/extra/cuda/base.h
+ * \brief CUDA base utilities.
+ */
+#ifndef TVM_FFI_EXTRA_CUDA_BASE_H_
+#define TVM_FFI_EXTRA_CUDA_BASE_H_
+
+#include <cuda_runtime.h>
+#include <tvm/ffi/error.h>
+
+namespace tvm {
+namespace ffi {
+
+/*!
+ * \brief Macro for checking CUDA runtime API errors.
+ *
+ * Evaluates `stmt` exactly once; if the resulting `cudaError_t` is not `cudaSuccess`, throws a
+ * RuntimeError whose message contains the error name, its numeric code and its description.
+ *
+ * \param stmt The CUDA runtime API call to check (its result is stored in a `cudaError_t`).
+ */
+#define TVM_FFI_CHECK_CUDA_ERROR(stmt)                                                        \
+  do {                                                                                        \
+    cudaError_t tvm_ffi_cuda_err = (stmt);                                                    \
+    if (tvm_ffi_cuda_err != cudaSuccess) {                                                    \
+      const char* tvm_ffi_err_name = cudaGetErrorName(tvm_ffi_cuda_err);                     \
+      const char* tvm_ffi_err_str = cudaGetErrorString(tvm_ffi_cuda_err);                    \
+      TVM_FFI_THROW(RuntimeError) << "CUDA Runtime Error: " << tvm_ffi_err_name << " ("      \
+                                  << static_cast<int>(tvm_ffi_cuda_err) << "): " << tvm_ffi_err_str; \
+    }                                                                                         \
+  } while (0)
+
+} // namespace ffi
+} // namespace tvm
+
+#endif // TVM_FFI_EXTRA_CUDA_BASE_H_
diff --git a/include/tvm/ffi/extra/cuda/cubin_launcher.h
b/include/tvm/ffi/extra/cuda/cubin_launcher.h
index f0cc924..72eadd2 100644
--- a/include/tvm/ffi/extra/cuda/cubin_launcher.h
+++ b/include/tvm/ffi/extra/cuda/cubin_launcher.h
@@ -26,12 +26,13 @@
* - Multi-GPU execution using CUDA primary contexts
* - Kernel parameter management and launch configuration
*/
-#ifndef TVM_FFI_EXTRA_CUBIN_LAUNCHER_H_
-#define TVM_FFI_EXTRA_CUBIN_LAUNCHER_H_
+#ifndef TVM_FFI_EXTRA_CUDA_CUBIN_LAUNCHER_H_
+#define TVM_FFI_EXTRA_CUDA_CUBIN_LAUNCHER_H_
#include <cuda_runtime.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/extra/c_env_api.h>
+#include <tvm/ffi/extra/cuda/base.h>
#include <tvm/ffi/string.h>
#include <cstdint>
@@ -40,25 +41,6 @@
namespace tvm {
namespace ffi {
-/*!
- * \brief Macro for checking CUDA runtime API errors.
- *
- * This macro checks the return value of CUDA runtime API calls and throws
- * a RuntimeError with detailed error information if the call fails.
- *
- * \param stmt The CUDA runtime API call to check.
- */
-#define TVM_FFI_CHECK_CUDA_ERROR(stmt)
\
- do {
\
- cudaError_t __err = (stmt);
\
- if (__err != cudaSuccess) {
\
- const char* __err_name = cudaGetErrorName(__err);
\
- const char* __err_str = cudaGetErrorString(__err);
\
- TVM_FFI_THROW(RuntimeError) << "CUDA Runtime Error: " << __err_name << "
(" \
- << static_cast<int>(__err) << "): " <<
__err_str; \
- }
\
- } while (0)
-
/*!
* \brief A simple 3D dimension type for CUDA kernel launch configuration.
*
@@ -619,4 +601,4 @@ inline CubinKernel CubinModule::operator[](const char*
name) { return GetKernel(
} // namespace ffi
} // namespace tvm
-#endif // TVM_FFI_EXTRA_CUBIN_LAUNCHER_H_
+#endif // TVM_FFI_EXTRA_CUDA_CUBIN_LAUNCHER_H_
diff --git a/include/tvm/ffi/extra/cuda/device_guard.h
b/include/tvm/ffi/extra/cuda/device_guard.h
new file mode 100644
index 0000000..083580f
--- /dev/null
+++ b/include/tvm/ffi/extra/cuda/device_guard.h
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/ffi/extra/cuda/device_guard.h
+ * \brief CUDA device guard that temporarily switches the current CUDA device.
+ */
+#ifndef TVM_FFI_EXTRA_CUDA_DEVICE_GUARD_H_
+#define TVM_FFI_EXTRA_CUDA_DEVICE_GUARD_H_
+
+#include <tvm/ffi/extra/cuda/base.h>
+
+#include <exception>
+
+namespace tvm {
+namespace ffi {
+
+/*!
+ * \brief CUDA device guard, analogous to `c10::cuda::CUDAGuard` in PyTorch.
+ *
+ * On construction it records the current device with `cudaGetDevice` and, if that differs
+ * from the requested index, switches to the requested device with `cudaSetDevice`. On
+ * destruction it switches back to the recorded device if a switch was made.
+ *
+ * The guard is neither copyable nor movable, so exactly one object restores the device.
+ *
+ * Example usage:
+ * \code
+ * void kernel(ffi::TensorView x) {
+ *   ffi::CUDADeviceGuard guard(x.device().device_id);
+ *   ...
+ * }
+ * \endcode
+ */
+struct CUDADeviceGuard {
+  CUDADeviceGuard() = delete;
+  // A guard owns the restore of exactly one device switch; forbid copies (declaring the
+  // copy operations deleted also suppresses the implicit move operations).
+  CUDADeviceGuard(const CUDADeviceGuard&) = delete;
+  CUDADeviceGuard& operator=(const CUDADeviceGuard&) = delete;
+
+  /*!
+   * \brief Make `device_index` the current CUDA device, remembering the previous one.
+   * \param device_index The CUDA device index to make current for the guard's lifetime.
+   * \throws RuntimeError (via TVM_FFI_CHECK_CUDA_ERROR) if getting or setting the device fails.
+   */
+  explicit CUDADeviceGuard(int device_index)
+      : target_device_index_(device_index),
+        uncaught_exceptions_at_entry_(std::uncaught_exceptions()) {
+    TVM_FFI_CHECK_CUDA_ERROR(cudaGetDevice(&original_device_index_));
+    if (target_device_index_ != original_device_index_) {
+      TVM_FFI_CHECK_CUDA_ERROR(cudaSetDevice(target_device_index_));
+    }
+  }
+
+  /*!
+   * \brief Restore the original device if the constructor switched away from it.
+   *
+   * A restore failure is reported by throwing RuntimeError, except while the stack is being
+   * unwound by an exception raised after construction: throwing from a destructor then would
+   * call `std::terminate`, so the restore is best-effort and its error code is ignored.
+   */
+  ~CUDADeviceGuard() noexcept(false) {
+    if (original_device_index_ == target_device_index_) {
+      return;
+    }
+    if (std::uncaught_exceptions() > uncaught_exceptions_at_entry_) {
+      // Unwinding: preserve the in-flight exception instead of terminating the process.
+      (void)cudaSetDevice(original_device_index_);
+      return;
+    }
+    TVM_FFI_CHECK_CUDA_ERROR(cudaSetDevice(original_device_index_));
+  }
+
+ private:
+  /*! \brief Device that was current when the guard was constructed. */
+  int original_device_index_;
+  /*! \brief Device made current by the guard. */
+  int target_device_index_;
+  /*! \brief `std::uncaught_exceptions()` at construction, used to detect stack unwinding. */
+  int uncaught_exceptions_at_entry_;
+};
+
+} // namespace ffi
+} // namespace tvm
+#endif // TVM_FFI_EXTRA_CUDA_DEVICE_GUARD_H_