This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-1.2-lts
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-1.2-lts by this push:
     new b083729e0b [feature](function) Add new parameters to 'trim' in 1.2 
(#19730)
b083729e0b is described below

commit b083729e0bd383344afe5216f0f6c4df201bfdab
Author: Mryange <[email protected]>
AuthorDate: Wed May 17 19:28:43 2023 +0800

    [feature](function) Add new parameters to 'trim' in 1.2 (#19730)
    
    * update
    
    * format
---
 be/src/util/simd/vstring_function.h                |  95 ++++++++++++++
 .../aggregate_function_collect.h                   |   2 +-
 be/src/vec/functions/function_string.cpp           | 140 +++++++++++++++++----
 gensrc/script/doris_builtins_functions.py          |  12 ++
 .../correctness/test_trim_new_parameters.groovy    |  70 +++++++++++
 5 files changed, 295 insertions(+), 24 deletions(-)

diff --git a/be/src/util/simd/vstring_function.h 
b/be/src/util/simd/vstring_function.h
index 3c1a4e7f32..bbc0b25164 100644
--- a/be/src/util/simd/vstring_function.h
+++ b/be/src/util/simd/vstring_function.h
@@ -125,6 +125,101 @@ public:
         return rtrim(ltrim(str));
     }
 
+    // Strips every trailing copy of `rhs` from `str` and returns what is left as a slice of
+    // `str` (no bytes are copied; an empty ref once everything is stripped). A one-byte `rhs`
+    // uses a byte scan (SIMD when available); longer patterns are matched whole with memcmp.
+    static StringRef rtrim(const StringRef& str, const StringRef& rhs) {
+        if (str.size == 0 || rhs.size == 0) {
+            return str;
+        }
+        if (rhs.size == 1) {
+            const char pad = rhs.data[0];
+            auto begin = 0;
+            // Signed on purpose: `end` reaches -1 when every byte is stripped.
+            int64_t end = str.size - 1;
+#if defined(__SSE2__) || defined(__aarch64__)
+            const auto pattern = _mm_set1_epi8(pad);
+            while (end - begin + 1 >= REGISTER_SIZE) {
+                const char* window = str.data + end + 1 - REGISTER_SIZE;
+                const auto hits = _mm_movemask_epi8(_mm_cmpeq_epi8(
+                        _mm_loadu_si128(reinterpret_cast<const __m128i*>(window)), pattern));
+                // Number of `pad` bytes at the end of this window (one mask bit per byte).
+                const int run = __builtin_clz(~(hits << REGISTER_SIZE));
+                if (run == 0) {
+                    return StringRef(str.data + begin, end - begin + 1);
+                }
+                end -= run;
+            }
+#endif
+            while (end >= begin && str.data[end] == pad) {
+                --end;
+            }
+            if (end < 0) {
+                return StringRef("", 0);
+            }
+            return StringRef(str.data + begin, end - begin + 1);
+        }
+        auto begin = 0;
+        // Unsigned: wraps when everything is consumed; `end - begin + 1` then wraps back to 0.
+        auto end = str.size - 1;
+        const auto rhs_size = rhs.size;
+        while (end - begin + 1 >= rhs_size &&
+               memcmp(str.data + end - rhs_size + 1, rhs.data, rhs_size) == 0) {
+            end -= rhs_size;
+        }
+        return StringRef(str.data + begin, end - begin + 1);
+    }
+
+    // Strips every leading copy of `rhs` from `str` and returns the remaining slice of `str`
+    // (no bytes are copied). A one-byte `rhs` takes a byte-scan fast path (SIMD-accelerated
+    // when available); a longer `rhs` is only removed as whole copies, compared with memcmp.
+    static StringRef ltrim(const StringRef& str, const StringRef& rhs) {
+        if (str.size == 0 || rhs.size == 0) {
+            return str;
+        }
+        // Branch on the *pattern* length: only a one-byte rhs can be stripped byte by byte.
+        if (rhs.size == 1) {
+            size_t begin = 0;
+            const size_t end = str.size - 1; // str.size >= 1 here, so this cannot wrap
+            const char blank = rhs.data[0];
+#if defined(__SSE2__) || defined(__aarch64__)
+            const auto pattern = _mm_set1_epi8(blank);
+            while (end - begin + 1 >= REGISTER_SIZE) {
+                const auto v_haystack =
+                        _mm_loadu_si128(reinterpret_cast<const __m128i*>(str.data + begin));
+                const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack, pattern);
+                // Bit i is set when byte i is NOT `blank`; zero means the window is all blank.
+                const auto mask = _mm_movemask_epi8(v_against_pattern) ^ 0xffff;
+                if (mask == 0) {
+                    begin += REGISTER_SIZE;
+                } else {
+                    begin += __builtin_ctz(mask); // index of the first non-blank byte
+                    return StringRef(str.data + begin, end - begin + 1);
+                }
+            }
+#endif
+            while (begin <= end && str.data[begin] == blank) {
+                ++begin;
+            }
+            return StringRef(str.data + begin, end - begin + 1);
+        }
+        // Multi-byte rhs: strip whole copies while at least rhs.size bytes remain and match.
+        size_t begin = 0;
+        const size_t end = str.size - 1;
+        while (end - begin + 1 >= rhs.size &&
+               memcmp(str.data + begin, rhs.data, rhs.size) == 0) {
+            begin += rhs.size;
+        }
+        return StringRef(str.data + begin, end - begin + 1);
+    }
+
+    // Strips repeated copies of `rhs` from both ends of `str` (front first, then back);
+    // an empty `str` or `rhs` is returned unchanged.
+    static StringRef trim(const StringRef& str, const StringRef& rhs) {
+        const bool nothing_to_strip = str.size == 0 || rhs.size == 0;
+        return nothing_to_strip ? str : rtrim(ltrim(str, rhs), rhs);
+    }
+
     // Gcc will do auto simd in this function
     static bool is_ascii(const StringVal& str) {
         char or_code = 0;
diff --git a/be/src/vec/aggregate_functions/aggregate_function_collect.h 
b/be/src/vec/aggregate_functions/aggregate_function_collect.h
index 5bd60170bb..897eae5806 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_collect.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_collect.h
@@ -252,7 +252,7 @@ public:
     AggregateFunctionCollect(const DataTypePtr& argument_type,
                              UInt64 max_size_ = 
std::numeric_limits<UInt64>::max())
             : IAggregateFunctionDataHelper<Data, 
AggregateFunctionCollect<Data, HasLimit>>(
-                      {argument_type},{}),
+                      {argument_type}, {}),
               return_type(argument_type) {}
 
     std::string get_name() const override {
diff --git a/be/src/vec/functions/function_string.cpp 
b/be/src/vec/functions/function_string.cpp
index ad8499dcc7..52066e3813 100644
--- a/be/src/vec/functions/function_string.cpp
+++ b/be/src/vec/functions/function_string.cpp
@@ -309,38 +309,134 @@ struct InitcapImpl {
 struct NameTrim {
     static constexpr auto name = "trim";
 };
-
 struct NameLTrim {
     static constexpr auto name = "ltrim";
 };
-
 struct NameRTrim {
     static constexpr auto name = "rtrim";
 };
-
 template <bool is_ltrim, bool is_rtrim>
-struct TrimImpl {
-    static Status vector(const ColumnString::Chars& data, const 
ColumnString::Offsets& offsets,
+struct TrimUtil { // strips `rhs` from the chosen end(s) of every row into res_data/res_offsets
+    static Status vector(const ColumnString::Chars& str_data, // always returns Status::OK()
+                         const ColumnString::Offsets& str_offsets, const 
StringRef& rhs,
                          ColumnString::Chars& res_data, ColumnString::Offsets& 
res_offsets) {
-        size_t offset_size = offsets.size();
-        res_offsets.resize(offsets.size());
-
+        size_t offset_size = str_offsets.size(); // number of rows
+        res_offsets.resize(str_offsets.size());
         for (size_t i = 0; i < offset_size; ++i) {
-            const char* raw_str = reinterpret_cast<const 
char*>(&data[offsets[i - 1]]);
-            ColumnString::Offset size = offsets[i] - offsets[i - 1];
-            StringVal str(raw_str, size);
+            const char* raw_str = reinterpret_cast<const 
char*>(&str_data[str_offsets[i - 1]]); // i == 0 reads str_offsets[-1]; presumably 0 via the padded offsets array -- confirm
+            ColumnString::Offset size = str_offsets[i] - str_offsets[i - 1];
+            StringRef str(raw_str, size); // view of row i (no copy)
             if constexpr (is_ltrim) {
-                str = simd::VStringFunctions::ltrim(str);
+                str = simd::VStringFunctions::ltrim(str, rhs);
             }
             if constexpr (is_rtrim) {
-                str = simd::VStringFunctions::rtrim(str);
+                str = simd::VStringFunctions::rtrim(str, rhs);
             }
-            StringOP::push_value_string(std::string_view((char*)str.ptr, 
str.len), i, res_data,
+            StringOP::push_value_string(std::string_view((char*)str.data, 
str.size), i, res_data,
                                         res_offsets);
         }
         return Status::OK();
     }
 };
+// Single-argument trim/ltrim/rtrim: strips runs of ' ' from the chosen end(s) of every row.
+template <bool is_ltrim, bool is_rtrim, typename Name>
+struct Trim1Impl {
+    static constexpr auto name = Name::name;
+
+    static DataTypes get_variadic_argument_types() { return {std::make_shared<DataTypeString>()}; }
+
+    static Status execute(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
+                          size_t result, size_t input_rows_count) {
+        const ColumnPtr source = block.get_by_position(arguments[0]).column;
+        const auto* strings = assert_cast<const ColumnString*>(source.get());
+        if (!strings) {
+            return Status::RuntimeError("Illegal column {} of argument of function {}",
+                                        block.get_by_position(arguments[0]).column->get_name(),
+                                        name);
+        }
+        // The implicit pattern is a single ASCII space.
+        const StringRef space(" ", 1);
+        auto col_res = ColumnString::create();
+        TrimUtil<is_ltrim, is_rtrim>::vector(strings->get_chars(), strings->get_offsets(), space,
+                                             col_res->get_chars(), col_res->get_offsets());
+        block.replace_by_position(result, std::move(col_res));
+        return Status::OK();
+    }
+};
+
+// Two-argument trim/ltrim/rtrim: strips whole copies of a constant pattern (argument 2).
+template <bool is_ltrim, bool is_rtrim, typename Name>
+struct Trim2Impl {
+    static constexpr auto name = Name::name;
+
+    static DataTypes get_variadic_argument_types() {
+        return {std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>()};
+    }
+
+    static Status execute(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
+                          size_t result, size_t input_rows_count) {
+        const ColumnPtr source = block.get_by_position(arguments[0]).column;
+        // The pattern must be a ColumnConst; its value is row 0 of the wrapped data column.
+        const auto& pattern_data =
+                assert_cast<const ColumnConst*>(block.get_by_position(arguments[1]).column.get())
+                        ->get_data_column_ptr();
+        const auto* strings = assert_cast<const ColumnString*>(source.get());
+        if (!strings) {
+            return Status::RuntimeError("Illegal column {} of argument of function {}",
+                                        block.get_by_position(arguments[0]).column->get_name(),
+                                        name);
+        }
+        const auto* pattern_col = assert_cast<const ColumnString*>(pattern_data.get());
+        if (!pattern_col) {
+            return Status::RuntimeError("Illegal column {} of argument of function {}",
+                                        block.get_by_position(arguments[1]).column->get_name(),
+                                        name);
+        }
+        // Row 0 of the pattern column spans bytes [0, offsets[0]) of its chars.
+        const char* pattern_ptr = reinterpret_cast<const char*>(&(pattern_col->get_chars()[0]));
+        const StringRef pattern(pattern_ptr, pattern_col->get_offsets()[0]);
+        auto col_res = ColumnString::create();
+        TrimUtil<is_ltrim, is_rtrim>::vector(strings->get_chars(), strings->get_offsets(), pattern,
+                                             col_res->get_chars(), col_res->get_offsets());
+        block.replace_by_position(result, std::move(col_res));
+        return Status::OK();
+    }
+};
+
+// Adapts a trim policy (e.g. Trim1Impl / Trim2Impl) to IFunction; arity and types come from it.
+template <typename Impl>
+class FunctionTrim : public IFunction {
+public:
+    static constexpr auto name = Impl::name;
+    static FunctionPtr create() { return std::make_shared<FunctionTrim<Impl>>(); }
+    String get_name() const override { return Impl::name; }
+    bool get_is_injective(const Block&) override { return false; }
+    bool use_default_implementation_for_constants() const override { return true; }
+    // Argument 1 (the trim pattern), when present, must be a constant.
+    ColumnNumbers get_arguments_that_are_always_constant() const override { return {1}; }
+
+    DataTypes get_variadic_argument_types_impl() const override {
+        return Impl::get_variadic_argument_types();
+    }
+
+    size_t get_number_of_arguments() const override {
+        return get_variadic_argument_types_impl().size();
+    }
+
+    // The result type is the type of the string argument being trimmed.
+    DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
+        if (!is_string_or_fixed_string(arguments[0])) {
+            LOG(FATAL) << fmt::format("Illegal type {} of argument of function {}",
+                                      arguments[0]->get_name(), get_name());
+        }
+        return arguments[0];
+    }
+
+    Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
+                        size_t result, size_t input_rows_count) override {
+        return Impl::execute(context, block, arguments, result, input_rows_count);
+    }
+};
 
 struct UnHexImpl {
     static constexpr auto name = "unhex";
@@ -631,12 +727,7 @@ using FunctionToUpper = 
FunctionStringToString<TransferImpl<NameToUpper>, NameTo
 
 using FunctionToInitcap = FunctionStringToString<InitcapImpl, NameToInitcap>;
 
-using FunctionLTrim = FunctionStringToString<TrimImpl<true, false>, NameLTrim>;
-
-using FunctionRTrim = FunctionStringToString<TrimImpl<false, true>, NameRTrim>;
-
-using FunctionTrim = FunctionStringToString<TrimImpl<true, true>, NameTrim>;
-
+using FunctionUnHex = FunctionStringOperateToNullType<UnHexImpl>; // NOTE(review): unrelated to trim; confirm intended and not a duplicate alias
 using FunctionToBase64 = FunctionStringOperateToNullType<ToBase64Impl>;
 
 using FunctionFromBase64 = FunctionStringOperateToNullType<FromBase64Impl>;
@@ -663,9 +754,12 @@ void register_function_string(SimpleFunctionFactory& 
factory) {
     factory.register_function<FunctionToLower>();
     factory.register_function<FunctionToUpper>();
     factory.register_function<FunctionToInitcap>();
-    factory.register_function<FunctionLTrim>();
-    factory.register_function<FunctionRTrim>();
-    factory.register_function<FunctionTrim>();
+    factory.register_function<FunctionTrim<Trim1Impl<true, true, NameTrim>>>();
+    factory.register_function<FunctionTrim<Trim1Impl<true, false, 
NameLTrim>>>();
+    factory.register_function<FunctionTrim<Trim1Impl<false, true, 
NameRTrim>>>();
+    factory.register_function<FunctionTrim<Trim2Impl<true, true, NameTrim>>>();
+    factory.register_function<FunctionTrim<Trim2Impl<true, false, 
NameLTrim>>>();
+    factory.register_function<FunctionTrim<Trim2Impl<false, true, 
NameRTrim>>>();
     factory.register_function<FunctionConvertTo>();
     factory.register_function<FunctionSubstring<Substr3Impl>>();
     factory.register_function<FunctionSubstring<Substr2Impl>>();
diff --git a/gensrc/script/doris_builtins_functions.py 
b/gensrc/script/doris_builtins_functions.py
index 98656e03f9..3538ecfa93 100755
--- a/gensrc/script/doris_builtins_functions.py
+++ b/gensrc/script/doris_builtins_functions.py
@@ -2206,10 +2206,16 @@ visible_functions = [
             
'_ZN5doris15StringFunctions7initcapEPN9doris_udf15FunctionContextERKNS1_9StringValE',
 '', '', 'vec', ''],
     [['trim'], 'VARCHAR', ['VARCHAR'],
             
'_ZN5doris15StringFunctions4trimEPN9doris_udf15FunctionContextERKNS1_9StringValE',
 '', '', 'vec', ''],
+    [['trim'], 'VARCHAR', ['VARCHAR','VARCHAR'],
+            
'_ZN5doris15StringFunctions4trimEPN9doris_udf15FunctionContextERKNS1_9StringValE',
 '', '', 'vec', ''],
     [['ltrim'], 'VARCHAR', ['VARCHAR'],
             
'_ZN5doris15StringFunctions5ltrimEPN9doris_udf15FunctionContextERKNS1_9StringValE',
 '', '', 'vec', ''],
+    [['ltrim'], 'VARCHAR', ['VARCHAR','VARCHAR'],
+            
'_ZN5doris15StringFunctions5ltrimEPN9doris_udf15FunctionContextERKNS1_9StringValE',
 '', '', 'vec', ''],
     [['rtrim'], 'VARCHAR', ['VARCHAR'],
             
'_ZN5doris15StringFunctions5rtrimEPN9doris_udf15FunctionContextERKNS1_9StringValE',
 '', '', 'vec', ''],
+    [['rtrim'], 'VARCHAR', ['VARCHAR','VARCHAR'],
+            
'_ZN5doris15StringFunctions5rtrimEPN9doris_udf15FunctionContextERKNS1_9StringValE',
 '', '', 'vec', ''],
     [['ascii'], 'INT', ['VARCHAR'],
             
'_ZN5doris15StringFunctions5asciiEPN9doris_udf15FunctionContextERKNS1_9StringValE',
 '', '', 'vec', ''],
     [['instr'], 'INT', ['VARCHAR', 'VARCHAR'],
@@ -2390,10 +2396,16 @@ visible_functions = [
             
'_ZN5doris15StringFunctions5upperEPN9doris_udf15FunctionContextERKNS1_9StringValE',
 '', '', 'vec', ''],
     [['trim'], 'STRING', ['STRING'],
             
'_ZN5doris15StringFunctions4trimEPN9doris_udf15FunctionContextERKNS1_9StringValE',
 '', '', 'vec', ''],
+    [['trim'], 'STRING', ['STRING','STRING'],
+        
'_ZN5doris15StringFunctions4trimEPN9doris_udf15FunctionContextERKNS1_9StringValE',
 '', '', 'vec', ''],
     [['ltrim'], 'STRING', ['STRING'],
             
'_ZN5doris15StringFunctions5ltrimEPN9doris_udf15FunctionContextERKNS1_9StringValE',
 '', '', 'vec', ''],
+    [['ltrim'], 'STRING', ['STRING','STRING'],
+            
'_ZN5doris15StringFunctions5ltrimEPN9doris_udf15FunctionContextERKNS1_9StringValE',
 '', '', 'vec', ''],
     [['rtrim'], 'STRING', ['STRING'],
             
'_ZN5doris15StringFunctions5rtrimEPN9doris_udf15FunctionContextERKNS1_9StringValE',
 '', '', 'vec', ''],
+    [['rtrim'], 'STRING', ['STRING','STRING'],
+            
'_ZN5doris15StringFunctions5rtrimEPN9doris_udf15FunctionContextERKNS1_9StringValE',
 '', '', 'vec', ''],
     [['ascii'], 'INT', ['STRING'],
             
'_ZN5doris15StringFunctions5asciiEPN9doris_udf15FunctionContextERKNS1_9StringValE',
 '', '', 'vec', ''],
     [['instr'], 'INT', ['STRING', 'STRING'],
diff --git a/regression-test/suites/correctness/test_trim_new_parameters.groovy 
b/regression-test/suites/correctness/test_trim_new_parameters.groovy
new file mode 100644
index 0000000000..3209eb7aae
--- /dev/null
+++ b/regression-test/suites/correctness/test_trim_new_parameters.groovy
@@ -0,0 +1,70 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("test_trim_new_parameters") {
+    sql """ DROP TABLE IF EXISTS tbl_trim_new_parameters """
+    sql """
+        CREATE TABLE tbl_trim_new_parameters (
+            id INT DEFAULT '10',
+            username VARCHAR(32) DEFAULT ''
+        ) ENGINE=OLAP
+        AGGREGATE KEY(id,username)
+        DISTRIBUTED BY HASH(id) BUCKETS 10
+        PROPERTIES (
+         "replication_allocation" = "tag.location.default: 1",
+         "in_memory" = "false",
+         "storage_format" = "V2"
+        );
+    """
+    sql """
+        insert into tbl_trim_new_parameters values(1,'abcabccccabc'),(2,'abcabcabc'),(3,'')
+    """
+
+    List<List<Object>> results = sql "select id,trim(username,'abc') from tbl_trim_new_parameters order by id"
+
+    assertEquals(results.size(), 3)
+    assertEquals(results[0][0], 1)
+    assertEquals(results[1][0], 2)
+    assertEquals(results[2][0], 3)
+    assertEquals(results[0][1], 'ccc')
+    assertEquals(results[1][1], '')
+    assertEquals(results[2][1], '')
+
+    List<List<Object>> trim = sql "select trim('   abc   ')"
+    assertEquals(trim[0][0], 'abc')
+
+    List<List<Object>> ltrim = sql "select ltrim('   abc   ')"
+    assertEquals(ltrim[0][0], 'abc   ')
+
+    List<List<Object>> rtrim = sql "select rtrim('   abc   ')"
+    assertEquals(rtrim[0][0], '   abc')
+
+    trim = sql "select trim('abcabcTTTabcabc','abc')"
+    assertEquals(trim[0][0], 'TTT')
+
+    ltrim = sql "select ltrim('abcabcTTTbc','abc')"
+    assertEquals(ltrim[0][0], 'TTTbc')
+
+    rtrim = sql "select rtrim('bcTTTabcabc','abc')"
+    assertEquals(rtrim[0][0], 'bcTTT')
+
+    // One-character input with a longer pattern: no whole copy fits, so nothing is stripped.
+    ltrim = sql "select ltrim('a','abc')"
+    assertEquals(ltrim[0][0], 'a')
+    rtrim = sql "select rtrim('a','abc')"
+    assertEquals(rtrim[0][0], 'a')
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to