This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new dfc9fc03b1 [Contrib] Fix CUDA contrib build after FFI/header cleanups
(#19539)
dfc9fc03b1 is described below
commit dfc9fc03b1d3e381d0cbf189e21e50f7c7a60041
Author: Ruihang Lai <[email protected]>
AuthorDate: Tue May 12 00:16:34 2026 -0400
[Contrib] Fix CUDA contrib build after FFI/header cleanups (#19539)
Six CUDA sources in src/runtime/contrib used LOG(FATAL) via transitive
includes that #19483 trimmed; add the explicit <tvm/runtime/logging.h>
include to thrust.cu, attention_kernels.cu, and the four cutlass kernel
headers (fp16_group_gemm_runner_sm90, fp16_group_gemm_runner_sm100,
gemm_runner, fp8_groupwise_scaled_gemm). nvshmem/init.cc receives the
same explicit include.
cache_kernels.cu used the bare Array{...} alias that #19483 removed;
switch to ffi::Array<Tensor>{...}.
attention_kernels.cu registered FFI functions whose parameters were raw
DLTensor*; the new reflection registry requires TypeSchema, so wrap both
TVM_FFI_STATIC_INIT_BLOCK registrations to take Tensor and forward to
the unchanged launchers via GetDLTensorPtr() (with const_cast for the
output tensors, matching the mt_random_engine / cudnn pattern).
---
.../cutlass/fp16_group_gemm_runner_sm100.cuh | 2 +
.../cutlass/fp16_group_gemm_runner_sm90.cuh | 2 +
.../contrib/cutlass/fp8_groupwise_scaled_gemm.cuh | 1 +
src/runtime/contrib/cutlass/gemm_runner.cuh | 2 +
src/runtime/contrib/nvshmem/init.cc | 1 +
src/runtime/contrib/thrust/thrust.cu | 1 +
src/runtime/contrib/vllm/attention_kernels.cu | 49 ++++++++++++++++------
src/runtime/contrib/vllm/cache_kernels.cu | 4 +-
8 files changed, 48 insertions(+), 14 deletions(-)
diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh
b/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh
index 22a9bea646..17f5c23a75 100644
--- a/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh
+++ b/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh
@@ -17,6 +17,8 @@
* under the License.
*/
+#include <tvm/runtime/logging.h>
+
#include <fstream>
#include <iostream>
#include <sstream>
diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh
b/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh
index 4fc513e3db..2ee0026766 100644
--- a/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh
+++ b/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh
@@ -17,6 +17,8 @@
* under the License.
*/
+#include <tvm/runtime/logging.h>
+
#include <fstream>
#include <iostream>
#include <sstream>
diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh
b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh
index 338a96c8b7..26dbcad6c5 100644
--- a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh
+++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh
@@ -21,6 +21,7 @@
#include <float.h>
#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/function.h>
+#include <tvm/runtime/logging.h>
#include <tvm/runtime/tensor.h>
#include "cutlass/bfloat16.h"
diff --git a/src/runtime/contrib/cutlass/gemm_runner.cuh
b/src/runtime/contrib/cutlass/gemm_runner.cuh
index b0907bfe29..c6815f60c5 100644
--- a/src/runtime/contrib/cutlass/gemm_runner.cuh
+++ b/src/runtime/contrib/cutlass/gemm_runner.cuh
@@ -17,6 +17,8 @@
* under the License.
*/
+#include <tvm/runtime/logging.h>
+
#include <fstream>
#include <iostream>
#include <sstream>
diff --git a/src/runtime/contrib/nvshmem/init.cc
b/src/runtime/contrib/nvshmem/init.cc
index 1528f03d8e..b82ab0530b 100644
--- a/src/runtime/contrib/nvshmem/init.cc
+++ b/src/runtime/contrib/nvshmem/init.cc
@@ -23,6 +23,7 @@
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/disco/disco_worker.h>
+#include <tvm/runtime/logging.h>
#include "../../cuda/cuda_common.h"
diff --git a/src/runtime/contrib/thrust/thrust.cu
b/src/runtime/contrib/thrust/thrust.cu
index d306750c48..16217432dc 100644
--- a/src/runtime/contrib/thrust/thrust.cu
+++ b/src/runtime/contrib/thrust/thrust.cu
@@ -35,6 +35,7 @@
#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/runtime/logging.h>
#include <algorithm>
#include <functional>
diff --git a/src/runtime/contrib/vllm/attention_kernels.cu
b/src/runtime/contrib/vllm/attention_kernels.cu
index f9b812b2a2..ec0caa5f3d 100644
--- a/src/runtime/contrib/vllm/attention_kernels.cu
+++ b/src/runtime/contrib/vllm/attention_kernels.cu
@@ -37,6 +37,7 @@
#include <float.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/runtime/logging.h>
#include <tvm/runtime/tensor.h>
#include <algorithm>
@@ -756,10 +757,10 @@ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def(
"tvm.contrib.vllm.single_query_cached_kv_attention",
- [](const DLTensor* query, const DLTensor* key_cache, const DLTensor*
value_cache,
- const DLTensor* block_tables, const DLTensor* context_lens, int
block_size,
- const DLTensor* max_context_len_tensor, // TODO(masahi): pass integer
- DLTensor* exp_sums, DLTensor* max_logits, DLTensor* tmp_out,
DLTensor* out) {
+ [](Tensor query, Tensor key_cache, Tensor value_cache, Tensor
block_tables,
+ Tensor context_lens, int block_size,
+ Tensor max_context_len_tensor, // TODO(masahi): pass integer
+ Tensor exp_sums, Tensor max_logits, Tensor tmp_out, Tensor out) {
int num_seqs = query->shape[0];
int num_heads = query->shape[1];
int max_context_len =
static_cast<int*>(max_context_len_tensor->data)[0];
@@ -768,13 +769,19 @@ TVM_FFI_STATIC_INIT_BLOCK() {
bool use_v1 =
max_context_len <= 8192 && (max_num_partitions == 1 || num_seqs *
num_heads > 512);
if (use_v1) {
- single_query_cached_kv_attention_v1(query, key_cache, value_cache,
block_tables,
- context_lens, block_size,
max_context_len_tensor,
- out);
+ single_query_cached_kv_attention_v1(
+ query.GetDLTensorPtr(), key_cache.GetDLTensorPtr(),
value_cache.GetDLTensorPtr(),
+ block_tables.GetDLTensorPtr(), context_lens.GetDLTensorPtr(),
block_size,
+ max_context_len_tensor.GetDLTensorPtr(),
const_cast<DLTensor*>(out.GetDLTensorPtr()));
} else {
- single_query_cached_kv_attention_v2(query, key_cache, value_cache,
block_tables,
- context_lens, block_size,
max_context_len_tensor,
- exp_sums, max_logits, tmp_out,
out);
+ single_query_cached_kv_attention_v2(
+ query.GetDLTensorPtr(), key_cache.GetDLTensorPtr(),
value_cache.GetDLTensorPtr(),
+ block_tables.GetDLTensorPtr(), context_lens.GetDLTensorPtr(),
block_size,
+ max_context_len_tensor.GetDLTensorPtr(),
+ const_cast<DLTensor*>(exp_sums.GetDLTensorPtr()),
+ const_cast<DLTensor*>(max_logits.GetDLTensorPtr()),
+ const_cast<DLTensor*>(tmp_out.GetDLTensorPtr()),
+ const_cast<DLTensor*>(out.GetDLTensorPtr()));
}
});
}
@@ -784,9 +791,27 @@ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("tvm.contrib.vllm.single_query_cached_kv_attention_v1",
- single_query_cached_kv_attention_v1)
+ [](Tensor query, Tensor key_cache, Tensor value_cache, Tensor
block_tables,
+ Tensor context_lens, int block_size, Tensor
max_context_len_tensor, Tensor out) {
+ single_query_cached_kv_attention_v1(
+ query.GetDLTensorPtr(), key_cache.GetDLTensorPtr(),
value_cache.GetDLTensorPtr(),
+ block_tables.GetDLTensorPtr(), context_lens.GetDLTensorPtr(),
block_size,
+ max_context_len_tensor.GetDLTensorPtr(),
+ const_cast<DLTensor*>(out.GetDLTensorPtr()));
+ })
.def("tvm.contrib.vllm.single_query_cached_kv_attention_v2",
- single_query_cached_kv_attention_v2);
+ [](Tensor query, Tensor key_cache, Tensor value_cache, Tensor
block_tables,
+ Tensor context_lens, int block_size, Tensor
max_context_len_tensor, Tensor exp_sums,
+ Tensor max_logits, Tensor tmp_out, Tensor out) {
+ single_query_cached_kv_attention_v2(
+ query.GetDLTensorPtr(), key_cache.GetDLTensorPtr(),
value_cache.GetDLTensorPtr(),
+ block_tables.GetDLTensorPtr(), context_lens.GetDLTensorPtr(),
block_size,
+ max_context_len_tensor.GetDLTensorPtr(),
+ const_cast<DLTensor*>(exp_sums.GetDLTensorPtr()),
+ const_cast<DLTensor*>(max_logits.GetDLTensorPtr()),
+ const_cast<DLTensor*>(tmp_out.GetDLTensorPtr()),
+ const_cast<DLTensor*>(out.GetDLTensorPtr()));
+ });
}
} // namespace runtime
diff --git a/src/runtime/contrib/vllm/cache_kernels.cu
b/src/runtime/contrib/vllm/cache_kernels.cu
index 5ddf18e482..5af93a1fd9 100644
--- a/src/runtime/contrib/vllm/cache_kernels.cu
+++ b/src/runtime/contrib/vllm/cache_kernels.cu
@@ -154,7 +154,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
static_cast<const int*>(slot_mapping->data), key_stride,
value_stride, num_heads,
head_size, block_size, vec_size);
- return Array{key_cache, value_cache};
+ return ffi::Array<Tensor>{key_cache, value_cache};
})
.def("tvm.contrib.vllm.reconstruct_from_cache",
[](Tensor key_cache, Tensor value_cache, Tensor slot_mapping) {
@@ -182,7 +182,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
static_cast<scalar_t*>(value->data), key_stride,
value_stride, num_heads,
head_size, block_size, vec_size);
- return Array{key, value};
+ return ffi::Array<Tensor>{key, value};
})
.def("tvm.contrib.vllm.copy_blocks", [](ffi::Array<Tensor>
key_value_caches,
Tensor block_mapping) {