This is an automated email from the ASF dual-hosted git repository.

lihaopeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 996a54482d1 [opt](if) vec exec for more type in function if (#56905)
996a54482d1 is described below

commit 996a54482d16dcbbd8b700881478bcb88dda7fae
Author: Mryange <[email protected]>
AuthorDate: Wed Oct 15 11:56:36 2025 +0800

    [opt](if) vec exec for more type in function if (#56905)
---
 be/src/vec/functions/if.cpp                   | 75 ++++++++++++++++-----------
 be/src/vec/functions/if.h                     | 39 +++++++++-----
 be/test/vec/function/function_num_if_test.cpp |  6 +--
 3 files changed, 74 insertions(+), 46 deletions(-)

diff --git a/be/src/vec/functions/if.cpp b/be/src/vec/functions/if.cpp
index b7b7d8ab8da..aac252f71af 100644
--- a/be/src/vec/functions/if.cpp
+++ b/be/src/vec/functions/if.cpp
@@ -30,6 +30,8 @@
 #include <utility>
 
 #include "common/status.h"
+#include "runtime/define_primitive_type.h"
+#include "runtime/primitive_type.h"
 #include "util/simd/bits.h"
 #include "vec/aggregate_functions/aggregate_function.h"
 #include "vec/columns/column.h"
@@ -43,8 +45,13 @@
 #include "vec/core/column_with_type_and_name.h"
 #include "vec/core/types.h"
 #include "vec/data_types/data_type.h"
+#include "vec/data_types/data_type_date_or_datetime_v2.h"
+#include "vec/data_types/data_type_decimal.h"
+#include "vec/data_types/data_type_ipv4.h"
+#include "vec/data_types/data_type_ipv6.h"
 #include "vec/data_types/data_type_nullable.h"
 #include "vec/data_types/data_type_number.h"
+#include "vec/data_types/data_type_time.h"
 #include "vec/functions/cast_type_to_either.h"
 #include "vec/functions/function.h"
 #include "vec/functions/function_helpers.h"
@@ -190,34 +197,27 @@ public:
         return Status::OK();
     }
 
-    void execute_basic_type(Block& block, const ColumnUInt8* cond_col,
-                            const ColumnWithTypeAndName& then_col,
-                            const ColumnWithTypeAndName& else_col, uint32_t 
result,
-                            Status& status) const {
+    template <PrimitiveType PType>
+    Status execute_basic_type(Block& block, const ColumnUInt8* cond_col,
+                              const ColumnWithTypeAndName& then_col,
+                              const ColumnWithTypeAndName& else_col, uint32_t 
result,
+                              Status& status) const {
         if (then_col.type->get_primitive_type() != 
else_col.type->get_primitive_type()) {
-            status = Status::InternalError("then and else column type must be 
same");
-            return;
+            return Status::InternalError(
+                    "then and else column type must be same for function {} , 
but got {} , {}",
+                    get_name(), then_col.type->get_name(), 
else_col.type->get_name());
         }
-        DCHECK(is_int(then_col.type->get_primitive_type()) ||
-               is_float_or_double(then_col.type->get_primitive_type()))
-                << then_col.type->get_name();
-        auto valid = cast_type_to_either<DataTypeInt8, DataTypeInt16, 
DataTypeInt32, DataTypeInt64,
-                                         DataTypeInt128, DataTypeFloat32, 
DataTypeFloat64>(
-                then_col.type.get(), [&](const auto& type) -> bool {
-                    using DataType = std::decay_t<decltype(type)>;
-                    auto res_column = NumIfImpl<DataType::PType>::execute_if(
-                            cond_col->get_data(), then_col.column, 
else_col.column);
-                    if (!res_column) {
-                        return false;
-                    }
-                    block.replace_by_position(result, std::move(res_column));
-                    return true;
-                });
-        if (!valid) {
-            status = Status::InternalError("unexpected args column type {} , 
{} , of function {}",
-                                           then_col.type->get_name(), 
else_col.type->get_name(),
-                                           get_name());
+
+        auto res_column =
+                NumIfImpl<PType>::execute_if(cond_col->get_data(), 
then_col.column, else_col.column,
+                                             
block.get_by_position(result).type->get_scale());
+        if (!res_column) {
+            return Status::InternalError("unexpected args column {} , {} , of 
function {}",
+                                         then_col.column->get_name(), 
else_col.column->get_name(),
+                                         get_name());
         }
+        block.replace_by_position(result, std::move(res_column));
+        return Status::OK();
     }
 
     Status execute_for_null_then_else(FunctionContext* context, Block& block,
@@ -525,11 +525,26 @@ public:
             return Status::OK();
         }
 
-        if (is_int(arg_then.type->get_primitive_type()) ||
-            is_float_or_double(arg_then.type->get_primitive_type())) {
-            Status status;
-            execute_basic_type(block, cond_col, arg_then, arg_else, result, 
status);
-            return status;
+        Status vec_exec;
+        auto can_use_vec_exec = cast_type_to_either<
+                // int
+                DataTypeInt8, DataTypeInt16, DataTypeInt32, DataTypeInt64, 
DataTypeInt128,
+                DataTypeBool,
+                // float
+                DataTypeFloat32, DataTypeFloat64,
+                // date time
+                DataTypeDateTimeV2, DataTypeDateV2, DataTypeTimeV2,
+                // decimal
+                DataTypeDecimal32, DataTypeDecimal64, DataTypeDecimal128, 
DataTypeDecimal256,
+                // ip
+                DataTypeIPv4, DataTypeIPv6>(arg_then.type.get(), [&](const 
auto& type) -> bool {
+            using DataType = std::decay_t<decltype(type)>;
+            vec_exec = execute_basic_type<DataType::PType>(block, cond_col, 
arg_then, arg_else,
+                                                           result, vec_exec);
+            return true;
+        });
+        if (can_use_vec_exec) {
+            return vec_exec;
         } else {
             return execute_generic(block, cond_col, arg_then, arg_else, 
result, input_rows_count);
         }
diff --git a/be/src/vec/functions/if.h b/be/src/vec/functions/if.h
index cfb60ff5c49..98bd62382d1 100644
--- a/be/src/vec/functions/if.h
+++ b/be/src/vec/functions/if.h
@@ -23,6 +23,7 @@
 
 #include <boost/iterator/iterator_facade.hpp>
 
+#include "runtime/primitive_type.h"
 #include "vec/columns/column.h"
 #include "vec/columns/column_vector.h"
 #include "vec/common/pod_array_fwd.h"
@@ -43,33 +44,43 @@ private:
     using Type = typename PrimitiveTypeTraits<PType>::ColumnItemType;
     using ArrayCond = PaddedPODArray<UInt8>;
     using Array = PaddedPODArray<Type>;
-    using ColVecResult = ColumnVector<PType>;
-    using ColVecT = ColumnVector<PType>;
+    using ColVecT = typename PrimitiveTypeTraits<PType>::ColumnType;
 
 public:
     static const Array& get_data_from_column_const(const ColumnConst* column) {
         return assert_cast<const 
ColVecT&>(column->get_data_column()).get_data();
     }
 
+    static auto create_column(const ArrayCond& cond, int scale) {
+        if constexpr (is_decimal(PType)) {
+            return ColVecT::create(cond.size(), scale);
+        } else {
+            return ColVecT::create(cond.size());
+        }
+    }
+
     static ColumnPtr execute_if(const ArrayCond& cond, const ColumnPtr& 
then_col,
-                                const ColumnPtr& else_col) {
+                                const ColumnPtr& else_col, int result_scale) {
         if (const auto* col_then = 
check_and_get_column<ColVecT>(then_col.get())) {
             if (const auto* col_else = 
check_and_get_column<ColVecT>(else_col.get())) {
-                return execute_impl<false, false>(cond, col_then->get_data(), 
col_else->get_data());
+                return execute_impl<false, false>(cond, col_then->get_data(), 
col_else->get_data(),
+                                                  result_scale);
             } else if (const auto* col_const_else =
                                
check_and_get_column_const<ColVecT>(else_col.get())) {
                 return execute_impl<false, true>(cond, col_then->get_data(),
-                                                 
get_data_from_column_const(col_const_else));
+                                                 
get_data_from_column_const(col_const_else),
+                                                 result_scale);
             }
         } else if (const auto* col_const_then =
                            
check_and_get_column_const<ColVecT>(then_col.get())) {
             if (const auto* col_else = 
check_and_get_column<ColVecT>(else_col.get())) {
                 return execute_impl<true, false>(cond, 
get_data_from_column_const(col_const_then),
-                                                 col_else->get_data());
+                                                 col_else->get_data(), 
result_scale);
             } else if (const auto* col_const_else =
                                
check_and_get_column_const<ColVecT>(else_col.get())) {
                 return execute_impl<true, true>(cond, 
get_data_from_column_const(col_const_then),
-                                                
get_data_from_column_const(col_const_else));
+                                                
get_data_from_column_const(col_const_else),
+                                                result_scale);
             }
         }
         return nullptr;
@@ -77,24 +88,26 @@ public:
 
 private:
     template <bool is_const_a, bool is_const_b>
-    static ColumnPtr execute_impl(const ArrayCond& cond, const Array& a, const 
Array& b) {
+    static ColumnPtr execute_impl(const ArrayCond& cond, const Array& a, const 
Array& b,
+                                  int result_scale) {
 #ifdef __ARM_NEON
         if constexpr (can_use_neon_opt()) {
-            auto col_res = ColVecResult::create(cond.size());
+            auto col_res = create_column(cond, result_scale);
             auto res = col_res->get_data().data();
             neon_execute<is_const_a, is_const_b>(cond.data(), res, a.data(), 
b.data(), cond.size());
             return col_res;
         }
 #endif
-        return native_execute<is_const_a, is_const_b>(cond, a, b);
+        return native_execute<is_const_a, is_const_b>(cond, a, b, 
result_scale);
     }
 
     // res[i] = cond[i] ? a[i] : b[i];
     template <bool is_const_a, bool is_const_b>
-    static ColumnPtr native_execute(const ArrayCond& cond, const Array& a, 
const Array& b) {
+    static ColumnPtr native_execute(const ArrayCond& cond, const Array& a, 
const Array& b,
+                                    int result_scale) {
         size_t size = cond.size();
-        auto col_res = ColVecResult::create(size);
-        typename ColVecResult::Container& res = col_res->get_data();
+        auto col_res = create_column(cond, result_scale);
+        auto& res = col_res->get_data();
         for (size_t i = 0; i < size; ++i) {
             res[i] = cond[i] ? a[index_check_const<is_const_a>(i)]
                              : b[index_check_const<is_const_b>(i)];
diff --git a/be/test/vec/function/function_num_if_test.cpp 
b/be/test/vec/function/function_num_if_test.cpp
index 110f51ee2dc..8c4fb7d88c0 100644
--- a/be/test/vec/function/function_num_if_test.cpp
+++ b/be/test/vec/function/function_num_if_test.cpp
@@ -31,7 +31,7 @@ TEST(NumIfImplTest, smallTest) {
     auto cond = ColumnHelper::create_column<DataTypeUInt8>({1, 0, 1, 0, 1, 0, 
1, 0});
     auto a = ColumnHelper::create_column<DataTypeInt32>({1, 2, 3, 4, 5, 6, 7, 
8});
     auto b = ColumnHelper::create_column<DataTypeInt32>({10, 20, 30, 40, 50, 
60, 70, 80});
-    auto res = NumIfImpl<TYPE_INT>::execute_if(get_cond_data(cond), a, b);
+    auto res = NumIfImpl<TYPE_INT>::execute_if(get_cond_data(cond), a, b, 0);
     ColumnHelper::column_equal(
             res, ColumnHelper::create_column<DataTypeInt32>({1, 20, 3, 40, 5, 
60, 7, 80}));
 }
@@ -49,7 +49,7 @@ TEST(NumIfImplTest, largeTest) {
     auto cond = ColumnHelper::create_column<DataTypeUInt8>(cond_data);
     auto a = ColumnHelper::create_column<DataTypeInt32>(a_data);
     auto b = ColumnHelper::create_column<DataTypeInt32>(b_data);
-    auto res = NumIfImpl<TYPE_INT>::execute_if(get_cond_data(cond), a, b);
+    auto res = NumIfImpl<TYPE_INT>::execute_if(get_cond_data(cond), a, b, 0);
     std::vector<int32_t> expected_data(1024, 0);
     for (size_t i = 0; i < 1024; ++i) {
         expected_data[i] = cond_data[i] ? a_data[i] : b_data[i];
@@ -79,7 +79,7 @@ void test_for_all_const_no_const() {
     if (is_b_const) {
         b = ColumnConst::create(b, 1);
     }
-    auto res = NumIfImpl<DataType::PType>::execute_if(get_cond_data(cond), a, 
b);
+    auto res = NumIfImpl<DataType::PType>::execute_if(get_cond_data(cond), a, 
b, 0);
     std::vector<FieldType> expected_data(size, 0);
     for (size_t i = 0; i < size; ++i) {
         expected_data[i] = cond_data[i] ? a_data[is_a_const ? 0 : i] : 
b_data[is_b_const ? 0 : i];


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to