This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3ad4eb3a94 [CUDA] Fix thrust with latest FFI refactor (#18024)
3ad4eb3a94 is described below
commit 3ad4eb3a94a259bcbc3ee3a50785bf752f39ed2f
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Jun 1 07:20:53 2025 -0400
[CUDA] Fix thrust with latest FFI refactor (#18024)
---
cmake/modules/CUDA.cmake | 1 +
src/runtime/contrib/thrust/thrust.cu | 189 ++++++++++++++++++-----------------
2 files changed, 96 insertions(+), 94 deletions(-)
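For context, the sketch below (not part of the commit) illustrates the post-refactor FFI convention this patch migrates to: args.size() replaces args.num_args, every packed argument is retrieved with an explicit .cast<T>(), and DLDataTypeToString is now qualified with ffi:: (declared in <tvm/ffi/dtype.h>, hence the new include and the tvm_ffi_header link in CMake). The registered name "tvm.contrib.thrust.example" and the handler body are hypothetical; the real registrations follow in the diff.

    // Illustrative sketch only (hypothetical "example" op); mirrors the pattern used in thrust.cu below.
    // Assumes the same namespace context as thrust.cu (tvm::contrib), so ffi:: resolves to tvm::ffi.
    #include <tvm/ffi/dtype.h>
    #include <tvm/ffi/function.h>
    #include <tvm/runtime/logging.h>

    TVM_FFI_REGISTER_GLOBAL("tvm.contrib.thrust.example")
        .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
          ICHECK_GE(args.size(), 2);                 // args.size() instead of args.num_args
          auto input = args[0].cast<DLTensor*>();    // explicit cast for every argument
          auto output = args[1].cast<DLTensor*>();
          DLTensor* workspace = nullptr;
          if (args.size() == 3) {
            workspace = args[2].cast<DLTensor*>();   // no more implicit DLTensor* conversion
          }
          auto in_dtype = ffi::DLDataTypeToString(input->dtype);  // now in the ffi namespace
          // ... dispatch on in_dtype and launch the kernel, as the ops below do ...
        });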
diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake
index f9dd4a8903..84261c6ea0 100644
--- a/cmake/modules/CUDA.cmake
+++ b/cmake/modules/CUDA.cmake
@@ -109,6 +109,7 @@ if(USE_CUDA)
message(STATUS "Build with Thrust support")
tvm_file_glob(GLOB CONTRIB_THRUST_SRC src/runtime/contrib/thrust/*.cu)
add_library(tvm_thrust_objs OBJECT ${CONTRIB_THRUST_SRC})
+ target_link_libraries(tvm_thrust_objs PRIVATE tvm_ffi_header)
target_compile_options(tvm_thrust_objs PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:--expt-extended-lambda>)
target_compile_definitions(tvm_thrust_objs PUBLIC
DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
if (NOT USE_THRUST MATCHES ${IS_TRUE_PATTERN})
diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu
index 19f82b1855..6b6b9df834 100644
--- a/src/runtime/contrib/thrust/thrust.cu
+++ b/src/runtime/contrib/thrust/thrust.cu
@@ -31,6 +31,7 @@
#include <thrust/scan.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
+#include <tvm/ffi/dtype.h>
#include <tvm/ffi/function.h>
#include <algorithm>
@@ -233,24 +234,24 @@ void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* indices
}
TVM_FFI_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
-.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
- ICHECK_GE(args.num_args, 4);
- auto input = args[0].cast<DLTensor*>();
- auto values_out = args[1].cast<DLTensor*>();
- auto indices_out = args[2].cast<DLTensor*>();
- bool is_ascend = args[3].cast<bool>();
- DLTensor* workspace = nullptr;
- if (args.num_args == 5) {
- workspace = args[4];
- }
+ .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
+ ICHECK_GE(args.size(), 4);
+ auto input = args[0].cast<DLTensor*>();
+ auto values_out = args[1].cast<DLTensor*>();
+ auto indices_out = args[2].cast<DLTensor*>();
+ bool is_ascend = args[3].cast<bool>();
+ DLTensor* workspace = nullptr;
+ if (args.size() == 5) {
+ workspace = args[4].cast<DLTensor*>();
+ }
- auto data_dtype = DLDataTypeToString(input->dtype);
- auto out_dtype = DLDataTypeToString(indices_out->dtype);
+ auto data_dtype = ffi::DLDataTypeToString(input->dtype);
+ auto out_dtype = ffi::DLDataTypeToString(indices_out->dtype);
- int n_values = input->shape[input->ndim - 1];
- thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, data_dtype, out_dtype,
- workspace);
-});
+ int n_values = input->shape[input->ndim - 1];
+ thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, data_dtype, out_dtype,
+ workspace);
+ });
template <typename KeyType, typename ValueType>
void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in, DLTensor* keys_out,
@@ -281,19 +282,19 @@ void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in, DLTensor*
TVM_FFI_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
- ICHECK_GE(args.num_args, 5);
+ ICHECK_GE(args.size(), 5);
auto keys_in = args[0].cast<DLTensor*>();
auto values_in = args[1].cast<DLTensor*>();
auto keys_out = args[2].cast<DLTensor*>();
auto values_out = args[3].cast<DLTensor*>();
bool for_scatter = args[4].cast<bool>();
DLTensor* workspace = nullptr;
- if (args.num_args == 6) {
- workspace = args[5];
+ if (args.size() == 6) {
+ workspace = args[5].cast<DLTensor*>();
}
- auto key_dtype = DLDataTypeToString(keys_in->dtype);
- auto value_dtype = DLDataTypeToString(values_in->dtype);
+ auto key_dtype = ffi::DLDataTypeToString(keys_in->dtype);
+ auto value_dtype = ffi::DLDataTypeToString(values_in->dtype);
if (key_dtype == "int32") {
if (value_dtype == "int32") {
@@ -395,82 +396,82 @@ void thrust_scan(DLTensor* data, DLTensor* output, bool exclusive, DLTensor* wor
}
TVM_FFI_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan")
-.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
- ICHECK(args.num_args == 2 || args.num_args == 3 || args.num_args == 4);
- auto data = args[0].cast<DLTensor*>();
- auto output = args[1].cast<DLTensor*>();
- bool exclusive = false;
- DLTensor* workspace = nullptr;
-
- if (args.num_args >= 3) {
- exclusive = args[2];
- }
+ .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
+ ICHECK(args.size() == 2 || args.size() == 3 || args.size() == 4);
+ auto data = args[0].cast<DLTensor*>();
+ auto output = args[1].cast<DLTensor*>();
+ bool exclusive = false;
+ DLTensor* workspace = nullptr;
- if (args.num_args == 4) {
- workspace = args[3];
- }
+ if (args.size() >= 3) {
+ exclusive = args[2].cast<bool>();
+ }
- auto in_dtype = DLDataTypeToString(data->dtype);
- auto out_dtype = DLDataTypeToString(output->dtype);
+ if (args.size() == 4) {
+ workspace = args[3].cast<DLTensor*>();
+ }
- if (in_dtype == "bool") {
- if (out_dtype == "int32") {
- thrust_scan<bool, int>(data, output, exclusive, workspace);
- } else if (out_dtype == "int64") {
- thrust_scan<bool, int64_t>(data, output, exclusive, workspace);
- } else if (out_dtype == "float32") {
- thrust_scan<bool, float>(data, output, exclusive, workspace);
- } else if (out_dtype == "float64") {
- thrust_scan<bool, double>(data, output, exclusive, workspace);
- } else {
- LOG(FATAL) << "Unsupported output dtype: " << out_dtype
- << ". Supported output dtypes are int32, int64, float32, and
float64";
- }
- } else if (in_dtype == "int32") {
- if (out_dtype == "int32") {
- thrust_scan<int, int>(data, output, exclusive, workspace);
- } else if (out_dtype == "int64") {
- thrust_scan<int, int64_t>(data, output, exclusive, workspace);
- } else if (out_dtype == "float32") {
- thrust_scan<int, float>(data, output, exclusive, workspace);
- } else if (out_dtype == "float64") {
- thrust_scan<int, double>(data, output, exclusive, workspace);
- } else {
- LOG(FATAL) << "Unsupported output dtype: " << out_dtype
- << ". Supported output dtypes are int32, int64, float32, and
float64";
- }
- } else if (in_dtype == "int64") {
- if (out_dtype == "int64") {
- thrust_scan<int64_t, int64_t>(data, output, exclusive, workspace);
- } else if (out_dtype == "float32") {
- thrust_scan<int64_t, float>(data, output, exclusive, workspace);
- } else if (out_dtype == "float64") {
- thrust_scan<int64_t, double>(data, output, exclusive, workspace);
- } else {
- LOG(FATAL) << "Unsupported output dtype: " << out_dtype
- << ". Supported output dtypes are int64, float32, and
float64";
- }
- } else if (in_dtype == "float32") {
- if (out_dtype == "float32") {
- thrust_scan<float, float>(data, output, exclusive, workspace);
- } else if (out_dtype == "float64") {
- thrust_scan<float, double>(data, output, exclusive, workspace);
- } else {
- LOG(FATAL) << "Unsupported output dtype: " << out_dtype
- << ". Supported output dtypes are float32, and float64";
- }
- } else if (in_dtype == "float64") {
- if (out_dtype == "float64") {
- thrust_scan<double, double>(data, output, exclusive, workspace);
- } else {
- LOG(FATAL) << "Unsupported output dtype: " << out_dtype
- << ". Supported output dtype is float64";
- }
- } else {
- LOG(FATAL) << "Unsupported input dtype: " << in_dtype
- << ". Supported input dtypes are bool, int32, int64, float32,
and float64";
- }
-});
+ auto in_dtype = ffi::DLDataTypeToString(data->dtype);
+ auto out_dtype = ffi::DLDataTypeToString(output->dtype);
+
+ if (in_dtype == "bool") {
+ if (out_dtype == "int32") {
+ thrust_scan<bool, int>(data, output, exclusive, workspace);
+ } else if (out_dtype == "int64") {
+ thrust_scan<bool, int64_t>(data, output, exclusive, workspace);
+ } else if (out_dtype == "float32") {
+ thrust_scan<bool, float>(data, output, exclusive, workspace);
+ } else if (out_dtype == "float64") {
+ thrust_scan<bool, double>(data, output, exclusive, workspace);
+ } else {
+ LOG(FATAL) << "Unsupported output dtype: " << out_dtype
+ << ". Supported output dtypes are int32, int64, float32,
and float64";
+ }
+ } else if (in_dtype == "int32") {
+ if (out_dtype == "int32") {
+ thrust_scan<int, int>(data, output, exclusive, workspace);
+ } else if (out_dtype == "int64") {
+ thrust_scan<int, int64_t>(data, output, exclusive, workspace);
+ } else if (out_dtype == "float32") {
+ thrust_scan<int, float>(data, output, exclusive, workspace);
+ } else if (out_dtype == "float64") {
+ thrust_scan<int, double>(data, output, exclusive, workspace);
+ } else {
+ LOG(FATAL) << "Unsupported output dtype: " << out_dtype
+ << ". Supported output dtypes are int32, int64, float32,
and float64";
+ }
+ } else if (in_dtype == "int64") {
+ if (out_dtype == "int64") {
+ thrust_scan<int64_t, int64_t>(data, output, exclusive, workspace);
+ } else if (out_dtype == "float32") {
+ thrust_scan<int64_t, float>(data, output, exclusive, workspace);
+ } else if (out_dtype == "float64") {
+ thrust_scan<int64_t, double>(data, output, exclusive, workspace);
+ } else {
+ LOG(FATAL) << "Unsupported output dtype: " << out_dtype
+ << ". Supported output dtypes are int64, float32, and
float64";
+ }
+ } else if (in_dtype == "float32") {
+ if (out_dtype == "float32") {
+ thrust_scan<float, float>(data, output, exclusive, workspace);
+ } else if (out_dtype == "float64") {
+ thrust_scan<float, double>(data, output, exclusive, workspace);
+ } else {
+ LOG(FATAL) << "Unsupported output dtype: " << out_dtype
+ << ". Supported output dtypes are float32, and float64";
+ }
+ } else if (in_dtype == "float64") {
+ if (out_dtype == "float64") {
+ thrust_scan<double, double>(data, output, exclusive, workspace);
+ } else {
+ LOG(FATAL) << "Unsupported output dtype: " << out_dtype
+ << ". Supported output dtype is float64";
+ }
+ } else {
+ LOG(FATAL) << "Unsupported input dtype: " << in_dtype
+ << ". Supported input dtypes are bool, int32, int64,
float32, and float64";
+ }
+ });
} // namespace contrib
} // namespace tvm