Shuo-O commented on code in PR #51462:
URL: https://github.com/apache/doris/pull/51462#discussion_r2142096420


##########
be/src/vec/functions/function_regexp.cpp:
##########
@@ -65,6 +66,281 @@ struct FourParamTypes {
     }
 };
 
+struct TwoParamTypes {
+    static DataTypes get_variadic_argument_types() {
+        return {std::make_shared<DataTypeString>(), 
std::make_shared<DataTypeString>()};
+    }
+};
+
+// 
-----------------------------------------------------------------------------
+// RegexpPositionImpl:实现 regexp_position(string, pattern[, start])
+// 第三列在 ThreeParamTypes 中是 String,需要在运行时转为 Int64
+// 
-----------------------------------------------------------------------------
+struct RegexpPositionImpl {
+    static constexpr auto name = "regexp_position";
+
+    // 非 const-args 情况
+    static void execute_impl(FunctionContext* context, ColumnPtr 
argument_columns[],
+                             size_t argument_size, size_t input_rows_count,
+                             ColumnInt64::Container& result_data, NullMap& 
null_map) {
+        const auto* str_col = 
check_and_get_column<ColumnString>(argument_columns[0].get());
+        const auto* pattern_col = 
check_and_get_column<ColumnString>(argument_columns[1].get());
+        const ColumnString* start_str_col = nullptr;
+        if (argument_size == 3) {
+            start_str_col = 
check_and_get_column<ColumnString>(argument_columns[2].get());
+        }
+
+        for (size_t i = 0; i < input_rows_count; ++i) {
+            if (null_map[i]) {
+                result_data[i] = -1;
+                continue;
+            }
+
+            // 1) 把第三列 String 转为 Int64 start_pos
+            Int64 start_pos = 1;
+            if (start_str_col) {
+                auto start_data = start_str_col->get_data_at(i);
+                std::string start_str(start_data.data, start_data.size);
+                try {
+                    start_pos = std::stoll(start_str);
+                } catch (...) {
+                    result_data[i] = -1;
+                    continue;
+                }
+                if (start_pos < 1) {
+                    result_data[i] = -1;
+                    continue;
+                }
+            }
+
+            // 2) 拿 pattern 常量或非常量,以及 THREAD_LOCAL 中预编译的 re2
+            re2::RE2* re = reinterpret_cast<re2::RE2*>(
+                    
context->get_function_state(FunctionContext::THREAD_LOCAL));
+            std::unique_ptr<re2::RE2> scoped_re;
+            if (!re) {
+                auto pattern_data = pattern_col->get_data_at(i);
+                std::string pat(pattern_data.data, pattern_data.size);
+                std::string error_str;
+                bool st = StringFunctions::compile_regex(pat, &error_str, 
StringRef(), StringRef(),
+                                                         scoped_re);
+                if (!st) {
+                    context->add_warning(error_str.c_str());
+                    result_data[i] = -1;
+                    continue;
+                }
+                re = scoped_re.get();
+            }
+
+            // 3) 获取待匹配字符串
+            auto str_data = str_col->get_data_at(i);
+            re2::StringPiece str_sp(str_data.data, str_data.size);
+
+            // 4) 计算 0 基起始下标
+            size_t search_start = static_cast<size_t>(start_pos - 1);
+            if (search_start >= str_data.size) {
+                result_data[i] = -1;
+                continue;
+            }
+
+            // 5) 执行匹配,提取第一组子串
+            re2::StringPiece match;
+            bool success =
+                    re->Match(str_sp, search_start, str_data.size, 
re2::RE2::UNANCHORED, &match, 1);
+            if (!success) {
+                result_data[i] = -1;
+                continue;
+            }
+
+            // 6) 计算 1 基匹配位置
+            size_t match_pos = match.data() - str_data.data + 1;
+            result_data[i] = static_cast<Int64>(match_pos);
+        }
+    }
+
+    // 全常量(pattern & start_str 都为常量)的执行路径
+    static void execute_impl_const_args(FunctionContext* context, ColumnPtr 
argument_columns[],
+                                        size_t argument_size, size_t 
input_rows_count,
+                                        ColumnInt64::Container& result_data, 
NullMap& null_map) {
+        const auto* str_col = 
check_and_get_column<ColumnString>(argument_columns[0].get());
+        const auto* pattern_col = 
check_and_get_column<ColumnString>(argument_columns[1].get());
+        const ColumnString* start_str_col = nullptr;
+        Int64 start_pos_const = 1;
+
+        if (argument_size == 3) {
+            start_str_col = 
check_and_get_column<ColumnString>(argument_columns[2].get());
+            auto start_data = start_str_col->get_data_at(0);
+            std::string start_str(start_data.data, start_data.size);
+            try {
+                start_pos_const = std::stoll(start_str);
+            } catch (...) {
+                for (size_t i = 0; i < input_rows_count; ++i) {
+                    result_data[i] = -1;
+                }
+                return;
+            }
+            if (start_pos_const < 1) {
+                for (size_t i = 0; i < input_rows_count; ++i) {
+                    result_data[i] = -1;
+                }
+                return;
+            }
+        }
+
+        re2::RE2* re = reinterpret_cast<re2::RE2*>(
+                context->get_function_state(FunctionContext::THREAD_LOCAL));
+        if (!re) {
+            for (size_t i = 0; i < input_rows_count; ++i) {
+                result_data[i] = -1;
+            }
+            return;
+        }
+
+        for (size_t i = 0; i < input_rows_count; ++i) {
+            if (null_map[i]) {
+                result_data[i] = -1;
+                continue;
+            }
+            auto str_data = str_col->get_data_at(i);
+            re2::StringPiece str_sp(str_data.data, str_data.size);
+
+            size_t search_start = static_cast<size_t>(start_pos_const - 1);
+            if (search_start >= str_data.size) {
+                result_data[i] = -1;
+                continue;
+            }
+
+            re2::StringPiece match;
+            bool success =
+                    re->Match(str_sp, search_start, str_data.size, 
re2::RE2::UNANCHORED, &match, 1);
+            if (!success) {
+                result_data[i] = -1;
+                continue;
+            }
+
+            size_t match_pos = match.data() - str_data.data + 1;
+            result_data[i] = static_cast<Int64>(match_pos);
+        }
+    }
+};
+
+// -----------------------------------------------------------------------------
+// FunctionRegexpPosition: wraps RegexpPositionImpl behind the IFunction interface
+// -----------------------------------------------------------------------------
+
+// >>> Added: new FunctionRegexpPosition class, which exposes regexp_position as a Doris function
+class FunctionRegexpPosition : public IFunction {
+public:
+    static constexpr auto name = RegexpPositionImpl::name;
+
+    static FunctionPtr create() { return 
std::make_shared<FunctionRegexpPosition>(); }
+
+    String get_name() const override { return name; }
+
+    // 变参函数
+    size_t get_number_of_arguments() const override { return 0; }
+    bool is_variadic() const override { return true; }
+
+    // 返回类型:Nullable(Int64)
+    DataTypePtr get_return_type_impl(const DataTypes& /*arguments*/) const 
override {
+        return make_nullable(std::make_shared<DataTypeInt64>());
+    }
+
+    // >>> Modified: 根据实际参数个数选 TwoParamTypes or ThreeParamTypes
+    DataTypes get_variadic_argument_types_impl() const override {
+        size_t actual_args = getVariadicNumberOfArguments(); // 框架提供:实际列数

Review Comment:
   I've fixed it
   
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to