This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new cdfdd0e4ec [Contrib] Enable fp16 for thrust sort (#16887)
cdfdd0e4ec is described below

commit cdfdd0e4ec7452bedf4e79ba0ff474d2de70bbbf
Author: Siyuan Feng <[email protected]>
AuthorDate: Tue Apr 16 20:13:21 2024 +0800

    [Contrib] Enable fp16 for thrust sort (#16887)
    
    [Contrib] Enable fp16 for thrust
    
    Enable fp16 for thrust to support LLM cases
---
 src/runtime/contrib/thrust/thrust.cu | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu
index 28edba64aa..048df518e3 100644
--- a/src/runtime/contrib/thrust/thrust.cu
+++ b/src/runtime/contrib/thrust/thrust.cu
@@ -167,7 +167,19 @@ void thrust_sort(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, b
 void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* 
indices_out,
                         bool is_ascend, int sort_len, std::string data_dtype, 
std::string out_dtype,
                         DLTensor* workspace) {
-  if (data_dtype == "float32") {
+  if (data_dtype == "float16") {
+    if (out_dtype == "int32") {
+      thrust_sort<half, int32_t>(input, values_out, indices_out, is_ascend, 
sort_len, workspace);
+    } else if (out_dtype == "int64") {
+      thrust_sort<half, int64_t>(input, values_out, indices_out, is_ascend, 
sort_len, workspace);
+    } else if (out_dtype == "float32") {
+      thrust_sort<half, float>(input, values_out, indices_out, is_ascend, 
sort_len, workspace);
+    } else if (out_dtype == "float64") {
+      thrust_sort<half, double>(input, values_out, indices_out, is_ascend, 
sort_len, workspace);
+    } else {
+      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
+    }
+  } else if (data_dtype == "float32") {
     if (out_dtype == "int32") {
       thrust_sort<float, int32_t>(input, values_out, indices_out, is_ascend, 
sort_len, workspace);
     } else if (out_dtype == "int64") {

Reply via email to