This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch dev-1.0.0
in repository https://gitbox.apache.org/repos/asf/incubator-doris.git

commit c460367d95cab73fc8b818cd27d23fb186245527
Author: Pxl <[email protected]>
AuthorDate: Tue Mar 8 18:57:12 2022 +0800

    [feature][vectorized] support replace() (#8384)
---
 be/src/vec/common/string_ref.h                | 27 ++++++----
 be/src/vec/functions/function_string.cpp      |  1 +
 be/src/vec/functions/function_string.h        | 58 +++++++++++++++++++++
 be/test/exprs/runtime_filter_test.cpp         | 72 +++++++++++++--------------
 be/test/vec/function/function_string_test.cpp | 17 +++++++
 5 files changed, 130 insertions(+), 45 deletions(-)

diff --git a/be/src/vec/common/string_ref.h b/be/src/vec/common/string_ref.h
index 8ecbe07..ebc4dcc 100644
--- a/be/src/vec/common/string_ref.h
+++ b/be/src/vec/common/string_ref.h
@@ -23,6 +23,7 @@
 #include <functional>
 #include <ostream>
 #include <string>
+#include <string_view>
 #include <vector>
 
 #include "gutil/hash/city.h"
@@ -52,6 +53,7 @@ struct StringRef {
     StringRef() = default;
 
     std::string to_string() const { return std::string(data, size); }
+    std::string_view to_string_view() const { return std::string_view(data, 
size); }
 
     explicit operator std::string() const { return to_string(); }
 
@@ -108,22 +110,25 @@ inline bool memequalSSE2Wide(const char* p1, const char* 
p2, size_t size) {
         if (size >= 8) {
             /// Chunks of [8,16] bytes.
             return unaligned_load<uint64_t>(p1) == 
unaligned_load<uint64_t>(p2) &&
-                   unaligned_load<uint64_t>(p1 + size - 8) == 
unaligned_load<uint64_t>(p2 + size - 8);
+                   unaligned_load<uint64_t>(p1 + size - 8) ==
+                           unaligned_load<uint64_t>(p2 + size - 8);
         } else if (size >= 4) {
             /// Chunks of [4,7] bytes.
             return unaligned_load<uint32_t>(p1) == 
unaligned_load<uint32_t>(p2) &&
-                   unaligned_load<uint32_t>(p1 + size - 4) == 
unaligned_load<uint32_t>(p2 + size - 4);
+                   unaligned_load<uint32_t>(p1 + size - 4) ==
+                           unaligned_load<uint32_t>(p2 + size - 4);
         } else if (size >= 2) {
             /// Chunks of [2,3] bytes.
             return unaligned_load<uint16_t>(p1) == 
unaligned_load<uint16_t>(p2) &&
-                   unaligned_load<uint16_t>(p1 + size - 2) == 
unaligned_load<uint16_t>(p2 + size - 2);
+                   unaligned_load<uint16_t>(p1 + size - 2) ==
+                           unaligned_load<uint16_t>(p2 + size - 2);
         } else if (size >= 1) {
             /// A single byte.
             return *p1 == *p2;
         }
         return true;
     }
-    
+
     while (size >= 64) {
         if (compareSSE2x4(p1, p2)) {
             p1 += 64;
@@ -133,11 +138,15 @@ inline bool memequalSSE2Wide(const char* p1, const char* 
p2, size_t size) {
             return false;
     }
 
-    switch (size / 16)
-    {
-        case 3: if (!compareSSE2(p1 + 32, p2 + 32)) return false; 
[[fallthrough]];
-        case 2: if (!compareSSE2(p1 + 16, p2 + 16)) return false; 
[[fallthrough]];
-        case 1: if (!compareSSE2(p1, p2)) return false;
+    switch (size / 16) {
+    case 3:
+        if (!compareSSE2(p1 + 32, p2 + 32)) return false;
+        [[fallthrough]];
+    case 2:
+        if (!compareSSE2(p1 + 16, p2 + 16)) return false;
+        [[fallthrough]];
+    case 1:
+        if (!compareSSE2(p1, p2)) return false;
     }
 
     return compareSSE2(p1 + size - 16, p2 + size - 16);
diff --git a/be/src/vec/functions/function_string.cpp 
b/be/src/vec/functions/function_string.cpp
index 5d84f01..0375994 100644
--- a/be/src/vec/functions/function_string.cpp
+++ b/be/src/vec/functions/function_string.cpp
@@ -668,6 +668,7 @@ void register_function_string(SimpleFunctionFactory& 
factory) {
     factory.register_function<FunctionMoneyFormat<MoneyFormatInt128Impl>>();
     factory.register_function<FunctionMoneyFormat<MoneyFormatDecimalImpl>>();
     factory.register_function<FunctionStringMd5AndSM3<SM3Sum>>();
+    factory.register_function<FunctionReplace>();
 
     factory.register_alias(FunctionLeft::name, "strleft");
     factory.register_alias(FunctionRight::name, "strright");
diff --git a/be/src/vec/functions/function_string.h 
b/be/src/vec/functions/function_string.h
index 613d4ca..cc54008 100644
--- a/be/src/vec/functions/function_string.h
+++ b/be/src/vec/functions/function_string.h
@@ -36,6 +36,7 @@
 #include "vec/columns/column_nullable.h"
 #include "vec/columns/column_string.h"
 #include "vec/columns/columns_number.h"
+#include "vec/common/assert_cast.h"
 #include "vec/common/string_ref.h"
 #include "vec/data_types/data_type_nullable.h"
 #include "vec/data_types/data_type_number.h"
@@ -1271,4 +1272,61 @@ private:
     }
 };
 
+class FunctionReplace : public IFunction {
+public:
+    static constexpr auto name = "replace";
+    static FunctionPtr create() { return std::make_shared<FunctionReplace>(); }
+    String get_name() const override { return name; }
+    size_t get_number_of_arguments() const override { return 3; }
+
+    DataTypePtr get_return_type_impl(const DataTypes& arguments) const 
override {
+        return std::make_shared<DataTypeString>();
+    }
+
+    DataTypes get_variadic_argument_types_impl() const override {
+        return {std::make_shared<DataTypeString>(), 
std::make_shared<DataTypeString>(),
+                std::make_shared<DataTypeString>()};
+    }
+
+    bool use_default_implementation_for_constants() const override { return 
true; }
+
+    Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
+                        size_t result, size_t input_rows_count) override {
+        auto col_origin =
+                
block.get_by_position(arguments[0]).column->convert_to_full_column_if_const();
+        auto col_old =
+                
block.get_by_position(arguments[1]).column->convert_to_full_column_if_const();
+        auto col_new =
+                
block.get_by_position(arguments[2]).column->convert_to_full_column_if_const();
+
+        ColumnString::MutablePtr col_res = ColumnString::create();
+
+        for (int i = 0; i < input_rows_count; ++i) {
+            StringRef origin_str =
+                    assert_cast<const 
ColumnString*>(col_origin.get())->get_data_at(i);
+            StringRef old_str = assert_cast<const 
ColumnString*>(col_old.get())->get_data_at(i);
+            StringRef new_str = assert_cast<const 
ColumnString*>(col_new.get())->get_data_at(i);
+
+            std::string result = replace(origin_str.to_string(), 
old_str.to_string_view(),
+                                         new_str.to_string_view());
+            col_res->insert_data(result.data(), result.length());
+        }
+
+        block.replace_by_position(result, std::move(col_res));
+        return Status::OK();
+    }
+
+private:
+    std::string replace(std::string str, std::string_view old_str, 
std::string_view new_str) {
+        std::string::size_type pos = 0;
+        std::string::size_type oldLen = old_str.size();
+        std::string::size_type newLen = new_str.size();
+        while ((pos = str.find(old_str, pos)) != std::string::npos) {
+            str.replace(pos, oldLen, new_str);
+            pos += newLen;
+        }
+        return str;
+    }
+};
+
 } // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/test/exprs/runtime_filter_test.cpp 
b/be/test/exprs/runtime_filter_test.cpp
index 6ca422f..4c4a206 100644
--- a/be/test/exprs/runtime_filter_test.cpp
+++ b/be/test/exprs/runtime_filter_test.cpp
@@ -104,11 +104,13 @@ IRuntimeFilter* 
create_runtime_filter(TRuntimeFilterType::type type, TQueryOptio
     }
 
     IRuntimeFilter* runtime_filter = nullptr;
-    Status status = IRuntimeFilter::create(_runtime_stat,
-                                           
_runtime_stat->instance_mem_tracker().get(), _obj_pool,
-                                           &desc, options, 
RuntimeFilterRole::PRODUCER, -1, &runtime_filter);
+    Status status = IRuntimeFilter::create(
+            _runtime_stat, _runtime_stat->instance_mem_tracker().get(), 
_obj_pool, &desc, options,
+            RuntimeFilterRole::PRODUCER, -1, &runtime_filter);
+
     assert(status.ok());
-    return runtime_filter;
+
+    return status.ok() ? runtime_filter : nullptr;
 }
 
 std::vector<TupleRow>* create_rows(ObjectPool* _obj_pool, int from, int to) {
@@ -142,8 +144,8 @@ TEST_F(RuntimeFilterTest, runtime_filter_basic_test) {
     TQueryOptions options;
     options.runtime_filter_max_in_num = 1024;
 
-    IRuntimeFilter* runtime_filter =
-            create_runtime_filter(TRuntimeFilterType::BLOOM, &options, 
_runtime_stat.get(), &_obj_pool);
+    IRuntimeFilter* runtime_filter = 
create_runtime_filter(TRuntimeFilterType::BLOOM, &options,
+                                                           
_runtime_stat.get(), &_obj_pool);
     insert(runtime_filter, build_expr_ctx, tuple_rows);
 
     // get expr context from filter
@@ -184,12 +186,12 @@ TEST_F(RuntimeFilterTest, 
runtime_filter_merge_in_filter_test) {
     auto rows1 = create_rows(&_obj_pool, 1, 1024);
     auto rows2 = create_rows(&_obj_pool, 1025, 2048);
 
-    IRuntimeFilter* runtime_filter =
-            create_runtime_filter(TRuntimeFilterType::IN, &options, 
_runtime_stat.get(), &_obj_pool);
+    IRuntimeFilter* runtime_filter = 
create_runtime_filter(TRuntimeFilterType::IN, &options,
+                                                           
_runtime_stat.get(), &_obj_pool);
     insert(runtime_filter, build_expr_ctx, rows1);
 
-    IRuntimeFilter* runtime_filter2 =
-            create_runtime_filter(TRuntimeFilterType::IN, &options, 
_runtime_stat.get(), &_obj_pool);
+    IRuntimeFilter* runtime_filter2 = 
create_runtime_filter(TRuntimeFilterType::IN, &options,
+                                                            
_runtime_stat.get(), &_obj_pool);
     insert(runtime_filter2, build_expr_ctx, rows2);
 
     Status status = runtime_filter->merge_from(runtime_filter2->get_wrapper());
@@ -232,12 +234,12 @@ TEST_F(RuntimeFilterTest, 
runtime_filter_ignore_in_filter_test) {
     auto rows1 = create_rows(&_obj_pool, 1, 1);
     auto rows2 = create_rows(&_obj_pool, 2, 2);
 
-    IRuntimeFilter* runtime_filter =
-            create_runtime_filter(TRuntimeFilterType::IN, &options, 
_runtime_stat.get(), &_obj_pool);
+    IRuntimeFilter* runtime_filter = 
create_runtime_filter(TRuntimeFilterType::IN, &options,
+                                                           
_runtime_stat.get(), &_obj_pool);
     insert(runtime_filter, build_expr_ctx, rows1);
 
-    IRuntimeFilter* runtime_filter2 =
-            create_runtime_filter(TRuntimeFilterType::IN, &options, 
_runtime_stat.get(), &_obj_pool);
+    IRuntimeFilter* runtime_filter2 = 
create_runtime_filter(TRuntimeFilterType::IN, &options,
+                                                            
_runtime_stat.get(), &_obj_pool);
     insert(runtime_filter2, build_expr_ctx, rows2);
 
     Status status = runtime_filter->merge_from(runtime_filter2->get_wrapper());
@@ -280,13 +282,13 @@ TEST_F(RuntimeFilterTest, 
runtime_filter_in_or_bloom_filter_in_merge_in_test) {
     auto rows1 = create_rows(&_obj_pool, 1, 1);
     auto rows2 = create_rows(&_obj_pool, 2, 2);
 
-    IRuntimeFilter* runtime_filter =
-            create_runtime_filter(TRuntimeFilterType::IN_OR_BLOOM, &options, 
_runtime_stat.get(), &_obj_pool);
+    IRuntimeFilter* runtime_filter = create_runtime_filter(
+            TRuntimeFilterType::IN_OR_BLOOM, &options, _runtime_stat.get(), 
&_obj_pool);
     insert(runtime_filter, build_expr_ctx, rows1);
     ASSERT_FALSE(runtime_filter->is_bloomfilter());
 
-    IRuntimeFilter* runtime_filter2 =
-            create_runtime_filter(TRuntimeFilterType::IN, &options, 
_runtime_stat.get(), &_obj_pool);
+    IRuntimeFilter* runtime_filter2 = 
create_runtime_filter(TRuntimeFilterType::IN, &options,
+                                                            
_runtime_stat.get(), &_obj_pool);
     insert(runtime_filter2, build_expr_ctx, rows2);
     ASSERT_FALSE(runtime_filter2->is_bloomfilter());
 
@@ -331,13 +333,13 @@ TEST_F(RuntimeFilterTest, 
runtime_filter_in_or_bloom_filter_in_merge_in_upgrade_
     auto rows1 = create_rows(&_obj_pool, 1, 1);
     auto rows2 = create_rows(&_obj_pool, 2, 2);
 
-    IRuntimeFilter* runtime_filter =
-            create_runtime_filter(TRuntimeFilterType::IN_OR_BLOOM, &options, 
_runtime_stat.get(), &_obj_pool);
+    IRuntimeFilter* runtime_filter = create_runtime_filter(
+            TRuntimeFilterType::IN_OR_BLOOM, &options, _runtime_stat.get(), 
&_obj_pool);
     insert(runtime_filter, build_expr_ctx, rows1);
     ASSERT_FALSE(runtime_filter->is_bloomfilter());
 
-    IRuntimeFilter* runtime_filter2 =
-            create_runtime_filter(TRuntimeFilterType::IN, &options, 
_runtime_stat.get(), &_obj_pool);
+    IRuntimeFilter* runtime_filter2 = 
create_runtime_filter(TRuntimeFilterType::IN, &options,
+                                                            
_runtime_stat.get(), &_obj_pool);
     insert(runtime_filter2, build_expr_ctx, rows2);
     ASSERT_FALSE(runtime_filter2->is_bloomfilter());
 
@@ -382,13 +384,13 @@ TEST_F(RuntimeFilterTest, 
runtime_filter_in_or_bloom_filter_in_merge_bloom_filte
     auto rows1 = create_rows(&_obj_pool, 1, 1);
     auto rows2 = create_rows(&_obj_pool, 2, 2);
 
-    IRuntimeFilter* runtime_filter =
-            create_runtime_filter(TRuntimeFilterType::IN_OR_BLOOM, &options, 
_runtime_stat.get(), &_obj_pool);
+    IRuntimeFilter* runtime_filter = create_runtime_filter(
+            TRuntimeFilterType::IN_OR_BLOOM, &options, _runtime_stat.get(), 
&_obj_pool);
     insert(runtime_filter, build_expr_ctx, rows1);
     ASSERT_FALSE(runtime_filter->is_bloomfilter());
 
-    IRuntimeFilter* runtime_filter2 =
-            create_runtime_filter(TRuntimeFilterType::BLOOM, &options, 
_runtime_stat.get(), &_obj_pool);
+    IRuntimeFilter* runtime_filter2 = 
create_runtime_filter(TRuntimeFilterType::BLOOM, &options,
+                                                            
_runtime_stat.get(), &_obj_pool);
     insert(runtime_filter2, build_expr_ctx, rows2);
     ASSERT_TRUE(runtime_filter2->is_bloomfilter());
 
@@ -433,15 +435,15 @@ TEST_F(RuntimeFilterTest, 
runtime_filter_in_or_bloom_filter_bloom_filter_merge_i
     auto rows1 = create_rows(&_obj_pool, 1, 3);
     auto rows2 = create_rows(&_obj_pool, 4, 4);
 
-    IRuntimeFilter* runtime_filter =
-            create_runtime_filter(TRuntimeFilterType::IN_OR_BLOOM, &options, 
_runtime_stat.get(), &_obj_pool);
+    IRuntimeFilter* runtime_filter = create_runtime_filter(
+            TRuntimeFilterType::IN_OR_BLOOM, &options, _runtime_stat.get(), 
&_obj_pool);
     insert(runtime_filter, build_expr_ctx, rows1);
     ASSERT_FALSE(runtime_filter->is_bloomfilter());
     runtime_filter->change_to_bloom_filter();
     ASSERT_TRUE(runtime_filter->is_bloomfilter());
 
-    IRuntimeFilter* runtime_filter2 =
-            create_runtime_filter(TRuntimeFilterType::IN, &options, 
_runtime_stat.get(), &_obj_pool);
+    IRuntimeFilter* runtime_filter2 = 
create_runtime_filter(TRuntimeFilterType::IN, &options,
+                                                            
_runtime_stat.get(), &_obj_pool);
     insert(runtime_filter2, build_expr_ctx, rows2);
     ASSERT_FALSE(runtime_filter2->is_bloomfilter());
 
@@ -486,15 +488,15 @@ TEST_F(RuntimeFilterTest, 
runtime_filter_in_or_bloom_filter_bloom_filter_merge_b
     auto rows1 = create_rows(&_obj_pool, 1, 3);
     auto rows2 = create_rows(&_obj_pool, 4, 6);
 
-    IRuntimeFilter* runtime_filter =
-            create_runtime_filter(TRuntimeFilterType::IN_OR_BLOOM, &options, 
_runtime_stat.get(), &_obj_pool);
+    IRuntimeFilter* runtime_filter = create_runtime_filter(
+            TRuntimeFilterType::IN_OR_BLOOM, &options, _runtime_stat.get(), 
&_obj_pool);
     insert(runtime_filter, build_expr_ctx, rows1);
     ASSERT_FALSE(runtime_filter->is_bloomfilter());
     runtime_filter->change_to_bloom_filter();
     ASSERT_TRUE(runtime_filter->is_bloomfilter());
 
-    IRuntimeFilter* runtime_filter2 =
-            create_runtime_filter(TRuntimeFilterType::BLOOM, &options, 
_runtime_stat.get(), &_obj_pool);
+    IRuntimeFilter* runtime_filter2 = 
create_runtime_filter(TRuntimeFilterType::BLOOM, &options,
+                                                            
_runtime_stat.get(), &_obj_pool);
     insert(runtime_filter2, build_expr_ctx, rows2);
     ASSERT_TRUE(runtime_filter2->is_bloomfilter());
 
@@ -502,8 +504,6 @@ TEST_F(RuntimeFilterTest, 
runtime_filter_in_or_bloom_filter_bloom_filter_merge_b
     ASSERT_TRUE(status.ok());
     ASSERT_FALSE(runtime_filter->is_ignored());
     ASSERT_TRUE(runtime_filter->is_bloomfilter());
-//    
ASSERT_TRUE(runtime_filter->get_profile()->get_info_string("RealRuntimeFilterType")
 ==
-//                        
::doris::to_string(doris::RuntimeFilterType::BLOOM_FILTER);
 
     // get expr context from filter
 
diff --git a/be/test/vec/function/function_string_test.cpp 
b/be/test/vec/function/function_string_test.cpp
index 0a36af2..47bfea4 100644
--- a/be/test/vec/function/function_string_test.cpp
+++ b/be/test/vec/function/function_string_test.cpp
@@ -26,6 +26,7 @@
 #include "util/url_coding.h"
 #include "vec/core/field.h"
 #include "vec/core/types.h"
+#include "vec/data_types/data_type_string.h"
 
 namespace doris::vectorized {
 using namespace ut_type;
@@ -1000,6 +1001,22 @@ TEST(function_string_test, function_str_to_date_test) {
     check_function<DataTypeDateTime, true>(func_name, input_types, data_set);
 }
 
+TEST(function_string_test, function_replace) {
+    std::string func_name = "replace";
+    InputTypeSet input_types = {
+            TypeIndex::String,
+            TypeIndex::String,
+            TypeIndex::String,
+    };
+    DataSet data_set = {{{Null(), VARCHAR("9090"), VARCHAR("")}, {Null()}},
+                        {{VARCHAR("http://www.baidu.com:9090"), 
VARCHAR("9090"), VARCHAR("")},
+                         {VARCHAR("http://www.baidu.com:")}},
+                        {{VARCHAR("aaaaa"), VARCHAR("a"), VARCHAR("")}, 
{VARCHAR("")}},
+                        {{VARCHAR("aaaaa"), VARCHAR("aa"), VARCHAR("")}, 
{VARCHAR("a")}},
+                        {{VARCHAR("aaaaa"), VARCHAR("aa"), VARCHAR("a")}, 
{VARCHAR("aaa")}}};
+    check_function<DataTypeString, true>(func_name, input_types, data_set);
+}
+
 } // namespace doris::vectorized
 
 int main(int argc, char** argv) {

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to