HappenLee commented on code in PR #51652:
URL: https://github.com/apache/doris/pull/51652#discussion_r2146942126


##########
be/src/vec/functions/function_regexp.cpp:
##########
@@ -43,13 +43,158 @@
 #include "vec/core/types.h"
 #include "vec/data_types/data_type.h"
 #include "vec/data_types/data_type_nullable.h"
+#include "vec/data_types/data_type_number.h"
 #include "vec/data_types/data_type_string.h"
 #include "vec/functions/function.h"
 #include "vec/functions/simple_function_factory.h"
 #include "vec/utils/stringop_substring.h"
 
 namespace doris::vectorized {
 #include "common/compile_check_begin.h"
+struct RegexpCountImpl {
+    static void execute_impl(FunctionContext* context, ColumnPtr 
argument_columns[],
+                             size_t input_rows_count, ColumnInt32::Container& 
result_data,
+                             NullMap& null_map) {
+        const auto* str_col = 
check_and_get_column<ColumnString>(argument_columns[0].get());
+        const auto* pattern_col = 
check_and_get_column<ColumnString>(argument_columns[1].get());
+        for (int i = 0; i < input_rows_count; ++i) {
+            if (null_map[i]) {
+                result_data[i] = 0;
+                continue;
+            }
+            result_data[i] = _execute_inner_loop(context, str_col, 
pattern_col, null_map, i);
+        }
+    }
+
+    static int _execute_inner_loop(FunctionContext* context, const 
ColumnString* str_col,
+                                   const ColumnString* pattern_col, NullMap& 
null_map,
+                                   const size_t index_now) {
+        re2::RE2* re = reinterpret_cast<re2::RE2*>(
+                context->get_function_state(FunctionContext::THREAD_LOCAL));
+        std::unique_ptr<re2::RE2> scoped_re;
+        if (re == nullptr) {
+            std::string error_str;
+            const auto& pattern = 
pattern_col->get_data_at(index_check_const(index_now, false));
+            bool st = StringFunctions::compile_regex(pattern, &error_str, 
StringRef(), StringRef(),
+                                                     scoped_re);
+            if (!st) {
+                context->add_warning(error_str.c_str());
+                null_map[index_now] = 1;
+                return 0;
+            }
+            re = scoped_re.get();
+        }
+
+        const auto& str = str_col->get_data_at(index_now);
+
+        int count = 0;
+        size_t pos = 0;
+        while (pos < str.size) {
+            auto str_pos = str.data + pos;
+            auto str_size = str.size - pos;
+            re2::StringPiece str_sp_current = re2::StringPiece(str_pos, 
str_size);
+            re2::StringPiece match;
+
+            bool success = re->Match(str_sp_current, 0, str_size, 
re2::RE2::UNANCHORED, &match, 1);
+            if (!success) {
+                break;
+            }
+            if (match.empty()) {
+                pos += 1;
+                continue;
+            }
+            count++;
+            size_t match_start = match.data() - str_sp_current.data();
+            pos += match_start + match.size();
+        }
+
+        return count;
+    }
+};
+
+class FunctionRegexpCount : public IFunction {
+public:
+    static constexpr auto name = "regexp_count";
+
+    static FunctionPtr create() { return 
std::make_shared<FunctionRegexpCount>(); }
+
+    String get_name() const override { return name; }
+
+    size_t get_number_of_arguments() const override { return 2; }
+
+    DataTypePtr get_return_type_impl(const DataTypes& arguments) const 
override {
+        DataTypePtr int32_type = std::make_shared<DataTypeInt32>();
+        bool is_nullable = false;
+        for (const auto& arg : arguments) {
+            if (arg->is_nullable()) {
+                is_nullable = true;
+                break;
+            }
+        }
+        return is_nullable ? make_nullable(int32_type) : int32_type;
+    }
+
+    Status open(FunctionContext* context, FunctionContext::FunctionStateScope 
scope) override {
+        if (scope == FunctionContext::THREAD_LOCAL) {
+            if (context->is_col_constant(1)) {
+                DCHECK(!context->get_function_state(scope));
+                const auto pattern_col = 
context->get_constant_col(1)->column_ptr;
+                const auto& pattern = pattern_col->get_data_at(0);
+                if (pattern.size == 0) {
+                    return Status::OK();
+                }
+
+                std::string error_str;
+                std::unique_ptr<re2::RE2> scoped_re;
+                bool st = StringFunctions::compile_regex(pattern, &error_str, 
StringRef(),
+                                                         StringRef(), 
scoped_re);
+                if (!st) {
+                    context->set_error(error_str.c_str());
+                    return Status::InvalidArgument(error_str);
+                }
+                std::shared_ptr<re2::RE2> re(scoped_re.release());
+                context->set_function_state(scope, re);
+            }
+        }
+        return Status::OK();
+    }
+
+    Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
+                        uint32_t result, size_t input_rows_count) const 
override {
+        auto result_null_map = ColumnUInt8::create(input_rows_count, 0);
+        auto result_data_column = ColumnInt32::create(input_rows_count);
+        auto& result_data = result_data_column->get_data();
+
+        bool col_const[2];
+        ColumnPtr argument_columns[2];

Review Comment:
   Constness is already checked in `open()`, so there is no need to check it again
here: if `re != nullptr`, argument 1 is constant.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to