HappenLee commented on code in PR #40929:
URL: https://github.com/apache/doris/pull/40929#discussion_r1764940166


##########
be/src/vec/functions/function_string.h:
##########
@@ -3387,59 +3387,115 @@ struct SubReplaceImpl {
     static Status replace_execute(Block& block, const ColumnNumbers& 
arguments, size_t result,
                                   size_t input_rows_count) {
         auto res_column = ColumnString::create();
-        auto result_column = assert_cast<ColumnString*>(res_column.get());
+        auto* result_column = assert_cast<ColumnString*>(res_column.get());
         auto args_null_map = ColumnUInt8::create(input_rows_count, 0);
         ColumnPtr argument_columns[4];
+        bool col_const[4];
         for (int i = 0; i < 4; ++i) {
-            argument_columns[i] =
-                    
block.get_by_position(arguments[i]).column->convert_to_full_column_if_const();
-            if (auto* nullable = 
check_and_get_column<ColumnNullable>(*argument_columns[i])) {
-                // Danger: Here must dispose the null map data first! Because
-                // argument_columns[i]=nullable->get_nested_column_ptr(); will 
release the mem
-                // of column nullable mem of null map
-                VectorizedUtils::update_null_map(args_null_map->get_data(),
-                                                 
nullable->get_null_map_data());
-                argument_columns[i] = nullable->get_nested_column_ptr();
-            }
+            std::tie(argument_columns[i], col_const[i]) =
+                    
unpack_if_const(block.get_by_position(arguments[i]).column);
         }
+        const auto* data_column = assert_cast<const 
ColumnString*>(argument_columns[0].get());
+        const auto* mask_column = assert_cast<const 
ColumnString*>(argument_columns[1].get());
+        const auto* start_column =
+                assert_cast<const 
ColumnVector<Int32>*>(argument_columns[2].get());
+        const auto* length_column =
+                assert_cast<const 
ColumnVector<Int32>*>(argument_columns[3].get());
 
-        auto data_column = assert_cast<const 
ColumnString*>(argument_columns[0].get());
-        auto mask_column = assert_cast<const 
ColumnString*>(argument_columns[1].get());
-        auto start_column = assert_cast<const 
ColumnVector<Int32>*>(argument_columns[2].get());
-        auto length_column = assert_cast<const 
ColumnVector<Int32>*>(argument_columns[3].get());
-
-        vector(data_column, mask_column, start_column->get_data(), 
length_column->get_data(),
-               args_null_map->get_data(), result_column, input_rows_count);
-
+        std::visit(
+                [&](auto origin_str_const, auto new_str_const, auto 
start_const, auto len_const) {
+                    if (simd::VStringFunctions::is_ascii(
+                                StringRef {data_column->get_chars().data(), 
data_column->size()})) {
+                        vector_ascii<origin_str_const, new_str_const, 
start_const, len_const>(
+                                data_column, mask_column, 
start_column->get_data(),
+                                length_column->get_data(), 
args_null_map->get_data(), result_column,
+                                input_rows_count);
+                    } else {
+                        vector_utf8<origin_str_const, new_str_const, 
start_const, len_const>(
+                                data_column, mask_column, 
start_column->get_data(),
+                                length_column->get_data(), 
args_null_map->get_data(), result_column,
+                                input_rows_count);
+                    }
+                },
+                vectorized::make_bool_variant(col_const[0]),
+                vectorized::make_bool_variant(col_const[1]),
+                vectorized::make_bool_variant(col_const[2]),
+                vectorized::make_bool_variant(col_const[3]));
         block.get_by_position(result).column =
                 ColumnNullable::create(std::move(res_column), 
std::move(args_null_map));
         return Status::OK();
     }
 
 private:
-    static void vector(const ColumnString* data_column, const ColumnString* 
mask_column,
-                       const PaddedPODArray<Int32>& start, const 
PaddedPODArray<Int32>& length,
-                       NullMap& args_null_map, ColumnString* result_column,
-                       size_t input_rows_count) {
+    template <bool origin_str_const, bool new_str_const, bool start_const, 
bool len_const>
+    static void vector_ascii(const ColumnString* data_column, const 
ColumnString* mask_column,
+                             const PaddedPODArray<Int32>& args_start,
+                             const PaddedPODArray<Int32>& args_length, 
NullMap& args_null_map,
+                             ColumnString* result_column, size_t 
input_rows_count) {
         ColumnString::Chars& res_chars = result_column->get_chars();
         ColumnString::Offsets& res_offsets = result_column->get_offsets();
         for (size_t row = 0; row < input_rows_count; ++row) {
-            StringRef origin_str = data_column->get_data_at(row);
-            StringRef new_str = mask_column->get_data_at(row);
-            size_t origin_str_len = origin_str.size;
+            StringRef origin_str =
+                    
data_column->get_data_at(index_check_const<origin_str_const>(row));
+            StringRef new_str = 
mask_column->get_data_at(index_check_const<new_str_const>(row));
+            const auto start = args_start[index_check_const<start_const>(row)];
+            const auto length = args_length[index_check_const<len_const>(row)];
+            const size_t origin_str_len = origin_str.size;
             //input is null, start < 0, len < 0, str_size <= start. return NULL
-            if (args_null_map[row] || start[row] < 0 || length[row] < 0 ||
-                origin_str_len <= start[row]) {
+            if (args_null_map[row] || start < 0 || length < 0 || 
origin_str_len <= start) {
                 res_offsets.push_back(res_chars.size());
                 args_null_map[row] = 1;
             } else {
                 std::string_view replace_str = new_str.to_string_view();
                 std::string result = origin_str.to_string();
-                result.replace(start[row], length[row], replace_str);
+                result.replace(start, length, replace_str);
                 result_column->insert_data(result.data(), result.length());
             }
         }
     }
+
+    template <bool origin_str_const, bool new_str_const, bool start_const, 
bool len_const>
+    static void vector_utf8(const ColumnString* data_column, const 
ColumnString* mask_column,
+                            const PaddedPODArray<Int32>& args_start,
+                            const PaddedPODArray<Int32>& args_length, NullMap& 
args_null_map,
+                            ColumnString* result_column, size_t 
input_rows_count) {
+        ColumnString::Chars& res_chars = result_column->get_chars();
+        ColumnString::Offsets& res_offsets = result_column->get_offsets();
+        PaddedPODArray<size_t> index;
+
+        for (size_t row = 0; row < input_rows_count; ++row) {

Review Comment:
   Could you recheck this logic?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to