This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch branch-1.2-lts
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-1.2-lts by this push:
new b083729e0b [feature](function) Add new parameters to 'trim' in 1.2
(#19730)
b083729e0b is described below
commit b083729e0bd383344afe5216f0f6c4df201bfdab
Author: Mryange <[email protected]>
AuthorDate: Wed May 17 19:28:43 2023 +0800
[feature](function) Add new parameters to 'trim' in 1.2 (#19730)
* update
* format
---
be/src/util/simd/vstring_function.h | 95 ++++++++++++++
.../aggregate_function_collect.h | 2 +-
be/src/vec/functions/function_string.cpp | 140 +++++++++++++++++----
gensrc/script/doris_builtins_functions.py | 12 ++
.../correctness/test_trim_new_parameters.groovy | 70 +++++++++++
5 files changed, 295 insertions(+), 24 deletions(-)
diff --git a/be/src/util/simd/vstring_function.h
b/be/src/util/simd/vstring_function.h
index 3c1a4e7f32..bbc0b25164 100644
--- a/be/src/util/simd/vstring_function.h
+++ b/be/src/util/simd/vstring_function.h
@@ -125,6 +125,101 @@ public:
return rtrim(ltrim(str));
}
+ static StringRef rtrim(const StringRef& str, const StringRef& rhs) {
+ if (str.size == 0 || rhs.size == 0) {
+ return str;
+ }
+ if (rhs.size == 1) {
+ auto begin = 0;
+ int64_t end = str.size - 1;
+ const char blank = rhs.data[0];
+#if defined(__SSE2__) || defined(__aarch64__)
+ const auto pattern = _mm_set1_epi8(blank);
+ while (end - begin + 1 >= REGISTER_SIZE) {
+ const auto v_haystack = _mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(str.data + end + 1 -
REGISTER_SIZE));
+ const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack,
pattern);
+ const auto mask = _mm_movemask_epi8(v_against_pattern);
+ int offset = __builtin_clz(~(mask << REGISTER_SIZE));
+ /// means not found
+ if (offset == 0) {
+ return StringRef(str.data + begin, end - begin + 1);
+ } else {
+ end -= offset;
+ }
+ }
+#endif
+ while (end >= begin && str.data[end] == blank) {
+ --end;
+ }
+ if (end < 0) {
+ return StringRef("", 0);
+ }
+ return StringRef(str.data + begin, end - begin + 1);
+ }
+ auto begin = 0;
+ auto end = str.size - 1;
+ const auto rhs_size = rhs.size;
+ while (end - begin + 1 >= rhs_size) {
+ if (memcmp(str.data + end - rhs_size + 1, rhs.data, rhs_size) ==
0) {
+ end -= rhs.size;
+ } else {
+ break;
+ }
+ }
+ return StringRef(str.data + begin, end - begin + 1);
+ }
+
+ static StringRef ltrim(const StringRef& str, const StringRef& rhs) {
+ if (str.size == 0 || rhs.size == 0) {
+ return str;
+ }
+ if (str.size == 1) {
+ auto begin = 0;
+ auto end = str.size - 1;
+ const char blank = rhs.data[0];
+#if defined(__SSE2__) || defined(__aarch64__)
+ const auto pattern = _mm_set1_epi8(blank);
+ while (end - begin + 1 >= REGISTER_SIZE) {
+ const auto v_haystack =
+ _mm_loadu_si128(reinterpret_cast<const
__m128i*>(str.data + begin));
+ const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack,
pattern);
+ const auto mask = _mm_movemask_epi8(v_against_pattern) ^
0xffff;
+ /// zero means not found
+ if (mask == 0) {
+ begin += REGISTER_SIZE;
+ } else {
+ const auto offset = __builtin_ctz(mask);
+ begin += offset;
+ return StringRef(str.data + begin, end - begin + 1);
+ }
+ }
+#endif
+ while (begin <= end && str.data[begin] == blank) {
+ ++begin;
+ }
+ return StringRef(str.data + begin, end - begin + 1);
+ }
+ auto begin = 0;
+ auto end = str.size - 1;
+ const auto rhs_size = rhs.size;
+ while (end - begin + 1 >= rhs_size) {
+ if (memcmp(str.data + begin, rhs.data, rhs_size) == 0) {
+ begin += rhs.size;
+ } else {
+ break;
+ }
+ }
+ return StringRef(str.data + begin, end - begin + 1);
+ }
+
+ static StringRef trim(const StringRef& str, const StringRef& rhs) {
+ if (str.size == 0 || rhs.size == 0) {
+ return str;
+ }
+ return rtrim(ltrim(str, rhs), rhs);
+ }
+
// Gcc will do auto simd in this function
static bool is_ascii(const StringVal& str) {
char or_code = 0;
diff --git a/be/src/vec/aggregate_functions/aggregate_function_collect.h
b/be/src/vec/aggregate_functions/aggregate_function_collect.h
index 5bd60170bb..897eae5806 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_collect.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_collect.h
@@ -252,7 +252,7 @@ public:
/// Single-argument constructor. The base helper receives the argument type and an
/// empty parameter list; `return_type` is initialised from the argument type.
AggregateFunctionCollect(const DataTypePtr& argument_type,
                         UInt64 max_size_ = std::numeric_limits<UInt64>::max())
        : IAggregateFunctionDataHelper<Data, AggregateFunctionCollect<Data, HasLimit>>(
                  {argument_type}, {}),
          return_type(argument_type) {}
std::string get_name() const override {
diff --git a/be/src/vec/functions/function_string.cpp
b/be/src/vec/functions/function_string.cpp
index ad8499dcc7..52066e3813 100644
--- a/be/src/vec/functions/function_string.cpp
+++ b/be/src/vec/functions/function_string.cpp
@@ -309,38 +309,134 @@ struct InitcapImpl {
// SQL-visible name of the two-sided trim function.
struct NameTrim {
    static constexpr const char* name = "trim";
};
-
// SQL-visible name of the leading-side trim function.
struct NameLTrim {
    static constexpr const char* name = "ltrim";
};
-
// SQL-visible name of the trailing-side trim function.
struct NameRTrim {
    static constexpr const char* name = "rtrim";
};
-
template <bool is_ltrim, bool is_rtrim>
-struct TrimImpl {
- static Status vector(const ColumnString::Chars& data, const
ColumnString::Offsets& offsets,
+struct TrimUtil {
+ static Status vector(const ColumnString::Chars& str_data,
+ const ColumnString::Offsets& str_offsets, const
StringRef& rhs,
ColumnString::Chars& res_data, ColumnString::Offsets&
res_offsets) {
- size_t offset_size = offsets.size();
- res_offsets.resize(offsets.size());
-
+ size_t offset_size = str_offsets.size();
+ res_offsets.resize(str_offsets.size());
for (size_t i = 0; i < offset_size; ++i) {
- const char* raw_str = reinterpret_cast<const
char*>(&data[offsets[i - 1]]);
- ColumnString::Offset size = offsets[i] - offsets[i - 1];
- StringVal str(raw_str, size);
+ const char* raw_str = reinterpret_cast<const
char*>(&str_data[str_offsets[i - 1]]);
+ ColumnString::Offset size = str_offsets[i] - str_offsets[i - 1];
+ StringRef str(raw_str, size);
if constexpr (is_ltrim) {
- str = simd::VStringFunctions::ltrim(str);
+ str = simd::VStringFunctions::ltrim(str, rhs);
}
if constexpr (is_rtrim) {
- str = simd::VStringFunctions::rtrim(str);
+ str = simd::VStringFunctions::rtrim(str, rhs);
}
- StringOP::push_value_string(std::string_view((char*)str.ptr,
str.len), i, res_data,
+ StringOP::push_value_string(std::string_view((char*)str.data,
str.size), i, res_data,
res_offsets);
}
return Status::OK();
}
};
+// This is an implementation of a parameter for the Trim function.
+template <bool is_ltrim, bool is_rtrim, typename Name>
+struct Trim1Impl {
+ static constexpr auto name = Name::name;
+
+ static DataTypes get_variadic_argument_types() { return
{std::make_shared<DataTypeString>()}; }
+
+ static Status execute(FunctionContext* context, Block& block, const
ColumnNumbers& arguments,
+ size_t result, size_t input_rows_count) {
+ const ColumnPtr column = block.get_by_position(arguments[0]).column;
+ if (auto col = assert_cast<const ColumnString*>(column.get())) {
+ auto col_res = ColumnString::create();
+ char blank[] = " ";
+ StringRef rhs(blank, 1);
+ TrimUtil<is_ltrim, is_rtrim>::vector(col->get_chars(),
col->get_offsets(), rhs,
+ col_res->get_chars(),
col_res->get_offsets());
+ block.replace_by_position(result, std::move(col_res));
+ } else {
+ return Status::RuntimeError("Illegal column {} of argument of
function {}",
+
block.get_by_position(arguments[0]).column->get_name(),
+ name);
+ }
+ return Status::OK();
+ }
+};
+
+// This is an implementation of two parameters for the Trim function.
+template <bool is_ltrim, bool is_rtrim, typename Name>
+struct Trim2Impl {
+ static constexpr auto name = Name::name;
+
+ static DataTypes get_variadic_argument_types() {
+ return {std::make_shared<DataTypeString>(),
std::make_shared<DataTypeString>()};
+ }
+
+ static Status execute(FunctionContext* context, Block& block, const
ColumnNumbers& arguments,
+ size_t result, size_t input_rows_count) {
+ const ColumnPtr column = block.get_by_position(arguments[0]).column;
+ const auto& rcol =
+ assert_cast<const
ColumnConst*>(block.get_by_position(arguments[1]).column.get())
+ ->get_data_column_ptr();
+ if (auto col = assert_cast<const ColumnString*>(column.get())) {
+ if (auto col_right = assert_cast<const ColumnString*>(rcol.get()))
{
+ auto col_res = ColumnString::create();
+ const char* raw_rhs = reinterpret_cast<const
char*>(&(col_right->get_chars()[0]));
+ ColumnString::Offset rhs_size = col_right->get_offsets()[0];
+ StringRef rhs(raw_rhs, rhs_size);
+ TrimUtil<is_ltrim, is_rtrim>::vector(col->get_chars(),
col->get_offsets(), rhs,
+ col_res->get_chars(),
col_res->get_offsets());
+ block.replace_by_position(result, std::move(col_res));
+ } else {
+ return Status::RuntimeError("Illegal column {} of argument of
function {}",
+
block.get_by_position(arguments[1]).column->get_name(),
+ name);
+ }
+
+ } else {
+ return Status::RuntimeError("Illegal column {} of argument of
function {}",
+
block.get_by_position(arguments[0]).column->get_name(),
+ name);
+ }
+ return Status::OK();
+ }
+};
+
+template <typename impl>
+class FunctionTrim : public IFunction {
+public:
+ static constexpr auto name = impl::name;
+ static FunctionPtr create() { return
std::make_shared<FunctionTrim<impl>>(); }
+ String get_name() const override { return impl::name; }
+
+ size_t get_number_of_arguments() const override {
+ return get_variadic_argument_types_impl().size();
+ }
+
+ bool get_is_injective(const Block&) override { return false; }
+
+ DataTypePtr get_return_type_impl(const DataTypes& arguments) const
override {
+ if (!is_string_or_fixed_string(arguments[0])) {
+ LOG(FATAL) << fmt::format("Illegal type {} of argument of function
{}",
+ arguments[0]->get_name(), get_name());
+ }
+ return arguments[0];
+ }
+ // The second parameter of "trim" is a constant.
+ ColumnNumbers get_arguments_that_are_always_constant() const override {
return {1}; }
+
+ bool use_default_implementation_for_constants() const override { return
true; }
+
+ DataTypes get_variadic_argument_types_impl() const override {
+ return impl::get_variadic_argument_types();
+ }
+
+ Status execute_impl(FunctionContext* context, Block& block, const
ColumnNumbers& arguments,
+ size_t result, size_t input_rows_count) override {
+ return impl::execute(context, block, arguments, result,
input_rows_count);
+ }
+};
struct UnHexImpl {
static constexpr auto name = "unhex";
@@ -631,12 +727,7 @@ using FunctionToUpper =
FunctionStringToString<TransferImpl<NameToUpper>, NameTo
using FunctionToInitcap = FunctionStringToString<InitcapImpl, NameToInitcap>;
-using FunctionLTrim = FunctionStringToString<TrimImpl<true, false>, NameLTrim>;
-
-using FunctionRTrim = FunctionStringToString<TrimImpl<false, true>, NameRTrim>;
-
-using FunctionTrim = FunctionStringToString<TrimImpl<true, true>, NameTrim>;
-
+using FunctionUnHex = FunctionStringOperateToNullType<UnHexImpl>;
using FunctionToBase64 = FunctionStringOperateToNullType<ToBase64Impl>;
using FunctionFromBase64 = FunctionStringOperateToNullType<FromBase64Impl>;
@@ -663,9 +754,12 @@ void register_function_string(SimpleFunctionFactory&
factory) {
factory.register_function<FunctionToLower>();
factory.register_function<FunctionToUpper>();
factory.register_function<FunctionToInitcap>();
- factory.register_function<FunctionLTrim>();
- factory.register_function<FunctionRTrim>();
- factory.register_function<FunctionTrim>();
+ factory.register_function<FunctionTrim<Trim1Impl<true, true, NameTrim>>>();
+ factory.register_function<FunctionTrim<Trim1Impl<true, false,
NameLTrim>>>();
+ factory.register_function<FunctionTrim<Trim1Impl<false, true,
NameRTrim>>>();
+ factory.register_function<FunctionTrim<Trim2Impl<true, true, NameTrim>>>();
+ factory.register_function<FunctionTrim<Trim2Impl<true, false,
NameLTrim>>>();
+ factory.register_function<FunctionTrim<Trim2Impl<false, true,
NameRTrim>>>();
factory.register_function<FunctionConvertTo>();
factory.register_function<FunctionSubstring<Substr3Impl>>();
factory.register_function<FunctionSubstring<Substr2Impl>>();
diff --git a/gensrc/script/doris_builtins_functions.py
b/gensrc/script/doris_builtins_functions.py
index 98656e03f9..3538ecfa93 100755
--- a/gensrc/script/doris_builtins_functions.py
+++ b/gensrc/script/doris_builtins_functions.py
@@ -2206,10 +2206,16 @@ visible_functions = [
'_ZN5doris15StringFunctions7initcapEPN9doris_udf15FunctionContextERKNS1_9StringValE',
'', '', 'vec', ''],
[['trim'], 'VARCHAR', ['VARCHAR'],
'_ZN5doris15StringFunctions4trimEPN9doris_udf15FunctionContextERKNS1_9StringValE',
'', '', 'vec', ''],
+ [['trim'], 'VARCHAR', ['VARCHAR','VARCHAR'],
+
'_ZN5doris15StringFunctions4trimEPN9doris_udf15FunctionContextERKNS1_9StringValE',
'', '', 'vec', ''],
[['ltrim'], 'VARCHAR', ['VARCHAR'],
'_ZN5doris15StringFunctions5ltrimEPN9doris_udf15FunctionContextERKNS1_9StringValE',
'', '', 'vec', ''],
+ [['ltrim'], 'VARCHAR', ['VARCHAR','VARCHAR'],
+
'_ZN5doris15StringFunctions5ltrimEPN9doris_udf15FunctionContextERKNS1_9StringValE',
'', '', 'vec', ''],
[['rtrim'], 'VARCHAR', ['VARCHAR'],
'_ZN5doris15StringFunctions5rtrimEPN9doris_udf15FunctionContextERKNS1_9StringValE',
'', '', 'vec', ''],
+ [['rtrim'], 'VARCHAR', ['VARCHAR','VARCHAR'],
+
'_ZN5doris15StringFunctions5rtrimEPN9doris_udf15FunctionContextERKNS1_9StringValE',
'', '', 'vec', ''],
[['ascii'], 'INT', ['VARCHAR'],
'_ZN5doris15StringFunctions5asciiEPN9doris_udf15FunctionContextERKNS1_9StringValE',
'', '', 'vec', ''],
[['instr'], 'INT', ['VARCHAR', 'VARCHAR'],
@@ -2390,10 +2396,16 @@ visible_functions = [
'_ZN5doris15StringFunctions5upperEPN9doris_udf15FunctionContextERKNS1_9StringValE',
'', '', 'vec', ''],
[['trim'], 'STRING', ['STRING'],
'_ZN5doris15StringFunctions4trimEPN9doris_udf15FunctionContextERKNS1_9StringValE',
'', '', 'vec', ''],
+ [['trim'], 'STRING', ['STRING','STRING'],
+
'_ZN5doris15StringFunctions4trimEPN9doris_udf15FunctionContextERKNS1_9StringValE',
'', '', 'vec', ''],
[['ltrim'], 'STRING', ['STRING'],
'_ZN5doris15StringFunctions5ltrimEPN9doris_udf15FunctionContextERKNS1_9StringValE',
'', '', 'vec', ''],
+ [['ltrim'], 'STRING', ['STRING','STRING'],
+
'_ZN5doris15StringFunctions5ltrimEPN9doris_udf15FunctionContextERKNS1_9StringValE',
'', '', 'vec', ''],
[['rtrim'], 'STRING', ['STRING'],
'_ZN5doris15StringFunctions5rtrimEPN9doris_udf15FunctionContextERKNS1_9StringValE',
'', '', 'vec', ''],
+ [['rtrim'], 'STRING', ['STRING','STRING'],
+
'_ZN5doris15StringFunctions5rtrimEPN9doris_udf15FunctionContextERKNS1_9StringValE',
'', '', 'vec', ''],
[['ascii'], 'INT', ['STRING'],
'_ZN5doris15StringFunctions5asciiEPN9doris_udf15FunctionContextERKNS1_9StringValE',
'', '', 'vec', ''],
[['instr'], 'INT', ['STRING', 'STRING'],
diff --git a/regression-test/suites/correctness/test_trim_new_parameters.groovy
b/regression-test/suites/correctness/test_trim_new_parameters.groovy
new file mode 100644
index 0000000000..3209eb7aae
--- /dev/null
+++ b/regression-test/suites/correctness/test_trim_new_parameters.groovy
@@ -0,0 +1,70 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
// Regression checks for the one- and two-argument forms of trim/ltrim/rtrim.
suite("test_trim_new_parameters") {
    sql """ DROP TABLE IF EXISTS tbl_trim_new_parameters """
    sql """
        CREATE TABLE tbl_trim_new_parameters (
            id INT DEFAULT '10',
            username VARCHAR(32) DEFAULT ''
        ) ENGINE=OLAP
        AGGREGATE KEY(id,username)
        DISTRIBUTED BY HASH(id) BUCKETS 10
        PROPERTIES (
            "replication_allocation" = "tag.location.default: 1",
            "in_memory" = "false",
            "storage_format" = "V2"
        );
    """
    sql """
        insert into tbl_trim_new_parameters values(1,'abcabccccabc')
    """
    sql """
        insert into tbl_trim_new_parameters values(2,'abcabcabc')
    """
    sql """
        insert into tbl_trim_new_parameters values(3,'')
    """

    List<List<Object>> results = sql "select id,trim(username,'abc') from tbl_trim_new_parameters order by id"

    assertEquals(results.size(), 3)
    assertEquals(results[0][0], 1)
    assertEquals(results[1][0], 2)
    assertEquals(results[2][0], 3)
    assertEquals(results[0][1], 'ccc')
    assertEquals(results[1][1], '')
    assertEquals(results[2][1], '')

    List<List<Object>> trim = sql "select trim(' abc ')"
    assertEquals(trim[0][0], 'abc')

    List<List<Object>> ltrim = sql "select ltrim(' abc ')"
    assertEquals(ltrim[0][0], 'abc ')

    List<List<Object>> rtrim = sql "select rtrim(' abc ')"
    assertEquals(rtrim[0][0], ' abc')

    trim = sql "select trim('abcabcTTTabcabc','abc')"
    assertEquals(trim[0][0], 'TTT')

    ltrim = sql "select ltrim('abcabcTTTbc','abc')"
    assertEquals(ltrim[0][0], 'TTTbc')

    rtrim = sql "select rtrim('bcTTTabcabc','abc')"
    assertEquals(rtrim[0][0], 'bcTTT')

    // A one-character input with a multi-character pattern must be left alone: the
    // pattern only matches as a whole, never by its first character.
    ltrim = sql "select ltrim('a','abc')"
    assertEquals(ltrim[0][0], 'a')

    rtrim = sql "select rtrim('c','abc')"
    assertEquals(rtrim[0][0], 'c')

    // A one-character, non-space pattern strips every repetition on the chosen side(s).
    trim = sql "select trim('xxTTTxx','x')"
    assertEquals(trim[0][0], 'TTT')

    ltrim = sql "select ltrim('xxTTTxx','x')"
    assertEquals(ltrim[0][0], 'TTTxx')

    rtrim = sql "select rtrim('xxTTTxx','x')"
    assertEquals(rtrim[0][0], 'xxTTT')
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]