This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5550fb33cd [REFACTOR][CUDA] Phase out cuda_common.h (#19770)
5550fb33cd is described below

commit 5550fb33cd4e80290e79f937b58532fd9597538a
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Jun 15 09:58:53 2026 -0400

    [REFACTOR][CUDA] Phase out cuda_common.h (#19770)
    
    ## Summary
    
    The CUDA runtime can rely on the shared tvm-ffi CUDA error helper
    instead of carrying a TVM-local common header. This keeps CUDA error
    handling aligned with the FFI CUDA support and removes the remaining
    reason for src/backend/cuda/runtime/cuda_common.h.
    
    Main changes:
    
    - Move CUDA workspace thread-local state into cuda_device_api.cc.
    - Replace CUDA_CALL use sites with TVM_FFI_CHECK_CUDA_ERROR from
    tvm/ffi/extra/cuda/base.h.
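    - Inline the remaining CUDA driver checks at their call sites, or keep a
    guarded file-local CUDA_DRIVER_CALL macro where one is still needed
    (cuda_device_api.cc and cuda_module.cc), and delete cuda_common.h.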
---
 src/backend/cuda/runtime/cuda_common.h             |  69 ----------
 src/backend/cuda/runtime/cuda_device_api.cc        | 140 ++++++++++++---------
 src/backend/cuda/runtime/cuda_module.cc            |  16 ++-
 .../extra/contrib/cublas/cublas_json_runtime.cc    |   4 +-
 src/runtime/extra/contrib/cublas/cublas_utils.cc   |   5 +-
 src/runtime/extra/contrib/cudnn/conv_backward.cc   |   4 +-
 src/runtime/extra/contrib/cudnn/conv_forward.cc    |   2 +-
 .../contrib/cudnn/cudnn_frontend/attention.cc      |   4 +-
 .../extra/contrib/cudnn/cudnn_json_runtime.cc      |   4 +-
 src/runtime/extra/contrib/cudnn/cudnn_utils.cc     |   2 +-
 src/runtime/extra/contrib/cudnn/cudnn_utils.h      |   3 +-
 src/runtime/extra/contrib/cudnn/softmax.cc         |   2 +-
 src/runtime/extra/contrib/curand/curand.cc         |   2 +-
 .../cutlass/fp16_group_gemm_runner_sm100.cuh       |   3 +-
 .../cutlass/fp16_group_gemm_runner_sm90.cuh        |   3 +-
 .../fp8_groupwise_scaled_gemm_runner_sm100.cuh     |   4 +-
 .../fp8_groupwise_scaled_gemm_runner_sm90.cuh      |   4 +-
 ...p8_groupwise_scaled_group_gemm_runner_sm100.cuh |   4 +-
 src/runtime/extra/contrib/cutlass/gemm_runner.cuh  |   3 +-
 src/runtime/extra/contrib/nvshmem/dist_gemm.cu     |   7 +-
 src/runtime/extra/contrib/nvshmem/init.cc          |   7 +-
 .../extra/contrib/nvshmem/memory_allocator.cc      |   2 +-
 .../extra/contrib/tensorrt/tensorrt_calibrator.h   |  19 +--
 src/runtime/extra/contrib/thrust/thrust.cu         |   5 +-
 .../extra/disco/cuda_ipc/cuda_ipc_memory.cc        |  39 +++---
 src/runtime/extra/disco/nccl/nccl_context.h        |  17 ++-
 src/runtime/vm/cuda/cuda_graph_builtin.cc          |  18 +--
 tests/python/relax/test_vm_builtin.py              |   2 +-
 28 files changed, 182 insertions(+), 212 deletions(-)

diff --git a/src/backend/cuda/runtime/cuda_common.h 
b/src/backend/cuda/runtime/cuda_common.h
deleted file mode 100644
index 183a2e8702..0000000000
--- a/src/backend/cuda/runtime/cuda_common.h
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file cuda_common.h
- * \brief Common utilities for CUDA
- */
-#ifndef TVM_RUNTIME_CUDA_CUDA_COMMON_H_
-#define TVM_RUNTIME_CUDA_CUDA_COMMON_H_
-
-#include <cuda_runtime.h>
-#include <tvm/ffi/function.h>
-
-#include <string>
-
-#include "../../../runtime/workspace_pool.h"
-
-namespace tvm {
-namespace runtime {
-
-#define CUDA_DRIVER_CALL(x)                                             \
-  {                                                                     \
-    CUresult result = x;                                                \
-    if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { \
-      const char* msg;                                                  \
-      cuGetErrorName(result, &msg);                                     \
-      TVM_FFI_THROW(CUDAError) << "" #x " failed with error: " << msg;  \
-    }                                                                   \
-  }
-
-#ifndef CUDA_CALL
-#define CUDA_CALL(func)                                               \
-  {                                                                   \
-    cudaError_t e = (func);                                           \
-    TVM_FFI_ICHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \
-        << "CUDA: " << cudaGetErrorString(e);                         \
-  }
-#endif
-
-/*! \brief Thread local workspace */
-class CUDAThreadEntry {
- public:
-  /*! \brief thread local pool*/
-  WorkspacePool pool;
-  /*! \brief constructor */
-  CUDAThreadEntry();
-  // get the threadlocal workspace
-  static CUDAThreadEntry* ThreadLocal();
-};
-
-}  // namespace runtime
-}  // namespace tvm
-#endif  // TVM_RUNTIME_CUDA_CUDA_COMMON_H_
diff --git a/src/backend/cuda/runtime/cuda_device_api.cc 
b/src/backend/cuda/runtime/cuda_device_api.cc
index 969f40a081..68ae39de56 100644
--- a/src/backend/cuda/runtime/cuda_device_api.cc
+++ b/src/backend/cuda/runtime/cuda_device_api.cc
@@ -24,6 +24,7 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <tvm/ffi/extra/c_env_api.h>
+#include <tvm/ffi/extra/cuda/base.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/runtime/device_api.h>
@@ -32,14 +33,26 @@
 
 #include <cstring>
 
-#include "cuda_common.h"
+#include "../../../runtime/workspace_pool.h"
 
 namespace tvm {
 namespace runtime {
 
+#ifndef CUDA_DRIVER_CALL
+#define CUDA_DRIVER_CALL(x)                                             \
+  {                                                                     \
+    CUresult result = x;                                                \
+    if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { \
+      const char* msg;                                                  \
+      cuGetErrorName(result, &msg);                                     \
+      TVM_FFI_THROW(CUDAError) << "" #x " failed with error: " << msg;  \
+    }                                                                   \
+  }
+#endif
+
 class CUDADeviceAPI final : public DeviceAPI {
  public:
-  void SetDevice(Device dev) final { CUDA_CALL(cudaSetDevice(dev.device_id)); }
+  void SetDevice(Device dev) final { 
TVM_FFI_CHECK_CUDA_ERROR(cudaSetDevice(dev.device_id)); }
   void GetAttr(Device dev, DeviceAttrKind kind, ffi::Any* rv) final {
     int value = 0;
     switch (kind) {
@@ -50,23 +63,27 @@ class CUDADeviceAPI final : public DeviceAPI {
         break;
       }
       case kMaxThreadsPerBlock: {
-        CUDA_CALL(cudaDeviceGetAttribute(&value, 
cudaDevAttrMaxThreadsPerBlock, dev.device_id));
+        TVM_FFI_CHECK_CUDA_ERROR(
+            cudaDeviceGetAttribute(&value, cudaDevAttrMaxThreadsPerBlock, 
dev.device_id));
         break;
       }
       case kWarpSize: {
-        CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrWarpSize, 
dev.device_id));
+        TVM_FFI_CHECK_CUDA_ERROR(
+            cudaDeviceGetAttribute(&value, cudaDevAttrWarpSize, 
dev.device_id));
         break;
       }
       case kMaxSharedMemoryPerBlock: {
-        CUDA_CALL(
+        TVM_FFI_CHECK_CUDA_ERROR(
             cudaDeviceGetAttribute(&value, cudaDevAttrMaxSharedMemoryPerBlock, 
dev.device_id));
         break;
       }
       case kComputeVersion: {
         std::ostringstream os;
-        CUDA_CALL(cudaDeviceGetAttribute(&value, 
cudaDevAttrComputeCapabilityMajor, dev.device_id));
+        TVM_FFI_CHECK_CUDA_ERROR(
+            cudaDeviceGetAttribute(&value, cudaDevAttrComputeCapabilityMajor, 
dev.device_id));
         os << value << ".";
-        CUDA_CALL(cudaDeviceGetAttribute(&value, 
cudaDevAttrComputeCapabilityMinor, dev.device_id));
+        TVM_FFI_CHECK_CUDA_ERROR(
+            cudaDeviceGetAttribute(&value, cudaDevAttrComputeCapabilityMinor, 
dev.device_id));
         os << value;
         *rv = os.str();
         return;
@@ -79,18 +96,23 @@ class CUDADeviceAPI final : public DeviceAPI {
         return;
       }
       case kMaxClockRate: {
-        CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrClockRate, 
dev.device_id));
+        TVM_FFI_CHECK_CUDA_ERROR(
+            cudaDeviceGetAttribute(&value, cudaDevAttrClockRate, 
dev.device_id));
         break;
       }
       case kMultiProcessorCount: {
-        CUDA_CALL(cudaDeviceGetAttribute(&value, 
cudaDevAttrMultiProcessorCount, dev.device_id));
+        TVM_FFI_CHECK_CUDA_ERROR(
+            cudaDeviceGetAttribute(&value, cudaDevAttrMultiProcessorCount, 
dev.device_id));
         break;
       }
       case kMaxThreadDimensions: {
         int dims[3];
-        CUDA_CALL(cudaDeviceGetAttribute(&dims[0], cudaDevAttrMaxBlockDimX, 
dev.device_id));
-        CUDA_CALL(cudaDeviceGetAttribute(&dims[1], cudaDevAttrMaxBlockDimY, 
dev.device_id));
-        CUDA_CALL(cudaDeviceGetAttribute(&dims[2], cudaDevAttrMaxBlockDimZ, 
dev.device_id));
+        TVM_FFI_CHECK_CUDA_ERROR(
+            cudaDeviceGetAttribute(&dims[0], cudaDevAttrMaxBlockDimX, 
dev.device_id));
+        TVM_FFI_CHECK_CUDA_ERROR(
+            cudaDeviceGetAttribute(&dims[1], cudaDevAttrMaxBlockDimY, 
dev.device_id));
+        TVM_FFI_CHECK_CUDA_ERROR(
+            cudaDeviceGetAttribute(&dims[2], cudaDevAttrMaxBlockDimZ, 
dev.device_id));
 
         std::stringstream ss;  // use json string to return multiple int 
values;
         ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]";
@@ -98,7 +120,8 @@ class CUDADeviceAPI final : public DeviceAPI {
         return;
       }
       case kMaxRegistersPerBlock: {
-        CUDA_CALL(cudaDeviceGetAttribute(&value, 
cudaDevAttrMaxRegistersPerBlock, dev.device_id));
+        TVM_FFI_CHECK_CUDA_ERROR(
+            cudaDeviceGetAttribute(&value, cudaDevAttrMaxRegistersPerBlock, 
dev.device_id));
         break;
       }
       case kGcnArch:
@@ -112,20 +135,21 @@ class CUDADeviceAPI final : public DeviceAPI {
       case kL2CacheSizeBytes: {
         // Get size of device l2 cache size in bytes.
         int l2_size = 0;
-        CUDA_CALL(cudaDeviceGetAttribute(&l2_size, cudaDevAttrL2CacheSize, 
dev.device_id));
+        TVM_FFI_CHECK_CUDA_ERROR(
+            cudaDeviceGetAttribute(&l2_size, cudaDevAttrL2CacheSize, 
dev.device_id));
         *rv = l2_size;
         return;
       }
       case kTotalGlobalMemory: {
         cudaDeviceProp prop;
-        CUDA_CALL(cudaGetDeviceProperties(&prop, dev.device_id));
+        TVM_FFI_CHECK_CUDA_ERROR(cudaGetDeviceProperties(&prop, 
dev.device_id));
         int64_t total_global_memory = prop.totalGlobalMem;
         *rv = total_global_memory;
         return;
       }
       case kAvailableGlobalMemory: {
         size_t free_mem, total_mem;
-        CUDA_CALL(cudaMemGetInfo(&free_mem, &total_mem));
+        TVM_FFI_CHECK_CUDA_ERROR(cudaMemGetInfo(&free_mem, &total_mem));
         *rv = static_cast<int64_t>(free_mem);
         return;
       }
@@ -139,14 +163,14 @@ class CUDADeviceAPI final : public DeviceAPI {
     void* ret;
     if (dev.device_type == kDLCUDAHost) {
       VLOG(1) << "allocating " << nbytes << "bytes on host";
-      CUDA_CALL(cudaMallocHost(&ret, nbytes));
+      TVM_FFI_CHECK_CUDA_ERROR(cudaMallocHost(&ret, nbytes));
     } else {
-      CUDA_CALL(cudaSetDevice(dev.device_id));
+      TVM_FFI_CHECK_CUDA_ERROR(cudaSetDevice(dev.device_id));
       size_t free_mem, total_mem;
-      CUDA_CALL(cudaMemGetInfo(&free_mem, &total_mem));
+      TVM_FFI_CHECK_CUDA_ERROR(cudaMemGetInfo(&free_mem, &total_mem));
       VLOG(1) << "allocating " << nbytes << " bytes on device, with " << 
free_mem
               << " bytes currently free out of " << total_mem << " bytes 
available";
-      CUDA_CALL(cudaMalloc(&ret, nbytes));
+      TVM_FFI_CHECK_CUDA_ERROR(cudaMalloc(&ret, nbytes));
     }
     return ret;
   }
@@ -172,11 +196,11 @@ class CUDADeviceAPI final : public DeviceAPI {
 
     if (dev.device_type == kDLCUDAHost) {
       VLOG(1) << "freeing host memory";
-      CUDA_CALL(cudaFreeHost(ptr));
+      TVM_FFI_CHECK_CUDA_ERROR(cudaFreeHost(ptr));
     } else {
-      CUDA_CALL(cudaSetDevice(dev.device_id));
+      TVM_FFI_CHECK_CUDA_ERROR(cudaSetDevice(dev.device_id));
       VLOG(1) << "freeing device memory";
-      CUDA_CALL(cudaFree(ptr));
+      TVM_FFI_CHECK_CUDA_ERROR(cudaFree(ptr));
     }
   }
 
@@ -203,17 +227,17 @@ class CUDADeviceAPI final : public DeviceAPI {
     }
 
     if (dev_from.device_type == kDLCUDA && dev_to.device_type == kDLCUDA) {
-      CUDA_CALL(cudaSetDevice(dev_from.device_id));
+      TVM_FFI_CHECK_CUDA_ERROR(cudaSetDevice(dev_from.device_id));
       if (dev_from.device_id == dev_to.device_id) {
         GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream);
       } else {
         cudaMemcpyPeerAsync(to, dev_to.device_id, from, dev_from.device_id, 
size, cu_stream);
       }
     } else if (dev_from.device_type == kDLCUDA && dev_to.device_type == 
kDLCPU) {
-      CUDA_CALL(cudaSetDevice(dev_from.device_id));
+      TVM_FFI_CHECK_CUDA_ERROR(cudaSetDevice(dev_from.device_id));
       GPUCopy(from, to, size, cudaMemcpyDeviceToHost, cu_stream);
     } else if (dev_from.device_type == kDLCPU && dev_to.device_type == 
kDLCUDA) {
-      CUDA_CALL(cudaSetDevice(dev_to.device_id));
+      TVM_FFI_CHECK_CUDA_ERROR(cudaSetDevice(dev_to.device_id));
       GPUCopy(from, to, size, cudaMemcpyHostToDevice, cu_stream);
     } else {
       TVM_FFI_THROW(InternalError) << "expect copy from/to GPU or between GPU";
@@ -222,40 +246,40 @@ class CUDADeviceAPI final : public DeviceAPI {
 
  public:
   TVMStreamHandle CreateStream(Device dev) {
-    CUDA_CALL(cudaSetDevice(dev.device_id));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaSetDevice(dev.device_id));
     cudaStream_t retval;
-    CUDA_CALL(cudaStreamCreateWithFlags(&retval, cudaStreamNonBlocking));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&retval, 
cudaStreamNonBlocking));
     return static_cast<TVMStreamHandle>(retval);
   }
 
   void FreeStream(Device dev, TVMStreamHandle stream) {
-    CUDA_CALL(cudaSetDevice(dev.device_id));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaSetDevice(dev.device_id));
     cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
-    CUDA_CALL(cudaStreamDestroy(cu_stream));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaStreamDestroy(cu_stream));
   }
 
   void SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle 
event_dst) {
-    CUDA_CALL(cudaSetDevice(dev.device_id));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaSetDevice(dev.device_id));
     cudaStream_t src_stream = static_cast<cudaStream_t>(event_src);
     cudaStream_t dst_stream = static_cast<cudaStream_t>(event_dst);
     cudaEvent_t evt;
-    CUDA_CALL(cudaEventCreate(&evt));
-    CUDA_CALL(cudaEventRecord(evt, src_stream));
-    CUDA_CALL(cudaStreamWaitEvent(dst_stream, evt, 0));
-    CUDA_CALL(cudaEventDestroy(evt));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaEventCreate(&evt));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaEventRecord(evt, src_stream));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaStreamWaitEvent(dst_stream, evt, 0));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaEventDestroy(evt));
   }
 
   void StreamSync(Device dev, TVMStreamHandle stream) final {
-    CUDA_CALL(cudaSetDevice(dev.device_id));
-    CUDA_CALL(cudaStreamSynchronize(static_cast<cudaStream_t>(stream)));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaSetDevice(dev.device_id));
+    
TVM_FFI_CHECK_CUDA_ERROR(cudaStreamSynchronize(static_cast<cudaStream_t>(stream)));
   }
 
   void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final {
-    return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size);
+    return ThreadLocalWorkspacePool()->AllocWorkspace(dev, size);
   }
 
   void FreeWorkspace(Device dev, void* data) final {
-    CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(dev, data);
+    ThreadLocalWorkspacePool()->FreeWorkspace(dev, data);
   }
 
   bool SupportsDevicePointerArithmeticsOnHost() final { return true; }
@@ -268,17 +292,17 @@ class CUDADeviceAPI final : public DeviceAPI {
   }
 
  private:
+  static WorkspacePool* ThreadLocalWorkspacePool();
+
   static void GPUCopy(const void* from, void* to, size_t size, cudaMemcpyKind 
kind,
                       cudaStream_t stream) {
-    CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaMemcpyAsync(to, from, size, kind, stream));
   }
 };
 
-CUDAThreadEntry::CUDAThreadEntry() : pool(kDLCUDA, CUDADeviceAPI::Global()) {}
-
-CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() {
-  static thread_local CUDAThreadEntry inst;
-  return &inst;
+WorkspacePool* CUDADeviceAPI::ThreadLocalWorkspacePool() {
+  static thread_local WorkspacePool pool(kDLCUDA, CUDADeviceAPI::Global());
+  return &pool;
 }
 
 TVM_FFI_STATIC_INIT_BLOCK() {
@@ -301,28 +325,28 @@ class CUDATimerNode : public TimerNode {
     // This initial cudaEventRecord is sometimes pretty slow (~100us). Does
     // cudaEventRecord do some stream synchronization?
     int device_id;
-    CUDA_CALL(cudaGetDevice(&device_id));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaGetDevice(&device_id));
     stream_ = TVMFFIEnvGetStream(kDLCUDA, device_id);
-    CUDA_CALL(cudaEventRecord(start_, static_cast<cudaStream_t>(stream_)));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaEventRecord(start_, 
static_cast<cudaStream_t>(stream_)));
   }
   virtual void Stop() {
     int device_id;
-    CUDA_CALL(cudaGetDevice(&device_id));
-    CUDA_CALL(cudaEventRecord(stop_, static_cast<cudaStream_t>(stream_)));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaGetDevice(&device_id));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaEventRecord(stop_, 
static_cast<cudaStream_t>(stream_)));
   }
   virtual int64_t SyncAndGetElapsedNanos() {
-    CUDA_CALL(cudaEventSynchronize(stop_));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaEventSynchronize(stop_));
     float milliseconds = 0;
-    CUDA_CALL(cudaEventElapsedTime(&milliseconds, start_, stop_));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaEventElapsedTime(&milliseconds, start_, 
stop_));
     return milliseconds * 1e6;
   }
   virtual ~CUDATimerNode() {
-    CUDA_CALL(cudaEventDestroy(start_));
-    CUDA_CALL(cudaEventDestroy(stop_));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaEventDestroy(start_));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaEventDestroy(stop_));
   }
   CUDATimerNode() {
-    CUDA_CALL(cudaEventCreate(&start_));
-    CUDA_CALL(cudaEventCreate(&stop_));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaEventCreate(&start_));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaEventCreate(&stop_));
   }
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("runtime.cuda.CUDATimerNode", 
CUDATimerNode, TimerNode);
 
@@ -340,7 +364,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
 
 TVM_RUNTIME_DLL ffi::String GetCudaFreeMemory() {
   size_t free_mem, total_mem;
-  CUDA_CALL(cudaMemGetInfo(&free_mem, &total_mem));
+  TVM_FFI_CHECK_CUDA_ERROR(cudaMemGetInfo(&free_mem, &total_mem));
   std::stringstream ss;
   ss << "Current CUDA memory is " << free_mem << " bytes free out of " << 
total_mem
      << " bytes on device";
@@ -355,14 +379,14 @@ TVM_FFI_STATIC_INIT_BLOCK() {
         // TODO(tvm-team): remove once confirms all dep such as flashinfer
         // migrated to TVMFFIEnvGetStream
         int device_id;
-        CUDA_CALL(cudaGetDevice(&device_id));
+        TVM_FFI_CHECK_CUDA_ERROR(cudaGetDevice(&device_id));
         return static_cast<void*>(TVMFFIEnvGetStream(kDLCUDA, device_id));
       });
 }
 
 TVM_RUNTIME_DLL int GetCudaDeviceCount() {
   int count;
-  CUDA_CALL(cudaGetDeviceCount(&count));
+  TVM_FFI_CHECK_CUDA_ERROR(cudaGetDeviceCount(&count));
   return count;
 }
 
diff --git a/src/backend/cuda/runtime/cuda_module.cc 
b/src/backend/cuda/runtime/cuda_module.cc
index 8ea734da69..5a598c87cc 100644
--- a/src/backend/cuda/runtime/cuda_module.cc
+++ b/src/backend/cuda/runtime/cuda_module.cc
@@ -28,6 +28,7 @@
 #include <cuda_runtime.h>
 #include <tvm/ffi/cast.h>
 #include <tvm/ffi/extra/c_env_api.h>
+#include <tvm/ffi/extra/cuda/base.h>
 #include <tvm/ffi/extra/module.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
@@ -41,11 +42,22 @@
 #include "../../../runtime/pack_args.h"
 #include "../../../runtime/thread_storage_scope.h"
 #include "../../../support/bytes_io.h"
-#include "cuda_common.h"
 
 namespace tvm {
 namespace runtime {
 
+#ifndef CUDA_DRIVER_CALL
+#define CUDA_DRIVER_CALL(x)                                             \
+  {                                                                     \
+    CUresult result = x;                                                \
+    if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { \
+      const char* msg;                                                  \
+      cuGetErrorName(result, &msg);                                     \
+      TVM_FFI_THROW(CUDAError) << "" #x " failed with error: " << msg;  \
+    }                                                                   \
+  }
+#endif
+
 // Maximum number of GPU supported in CUDAModule (file-local).
 static constexpr const int kMaxNumGPUs = 32;
 
@@ -204,7 +216,7 @@ class CUDAWrappedFunc {
   // invoke the function with void arguments
   void operator()(ffi::PackedArgs args, ffi::Any* rv, void** void_args) const {
     int device_id;
-    CUDA_CALL(cudaGetDevice(&device_id));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaGetDevice(&device_id));
     ThreadWorkLoad wl = launch_param_config_.Extract(args);
 
     if (fcache_[device_id] == nullptr) {
diff --git a/src/runtime/extra/contrib/cublas/cublas_json_runtime.cc 
b/src/runtime/extra/contrib/cublas/cublas_json_runtime.cc
index 6520753e11..34ba853505 100644
--- a/src/runtime/extra/contrib/cublas/cublas_json_runtime.cc
+++ b/src/runtime/extra/contrib/cublas/cublas_json_runtime.cc
@@ -24,6 +24,7 @@
 
 #include <tvm/ffi/cast.h>
 #include <tvm/ffi/extra/c_env_api.h>
+#include <tvm/ffi/extra/cuda/base.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/runtime/tensor.h>
@@ -32,7 +33,6 @@
 #include <string>
 #include <vector>
 
-#include "../../../../backend/cuda/runtime/cuda_common.h"
 #include "../json/json_node.h"
 #include "../json/json_runtime.h"
 #include "cublas_utils.h"
@@ -89,7 +89,7 @@ class CublasJSONRuntime : public JSONRuntimeBase {
     }
 
     if (device_id == -1) {
-      CUDA_CALL(cudaGetDevice(&device_id));
+      TVM_FFI_CHECK_CUDA_ERROR(cudaGetDevice(&device_id));
     }
     auto* entry_ptr = 
tvm::contrib::CuBlasLtThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id});
     cudaStream_t stream = 
static_cast<cudaStream_t>(TVMFFIEnvGetStream(kDLCUDA, device_id));
diff --git a/src/runtime/extra/contrib/cublas/cublas_utils.cc 
b/src/runtime/extra/contrib/cublas/cublas_utils.cc
index b8c239f13a..dcd115026c 100644
--- a/src/runtime/extra/contrib/cublas/cublas_utils.cc
+++ b/src/runtime/extra/contrib/cublas/cublas_utils.cc
@@ -23,10 +23,9 @@
 #include "cublas_utils.h"
 
 #include <tvm/ffi/extra/c_env_api.h>
+#include <tvm/ffi/extra/cuda/base.h>
 #include <tvm/ffi/function.h>
 
-#include "../../../../backend/cuda/runtime/cuda_common.h"
-
 namespace tvm {
 namespace contrib {
 
@@ -51,7 +50,7 @@ CuBlasThreadEntry* CuBlasThreadEntry::ThreadLocal(DLDevice 
curr_device) {
 CuBlasLtThreadEntry::CuBlasLtThreadEntry() {
   CHECK_CUBLAS_ERROR(cublasLtCreate(&handle));
   CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&matmul_pref_desc));
-  CUDA_CALL(cudaMalloc(&workspace_ptr, workspace_size));
+  TVM_FFI_CHECK_CUDA_ERROR(cudaMalloc(&workspace_ptr, workspace_size));
 }
 
 CuBlasLtThreadEntry::~CuBlasLtThreadEntry() {
diff --git a/src/runtime/extra/contrib/cudnn/conv_backward.cc 
b/src/runtime/extra/contrib/cudnn/conv_backward.cc
index bfc65baaff..ca711f4e78 100644
--- a/src/runtime/extra/contrib/cudnn/conv_backward.cc
+++ b/src/runtime/extra/contrib/cudnn/conv_backward.cc
@@ -67,7 +67,7 @@ void BackwardDataFindAlgo(int format, int dims, int groups, 
const int pad[], con
                           const int dx_dim[], const std::string& data_dtype,
                           const std::string& conv_dtype, bool verbose, 
ffi::Any* ret) {
   int device_id;
-  CUDA_CALL(cudaGetDevice(&device_id));
+  TVM_FFI_CHECK_CUDA_ERROR(cudaGetDevice(&device_id));
   CuDNNThreadEntry* entry_ptr = 
CuDNNThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id});
   const int full_dims = dims + 2;
   std::vector<int64_t> dy_dim_int64(full_dims);
@@ -146,7 +146,7 @@ void BackwardFilterFindAlgo(int format, int dims, int 
groups, const int pad[], c
                             const int dw_dim[], const std::string& data_dtype,
                             const std::string& conv_dtype, bool verbose, 
ffi::Any* ret) {
   int device_id;
-  CUDA_CALL(cudaGetDevice(&device_id));
+  TVM_FFI_CHECK_CUDA_ERROR(cudaGetDevice(&device_id));
   CuDNNThreadEntry* entry_ptr = 
CuDNNThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id});
   const int full_dims = dims + 2;
   std::vector<int64_t> x_dim_int64(full_dims);
diff --git a/src/runtime/extra/contrib/cudnn/conv_forward.cc 
b/src/runtime/extra/contrib/cudnn/conv_forward.cc
index 6c6fd7eb40..befbddd788 100644
--- a/src/runtime/extra/contrib/cudnn/conv_forward.cc
+++ b/src/runtime/extra/contrib/cudnn/conv_forward.cc
@@ -112,7 +112,7 @@ void FindAlgo(int format, int dims, int groups, const int 
pad[], const int strid
               const std::string& data_dtype, const std::string& conv_dtype, 
bool verbose,
               ffi::Any* ret) {
   int device_id;
-  CUDA_CALL(cudaGetDevice(&device_id));
+  TVM_FFI_CHECK_CUDA_ERROR(cudaGetDevice(&device_id));
   CuDNNThreadEntry* entry_ptr = 
CuDNNThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id});
   const int full_dims = dims + 2;
   std::vector<int64_t> x_dim_int64(full_dims);
diff --git a/src/runtime/extra/contrib/cudnn/cudnn_frontend/attention.cc 
b/src/runtime/extra/contrib/cudnn/cudnn_frontend/attention.cc
index f31a38bd47..54d16b5a39 100644
--- a/src/runtime/extra/contrib/cudnn/cudnn_frontend/attention.cc
+++ b/src/runtime/extra/contrib/cudnn/cudnn_frontend/attention.cc
@@ -24,10 +24,10 @@
 
 #include "./attention.h"
 
+#include <tvm/ffi/extra/cuda/base.h>
 #include <tvm/ffi/function.h>
 #include <tvm/runtime/device_api.h>
 
-#include "../../../../../backend/cuda/runtime/cuda_common.h"
 #include "../cudnn_utils.h"
 
 namespace tvm {
@@ -99,7 +99,7 @@ void CuDNNSDPARunnerNode::Init(int64_t batch, int64_t 
seq_len, int64_t num_heads
   TVM_FFI_ICHECK(stats == nullptr);
   o->set_output(true).set_dim({batch, num_heads, seq_len, 
head_size_v}).set_stride(o_stride);
   int device_id;
-  CUDA_CALL(cudaGetDevice(&device_id));
+  TVM_FFI_CHECK_CUDA_ERROR(cudaGetDevice(&device_id));
   CuDNNThreadEntry* entry_ptr = 
CuDNNThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id});
   CUDNN_FRONTEND_CALL(graph_->build(entry_ptr->handle, 
{cudnn_frontend::HeurMode_t::A}));
 }
diff --git a/src/runtime/extra/contrib/cudnn/cudnn_json_runtime.cc 
b/src/runtime/extra/contrib/cudnn/cudnn_json_runtime.cc
index a7cf1ec2b3..5b3756fd79 100644
--- a/src/runtime/extra/contrib/cudnn/cudnn_json_runtime.cc
+++ b/src/runtime/extra/contrib/cudnn/cudnn_json_runtime.cc
@@ -103,7 +103,7 @@ class cuDNNJSONRuntime : public JSONRuntimeBase {
 
   std::function<void()> GetConv2DExec(const JSONGraphNode& node) {
     int device_id;
-    CUDA_CALL(cudaGetDevice(&device_id));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaGetDevice(&device_id));
     auto* entry_ptr = 
tvm::contrib::CuDNNThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id});
     auto op_name = node.GetOpName();
 
@@ -164,7 +164,7 @@ class cuDNNJSONRuntime : public JSONRuntimeBase {
     int algo = best_algo.cast<int>();
     std::function<void()> op_exec = [=, this]() {
       int device_id;
-      CUDA_CALL(cudaGetDevice(&device_id));
+      TVM_FFI_CHECK_CUDA_ERROR(cudaGetDevice(&device_id));
       cudaStream_t stream = 
static_cast<cudaStream_t>(TVMFFIEnvGetStream(kDLCUDA, device_id));
       CUDNN_CALL(cudnnSetStream(entry_ptr->handle, stream));
 
diff --git a/src/runtime/extra/contrib/cudnn/cudnn_utils.cc 
b/src/runtime/extra/contrib/cudnn/cudnn_utils.cc
index 5793febf01..5c34d4a2b0 100644
--- a/src/runtime/extra/contrib/cudnn/cudnn_utils.cc
+++ b/src/runtime/extra/contrib/cudnn/cudnn_utils.cc
@@ -270,7 +270,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::GlobalDef().def("tvm.contrib.cudnn.exists", []() -> bool {
     int device_id;
-    CUDA_CALL(cudaGetDevice(&device_id));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaGetDevice(&device_id));
     return CuDNNThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id}, 
false)->exists();
   });
 }
diff --git a/src/runtime/extra/contrib/cudnn/cudnn_utils.h 
b/src/runtime/extra/contrib/cudnn/cudnn_utils.h
index 81849c0f0c..cef82134ad 100644
--- a/src/runtime/extra/contrib/cudnn/cudnn_utils.h
+++ b/src/runtime/extra/contrib/cudnn/cudnn_utils.h
@@ -26,12 +26,11 @@
 
 #include <cudnn.h>
 #include <tvm/ffi/error.h>
+#include <tvm/ffi/extra/cuda/base.h>
 #include <tvm/runtime/device_api.h>
 
 #include <string>
 
-#include "../../../../backend/cuda/runtime/cuda_common.h"
-
 namespace tvm {
 namespace contrib {
 
diff --git a/src/runtime/extra/contrib/cudnn/softmax.cc 
b/src/runtime/extra/contrib/cudnn/softmax.cc
index d494fb3349..f0562dcba0 100644
--- a/src/runtime/extra/contrib/cudnn/softmax.cc
+++ b/src/runtime/extra/contrib/cudnn/softmax.cc
@@ -41,7 +41,7 @@ void softmax_impl(cudnnSoftmaxAlgorithm_t alg, 
ffi::PackedArgs args, ffi::Any* r
   if (axis < 0) axis += ndim;
   TVM_FFI_ICHECK(axis >= 0 && axis < ndim);
   int device_id;
-  CUDA_CALL(cudaGetDevice(&device_id));
+  TVM_FFI_CHECK_CUDA_ERROR(cudaGetDevice(&device_id));
   CuDNNThreadEntry* entry_ptr = 
CuDNNThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id});
   entry_ptr->softmax_entry.data_type = 
CuDNNDataType::DLTypeToCuDNNType(x->dtype);
 
diff --git a/src/runtime/extra/contrib/curand/curand.cc 
b/src/runtime/extra/contrib/curand/curand.cc
index 5f8ff07583..b2fc753865 100644
--- a/src/runtime/extra/contrib/curand/curand.cc
+++ b/src/runtime/extra/contrib/curand/curand.cc
@@ -17,11 +17,11 @@
  * under the License.
  */
 #include <curand.h>
+#include <tvm/ffi/extra/cuda/base.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/runtime/base.h>
 
-#include "../../../../backend/cuda/runtime/cuda_common.h"
 #include "./helper_cuda_kernels.h"
 
 namespace tvm {
diff --git a/src/runtime/extra/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh 
b/src/runtime/extra/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh
index c30b34d0f4..0a3068c5d9 100644
--- a/src/runtime/extra/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh
+++ b/src/runtime/extra/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh
@@ -17,6 +17,7 @@
  * under the License.
  */
 
+#include <tvm/ffi/extra/cuda/base.h>
 #include <tvm/runtime/logging.h>
 
 #include <fstream>
@@ -25,8 +26,6 @@
 #include <variant>
 #include <vector>
 
-#include "../../../../backend/cuda/runtime/cuda_common.h"
-
 // clang-format off
 #include "cutlass/cutlass.h"
 
diff --git a/src/runtime/extra/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh 
b/src/runtime/extra/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh
index dcb26b9071..552b46b0a0 100644
--- a/src/runtime/extra/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh
+++ b/src/runtime/extra/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh
@@ -17,6 +17,7 @@
  * under the License.
  */
 
+#include <tvm/ffi/extra/cuda/base.h>
 #include <tvm/runtime/logging.h>
 
 #include <fstream>
@@ -25,8 +26,6 @@
 #include <variant>
 #include <vector>
 
-#include "../../../../backend/cuda/runtime/cuda_common.h"
-
 // clang-format off
 #include "cutlass/cutlass.h"
 
diff --git 
a/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh 
b/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh
index 9c1962ad98..26eb35e34e 100644
--- 
a/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh
+++ 
b/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh
@@ -17,6 +17,8 @@
  * under the License.
  */
 
+#include <tvm/ffi/extra/cuda/base.h>
+
 #include <fstream>
 #include <iostream>
 #include <sstream>
@@ -24,8 +26,6 @@
 #include <variant>
 #include <vector>
 
-#include "../../../../backend/cuda/runtime/cuda_common.h"
-
 // clang-format off
 #include "cutlass/cutlass.h"
 
diff --git 
a/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh 
b/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh
index 1955e47759..9307e33d7f 100644
--- 
a/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh
+++ 
b/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh
@@ -17,6 +17,8 @@
  * under the License.
  */
 
+#include <tvm/ffi/extra/cuda/base.h>
+
 #include <fstream>
 #include <iostream>
 #include <sstream>
@@ -24,8 +26,6 @@
 #include <variant>
 #include <vector>
 
-#include "../../../../backend/cuda/runtime/cuda_common.h"
-
 // clang-format off
 #include "cutlass/cutlass.h"
 
diff --git 
a/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_group_gemm_runner_sm100.cuh
 
b/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_group_gemm_runner_sm100.cuh
index a516f31d3c..01e9ca7b8a 100644
--- 
a/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_group_gemm_runner_sm100.cuh
+++ 
b/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_group_gemm_runner_sm100.cuh
@@ -17,14 +17,14 @@
  * under the License.
  */
 
+#include <tvm/ffi/extra/cuda/base.h>
+
 #include <fstream>
 #include <iostream>
 #include <sstream>
 #include <variant>
 #include <vector>
 
-#include "../../../../backend/cuda/runtime/cuda_common.h"
-
 // clang-format off
 #include "cutlass/cutlass.h"
 
diff --git a/src/runtime/extra/contrib/cutlass/gemm_runner.cuh 
b/src/runtime/extra/contrib/cutlass/gemm_runner.cuh
index 58e1c9fbd0..174699bcec 100644
--- a/src/runtime/extra/contrib/cutlass/gemm_runner.cuh
+++ b/src/runtime/extra/contrib/cutlass/gemm_runner.cuh
@@ -17,6 +17,7 @@
  * under the License.
  */
 
+#include <tvm/ffi/extra/cuda/base.h>
 #include <tvm/runtime/logging.h>
 
 #include <fstream>
@@ -25,8 +26,6 @@
 #include <variant>
 #include <vector>
 
-#include "../../../../backend/cuda/runtime/cuda_common.h"
-
 // clang-format off
 #include "cutlass/cutlass.h"
 
diff --git a/src/runtime/extra/contrib/nvshmem/dist_gemm.cu 
b/src/runtime/extra/contrib/nvshmem/dist_gemm.cu
index d975e14038..26cc55a442 100644
--- a/src/runtime/extra/contrib/nvshmem/dist_gemm.cu
+++ b/src/runtime/extra/contrib/nvshmem/dist_gemm.cu
@@ -19,12 +19,11 @@
 #include <nvshmem.h>
 #include <nvshmemx.h>
 #include <tvm/ffi/extra/c_env_api.h>
+#include <tvm/ffi/extra/cuda/base.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/runtime/disco/disco_worker.h>
 
-#include "../../../../backend/cuda/runtime/cuda_common.h"
-
 namespace tvm {
 namespace runtime {
 
@@ -62,7 +61,7 @@ TVMStreamHandle stream_create() {
   DiscoWorker* worker = ThreadLocalDiscoWorker::Get()->worker;
   TVM_FFI_ICHECK(worker != nullptr) << "NVSHMEM stream creation failed: worker 
is not initialized";
   cudaStream_t retval;
-  CUDA_CALL(cudaStreamCreateWithFlags(&retval, cudaStreamNonBlocking));
+  TVM_FFI_CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&retval, 
cudaStreamNonBlocking));
   return static_cast<TVMStreamHandle>(retval);
 }
 
@@ -100,7 +99,7 @@ void transfer_to_peers_reduce_scatter(Tensor semaphore, 
Tensor gemm_out, Tensor
                    stream);
     } else {
       int device_id;
-      CUDA_CALL(cudaGetDevice(&device_id));
+      TVM_FFI_CHECK_CUDA_ERROR(cudaGetDevice(&device_id));
       TVMStreamHandle main_stream = TVMFFIEnvGetStream(kDLCUDA, device_id);
       copy_to_peer(get_pointer(staging_buffer, ffi::Shape{my_rank, 0, 0}), 
to_rank,
                    get_pointer(gemm_out, ffi::Shape{to_rank * LOCAL_M, 0}), 
LOCAL_M * N * 2,
diff --git a/src/runtime/extra/contrib/nvshmem/init.cc 
b/src/runtime/extra/contrib/nvshmem/init.cc
index 698f3b6802..aeb71bb6a1 100644
--- a/src/runtime/extra/contrib/nvshmem/init.cc
+++ b/src/runtime/extra/contrib/nvshmem/init.cc
@@ -20,14 +20,13 @@
 #include <nvshmem.h>
 #include <nvshmemx.h>
 #include <tvm/ffi/extra/c_env_api.h>
+#include <tvm/ffi/extra/cuda/base.h>
 #include <tvm/ffi/extra/json.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/runtime/disco/disco_worker.h>
 #include <tvm/runtime/logging.h>
 
-#include "../../../../backend/cuda/runtime/cuda_common.h"
-
 namespace tvm {
 namespace runtime {
 
@@ -66,7 +65,7 @@ void InitNVSHMEM(ffi::Shape uid_64, int num_workers, int 
worker_id_start) {
   nvshmemx_set_attr_uniqueid_args(worker_id, num_workers, &uid, &attr);
   nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
   int mype_node = nvshmem_team_my_pe(NVSHMEMX_TEAM_NODE);
-  CUDA_CALL(cudaSetDevice(mype_node));
+  TVM_FFI_CHECK_CUDA_ERROR(cudaSetDevice(mype_node));
   if (worker != nullptr) {
     if (worker->default_device.device_type == DLDeviceType::kDLCPU) {
       worker->default_device = Device{DLDeviceType::kDLCUDA, mype_node};
@@ -153,7 +152,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
       .def("runtime.nvshmem.cumodule_init", NVSHMEMXCumoduleInit)
       .def("runtime.disco.nvshmem.barrier_all_on_current_stream", []() {
         int device_id;
-        CUDA_CALL(cudaGetDevice(&device_id));
+        TVM_FFI_CHECK_CUDA_ERROR(cudaGetDevice(&device_id));
         TVMStreamHandle stream = TVMFFIEnvGetStream(kDLCUDA, device_id);
         NVSHMEMBarrierAllOnStream(stream);
       });
diff --git a/src/runtime/extra/contrib/nvshmem/memory_allocator.cc 
b/src/runtime/extra/contrib/nvshmem/memory_allocator.cc
index ee9afd0eca..cb6e3520c8 100644
--- a/src/runtime/extra/contrib/nvshmem/memory_allocator.cc
+++ b/src/runtime/extra/contrib/nvshmem/memory_allocator.cc
@@ -18,13 +18,13 @@
  */
 #include <nvshmem.h>
 #include <nvshmemx.h>
+#include <tvm/ffi/extra/cuda/base.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/runtime/memory/memory_manager.h>
 
 #include <thread>
 
-#include "../../../../backend/cuda/runtime/cuda_common.h"
 #include "../../../memory/pooled_allocator.h"
 #include "../../disco/utils.h"
 
diff --git a/src/runtime/extra/contrib/tensorrt/tensorrt_calibrator.h 
b/src/runtime/extra/contrib/tensorrt/tensorrt_calibrator.h
index 4d92afb234..408d50cc7e 100755
--- a/src/runtime/extra/contrib/tensorrt/tensorrt_calibrator.h
+++ b/src/runtime/extra/contrib/tensorrt/tensorrt_calibrator.h
@@ -23,10 +23,11 @@
 #ifndef TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_CALIBRATOR_H_
 #define TVM_RUNTIME_CONTRIB_TENSORRT_TENSORRT_CALIBRATOR_H_
 
+#include <tvm/ffi/extra/cuda/base.h>
+
 #include <string>
 #include <vector>
 
-#include "../../../../backend/cuda/runtime/cuda_common.h"
 #include "NvInfer.h"
 
 namespace tvm {
@@ -46,7 +47,7 @@ class TensorRTCalibrator : public 
nvinfer1::IInt8EntropyCalibrator2 {
     }
     // Free buffers
     for (size_t i = 0; i < buffers_.size(); ++i) {
-      CUDA_CALL(cudaFree(buffers_[i]));
+      TVM_FFI_CHECK_CUDA_ERROR(cudaFree(buffers_[i]));
     }
   }
 
@@ -55,8 +56,9 @@ class TensorRTCalibrator : public 
nvinfer1::IInt8EntropyCalibrator2 {
     std::vector<float*> data_host(bindings.size(), nullptr);
     for (size_t i = 0; i < bindings.size(); ++i) {
       data_host[i] = new float[batch_size_ * binding_sizes[i]];
-      CUDA_CALL(cudaMemcpy(static_cast<void*>(data_host[i]), bindings[i],
-                           batch_size_ * binding_sizes[i] * sizeof(float), 
cudaMemcpyDeviceToHost));
+      TVM_FFI_CHECK_CUDA_ERROR(cudaMemcpy(static_cast<void*>(data_host[i]), 
bindings[i],
+                                          batch_size_ * binding_sizes[i] * 
sizeof(float),
+                                          cudaMemcpyDeviceToHost));
     }
     data_.push_back(data_host);
     data_sizes_.push_back(binding_sizes);
@@ -73,9 +75,10 @@ class TensorRTCalibrator : public 
nvinfer1::IInt8EntropyCalibrator2 {
     TVM_FFI_ICHECK_EQ(input_names_.size(), nbBindings);
     for (size_t i = 0; i < input_names_.size(); ++i) {
       TVM_FFI_ICHECK_EQ(input_names_[i], names[i]);
-      CUDA_CALL(cudaMemcpy(buffers_[i], data_[num_batches_calibrated_][i],
-                           batch_size_ * 
data_sizes_[num_batches_calibrated_][i] * sizeof(float),
-                           cudaMemcpyHostToDevice));
+      TVM_FFI_CHECK_CUDA_ERROR(
+          cudaMemcpy(buffers_[i], data_[num_batches_calibrated_][i],
+                     batch_size_ * data_sizes_[num_batches_calibrated_][i] * 
sizeof(float),
+                     cudaMemcpyHostToDevice));
       bindings[i] = buffers_[i];
     }
     num_batches_calibrated_++;
@@ -120,7 +123,7 @@ class TensorRTCalibrator : public 
nvinfer1::IInt8EntropyCalibrator2 {
     const int num_inputs = data_sizes_[0].size();
     buffers_.assign(num_inputs, nullptr);
     for (int i = 0; i < num_inputs; ++i) {
-      CUDA_CALL(cudaMalloc(&buffers_[i], data_sizes_[0][i] * sizeof(float)));
+      TVM_FFI_CHECK_CUDA_ERROR(cudaMalloc(&buffers_[i], data_sizes_[0][i] * 
sizeof(float)));
     }
   }
 };
diff --git a/src/runtime/extra/contrib/thrust/thrust.cu 
b/src/runtime/extra/contrib/thrust/thrust.cu
index b361323231..f534e1fc76 100644
--- a/src/runtime/extra/contrib/thrust/thrust.cu
+++ b/src/runtime/extra/contrib/thrust/thrust.cu
@@ -33,6 +33,7 @@
 #include <thrust/sort.h>
 #include <tvm/ffi/dtype.h>
 #include <tvm/ffi/extra/c_env_api.h>
+#include <tvm/ffi/extra/cuda/base.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/runtime/logging.h>
@@ -41,8 +42,6 @@
 #include <functional>
 #include <memory>
 #include <vector>
-
-#include "../../../../backend/cuda/runtime/cuda_common.h"
 namespace tvm {
 namespace contrib {
 
@@ -95,7 +94,7 @@ class WorkspaceMemoryResource : public 
thrust::mr::memory_resource<void*> {
 
 auto get_thrust_exec_policy(WorkspaceMemoryResource* memory_resouce) {
   int device_id;
-  CUDA_CALL(cudaGetDevice(&device_id));
+  TVM_FFI_CHECK_CUDA_ERROR(cudaGetDevice(&device_id));
   cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetStream(kDLCUDA, 
device_id));
   return thrust::cuda::par_nosync(memory_resouce).on(stream);
 }
diff --git a/src/runtime/extra/disco/cuda_ipc/cuda_ipc_memory.cc 
b/src/runtime/extra/disco/cuda_ipc/cuda_ipc_memory.cc
index 109ce6ed37..426557b7b7 100644
--- a/src/runtime/extra/disco/cuda_ipc/cuda_ipc_memory.cc
+++ b/src/runtime/extra/disco/cuda_ipc/cuda_ipc_memory.cc
@@ -18,13 +18,13 @@
  */
 
 #include <cuda_runtime.h>
+#include <tvm/ffi/extra/cuda/base.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/runtime/disco/cuda_ipc_memory.h>
 #include <tvm/runtime/memory/memory_manager.h>
 
 #include "../../../../../3rdparty/tensorrt_llm/custom_allreduce_kernels.h"
-#include "../../../../backend/cuda/runtime/cuda_common.h"
 #include "../../../memory/pooled_allocator.h"
 #include "../nccl/nccl_context.h"
 
@@ -45,20 +45,22 @@ using tvm::runtime::memory::Buffer;
 std::vector<cudaIpcMemHandle_t> 
AllGatherIPCHandles(nccl::CCLThreadLocalContext* ctx,
                                                     cudaIpcMemHandle_t 
local_handle) {
   void *d_src, *d_dst;
-  CUDA_CALL(cudaMalloc(&d_src, CUDA_IPC_HANDLE_SIZE));
-  CUDA_CALL(cudaMalloc(&d_dst, CUDA_IPC_HANDLE_SIZE * 
ctx->worker->num_workers));
-  CUDA_CALL(cudaMemcpy(d_src, &local_handle, CUDA_IPC_HANDLE_SIZE, 
cudaMemcpyHostToDevice));
+  TVM_FFI_CHECK_CUDA_ERROR(cudaMalloc(&d_src, CUDA_IPC_HANDLE_SIZE));
+  TVM_FFI_CHECK_CUDA_ERROR(cudaMalloc(&d_dst, CUDA_IPC_HANDLE_SIZE * 
ctx->worker->num_workers));
+  TVM_FFI_CHECK_CUDA_ERROR(
+      cudaMemcpy(d_src, &local_handle, CUDA_IPC_HANDLE_SIZE, 
cudaMemcpyHostToDevice));
   NCCL_CALL(ncclAllGather(d_src, d_dst, CUDA_IPC_HANDLE_SIZE, ncclChar, 
ctx->global_comm,
                           /*stream=*/nullptr));
   std::vector<char> serial_handles(CUDA_IPC_HANDLE_SIZE * 
ctx->worker->num_workers, 0);
-  CUDA_CALL(cudaMemcpy(serial_handles.data(), d_dst,
-                       CUDA_IPC_HANDLE_SIZE * ctx->worker->num_workers, 
cudaMemcpyDefault));
+  TVM_FFI_CHECK_CUDA_ERROR(cudaMemcpy(serial_handles.data(), d_dst,
+                                      CUDA_IPC_HANDLE_SIZE * 
ctx->worker->num_workers,
+                                      cudaMemcpyDefault));
   std::vector<cudaIpcMemHandle_t> handles(ctx->worker->num_workers);
   for (int i = 0; i < ctx->worker->num_workers; ++i) {
     memcpy(handles[i].reserved, &serial_handles[i * CUDA_IPC_HANDLE_SIZE], 
CUDA_IPC_HANDLE_SIZE);
   }
-  CUDA_CALL(cudaFree(d_src));
-  CUDA_CALL(cudaFree(d_dst));
+  TVM_FFI_CHECK_CUDA_ERROR(cudaFree(d_src));
+  TVM_FFI_CHECK_CUDA_ERROR(cudaFree(d_dst));
   return handles;
 }
 
@@ -115,7 +117,7 @@ class CUDAIPCMemoryAllocator final : public 
memory::PooledAllocator {
 
   void DeviceFreeDataSpace(Device dev, void* ptr) final {
     TVM_FFI_ICHECK(dev.device_type == kDLCUDA);
-    CUDA_CALL(cudaSetDevice(dev.device_id));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaSetDevice(dev.device_id));
     nccl::CCLThreadLocalContext* ctx = nccl::CCLThreadLocalContext::Get();
     auto it = ipc_memory_map_.find(ptr);
     TVM_FFI_ICHECK(it != ipc_memory_map_.end());
@@ -146,20 +148,20 @@ class CUDAIPCMemoryAllocator final : public 
memory::PooledAllocator {
     // Alloc local buffer
     TVM_FFI_ICHECK(dev.device_type == kDLCUDA);
     void* ptr;
-    CUDA_CALL(cudaSetDevice(dev.device_id));
-    CUDA_CALL(cudaMalloc(&ptr, size));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaSetDevice(dev.device_id));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaMalloc(&ptr, size));
     // Reset allocated memory to zero when required.
     // We explicitly synchronize after memset, to make sure memset finishes
     // before using all-gather to exchange IPC handles.
     // This is important to ensure the memory reset get ordered
     // before any other peers read the memory.
     if (reset_memory_to_zero) {
-      CUDA_CALL(cudaMemset(ptr, 0, size));
-      CUDA_CALL(cudaDeviceSynchronize());
+      TVM_FFI_CHECK_CUDA_ERROR(cudaMemset(ptr, 0, size));
+      TVM_FFI_CHECK_CUDA_ERROR(cudaDeviceSynchronize());
     }
     // Create ipc handle
     cudaIpcMemHandle_t local_handle;
-    CUDA_CALL(cudaIpcGetMemHandle(&local_handle, ptr));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaIpcGetMemHandle(&local_handle, ptr));
     // All-gather IPC handles.
     nccl::CCLThreadLocalContext* ctx = nccl::CCLThreadLocalContext::Get();
     std::vector<cudaIpcMemHandle_t> handles = AllGatherIPCHandles(ctx, 
local_handle);
@@ -170,8 +172,9 @@ class CUDAIPCMemoryAllocator final : public 
memory::PooledAllocator {
         comm_ptrs[node_id] = ptr;
       } else {
         uint8_t* foreign_buffer;
-        
CUDA_CALL(cudaIpcOpenMemHandle(reinterpret_cast<void**>(&foreign_buffer), 
handles[node_id],
-                                       cudaIpcMemLazyEnablePeerAccess));
+        
TVM_FFI_CHECK_CUDA_ERROR(cudaIpcOpenMemHandle(reinterpret_cast<void**>(&foreign_buffer),
+                                                      handles[node_id],
+                                                      
cudaIpcMemLazyEnablePeerAccess));
         comm_ptrs[node_id] = foreign_buffer;
       }
     }
@@ -183,10 +186,10 @@ class CUDAIPCMemoryAllocator final : public 
memory::PooledAllocator {
     for (int i = 0; i < static_cast<int>(comm_ptrs.size()); ++i) {
       if (i != worker_id) {
         // Free ipc handle.
-        CUDA_CALL(cudaIpcCloseMemHandle(comm_ptrs[i]));
+        TVM_FFI_CHECK_CUDA_ERROR(cudaIpcCloseMemHandle(comm_ptrs[i]));
       } else {
         // Free local buffer.
-        CUDA_CALL(cudaFree(comm_ptrs[i]));
+        TVM_FFI_CHECK_CUDA_ERROR(cudaFree(comm_ptrs[i]));
       }
     }
   }
diff --git a/src/runtime/extra/disco/nccl/nccl_context.h 
b/src/runtime/extra/disco/nccl/nccl_context.h
index 3747fec434..7a99be0897 100644
--- a/src/runtime/extra/disco/nccl/nccl_context.h
+++ b/src/runtime/extra/disco/nccl/nccl_context.h
@@ -35,8 +35,7 @@
 #endif
 #if TVM_NCCL_RCCL_SWITCH == 0
 #include <nccl.h>
-
-#include "../../../../backend/cuda/runtime/cuda_common.h"
+#include <tvm/ffi/extra/cuda/base.h>
 #else
 #include <rccl/rccl.h>
 
@@ -62,10 +61,16 @@ namespace nccl {
 
 using deviceStream_t = cudaStream_t;
 const constexpr DLDeviceType TVM_DISCO_DEVICE_TYPE = DLDeviceType::kDLCUDA;
-inline void SetDevice(int device_id) { CUDA_CALL(cudaSetDevice(device_id)); }
-inline void StreamSynchronize(deviceStream_t stream) { 
CUDA_CALL(cudaStreamSynchronize(stream)); }
-inline void StreamCreate(deviceStream_t* stream) { 
CUDA_CALL(cudaStreamCreate(stream)); }
-inline void StreamDestroy(deviceStream_t stream) { 
CUDA_CALL(cudaStreamDestroy(stream)); }
+inline void SetDevice(int device_id) { 
TVM_FFI_CHECK_CUDA_ERROR(cudaSetDevice(device_id)); }
+inline void StreamSynchronize(deviceStream_t stream) {
+  TVM_FFI_CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));
+}
+inline void StreamCreate(deviceStream_t* stream) {
+  TVM_FFI_CHECK_CUDA_ERROR(cudaStreamCreate(stream));
+}
+inline void StreamDestroy(deviceStream_t stream) {
+  TVM_FFI_CHECK_CUDA_ERROR(cudaStreamDestroy(stream));
+}
 
 #else
 
diff --git a/src/runtime/vm/cuda/cuda_graph_builtin.cc 
b/src/runtime/vm/cuda/cuda_graph_builtin.cc
index 58348fb695..b51b287d81 100644
--- a/src/runtime/vm/cuda/cuda_graph_builtin.cc
+++ b/src/runtime/vm/cuda/cuda_graph_builtin.cc
@@ -24,11 +24,11 @@
 
 #include <tvm/ffi/container/array.h>
 #include <tvm/ffi/extra/c_env_api.h>
+#include <tvm/ffi/extra/cuda/base.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/runtime/vm/vm.h>
 
-#include "../../../backend/cuda/runtime/cuda_common.h"
 #include "../../../support/utils.h"
 namespace tvm {
 namespace runtime {
@@ -85,7 +85,7 @@ struct CUDAGraphCapturedState {
 
   ~CUDAGraphCapturedState() {
     if (exec) {
-      CUDA_CALL(cudaGraphExecDestroy(exec));
+      TVM_FFI_CHECK_CUDA_ERROR(cudaGraphExecDestroy(exec));
     }
   }
 
@@ -100,7 +100,7 @@ struct CUDAGraphCapturedState {
 
 class ScopedCUDAStream {
  public:
-  ScopedCUDAStream() { CUDA_CALL(cudaStreamCreate(&stream_)); }
+  ScopedCUDAStream() { TVM_FFI_CHECK_CUDA_ERROR(cudaStreamCreate(&stream_)); }
   ~ScopedCUDAStream() { cudaStreamDestroy(stream_); }
   ScopedCUDAStream(const ScopedCUDAStream&) = delete;
   ScopedCUDAStream(ScopedCUDAStream&&) = delete;
@@ -116,11 +116,11 @@ class ScopedCUDAStream {
 class CUDACaptureStream {
  public:
   explicit CUDACaptureStream(cudaGraph_t* graph) : output_graph_(graph) {
-    CUDA_CALL(cudaGetDevice(&device_id_));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaGetDevice(&device_id_));
     TVM_FFI_CHECK_SAFE_CALL(
         TVMFFIEnvSetStream(kDLCUDA, device_id_, capture_stream_,
                            
reinterpret_cast<TVMFFIStreamHandle*>(&prev_default_stream_)));
-    CUDA_CALL(cudaStreamBeginCapture(capture_stream_, 
cudaStreamCaptureModeGlobal));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaStreamBeginCapture(capture_stream_, 
cudaStreamCaptureModeGlobal));
   }
   ~CUDACaptureStream() noexcept(false) {
     cudaStreamEndCapture(capture_stream_, output_graph_);
@@ -158,8 +158,8 @@ class CUDAGraphExtensionNode : public VMExtensionNode {
       // Launch CUDA graph
       const auto& [states, exec] = it->second;
       int device_id;
-      CUDA_CALL(cudaGetDevice(&device_id));
-      CUDA_CALL(
+      TVM_FFI_CHECK_CUDA_ERROR(cudaGetDevice(&device_id));
+      TVM_FFI_CHECK_CUDA_ERROR(
           cudaGraphLaunch(exec, 
static_cast<cudaStream_t>(TVMFFIEnvGetStream(kDLCUDA, device_id))));
       return states;
     }
@@ -190,8 +190,8 @@ class CUDAGraphExtensionNode : public VMExtensionNode {
 
     CUDAGraphCapturedState entry;
     entry.states = capture_func_rv.cast<ffi::ObjectRef>();
-    CUDA_CALL(cudaGraphInstantiate(&entry.exec, graph, NULL, NULL, 0));
-    CUDA_CALL(cudaGraphDestroy(graph));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaGraphInstantiate(&entry.exec, graph, NULL, 
NULL, 0));
+    TVM_FFI_CHECK_CUDA_ERROR(cudaGraphDestroy(graph));
 
     ffi::ObjectRef states = entry.states;
 
diff --git a/tests/python/relax/test_vm_builtin.py 
b/tests/python/relax/test_vm_builtin.py
index f46d618104..677bf0b4a3 100644
--- a/tests/python/relax/test_vm_builtin.py
+++ b/tests/python/relax/test_vm_builtin.py
@@ -77,7 +77,7 @@ def test_alloc_tensor_raises_out_of_memory(target, dev):
     built = tvm.compile(Module, target=target)
     vm = relax.VirtualMachine(built, dev)
 
-    with pytest.raises(Exception, match="CUDA: out of memory"):
+    with pytest.raises(Exception, match="CUDA.*out of memory"):
         vm["main"]()
 
 

Reply via email to