This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-2.1 by this push:
new cc6ff12097e [opt](function) Optimize the trim function for single-char
inputs (#36497) (#37799)
cc6ff12097e is described below
commit cc6ff12097e3c0cc79cf70ae0d25dd7020faa9c5
Author: Mryange <[email protected]>
AuthorDate: Tue Jul 16 17:52:52 2024 +0800
[opt](function) Optimize the trim function for single-char inputs (#36497)
(#37799)
https://github.com/apache/doris/pull/36497
Before this change:
```
mysql [test]>select count(ltrim(str,"1")) from stringDb2;
+------------------------+
| count(ltrim(str, '1')) |
+------------------------+
| 64000000 |
+------------------------+
1 row in set (7.79 sec)
```
After this change:
```
mysql [test]>select count(ltrim(str,"1")) from stringDb2;
+------------------------+
| count(ltrim(str, '1')) |
+------------------------+
| 64000000 |
+------------------------+
1 row in set (0.73 sec)
```
## Proposed changes
Issue Number: close #xxx
<!--Describe your changes.-->
---
be/src/util/simd/vstring_function.h | 196 ++++++---------------
be/src/vec/functions/function_string.cpp | 54 +++---
.../correctness/test_trim_new_parameters.groovy | 3 +
3 files changed, 92 insertions(+), 161 deletions(-)
diff --git a/be/src/util/simd/vstring_function.h
b/be/src/util/simd/vstring_function.h
index dac964b1b94..4fff59a01df 100644
--- a/be/src/util/simd/vstring_function.h
+++ b/be/src/util/simd/vstring_function.h
@@ -17,6 +17,7 @@
#pragma once
+#include <immintrin.h>
#include <unistd.h>
#include <array>
@@ -100,169 +101,86 @@ public:
/// n equals to 16 chars length
static constexpr auto REGISTER_SIZE = sizeof(__m128i);
#endif
-public:
- static StringRef rtrim(const StringRef& str) {
- if (str.size == 0) {
- return str;
- }
- auto begin = 0;
- int64_t end = str.size - 1;
-#if defined(__SSE2__) || defined(__aarch64__)
- char blank = ' ';
- const auto pattern = _mm_set1_epi8(blank);
- while (end - begin + 1 >= REGISTER_SIZE) {
- const auto v_haystack = _mm_loadu_si128(
- reinterpret_cast<const __m128i*>(str.data + end + 1 -
REGISTER_SIZE));
- const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack, pattern);
- const auto mask = _mm_movemask_epi8(v_against_pattern);
- int offset = __builtin_clz(~(mask << REGISTER_SIZE));
- /// means not found
- if (offset == 0) {
- return StringRef(str.data + begin, end - begin + 1);
- } else {
- end -= offset;
- }
- }
-#endif
- while (end >= begin && str.data[end] == ' ') {
- --end;
- }
- if (end < 0) {
- return StringRef("");
- }
- return StringRef(str.data + begin, end - begin + 1);
- }
-
- static StringRef ltrim(const StringRef& str) {
- if (str.size == 0) {
- return str;
- }
- auto begin = 0;
- auto end = str.size - 1;
-#if defined(__SSE2__) || defined(__aarch64__)
- char blank = ' ';
- const auto pattern = _mm_set1_epi8(blank);
- while (end - begin + 1 >= REGISTER_SIZE) {
- const auto v_haystack =
- _mm_loadu_si128(reinterpret_cast<const __m128i*>(str.data
+ begin));
- const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack, pattern);
- const auto mask = _mm_movemask_epi8(v_against_pattern) ^ 0xffff;
- /// zero means not found
- if (mask == 0) {
- begin += REGISTER_SIZE;
- } else {
- const auto offset = __builtin_ctz(mask);
- begin += offset;
- return StringRef(str.data + begin, end - begin + 1);
- }
- }
-#endif
- while (begin <= end && str.data[begin] == ' ') {
- ++begin;
- }
- return StringRef(str.data + begin, end - begin + 1);
- }
- static StringRef trim(const StringRef& str) {
- if (str.size == 0) {
- return str;
+ template <bool trim_single>
+ static inline const unsigned char* rtrim(const unsigned char* begin, const
unsigned char* end,
+ const StringRef& remove_str) {
+ if (remove_str.size == 0) {
+ return end;
}
- return rtrim(ltrim(str));
- }
+ const auto* p = end;
- static StringRef rtrim(const StringRef& str, const StringRef& rhs) {
- if (str.size == 0 || rhs.size == 0) {
- return str;
- }
- if (rhs.size == 1) {
- auto begin = 0;
- int64_t end = str.size - 1;
- const char blank = rhs.data[0];
-#if defined(__SSE2__) || defined(__aarch64__)
- const auto pattern = _mm_set1_epi8(blank);
- while (end - begin + 1 >= REGISTER_SIZE) {
- const auto v_haystack = _mm_loadu_si128(
- reinterpret_cast<const __m128i*>(str.data + end + 1 -
REGISTER_SIZE));
- const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack,
pattern);
- const auto mask = _mm_movemask_epi8(v_against_pattern);
- int offset = __builtin_clz(~(mask << REGISTER_SIZE));
- /// means not found
- if (offset == 0) {
- return StringRef(str.data + begin, end - begin + 1);
- } else {
- end -= offset;
+ if constexpr (trim_single) {
+ const auto ch = remove_str.data[0];
+#if defined(__AVX2__) || defined(__aarch64__)
+ constexpr auto AVX2_BYTES = sizeof(__m256i);
+ const auto size = end - begin;
+ const auto* const avx2_begin = end - size / AVX2_BYTES *
AVX2_BYTES;
+ const auto spaces = _mm256_set1_epi8(ch);
+ for (p = end - AVX2_BYTES; p >= avx2_begin; p -= AVX2_BYTES) {
+ uint32_t masks = _mm256_movemask_epi8(
+ _mm256_cmpeq_epi8(_mm256_loadu_si256((__m256i*)p),
spaces));
+ if ((~masks)) {
+ break;
}
}
+ p += AVX2_BYTES;
#endif
- while (end >= begin && str.data[end] == blank) {
- --end;
- }
- if (end < 0) {
- return StringRef("");
+ for (; (p - 1) >= begin && *(p - 1) == ch; p--) {
}
- return StringRef(str.data + begin, end - begin + 1);
+ return p;
}
- auto begin = 0;
- auto end = str.size - 1;
- const auto rhs_size = rhs.size;
- while (end - begin + 1 >= rhs_size) {
- if (memcmp(str.data + end - rhs_size + 1, rhs.data, rhs_size) ==
0) {
- end -= rhs.size;
+
+ const auto remove_size = remove_str.size;
+ const auto* const remove_data = remove_str.data;
+ while (p - begin >= remove_size) {
+ if (memcmp(p - remove_size, remove_data, remove_size) == 0) {
+ p -= remove_str.size;
} else {
break;
}
}
- return StringRef(str.data + begin, end - begin + 1);
+ return p;
}
- static StringRef ltrim(const StringRef& str, const StringRef& rhs) {
- if (str.size == 0 || rhs.size == 0) {
- return str;
+ template <bool trim_single>
+ static inline const unsigned char* ltrim(const unsigned char* begin, const
unsigned char* end,
+ const StringRef& remove_str) {
+ if (remove_str.size == 0) {
+ return begin;
}
- if (str.size == 1) {
- auto begin = 0;
- auto end = str.size - 1;
- const char blank = rhs.data[0];
-#if defined(__SSE2__) || defined(__aarch64__)
- const auto pattern = _mm_set1_epi8(blank);
- while (end - begin + 1 >= REGISTER_SIZE) {
- const auto v_haystack =
- _mm_loadu_si128(reinterpret_cast<const
__m128i*>(str.data + begin));
- const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack,
pattern);
- const auto mask = _mm_movemask_epi8(v_against_pattern) ^
0xffff;
- /// zero means not found
- if (mask == 0) {
- begin += REGISTER_SIZE;
- } else {
- const auto offset = __builtin_ctz(mask);
- begin += offset;
- return StringRef(str.data + begin, end - begin + 1);
+ const auto* p = begin;
+
+ if constexpr (trim_single) {
+ const auto ch = remove_str.data[0];
+#if defined(__AVX2__) || defined(__aarch64__)
+ constexpr auto AVX2_BYTES = sizeof(__m256i);
+ const auto size = end - begin;
+ const auto* const avx2_end = begin + size / AVX2_BYTES *
AVX2_BYTES;
+ const auto spaces = _mm256_set1_epi8(ch);
+ for (; p < avx2_end; p += AVX2_BYTES) {
+ uint32_t masks = _mm256_movemask_epi8(
+ _mm256_cmpeq_epi8(_mm256_loadu_si256((__m256i*)p),
spaces));
+ if ((~masks)) {
+ break;
}
}
#endif
- while (begin <= end && str.data[begin] == blank) {
- ++begin;
+ for (; p < end && *p == ch; ++p) {
}
- return StringRef(str.data + begin, end - begin + 1);
+ return p;
}
- auto begin = 0;
- auto end = str.size - 1;
- const auto rhs_size = rhs.size;
- while (end - begin + 1 >= rhs_size) {
- if (memcmp(str.data + begin, rhs.data, rhs_size) == 0) {
- begin += rhs.size;
+
+ const auto remove_size = remove_str.size;
+ const auto* const remove_data = remove_str.data;
+ while (end - p >= remove_size) {
+ if (memcmp(p, remove_data, remove_size) == 0) {
+ p += remove_str.size;
} else {
break;
}
}
- return StringRef(str.data + begin, end - begin + 1);
- }
-
- static StringRef trim(const StringRef& str, const StringRef& rhs) {
- if (str.size == 0 || rhs.size == 0) {
- return str;
- }
- return rtrim(ltrim(str, rhs), rhs);
+ return p;
}
// Gcc will do auto simd in this function
diff --git a/be/src/vec/functions/function_string.cpp
b/be/src/vec/functions/function_string.cpp
index 841b6561bf7..cf82970185e 100644
--- a/be/src/vec/functions/function_string.cpp
+++ b/be/src/vec/functions/function_string.cpp
@@ -485,25 +485,29 @@ struct NameLTrim {
struct NameRTrim {
static constexpr auto name = "rtrim";
};
-template <bool is_ltrim, bool is_rtrim>
+template <bool is_ltrim, bool is_rtrim, bool trim_single>
struct TrimUtil {
static Status vector(const ColumnString::Chars& str_data,
- const ColumnString::Offsets& str_offsets, const
StringRef& rhs,
+ const ColumnString::Offsets& str_offsets, const
StringRef& remove_str,
ColumnString::Chars& res_data, ColumnString::Offsets&
res_offsets) {
- size_t offset_size = str_offsets.size();
- res_offsets.resize(str_offsets.size());
+ const size_t offset_size = str_offsets.size();
+ res_offsets.resize(offset_size);
+ res_data.reserve(str_data.size());
for (size_t i = 0; i < offset_size; ++i) {
- const char* raw_str = reinterpret_cast<const
char*>(&str_data[str_offsets[i - 1]]);
- ColumnString::Offset size = str_offsets[i] - str_offsets[i - 1];
- StringRef str(raw_str, size);
+ const auto* str_begin = str_data.data() + str_offsets[i - 1];
+ const auto* str_end = str_data.data() + str_offsets[i];
+
if constexpr (is_ltrim) {
- str = simd::VStringFunctions::ltrim(str, rhs);
+ str_begin =
+ simd::VStringFunctions::ltrim<trim_single>(str_begin,
str_end, remove_str);
}
if constexpr (is_rtrim) {
- str = simd::VStringFunctions::rtrim(str, rhs);
+ str_end =
+ simd::VStringFunctions::rtrim<trim_single>(str_begin,
str_end, remove_str);
}
- StringOP::push_value_string(std::string_view((char*)str.data,
str.size), i, res_data,
- res_offsets);
+
+ res_data.insert_assume_reserved(str_begin, str_end);
+ res_offsets[i] = res_data.size();
}
return Status::OK();
}
@@ -521,9 +525,9 @@ struct Trim1Impl {
if (const auto* col = assert_cast<const ColumnString*>(column.get())) {
auto col_res = ColumnString::create();
char blank[] = " ";
- StringRef rhs(blank, 1);
- RETURN_IF_ERROR((TrimUtil<is_ltrim, is_rtrim>::vector(
- col->get_chars(), col->get_offsets(), rhs,
col_res->get_chars(),
+ const StringRef remove_str(blank, 1);
+ RETURN_IF_ERROR((TrimUtil<is_ltrim, is_rtrim, true>::vector(
+ col->get_chars(), col->get_offsets(), remove_str,
col_res->get_chars(),
col_res->get_offsets())));
block.replace_by_position(result, std::move(col_res));
} else {
@@ -550,15 +554,21 @@ struct Trim2Impl {
const auto& rcol =
assert_cast<const
ColumnConst*>(block.get_by_position(arguments[1]).column.get())
->get_data_column_ptr();
- if (auto col = assert_cast<const ColumnString*>(column.get())) {
- if (auto col_right = assert_cast<const ColumnString*>(rcol.get()))
{
+ if (const auto* col = assert_cast<const ColumnString*>(column.get())) {
+ if (const auto* col_right = assert_cast<const
ColumnString*>(rcol.get())) {
auto col_res = ColumnString::create();
- const char* raw_rhs = reinterpret_cast<const
char*>(&(col_right->get_chars()[0]));
- ColumnString::Offset rhs_size = col_right->get_offsets()[0];
- StringRef rhs(raw_rhs, rhs_size);
- RETURN_IF_ERROR((TrimUtil<is_ltrim, is_rtrim>::vector(
- col->get_chars(), col->get_offsets(), rhs,
col_res->get_chars(),
- col_res->get_offsets())));
+ const auto* remove_str_raw = col_right->get_chars().data();
+ const ColumnString::Offset remove_str_size =
col_right->get_offsets()[0];
+ const StringRef remove_str(remove_str_raw, remove_str_size);
+ if (remove_str.size == 1) {
+ RETURN_IF_ERROR((TrimUtil<is_ltrim, is_rtrim,
true>::vector(
+ col->get_chars(), col->get_offsets(), remove_str,
col_res->get_chars(),
+ col_res->get_offsets())));
+ } else {
+ RETURN_IF_ERROR((TrimUtil<is_ltrim, is_rtrim,
false>::vector(
+ col->get_chars(), col->get_offsets(), remove_str,
col_res->get_chars(),
+ col_res->get_offsets())));
+ }
block.replace_by_position(result, std::move(col_res));
} else {
return Status::RuntimeError("Illegal column {} of argument of
function {}",
diff --git a/regression-test/suites/correctness/test_trim_new_parameters.groovy
b/regression-test/suites/correctness/test_trim_new_parameters.groovy
index 3209eb7aae7..17ac4a0c65e 100644
--- a/regression-test/suites/correctness/test_trim_new_parameters.groovy
+++ b/regression-test/suites/correctness/test_trim_new_parameters.groovy
@@ -67,4 +67,7 @@ suite("test_trim_new_parameters") {
rtrim = sql "select rtrim('bcTTTabcabc','abc')"
assertEquals(rtrim[0][0], 'bcTTT')
+
+ trim_one = sql "select
trim('aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabaaaaaaaaaaabcTTTabcabcaaaaaaaaaaaaaaaaaaaaaaaaaabaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa','a')"
+ assertEquals(trim_one[0][0],
'baaaaaaaaaaabcTTTabcabcaaaaaaaaaaaaaaaaaaaaaaaaaab')
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]