This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new cdfdd0e4ec [Contrib] Enable fp16 for thrust sort (#16887)
cdfdd0e4ec is described below
commit cdfdd0e4ec7452bedf4e79ba0ff474d2de70bbbf
Author: Siyuan Feng <[email protected]>
AuthorDate: Tue Apr 16 20:13:21 2024 +0800
[Contrib] Enable fp16 for thrust sort (#16887)
[Contrib] Enable fp16 for thrust
Enable fp16 for thrust to support LLM cases
---
src/runtime/contrib/thrust/thrust.cu | 14 +++++++++++++-
1 file changed, 13 insertions(+), 1 deletion(-)
diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu
index 28edba64aa..048df518e3 100644
--- a/src/runtime/contrib/thrust/thrust.cu
+++ b/src/runtime/contrib/thrust/thrust.cu
@@ -167,7 +167,19 @@ void thrust_sort(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, b
void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor*
indices_out,
bool is_ascend, int sort_len, std::string data_dtype,
std::string out_dtype,
DLTensor* workspace) {
- if (data_dtype == "float32") {
+ if (data_dtype == "float16") {
+ if (out_dtype == "int32") {
+ thrust_sort<half, int32_t>(input, values_out, indices_out, is_ascend,
sort_len, workspace);
+ } else if (out_dtype == "int64") {
+ thrust_sort<half, int64_t>(input, values_out, indices_out, is_ascend,
sort_len, workspace);
+ } else if (out_dtype == "float32") {
+ thrust_sort<half, float>(input, values_out, indices_out, is_ascend,
sort_len, workspace);
+ } else if (out_dtype == "float64") {
+ thrust_sort<half, double>(input, values_out, indices_out, is_ascend,
sort_len, workspace);
+ } else {
+ LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
+ }
+ } else if (data_dtype == "float32") {
if (out_dtype == "int32") {
thrust_sort<float, int32_t>(input, values_out, indices_out, is_ascend,
sort_len, workspace);
} else if (out_dtype == "int64") {