This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3ad4eb3a94 [CUDA] Fix thrust with latest FFI refactor (#18024)
3ad4eb3a94 is described below

commit 3ad4eb3a94a259bcbc3ee3a50785bf752f39ed2f
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Jun 1 07:20:53 2025 -0400

    [CUDA] Fix thrust with latest FFI refactor (#18024)
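    
    The tvm-ffi refactor changed the packed-function convention used by the
    Thrust kernels: argument counts are queried with args.size() rather than
    args.num_args, every argument access goes through an explicit cast<T>(),
    and DLDataTypeToString is now qualified with ffi:: (hence the new
    <tvm/ffi/dtype.h> include). On the CMake side, tvm_thrust_objs is linked
    against tvm_ffi_header so the .cu files pick up the FFI headers.
    
    A minimal sketch of the updated registration pattern, using a
    hypothetical global name "tvm.contrib.example.is_float32" that is not
    part of this commit and exists only for illustration:
    
        #include <dlpack/dlpack.h>
        #include <tvm/ffi/dtype.h>
        #include <tvm/ffi/function.h>
        #include <tvm/runtime/logging.h>
    
        namespace tvm {
        namespace contrib {
    
        TVM_FFI_REGISTER_GLOBAL("tvm.contrib.example.is_float32")
            .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
              // args.size() replaces the old args.num_args field.
              ICHECK_GE(args.size(), 1);
              // Arguments are fetched with an explicit cast<T>().
              auto input = args[0].cast<DLTensor*>();
              // DLDataTypeToString is now looked up under ffi::.
              auto dtype = ffi::DLDataTypeToString(input->dtype);
              // The ffi::Any* out-parameter carries the return value.
              *ret = (dtype == "float32");
            });
    
        }  // namespace contrib
        }  // namespace tvm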
---
 cmake/modules/CUDA.cmake             |   1 +
 src/runtime/contrib/thrust/thrust.cu | 189 ++++++++++++++++++-----------------
 2 files changed, 96 insertions(+), 94 deletions(-)

diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake
index f9dd4a8903..84261c6ea0 100644
--- a/cmake/modules/CUDA.cmake
+++ b/cmake/modules/CUDA.cmake
@@ -109,6 +109,7 @@ if(USE_CUDA)
     message(STATUS "Build with Thrust support")
     tvm_file_glob(GLOB CONTRIB_THRUST_SRC src/runtime/contrib/thrust/*.cu)
     add_library(tvm_thrust_objs OBJECT ${CONTRIB_THRUST_SRC})
+    target_link_libraries(tvm_thrust_objs PRIVATE tvm_ffi_header)
     target_compile_options(tvm_thrust_objs PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:--expt-extended-lambda>)
     target_compile_definitions(tvm_thrust_objs PUBLIC DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
     if (NOT USE_THRUST MATCHES ${IS_TRUE_PATTERN})
diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu
index 19f82b1855..6b6b9df834 100644
--- a/src/runtime/contrib/thrust/thrust.cu
+++ b/src/runtime/contrib/thrust/thrust.cu
@@ -31,6 +31,7 @@
 #include <thrust/scan.h>
 #include <thrust/sequence.h>
 #include <thrust/sort.h>
+#include <tvm/ffi/dtype.h>
 #include <tvm/ffi/function.h>
 
 #include <algorithm>
@@ -233,24 +234,24 @@ void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* indices
 }
 
 TVM_FFI_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
-.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
-  ICHECK_GE(args.num_args, 4);
-  auto input = args[0].cast<DLTensor*>();
-  auto values_out = args[1].cast<DLTensor*>();
-  auto indices_out = args[2].cast<DLTensor*>();
-  bool is_ascend = args[3].cast<bool>();
-  DLTensor* workspace = nullptr;
-  if (args.num_args == 5) {
-    workspace = args[4];
-  }
+    .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
+      ICHECK_GE(args.size(), 4);
+      auto input = args[0].cast<DLTensor*>();
+      auto values_out = args[1].cast<DLTensor*>();
+      auto indices_out = args[2].cast<DLTensor*>();
+      bool is_ascend = args[3].cast<bool>();
+      DLTensor* workspace = nullptr;
+      if (args.size() == 5) {
+        workspace = args[4].cast<DLTensor*>();
+      }
 
-  auto data_dtype = DLDataTypeToString(input->dtype);
-  auto out_dtype = DLDataTypeToString(indices_out->dtype);
+      auto data_dtype = ffi::DLDataTypeToString(input->dtype);
+      auto out_dtype = ffi::DLDataTypeToString(indices_out->dtype);
 
-  int n_values = input->shape[input->ndim - 1];
-  thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, data_dtype, out_dtype,
-                     workspace);
-});
+      int n_values = input->shape[input->ndim - 1];
+      thrust_sort_common(input, values_out, indices_out, is_ascend, n_values, data_dtype, out_dtype,
+                         workspace);
+    });
 
 template <typename KeyType, typename ValueType>
 void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in, DLTensor* keys_out,
@@ -281,19 +282,19 @@ void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in, DLTensor*
 
 TVM_FFI_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
     .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
-      ICHECK_GE(args.num_args, 5);
+      ICHECK_GE(args.size(), 5);
       auto keys_in = args[0].cast<DLTensor*>();
       auto values_in = args[1].cast<DLTensor*>();
       auto keys_out = args[2].cast<DLTensor*>();
       auto values_out = args[3].cast<DLTensor*>();
       bool for_scatter = args[4].cast<bool>();
       DLTensor* workspace = nullptr;
-      if (args.num_args == 6) {
-        workspace = args[5];
+      if (args.size() == 6) {
+        workspace = args[5].cast<DLTensor*>();
       }
 
-      auto key_dtype = DLDataTypeToString(keys_in->dtype);
-      auto value_dtype = DLDataTypeToString(values_in->dtype);
+      auto key_dtype = ffi::DLDataTypeToString(keys_in->dtype);
+      auto value_dtype = ffi::DLDataTypeToString(values_in->dtype);
 
       if (key_dtype == "int32") {
         if (value_dtype == "int32") {
@@ -395,82 +396,82 @@ void thrust_scan(DLTensor* data, DLTensor* output, bool exclusive, DLTensor* wor
 }
 
 TVM_FFI_REGISTER_GLOBAL("tvm.contrib.thrust.sum_scan")
-.set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
-  ICHECK(args.num_args == 2 || args.num_args == 3 || args.num_args == 4);
-  auto data = args[0].cast<DLTensor*>();
-  auto output = args[1].cast<DLTensor*>();
-  bool exclusive = false;
-  DLTensor* workspace = nullptr;
-
-  if (args.num_args >= 3) {
-    exclusive = args[2];
-  }
+    .set_body_packed([](ffi::PackedArgs args, ffi::Any* ret) {
+      ICHECK(args.size() == 2 || args.size() == 3 || args.size() == 4);
+      auto data = args[0].cast<DLTensor*>();
+      auto output = args[1].cast<DLTensor*>();
+      bool exclusive = false;
+      DLTensor* workspace = nullptr;
 
-  if (args.num_args == 4) {
-    workspace = args[3];
-  }
+      if (args.size() >= 3) {
+        exclusive = args[2].cast<bool>();
+      }
 
-  auto in_dtype = DLDataTypeToString(data->dtype);
-  auto out_dtype = DLDataTypeToString(output->dtype);
+      if (args.size() == 4) {
+        workspace = args[3].cast<DLTensor*>();
+      }
 
-  if (in_dtype == "bool") {
-    if (out_dtype == "int32") {
-      thrust_scan<bool, int>(data, output, exclusive, workspace);
-    } else if (out_dtype == "int64") {
-      thrust_scan<bool, int64_t>(data, output, exclusive, workspace);
-    } else if (out_dtype == "float32") {
-      thrust_scan<bool, float>(data, output, exclusive, workspace);
-    } else if (out_dtype == "float64") {
-      thrust_scan<bool, double>(data, output, exclusive, workspace);
-    } else {
-      LOG(FATAL) << "Unsupported output dtype: " << out_dtype
-                 << ". Supported output dtypes are int32, int64, float32, and float64";
-    }
-  } else if (in_dtype == "int32") {
-    if (out_dtype == "int32") {
-      thrust_scan<int, int>(data, output, exclusive, workspace);
-    } else if (out_dtype == "int64") {
-      thrust_scan<int, int64_t>(data, output, exclusive, workspace);
-    } else if (out_dtype == "float32") {
-      thrust_scan<int, float>(data, output, exclusive, workspace);
-    } else if (out_dtype == "float64") {
-      thrust_scan<int, double>(data, output, exclusive, workspace);
-    } else {
-      LOG(FATAL) << "Unsupported output dtype: " << out_dtype
-                 << ". Supported output dtypes are int32, int64, float32, and float64";
-    }
-  } else if (in_dtype == "int64") {
-    if (out_dtype == "int64") {
-      thrust_scan<int64_t, int64_t>(data, output, exclusive, workspace);
-    } else if (out_dtype == "float32") {
-      thrust_scan<int64_t, float>(data, output, exclusive, workspace);
-    } else if (out_dtype == "float64") {
-      thrust_scan<int64_t, double>(data, output, exclusive, workspace);
-    } else {
-      LOG(FATAL) << "Unsupported output dtype: " << out_dtype
-                 << ". Supported output dtypes are int64, float32, and float64";
-    }
-  } else if (in_dtype == "float32") {
-    if (out_dtype == "float32") {
-      thrust_scan<float, float>(data, output, exclusive, workspace);
-    } else if (out_dtype == "float64") {
-      thrust_scan<float, double>(data, output, exclusive, workspace);
-    } else {
-      LOG(FATAL) << "Unsupported output dtype: " << out_dtype
-                 << ". Supported output dtypes are float32, and float64";
-    }
-  } else if (in_dtype == "float64") {
-    if (out_dtype == "float64") {
-      thrust_scan<double, double>(data, output, exclusive, workspace);
-    } else {
-      LOG(FATAL) << "Unsupported output dtype: " << out_dtype
-                 << ". Supported output dtype is float64";
-    }
-  } else {
-    LOG(FATAL) << "Unsupported input dtype: " << in_dtype
-               << ". Supported input dtypes are bool, int32, int64, float32, and float64";
-  }
-});
+      auto in_dtype = ffi::DLDataTypeToString(data->dtype);
+      auto out_dtype = ffi::DLDataTypeToString(output->dtype);
+
+      if (in_dtype == "bool") {
+        if (out_dtype == "int32") {
+          thrust_scan<bool, int>(data, output, exclusive, workspace);
+        } else if (out_dtype == "int64") {
+          thrust_scan<bool, int64_t>(data, output, exclusive, workspace);
+        } else if (out_dtype == "float32") {
+          thrust_scan<bool, float>(data, output, exclusive, workspace);
+        } else if (out_dtype == "float64") {
+          thrust_scan<bool, double>(data, output, exclusive, workspace);
+        } else {
+          LOG(FATAL) << "Unsupported output dtype: " << out_dtype
+                     << ". Supported output dtypes are int32, int64, float32, and float64";
+        }
+      } else if (in_dtype == "int32") {
+        if (out_dtype == "int32") {
+          thrust_scan<int, int>(data, output, exclusive, workspace);
+        } else if (out_dtype == "int64") {
+          thrust_scan<int, int64_t>(data, output, exclusive, workspace);
+        } else if (out_dtype == "float32") {
+          thrust_scan<int, float>(data, output, exclusive, workspace);
+        } else if (out_dtype == "float64") {
+          thrust_scan<int, double>(data, output, exclusive, workspace);
+        } else {
+          LOG(FATAL) << "Unsupported output dtype: " << out_dtype
+                     << ". Supported output dtypes are int32, int64, float32, and float64";
+        }
+      } else if (in_dtype == "int64") {
+        if (out_dtype == "int64") {
+          thrust_scan<int64_t, int64_t>(data, output, exclusive, workspace);
+        } else if (out_dtype == "float32") {
+          thrust_scan<int64_t, float>(data, output, exclusive, workspace);
+        } else if (out_dtype == "float64") {
+          thrust_scan<int64_t, double>(data, output, exclusive, workspace);
+        } else {
+          LOG(FATAL) << "Unsupported output dtype: " << out_dtype
+                     << ". Supported output dtypes are int64, float32, and float64";
+        }
+      } else if (in_dtype == "float32") {
+        if (out_dtype == "float32") {
+          thrust_scan<float, float>(data, output, exclusive, workspace);
+        } else if (out_dtype == "float64") {
+          thrust_scan<float, double>(data, output, exclusive, workspace);
+        } else {
+          LOG(FATAL) << "Unsupported output dtype: " << out_dtype
+                     << ". Supported output dtypes are float32, and float64";
+        }
+      } else if (in_dtype == "float64") {
+        if (out_dtype == "float64") {
+          thrust_scan<double, double>(data, output, exclusive, workspace);
+        } else {
+          LOG(FATAL) << "Unsupported output dtype: " << out_dtype
+                     << ". Supported output dtype is float64";
+        }
+      } else {
+        LOG(FATAL) << "Unsupported input dtype: " << in_dtype
+                   << ". Supported input dtypes are bool, int32, int64, float32, and float64";
+      }
+    });
 
 }  // namespace contrib
 }  // namespace tvm
